You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

92 lines
2.9 KiB
Python

# Copyright (C) 2022-2023 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from abc import ABCMeta, abstractmethod
from http import HTTPStatus
from time import sleep
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union
from cvat_sdk.api_client.api_client import ApiClient, Endpoint
from cvat_sdk.core.helpers import get_paginated_collection
from urllib3 import HTTPResponse
from shared.utils.config import make_api_client
def export_dataset(
endpoint: Endpoint, *, max_retries: int = 20, interval: float = 0.1, **kwargs
) -> HTTPResponse:
for _ in range(max_retries):
(_, response) = endpoint.call_with_http_info(**kwargs, _parse_response=False)
if response.status == HTTPStatus.CREATED:
break
assert response.status == HTTPStatus.ACCEPTED
sleep(interval)
assert response.status == HTTPStatus.CREATED
(_, response) = endpoint.call_with_http_info(**kwargs, action="download", _parse_response=False)
assert response.status == HTTPStatus.OK
return response
FieldPath = Sequence[str]
class CollectionSimpleFilterTestBase(metaclass=ABCMeta):
# These fields need to be defined in the subclass
user: str
samples: List[Dict[str, Any]]
field_lookups: Dict[str, FieldPath] = None
@abstractmethod
def _get_endpoint(self, api_client: ApiClient) -> Endpoint:
...
def _retrieve_collection(self, **kwargs) -> List:
with make_api_client(self.user) as api_client:
return get_paginated_collection(self._get_endpoint(api_client), **kwargs)
@classmethod
def _get_field(cls, d: Dict[str, Any], path: Union[str, FieldPath]) -> Optional[Any]:
assert path
for key in path:
if isinstance(d, dict):
d = d.get(key)
else:
d = None
return d
def _map_field(self, name: str) -> FieldPath:
return (self.field_lookups or {}).get(name, [name])
@classmethod
def _find_valid_field_value(
cls, samples: Iterator[Dict[str, Any]], field_path: FieldPath
) -> Any:
value = None
for sample in samples:
value = cls._get_field(sample, field_path)
if value:
break
assert value, f"Failed to find a sample for the '{'.'.join(field_path)}' field"
return value
def _get_field_samples(self, field: str) -> Tuple[Any, List[Dict[str, Any]]]:
field_path = self._map_field(field)
field_value = self._find_valid_field_value(self.samples, field_path)
gt_objects = filter(lambda p: field_value == self._get_field(p, field_path), self.samples)
return field_value, gt_objects
def test_can_use_simple_filter_for_object_list(self, field):
value, gt_objects = self._get_field_samples(field)
received_items = self._retrieve_collection(**{field: str(value)})
assert set(p["id"] for p in gt_objects) == set(p.id for p in received_items)