You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
92 lines
2.9 KiB
Python
92 lines
2.9 KiB
Python
# Copyright (C) 2022-2023 CVAT.ai Corporation
|
|
#
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
from abc import ABCMeta, abstractmethod
|
|
from http import HTTPStatus
|
|
from time import sleep
|
|
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union
|
|
|
|
from cvat_sdk.api_client.api_client import ApiClient, Endpoint
|
|
from cvat_sdk.core.helpers import get_paginated_collection
|
|
from urllib3 import HTTPResponse
|
|
|
|
from shared.utils.config import make_api_client
|
|
|
|
|
|
def export_dataset(
    endpoint: Endpoint, *, max_retries: int = 20, interval: float = 0.1, **kwargs
) -> HTTPResponse:
    """
    Request an export through ``endpoint``, wait until it is ready and download it.

    The endpoint is called with ``kwargs`` until it answers 201 CREATED; every
    intermediate answer must be 202 ACCEPTED. The result is then fetched by calling
    the same endpoint again with ``action="download"``.

    Args:
        endpoint: the endpoint to call.
        max_retries: maximum number of polling requests; must be at least 1.
        interval: delay in seconds between consecutive polling requests.
        **kwargs: arguments forwarded to every endpoint call. Must not contain
            ``action``, because that keyword is set explicitly for the download call.

    Returns:
        The raw (unparsed) response of the download request.

    Raises:
        ValueError: if ``max_retries`` is less than 1.
        AssertionError: on an unexpected status code, or when the export is still
            not ready after ``max_retries`` requests.
    """
    # Without at least one request, ``response`` would be unbound after the loop.
    if max_retries < 1:
        raise ValueError(f"max_retries must be at least 1, got {max_retries}")

    for attempt in range(max_retries):
        (_, response) = endpoint.call_with_http_info(**kwargs, _parse_response=False)
        if response.status == HTTPStatus.CREATED:
            break

        assert (
            response.status == HTTPStatus.ACCEPTED
        ), f"unexpected status {response.status} while waiting for the export"

        # Waiting after the last allowed request would only delay the failure below.
        if attempt + 1 < max_retries:
            sleep(interval)

    assert (
        response.status == HTTPStatus.CREATED
    ), f"export is not ready after {max_retries} requests (last status {response.status})"

    (_, response) = endpoint.call_with_http_info(**kwargs, action="download", _parse_response=False)
    assert (
        response.status == HTTPStatus.OK
    ), f"unexpected status {response.status} for the download request"

    return response
|
|
|
|
|
|
# Path to a (possibly nested) field of a sample dict: a sequence of keys applied
# one after another, e.g. ["a", "b"] refers to sample["a"]["b"].
FieldPath = Sequence[str]
|
|
|
|
|
|
class CollectionSimpleFilterTestBase(metaclass=ABCMeta):
    """
    Reusable test for exact-match ("simple") filters of a collection endpoint.

    A subclass provides the user that makes the requests, the ground-truth samples,
    and optionally a mapping from filter names to field paths inside a sample. The
    test picks a value of the field from the samples, requests the collection
    filtered by that value, and compares the "id"s of the returned items with the
    "id"s of the samples that have this value.
    """

    # These fields need to be defined in the subclass
    user: str  # passed to make_api_client() for every request
    samples: List[Dict[str, Any]]  # ground-truth objects, compared by their "id"
    # Maps a filter name to the path of the corresponding field in a sample;
    # filter names not listed here are looked up as a top-level key of the same name.
    field_lookups: Optional[Dict[str, FieldPath]] = None

    @abstractmethod
    def _get_endpoint(self, api_client: ApiClient) -> Endpoint:
        """Return the endpoint that lists the collection under test."""
        ...

    def _retrieve_collection(self, **kwargs) -> List:
        """Fetch all items of the collection as ``self.user``, forwarding ``kwargs``
        to ``get_paginated_collection``."""
        with make_api_client(self.user) as api_client:
            return get_paginated_collection(self._get_endpoint(api_client), **kwargs)

    @classmethod
    def _get_field(cls, d: Dict[str, Any], path: Union[str, FieldPath]) -> Optional[Any]:
        """
        Return the value found at ``path`` inside the nested dict ``d``.

        ``path`` is either a single key or a sequence of keys applied one after
        another. None is returned when a key is missing or an intermediate value
        is not a dict.
        """
        assert path

        if isinstance(path, str):
            # A bare string is one key; iterating it directly would walk its characters.
            path = (path,)

        for key in path:
            if isinstance(d, dict):
                d = d.get(key)
            else:
                d = None

        return d

    def _map_field(self, name: str) -> FieldPath:
        """Translate a filter name into a field path, defaulting to ``[name]``."""
        return (self.field_lookups or {}).get(name, [name])

    @classmethod
    def _find_valid_field_value(
        cls, samples: Iterator[Dict[str, Any]], field_path: FieldPath
    ) -> Any:
        """
        Return the first truthy value at ``field_path`` among ``samples``.

        Falsy values (None, 0, "", empty containers) are skipped. Fails with an
        AssertionError if no sample has a truthy value for the field.
        """
        value = None
        for sample in samples:
            value = cls._get_field(sample, field_path)
            if value:
                break

        assert value, f"Failed to find a sample for the '{'.'.join(field_path)}' field"
        return value

    def _get_field_samples(self, field: str) -> Tuple[Any, List[Dict[str, Any]]]:
        """
        Pick a value for the filter ``field`` and collect the samples having exactly it.

        Returns a ``(value, matching_samples)`` pair.
        """
        field_path = self._map_field(field)
        field_value = self._find_valid_field_value(self.samples, field_path)

        # A list (not a lazy ``filter``) matches the annotation and can be iterated
        # more than once.
        gt_objects = [
            sample for sample in self.samples if field_value == self._get_field(sample, field_path)
        ]

        return field_value, gt_objects

    def test_can_use_simple_filter_for_object_list(self, field):
        """
        Check that filtering the collection by ``field`` returns exactly the samples
        that have the chosen value (compared by "id").

        NOTE(review): ``field`` is presumably supplied by pytest parametrization in
        the subclass — confirm against the subclasses.
        """
        value, gt_objects = self._get_field_samples(field)

        # Filter values are sent as strings.
        received_items = self._retrieve_collection(**{field: str(value)})

        assert {p["id"] for p in gt_objects} == {p.id for p in received_items}
|