SDK layer 2 - cover RC1 usecases (#4813)
parent
b60d3b481a
commit
53697ecac5
@ -0,0 +1,66 @@
|
|||||||
|
# Copyright (C) 2022 CVAT.ai Corporation
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
from abc import ABC
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional, Sequence
|
||||||
|
|
||||||
|
from cvat_sdk import models
|
||||||
|
from cvat_sdk.core.proxies.model_proxy import _EntityT
|
||||||
|
|
||||||
|
|
||||||
|
class AnnotationUpdateAction(Enum):
|
||||||
|
CREATE = "create"
|
||||||
|
UPDATE = "update"
|
||||||
|
DELETE = "delete"
|
||||||
|
|
||||||
|
|
||||||
|
class AnnotationCrudMixin(ABC):
|
||||||
|
# TODO: refactor
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _put_annotations_data_param(self) -> str:
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_annotations(self: _EntityT) -> models.ILabeledData:
|
||||||
|
(annotations, _) = self.api.retrieve_annotations(getattr(self, self._model_id_field))
|
||||||
|
return annotations
|
||||||
|
|
||||||
|
def set_annotations(self: _EntityT, data: models.ILabeledDataRequest):
|
||||||
|
self.api.update_annotations(
|
||||||
|
getattr(self, self._model_id_field), **{self._put_annotations_data_param: data}
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_annotations(
|
||||||
|
self: _EntityT,
|
||||||
|
data: models.IPatchedLabeledDataRequest,
|
||||||
|
*,
|
||||||
|
action: AnnotationUpdateAction = AnnotationUpdateAction.UPDATE,
|
||||||
|
):
|
||||||
|
self.api.partial_update_annotations(
|
||||||
|
action=action.value,
|
||||||
|
id=getattr(self, self._model_id_field),
|
||||||
|
patched_labeled_data_request=data,
|
||||||
|
)
|
||||||
|
|
||||||
|
def remove_annotations(self: _EntityT, *, ids: Optional[Sequence[int]] = None):
|
||||||
|
if ids:
|
||||||
|
anns = self.get_annotations()
|
||||||
|
|
||||||
|
if not isinstance(ids, set):
|
||||||
|
ids = set(ids)
|
||||||
|
|
||||||
|
anns_to_remove = models.PatchedLabeledDataRequest(
|
||||||
|
tags=[models.LabeledImageRequest(**a.to_dict()) for a in anns.tags if a.id in ids],
|
||||||
|
tracks=[
|
||||||
|
models.LabeledTrackRequest(**a.to_dict()) for a in anns.tracks if a.id in ids
|
||||||
|
],
|
||||||
|
shapes=[
|
||||||
|
models.LabeledShapeRequest(**a.to_dict()) for a in anns.shapes if a.id in ids
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.update_annotations(anns_to_remove, action=AnnotationUpdateAction.DELETE)
|
||||||
|
else:
|
||||||
|
self.api.destroy_annotations(getattr(self, self._model_id_field))
|
||||||
@ -0,0 +1,60 @@
|
|||||||
|
# Copyright (C) 2022 CVAT.ai Corporation
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from cvat_sdk.api_client import apis, models
|
||||||
|
from cvat_sdk.core.proxies.model_proxy import (
|
||||||
|
ModelCreateMixin,
|
||||||
|
ModelDeleteMixin,
|
||||||
|
ModelListMixin,
|
||||||
|
ModelRetrieveMixin,
|
||||||
|
ModelUpdateMixin,
|
||||||
|
build_model_bases,
|
||||||
|
)
|
||||||
|
|
||||||
|
_CommentEntityBase, _CommentRepoBase = build_model_bases(
|
||||||
|
models.CommentRead, apis.CommentsApi, api_member_name="comments_api"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Comment(
|
||||||
|
models.ICommentRead,
|
||||||
|
_CommentEntityBase,
|
||||||
|
ModelUpdateMixin[models.IPatchedCommentWriteRequest],
|
||||||
|
ModelDeleteMixin,
|
||||||
|
):
|
||||||
|
_model_partial_update_arg = "patched_comment_write_request"
|
||||||
|
|
||||||
|
|
||||||
|
class CommentsRepo(
|
||||||
|
_CommentRepoBase,
|
||||||
|
ModelListMixin[Comment],
|
||||||
|
ModelCreateMixin[Comment, models.ICommentWriteRequest],
|
||||||
|
ModelRetrieveMixin[Comment],
|
||||||
|
):
|
||||||
|
_entity_type = Comment
|
||||||
|
|
||||||
|
|
||||||
|
_IssueEntityBase, _IssueRepoBase = build_model_bases(
|
||||||
|
models.IssueRead, apis.IssuesApi, api_member_name="issues_api"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Issue(
|
||||||
|
models.IIssueRead,
|
||||||
|
_IssueEntityBase,
|
||||||
|
ModelUpdateMixin[models.IPatchedIssueWriteRequest],
|
||||||
|
ModelDeleteMixin,
|
||||||
|
):
|
||||||
|
_model_partial_update_arg = "patched_issue_write_request"
|
||||||
|
|
||||||
|
|
||||||
|
class IssuesRepo(
|
||||||
|
_IssueRepoBase,
|
||||||
|
ModelListMixin[Issue],
|
||||||
|
ModelCreateMixin[Issue, models.IIssueWriteRequest],
|
||||||
|
ModelRetrieveMixin[Issue],
|
||||||
|
):
|
||||||
|
_entity_type = Issue
|
||||||
@ -0,0 +1,166 @@
|
|||||||
|
# Copyright (C) 2022 CVAT.ai Corporation
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import io
|
||||||
|
import mimetypes
|
||||||
|
import os
|
||||||
|
import os.path as osp
|
||||||
|
from typing import List, Optional, Sequence
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from cvat_sdk.api_client import apis, models
|
||||||
|
from cvat_sdk.core.downloading import Downloader
|
||||||
|
from cvat_sdk.core.helpers import get_paginated_collection
|
||||||
|
from cvat_sdk.core.progress import ProgressReporter
|
||||||
|
from cvat_sdk.core.proxies.annotations import AnnotationCrudMixin
|
||||||
|
from cvat_sdk.core.proxies.issues import Issue
|
||||||
|
from cvat_sdk.core.proxies.model_proxy import (
|
||||||
|
ModelListMixin,
|
||||||
|
ModelRetrieveMixin,
|
||||||
|
ModelUpdateMixin,
|
||||||
|
build_model_bases,
|
||||||
|
)
|
||||||
|
from cvat_sdk.core.uploading import AnnotationUploader
|
||||||
|
|
||||||
|
_JobEntityBase, _JobRepoBase = build_model_bases(
|
||||||
|
models.JobRead, apis.JobsApi, api_member_name="jobs_api"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Job(
|
||||||
|
models.IJobRead,
|
||||||
|
_JobEntityBase,
|
||||||
|
ModelUpdateMixin[models.IPatchedJobWriteRequest],
|
||||||
|
AnnotationCrudMixin,
|
||||||
|
):
|
||||||
|
_model_partial_update_arg = "patched_job_write_request"
|
||||||
|
_put_annotations_data_param = "job_annotations_update_request"
|
||||||
|
|
||||||
|
def import_annotations(
|
||||||
|
self,
|
||||||
|
format_name: str,
|
||||||
|
filename: str,
|
||||||
|
*,
|
||||||
|
status_check_period: Optional[int] = None,
|
||||||
|
pbar: Optional[ProgressReporter] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Upload annotations for a job in the specified format (e.g. 'YOLO ZIP 1.0').
|
||||||
|
"""
|
||||||
|
|
||||||
|
AnnotationUploader(self._client).upload_file_and_wait(
|
||||||
|
self.api.create_annotations_endpoint,
|
||||||
|
filename,
|
||||||
|
format_name,
|
||||||
|
url_params={"id": self.id},
|
||||||
|
pbar=pbar,
|
||||||
|
status_check_period=status_check_period,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._client.logger.info(f"Annotation file '{filename}' for job #{self.id} uploaded")
|
||||||
|
|
||||||
|
def export_dataset(
|
||||||
|
self,
|
||||||
|
format_name: str,
|
||||||
|
filename: str,
|
||||||
|
*,
|
||||||
|
pbar: Optional[ProgressReporter] = None,
|
||||||
|
status_check_period: Optional[int] = None,
|
||||||
|
include_images: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Download annotations for a job in the specified format (e.g. 'YOLO ZIP 1.0').
|
||||||
|
"""
|
||||||
|
if include_images:
|
||||||
|
endpoint = self.api.retrieve_dataset_endpoint
|
||||||
|
else:
|
||||||
|
endpoint = self.api.retrieve_annotations_endpoint
|
||||||
|
|
||||||
|
Downloader(self._client).prepare_and_download_file_from_endpoint(
|
||||||
|
endpoint=endpoint,
|
||||||
|
filename=filename,
|
||||||
|
url_params={"id": self.id},
|
||||||
|
query_params={"format": format_name},
|
||||||
|
pbar=pbar,
|
||||||
|
status_check_period=status_check_period,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._client.logger.info(f"Dataset for job {self.id} has been downloaded to {filename}")
|
||||||
|
|
||||||
|
def get_frame(
|
||||||
|
self,
|
||||||
|
frame_id: int,
|
||||||
|
*,
|
||||||
|
quality: Optional[str] = None,
|
||||||
|
) -> io.RawIOBase:
|
||||||
|
(_, response) = self.api.retrieve_data(
|
||||||
|
self.id, number=frame_id, quality=quality, type="frame"
|
||||||
|
)
|
||||||
|
return io.BytesIO(response.data)
|
||||||
|
|
||||||
|
def get_preview(
|
||||||
|
self,
|
||||||
|
) -> io.RawIOBase:
|
||||||
|
(_, response) = self.api.retrieve_data(self.id, type="preview")
|
||||||
|
return io.BytesIO(response.data)
|
||||||
|
|
||||||
|
def download_frames(
|
||||||
|
self,
|
||||||
|
frame_ids: Sequence[int],
|
||||||
|
*,
|
||||||
|
outdir: str = "",
|
||||||
|
quality: str = "original",
|
||||||
|
filename_pattern: str = "frame_{frame_id:06d}{frame_ext}",
|
||||||
|
) -> Optional[List[Image.Image]]:
|
||||||
|
"""
|
||||||
|
Download the requested frame numbers for a job and save images as outdir/filename_pattern
|
||||||
|
"""
|
||||||
|
# TODO: add arg descriptions in schema
|
||||||
|
os.makedirs(outdir, exist_ok=True)
|
||||||
|
|
||||||
|
for frame_id in frame_ids:
|
||||||
|
frame_bytes = self.get_frame(frame_id, quality=quality)
|
||||||
|
|
||||||
|
im = Image.open(frame_bytes)
|
||||||
|
mime_type = im.get_format_mimetype() or "image/jpg"
|
||||||
|
im_ext = mimetypes.guess_extension(mime_type)
|
||||||
|
|
||||||
|
# FIXME It is better to use meta information from the server
|
||||||
|
# to determine the extension
|
||||||
|
# replace '.jpe' or '.jpeg' with a more used '.jpg'
|
||||||
|
if im_ext in (".jpe", ".jpeg", None):
|
||||||
|
im_ext = ".jpg"
|
||||||
|
|
||||||
|
outfile = filename_pattern.format(frame_id=frame_id, frame_ext=im_ext)
|
||||||
|
im.save(osp.join(outdir, outfile))
|
||||||
|
|
||||||
|
def get_meta(self) -> models.IDataMetaRead:
|
||||||
|
(meta, _) = self.api.retrieve_data_meta(self.id)
|
||||||
|
return meta
|
||||||
|
|
||||||
|
def get_frames_info(self) -> List[models.IFrameMeta]:
|
||||||
|
return self.get_meta().frames
|
||||||
|
|
||||||
|
def remove_frames_by_ids(self, ids: Sequence[int]) -> None:
|
||||||
|
self._client.api.tasks_api.jobs_partial_update_data_meta(
|
||||||
|
self.id,
|
||||||
|
patched_data_meta_write_request=models.PatchedDataMetaWriteRequest(deleted_frames=ids),
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_issues(self) -> List[Issue]:
|
||||||
|
return [Issue(self._client, m) for m in self.api.list_issues(id=self.id)[0]]
|
||||||
|
|
||||||
|
def get_commits(self) -> List[models.IJobCommit]:
|
||||||
|
return get_paginated_collection(self.api.list_commits_endpoint, id=self.id)
|
||||||
|
|
||||||
|
|
||||||
|
class JobsRepo(
|
||||||
|
_JobRepoBase,
|
||||||
|
ModelListMixin[Job],
|
||||||
|
ModelRetrieveMixin[Job],
|
||||||
|
):
|
||||||
|
_entity_type = Job
|
||||||
@ -0,0 +1,213 @@
|
|||||||
|
# Copyright (C) 2022 CVAT.ai Corporation
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from abc import ABC
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
Generic,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
overload,
|
||||||
|
)
|
||||||
|
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
from cvat_sdk.api_client.model_utils import IModelData, ModelNormal, to_json
|
||||||
|
from cvat_sdk.core.helpers import get_paginated_collection
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from cvat_sdk.core.client import Client
|
||||||
|
|
||||||
|
IModel = TypeVar("IModel", bound=IModelData)
|
||||||
|
ModelType = TypeVar("ModelType", bound=ModelNormal)
|
||||||
|
ApiType = TypeVar("ApiType")
|
||||||
|
|
||||||
|
|
||||||
|
class ModelProxy(ABC, Generic[ModelType, ApiType]):
|
||||||
|
_client: Client
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _api_member_name(self) -> str:
|
||||||
|
...
|
||||||
|
|
||||||
|
def __init__(self, client: Client) -> None:
|
||||||
|
self.__dict__["_client"] = client
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_api(cls, client: Client) -> ApiType:
|
||||||
|
return getattr(client.api, cls._api_member_name)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def api(self) -> ApiType:
|
||||||
|
return self.get_api(self._client)
|
||||||
|
|
||||||
|
|
||||||
|
class Entity(ModelProxy[ModelType, ApiType]):
|
||||||
|
"""
|
||||||
|
Represents a single object. Implements related operations and provides access to data members.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_model: ModelType
|
||||||
|
|
||||||
|
def __init__(self, client: Client, model: ModelType) -> None:
|
||||||
|
super().__init__(client)
|
||||||
|
self.__dict__["_model"] = model
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _model_id_field(self) -> str:
|
||||||
|
return "id"
|
||||||
|
|
||||||
|
def __getattr__(self, __name: str) -> Any:
|
||||||
|
# NOTE: be aware of potential problems with throwing AttributeError from @property
|
||||||
|
# in derived classes!
|
||||||
|
# https://medium.com/@ceshine/python-debugging-pitfall-mixed-use-of-property-and-getattr-f89e0ede13f1
|
||||||
|
return self._model[__name]
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return str(self._model)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<{self.__class__.__name__}: id={getattr(self, self._model_id_field)}>"
|
||||||
|
|
||||||
|
|
||||||
|
class Repo(ModelProxy[ModelType, ApiType]):
|
||||||
|
"""
|
||||||
|
Represents a collection of corresponding Entity objects.
|
||||||
|
Implements group and management operations for entities.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_entity_type: Type[Entity[ModelType, ApiType]]
|
||||||
|
|
||||||
|
|
||||||
|
### Utilities
|
||||||
|
|
||||||
|
|
||||||
|
def build_model_bases(
|
||||||
|
mt: Type[ModelType], at: Type[ApiType], *, api_member_name: Optional[str] = None
|
||||||
|
) -> Tuple[Type[Entity[ModelType, ApiType]], Type[Repo[ModelType, ApiType]]]:
|
||||||
|
"""
|
||||||
|
Helps to remove code duplication in declarations of derived classes
|
||||||
|
"""
|
||||||
|
|
||||||
|
class _EntityBase(Entity[ModelType, ApiType]):
|
||||||
|
if api_member_name:
|
||||||
|
_api_member_name = api_member_name
|
||||||
|
|
||||||
|
class _RepoBase(Repo[ModelType, ApiType]):
|
||||||
|
if api_member_name:
|
||||||
|
_api_member_name = api_member_name
|
||||||
|
|
||||||
|
return _EntityBase, _RepoBase
|
||||||
|
|
||||||
|
|
||||||
|
### CRUD mixins
|
||||||
|
|
||||||
|
_EntityT = TypeVar("_EntityT", bound=Entity)
|
||||||
|
|
||||||
|
#### Repo mixins
|
||||||
|
|
||||||
|
|
||||||
|
class ModelCreateMixin(Generic[_EntityT, IModel]):
|
||||||
|
def create(self: Repo, spec: Union[Dict[str, Any], IModel]) -> _EntityT:
|
||||||
|
"""
|
||||||
|
Creates a new object on the server and returns corresponding local object
|
||||||
|
"""
|
||||||
|
|
||||||
|
(model, _) = self.api.create(spec)
|
||||||
|
return self._entity_type(self._client, model)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelRetrieveMixin(Generic[_EntityT]):
|
||||||
|
def retrieve(self: Repo, obj_id: int) -> _EntityT:
|
||||||
|
"""
|
||||||
|
Retrieves an object from server by ID
|
||||||
|
"""
|
||||||
|
|
||||||
|
(model, _) = self.api.retrieve(id=obj_id)
|
||||||
|
return self._entity_type(self._client, model)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelListMixin(Generic[_EntityT]):
|
||||||
|
@overload
|
||||||
|
def list(self: Repo, *, return_json: Literal[False] = False) -> List[_EntityT]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def list(self: Repo, *, return_json: Literal[True] = False) -> List[Any]:
|
||||||
|
...
|
||||||
|
|
||||||
|
def list(self: Repo, *, return_json: bool = False) -> List[Union[_EntityT, Any]]:
|
||||||
|
"""
|
||||||
|
Retrieves all objects from the server and returns them in basic or JSON format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
results = get_paginated_collection(endpoint=self.api.list_endpoint, return_json=return_json)
|
||||||
|
|
||||||
|
if return_json:
|
||||||
|
return json.dumps(results)
|
||||||
|
return [self._entity_type(self._client, model) for model in results]
|
||||||
|
|
||||||
|
|
||||||
|
#### Entity mixins
|
||||||
|
|
||||||
|
|
||||||
|
class ModelUpdateMixin(ABC, Generic[IModel]):
|
||||||
|
@property
|
||||||
|
def _model_partial_update_arg(self: Entity) -> str:
|
||||||
|
...
|
||||||
|
|
||||||
|
def _export_update_fields(
|
||||||
|
self: Entity, overrides: Optional[Union[Dict[str, Any], IModel]] = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
# TODO: support field conversion and assignment updating
|
||||||
|
# fields = to_json(self._model)
|
||||||
|
|
||||||
|
if isinstance(overrides, ModelNormal):
|
||||||
|
overrides = to_json(overrides)
|
||||||
|
fields = deepcopy(overrides)
|
||||||
|
|
||||||
|
return fields
|
||||||
|
|
||||||
|
def fetch(self: Entity) -> Self:
|
||||||
|
"""
|
||||||
|
Updates current object from the server
|
||||||
|
"""
|
||||||
|
|
||||||
|
# TODO: implement revision checking
|
||||||
|
(self._model, _) = self.api.retrieve(id=getattr(self, self._model_id_field))
|
||||||
|
return self
|
||||||
|
|
||||||
|
def update(self: Entity, values: Union[Dict[str, Any], IModel]) -> Self:
|
||||||
|
"""
|
||||||
|
Commits local model changes to the server
|
||||||
|
"""
|
||||||
|
|
||||||
|
# TODO: implement revision checking
|
||||||
|
self.api.partial_update(
|
||||||
|
id=getattr(self, self._model_id_field),
|
||||||
|
**{self._model_partial_update_arg: self._export_update_fields(values)},
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: use the response model, once input and output models are same
|
||||||
|
return self.fetch()
|
||||||
|
|
||||||
|
|
||||||
|
class ModelDeleteMixin:
|
||||||
|
def remove(self: Entity) -> None:
|
||||||
|
"""
|
||||||
|
Removes current object on the server
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.api.destroy(id=getattr(self, self._model_id_field))
|
||||||
@ -0,0 +1,187 @@
|
|||||||
|
# Copyright (C) 2022 CVAT.ai Corporation
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os.path as osp
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from cvat_sdk.api_client import apis, models
|
||||||
|
from cvat_sdk.core.downloading import Downloader
|
||||||
|
from cvat_sdk.core.progress import ProgressReporter
|
||||||
|
from cvat_sdk.core.proxies.model_proxy import (
|
||||||
|
ModelCreateMixin,
|
||||||
|
ModelDeleteMixin,
|
||||||
|
ModelListMixin,
|
||||||
|
ModelRetrieveMixin,
|
||||||
|
ModelUpdateMixin,
|
||||||
|
build_model_bases,
|
||||||
|
)
|
||||||
|
from cvat_sdk.core.uploading import DatasetUploader, Uploader
|
||||||
|
|
||||||
|
_ProjectEntityBase, _ProjectRepoBase = build_model_bases(
|
||||||
|
models.ProjectRead, apis.ProjectsApi, api_member_name="projects_api"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Project(
|
||||||
|
_ProjectEntityBase, models.IProjectRead, ModelUpdateMixin[models.IPatchedProjectWriteRequest]
|
||||||
|
):
|
||||||
|
_model_partial_update_arg = "patched_project_write_request"
|
||||||
|
|
||||||
|
def import_dataset(
|
||||||
|
self,
|
||||||
|
format_name: str,
|
||||||
|
filename: str,
|
||||||
|
*,
|
||||||
|
status_check_period: Optional[int] = None,
|
||||||
|
pbar: Optional[ProgressReporter] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Import dataset for a project in the specified format (e.g. 'YOLO ZIP 1.0').
|
||||||
|
"""
|
||||||
|
|
||||||
|
DatasetUploader(self._client).upload_file_and_wait(
|
||||||
|
self.api.create_dataset_endpoint,
|
||||||
|
filename,
|
||||||
|
format_name,
|
||||||
|
url_params={"id": self.id},
|
||||||
|
pbar=pbar,
|
||||||
|
status_check_period=status_check_period,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._client.logger.info(f"Annotation file '{filename}' for project #{self.id} uploaded")
|
||||||
|
|
||||||
|
def export_dataset(
|
||||||
|
self,
|
||||||
|
format_name: str,
|
||||||
|
filename: str,
|
||||||
|
*,
|
||||||
|
pbar: Optional[ProgressReporter] = None,
|
||||||
|
status_check_period: Optional[int] = None,
|
||||||
|
include_images: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Download annotations for a project in the specified format (e.g. 'YOLO ZIP 1.0').
|
||||||
|
"""
|
||||||
|
if include_images:
|
||||||
|
endpoint = self.api.retrieve_dataset_endpoint
|
||||||
|
else:
|
||||||
|
endpoint = self.api.retrieve_annotations_endpoint
|
||||||
|
|
||||||
|
Downloader(self._client).prepare_and_download_file_from_endpoint(
|
||||||
|
endpoint=endpoint,
|
||||||
|
filename=filename,
|
||||||
|
url_params={"id": self.id},
|
||||||
|
query_params={"format": format_name},
|
||||||
|
pbar=pbar,
|
||||||
|
status_check_period=status_check_period,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._client.logger.info(f"Dataset for project {self.id} has been downloaded to {filename}")
|
||||||
|
|
||||||
|
def download_backup(
|
||||||
|
self,
|
||||||
|
filename: str,
|
||||||
|
*,
|
||||||
|
status_check_period: int = None,
|
||||||
|
pbar: Optional[ProgressReporter] = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Download a project backup
|
||||||
|
"""
|
||||||
|
|
||||||
|
Downloader(self._client).prepare_and_download_file_from_endpoint(
|
||||||
|
self.api.retrieve_backup_endpoint,
|
||||||
|
filename=filename,
|
||||||
|
pbar=pbar,
|
||||||
|
status_check_period=status_check_period,
|
||||||
|
url_params={"id": self.id},
|
||||||
|
)
|
||||||
|
|
||||||
|
self._client.logger.info(f"Backup for project {self.id} has been downloaded to {filename}")
|
||||||
|
|
||||||
|
def get_annotations(self) -> models.ILabeledData:
|
||||||
|
(annotations, _) = self.api.retrieve_annotations(self.id)
|
||||||
|
return annotations
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectsRepo(
|
||||||
|
_ProjectRepoBase,
|
||||||
|
ModelCreateMixin[Project, models.IProjectWriteRequest],
|
||||||
|
ModelListMixin[Project],
|
||||||
|
ModelRetrieveMixin[Project],
|
||||||
|
ModelDeleteMixin,
|
||||||
|
):
|
||||||
|
_entity_type = Project
|
||||||
|
|
||||||
|
def create_from_dataset(
|
||||||
|
self,
|
||||||
|
spec: models.IProjectWriteRequest,
|
||||||
|
*,
|
||||||
|
dataset_path: str = "",
|
||||||
|
dataset_format: str = "CVAT XML 1.1",
|
||||||
|
status_check_period: int = None,
|
||||||
|
pbar: Optional[ProgressReporter] = None,
|
||||||
|
) -> Project:
|
||||||
|
"""
|
||||||
|
Create a new project with the given name and labels JSON and
|
||||||
|
add the files to it.
|
||||||
|
|
||||||
|
Returns: id of the created project
|
||||||
|
"""
|
||||||
|
project = self.create(spec=spec)
|
||||||
|
self._client.logger.info("Created project ID: %s NAME: %s", project.id, project.name)
|
||||||
|
|
||||||
|
if dataset_path:
|
||||||
|
project.import_dataset(
|
||||||
|
format_name=dataset_format,
|
||||||
|
filename=dataset_path,
|
||||||
|
pbar=pbar,
|
||||||
|
status_check_period=status_check_period,
|
||||||
|
)
|
||||||
|
|
||||||
|
project.fetch()
|
||||||
|
return project
|
||||||
|
|
||||||
|
def create_from_backup(
|
||||||
|
self,
|
||||||
|
filename: str,
|
||||||
|
*,
|
||||||
|
status_check_period: int = None,
|
||||||
|
pbar: Optional[ProgressReporter] = None,
|
||||||
|
) -> Project:
|
||||||
|
"""
|
||||||
|
Import a project from a backup file
|
||||||
|
"""
|
||||||
|
if status_check_period is None:
|
||||||
|
status_check_period = self.config.status_check_period
|
||||||
|
|
||||||
|
params = {"filename": osp.basename(filename)}
|
||||||
|
url = self.api_map.make_endpoint_url(self.api.create_backup_endpoint.path)
|
||||||
|
|
||||||
|
uploader = Uploader(self)
|
||||||
|
response = uploader.upload_file(
|
||||||
|
url,
|
||||||
|
filename,
|
||||||
|
meta=params,
|
||||||
|
query_params=params,
|
||||||
|
pbar=pbar,
|
||||||
|
logger=self._client.logger.debug,
|
||||||
|
)
|
||||||
|
|
||||||
|
rq_id = json.loads(response.data)["rq_id"]
|
||||||
|
response = self._client.wait_for_completion(
|
||||||
|
url,
|
||||||
|
success_status=201,
|
||||||
|
positive_statuses=[202],
|
||||||
|
post_params={"rq_id": rq_id},
|
||||||
|
status_check_period=status_check_period,
|
||||||
|
)
|
||||||
|
|
||||||
|
project_id = json.loads(response.data)["id"]
|
||||||
|
self._client.logger.info(f"Project has been imported sucessfully. Project ID: {project_id}")
|
||||||
|
|
||||||
|
return self.retrieve(project_id)
|
||||||
@ -0,0 +1,388 @@
|
|||||||
|
# Copyright (C) 2022 CVAT.ai Corporation
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import mimetypes
|
||||||
|
import os
|
||||||
|
import os.path as osp
|
||||||
|
from enum import Enum
|
||||||
|
from time import sleep
|
||||||
|
from typing import Any, Dict, List, Optional, Sequence
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from cvat_sdk.api_client import apis, exceptions, models
|
||||||
|
from cvat_sdk.core import git
|
||||||
|
from cvat_sdk.core.downloading import Downloader
|
||||||
|
from cvat_sdk.core.progress import ProgressReporter
|
||||||
|
from cvat_sdk.core.proxies.annotations import AnnotationCrudMixin
|
||||||
|
from cvat_sdk.core.proxies.jobs import Job
|
||||||
|
from cvat_sdk.core.proxies.model_proxy import (
|
||||||
|
ModelCreateMixin,
|
||||||
|
ModelDeleteMixin,
|
||||||
|
ModelListMixin,
|
||||||
|
ModelRetrieveMixin,
|
||||||
|
ModelUpdateMixin,
|
||||||
|
build_model_bases,
|
||||||
|
)
|
||||||
|
from cvat_sdk.core.uploading import AnnotationUploader, DataUploader, Uploader
|
||||||
|
from cvat_sdk.core.utils import filter_dict
|
||||||
|
|
||||||
|
|
||||||
|
class ResourceType(Enum):
|
||||||
|
LOCAL = 0
|
||||||
|
SHARE = 1
|
||||||
|
REMOTE = 2
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.name.lower()
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return str(self)
|
||||||
|
|
||||||
|
|
||||||
|
_TaskEntityBase, _TaskRepoBase = build_model_bases(
|
||||||
|
models.TaskRead, apis.TasksApi, api_member_name="tasks_api"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Task(
|
||||||
|
_TaskEntityBase,
|
||||||
|
models.ITaskRead,
|
||||||
|
ModelUpdateMixin[models.IPatchedTaskWriteRequest],
|
||||||
|
ModelDeleteMixin,
|
||||||
|
AnnotationCrudMixin,
|
||||||
|
):
|
||||||
|
_model_partial_update_arg = "patched_task_write_request"
|
||||||
|
_put_annotations_data_param = "task_annotations_update_request"
|
||||||
|
|
||||||
|
def upload_data(
|
||||||
|
self,
|
||||||
|
resource_type: ResourceType,
|
||||||
|
resources: Sequence[str],
|
||||||
|
*,
|
||||||
|
pbar: Optional[ProgressReporter] = None,
|
||||||
|
params: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Add local, remote, or shared files to an existing task.
|
||||||
|
"""
|
||||||
|
params = params or {}
|
||||||
|
|
||||||
|
data = {}
|
||||||
|
if resource_type is ResourceType.LOCAL:
|
||||||
|
pass # handled later
|
||||||
|
elif resource_type is ResourceType.REMOTE:
|
||||||
|
data = {f"remote_files[{i}]": f for i, f in enumerate(resources)}
|
||||||
|
elif resource_type is ResourceType.SHARE:
|
||||||
|
data = {f"server_files[{i}]": f for i, f in enumerate(resources)}
|
||||||
|
|
||||||
|
data["image_quality"] = 70
|
||||||
|
data.update(
|
||||||
|
filter_dict(
|
||||||
|
params,
|
||||||
|
keep=[
|
||||||
|
"chunk_size",
|
||||||
|
"copy_data",
|
||||||
|
"image_quality",
|
||||||
|
"sorting_method",
|
||||||
|
"start_frame",
|
||||||
|
"stop_frame",
|
||||||
|
"use_cache",
|
||||||
|
"use_zip_chunks",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if params.get("frame_step") is not None:
|
||||||
|
data["frame_filter"] = f"step={params.get('frame_step')}"
|
||||||
|
|
||||||
|
if resource_type in [ResourceType.REMOTE, ResourceType.SHARE]:
|
||||||
|
self.api.create_data(
|
||||||
|
self.id,
|
||||||
|
data_request=models.DataRequest(**data),
|
||||||
|
_content_type="multipart/form-data",
|
||||||
|
)
|
||||||
|
elif resource_type == ResourceType.LOCAL:
|
||||||
|
url = self._client.api_map.make_endpoint_url(
|
||||||
|
self.api.create_data_endpoint.path, kwsub={"id": self.id}
|
||||||
|
)
|
||||||
|
|
||||||
|
DataUploader(self._client).upload_files(url, resources, pbar=pbar, **data)
|
||||||
|
|
||||||
|
def import_annotations(
|
||||||
|
self,
|
||||||
|
format_name: str,
|
||||||
|
filename: str,
|
||||||
|
*,
|
||||||
|
status_check_period: Optional[int] = None,
|
||||||
|
pbar: Optional[ProgressReporter] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Upload annotations for a task in the specified format (e.g. 'YOLO ZIP 1.0').
|
||||||
|
"""
|
||||||
|
|
||||||
|
AnnotationUploader(self._client).upload_file_and_wait(
|
||||||
|
self.api.create_annotations_endpoint,
|
||||||
|
filename,
|
||||||
|
format_name,
|
||||||
|
url_params={"id": self.id},
|
||||||
|
pbar=pbar,
|
||||||
|
status_check_period=status_check_period,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._client.logger.info(f"Annotation file '{filename}' for task #{self.id} uploaded")
|
||||||
|
|
||||||
|
def get_frame(
|
||||||
|
self,
|
||||||
|
frame_id: int,
|
||||||
|
*,
|
||||||
|
quality: Optional[str] = None,
|
||||||
|
) -> io.RawIOBase:
|
||||||
|
params = {}
|
||||||
|
if quality:
|
||||||
|
params["quality"] = quality
|
||||||
|
(_, response) = self.api.retrieve_data(self.id, number=frame_id, **params, type="frame")
|
||||||
|
return io.BytesIO(response.data)
|
||||||
|
|
||||||
|
def get_preview(
|
||||||
|
self,
|
||||||
|
) -> io.RawIOBase:
|
||||||
|
(_, response) = self.api.retrieve_data(self.id, type="preview")
|
||||||
|
return io.BytesIO(response.data)
|
||||||
|
|
||||||
|
def download_frames(
|
||||||
|
self,
|
||||||
|
frame_ids: Sequence[int],
|
||||||
|
*,
|
||||||
|
outdir: str = "",
|
||||||
|
quality: str = "original",
|
||||||
|
filename_pattern: str = "frame_{frame_id:06d}{frame_ext}",
|
||||||
|
) -> Optional[List[Image.Image]]:
|
||||||
|
"""
|
||||||
|
Download the requested frame numbers for a task and save images as outdir/filename_pattern
|
||||||
|
"""
|
||||||
|
# TODO: add arg descriptions in schema
|
||||||
|
os.makedirs(outdir, exist_ok=True)
|
||||||
|
|
||||||
|
for frame_id in frame_ids:
|
||||||
|
frame_bytes = self.get_frame(frame_id, quality=quality)
|
||||||
|
|
||||||
|
im = Image.open(frame_bytes)
|
||||||
|
mime_type = im.get_format_mimetype() or "image/jpg"
|
||||||
|
im_ext = mimetypes.guess_extension(mime_type)
|
||||||
|
|
||||||
|
# FIXME It is better to use meta information from the server
|
||||||
|
# to determine the extension
|
||||||
|
# replace '.jpe' or '.jpeg' with a more used '.jpg'
|
||||||
|
if im_ext in (".jpe", ".jpeg", None):
|
||||||
|
im_ext = ".jpg"
|
||||||
|
|
||||||
|
outfile = filename_pattern.format(frame_id=frame_id, frame_ext=im_ext)
|
||||||
|
im.save(osp.join(outdir, outfile))
|
||||||
|
|
||||||
|
def export_dataset(
|
||||||
|
self,
|
||||||
|
format_name: str,
|
||||||
|
filename: str,
|
||||||
|
*,
|
||||||
|
pbar: Optional[ProgressReporter] = None,
|
||||||
|
status_check_period: Optional[int] = None,
|
||||||
|
include_images: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Download annotations for a task in the specified format (e.g. 'YOLO ZIP 1.0').
|
||||||
|
"""
|
||||||
|
if include_images:
|
||||||
|
endpoint = self.api.retrieve_dataset_endpoint
|
||||||
|
else:
|
||||||
|
endpoint = self.api.retrieve_annotations_endpoint
|
||||||
|
|
||||||
|
Downloader(self._client).prepare_and_download_file_from_endpoint(
|
||||||
|
endpoint=endpoint,
|
||||||
|
filename=filename,
|
||||||
|
url_params={"id": self.id},
|
||||||
|
query_params={"format": format_name},
|
||||||
|
pbar=pbar,
|
||||||
|
status_check_period=status_check_period,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._client.logger.info(f"Dataset for task {self.id} has been downloaded to {filename}")
|
||||||
|
|
||||||
|
def download_backup(
|
||||||
|
self,
|
||||||
|
filename: str,
|
||||||
|
*,
|
||||||
|
status_check_period: int = None,
|
||||||
|
pbar: Optional[ProgressReporter] = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Download a task backup
|
||||||
|
"""
|
||||||
|
|
||||||
|
Downloader(self._client).prepare_and_download_file_from_endpoint(
|
||||||
|
self.api.retrieve_backup_endpoint,
|
||||||
|
filename=filename,
|
||||||
|
pbar=pbar,
|
||||||
|
status_check_period=status_check_period,
|
||||||
|
url_params={"id": self.id},
|
||||||
|
)
|
||||||
|
|
||||||
|
self._client.logger.info(f"Backup for task {self.id} has been downloaded to {filename}")
|
||||||
|
|
||||||
|
def get_jobs(self) -> List[Job]:
|
||||||
|
return [Job(self._client, m) for m in self.api.list_jobs(id=self.id)[0]]
|
||||||
|
|
||||||
|
def get_meta(self) -> models.IDataMetaRead:
|
||||||
|
(meta, _) = self.api.retrieve_data_meta(self.id)
|
||||||
|
return meta
|
||||||
|
|
||||||
|
def get_frames_info(self) -> List[models.IFrameMeta]:
|
||||||
|
return self.get_meta().frames
|
||||||
|
|
||||||
|
def remove_frames_by_ids(self, ids: Sequence[int]) -> None:
|
||||||
|
self.api.partial_update_data_meta(
|
||||||
|
self.id,
|
||||||
|
patched_data_meta_write_request=models.PatchedDataMetaWriteRequest(deleted_frames=ids),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TasksRepo(
|
||||||
|
_TaskRepoBase,
|
||||||
|
ModelCreateMixin[Task, models.ITaskWriteRequest],
|
||||||
|
ModelRetrieveMixin[Task],
|
||||||
|
ModelListMixin[Task],
|
||||||
|
ModelDeleteMixin,
|
||||||
|
):
|
||||||
|
_entity_type = Task
|
||||||
|
|
||||||
|
def create_from_data(
|
||||||
|
self,
|
||||||
|
spec: models.ITaskWriteRequest,
|
||||||
|
resource_type: ResourceType,
|
||||||
|
resources: Sequence[str],
|
||||||
|
*,
|
||||||
|
data_params: Optional[Dict[str, Any]] = None,
|
||||||
|
annotation_path: str = "",
|
||||||
|
annotation_format: str = "CVAT XML 1.1",
|
||||||
|
status_check_period: int = None,
|
||||||
|
dataset_repository_url: str = "",
|
||||||
|
use_lfs: bool = False,
|
||||||
|
pbar: Optional[ProgressReporter] = None,
|
||||||
|
) -> Task:
|
||||||
|
"""
|
||||||
|
Create a new task with the given name and labels JSON and
|
||||||
|
add the files to it.
|
||||||
|
|
||||||
|
Returns: id of the created task
|
||||||
|
"""
|
||||||
|
if status_check_period is None:
|
||||||
|
status_check_period = self._client.config.status_check_period
|
||||||
|
|
||||||
|
if getattr(spec, "project_id", None) and getattr(spec, "labels", None):
|
||||||
|
raise exceptions.ApiValueError(
|
||||||
|
"Can't set labels to a task inside a project. "
|
||||||
|
"Tasks inside a project use project's labels.",
|
||||||
|
["labels"],
|
||||||
|
)
|
||||||
|
|
||||||
|
task = self.create(spec=spec)
|
||||||
|
self._client.logger.info("Created task ID: %s NAME: %s", task.id, task.name)
|
||||||
|
|
||||||
|
task.upload_data(resource_type, resources, pbar=pbar, params=data_params)
|
||||||
|
|
||||||
|
self._client.logger.info("Awaiting for task %s creation...", task.id)
|
||||||
|
status: models.RqStatus = None
|
||||||
|
while status != models.RqStatusStateEnum.allowed_values[("value",)]["FINISHED"]:
|
||||||
|
sleep(status_check_period)
|
||||||
|
(status, response) = self.api.retrieve_status(task.id)
|
||||||
|
|
||||||
|
self._client.logger.info(
|
||||||
|
"Task %s creation status=%s, message=%s",
|
||||||
|
task.id,
|
||||||
|
status.state.value,
|
||||||
|
status.message,
|
||||||
|
)
|
||||||
|
|
||||||
|
if status.state.value == models.RqStatusStateEnum.allowed_values[("value",)]["FAILED"]:
|
||||||
|
raise exceptions.ApiException(
|
||||||
|
status=status.state.value, reason=status.message, http_resp=response
|
||||||
|
)
|
||||||
|
|
||||||
|
status = status.state.value
|
||||||
|
|
||||||
|
if annotation_path:
|
||||||
|
task.import_annotations(annotation_format, annotation_path, pbar=pbar)
|
||||||
|
|
||||||
|
if dataset_repository_url:
|
||||||
|
git.create_git_repo(
|
||||||
|
self,
|
||||||
|
task_id=task.id,
|
||||||
|
repo_url=dataset_repository_url,
|
||||||
|
status_check_period=status_check_period,
|
||||||
|
use_lfs=use_lfs,
|
||||||
|
)
|
||||||
|
|
||||||
|
task.fetch()
|
||||||
|
|
||||||
|
return task
|
||||||
|
|
||||||
|
def remove_by_ids(self, task_ids: Sequence[int]) -> None:
|
||||||
|
"""
|
||||||
|
Delete a list of tasks, ignoring those which don't exist.
|
||||||
|
"""
|
||||||
|
|
||||||
|
for task_id in task_ids:
|
||||||
|
(_, response) = self.api.destroy(task_id, _check_status=False)
|
||||||
|
|
||||||
|
if 200 <= response.status <= 299:
|
||||||
|
self._client.logger.info(f"Task ID {task_id} deleted")
|
||||||
|
elif response.status == 404:
|
||||||
|
self._client.logger.info(f"Task ID {task_id} not found")
|
||||||
|
else:
|
||||||
|
self._client.logger.warning(
|
||||||
|
f"Failed to delete task ID {task_id}: "
|
||||||
|
f"{response.msg} (status {response.status})"
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_from_backup(
|
||||||
|
self,
|
||||||
|
filename: str,
|
||||||
|
*,
|
||||||
|
status_check_period: int = None,
|
||||||
|
pbar: Optional[ProgressReporter] = None,
|
||||||
|
) -> Task:
|
||||||
|
"""
|
||||||
|
Import a task from a backup file
|
||||||
|
"""
|
||||||
|
if status_check_period is None:
|
||||||
|
status_check_period = self._client.config.status_check_period
|
||||||
|
|
||||||
|
params = {"filename": osp.basename(filename)}
|
||||||
|
url = self._client.api_map.make_endpoint_url(self.api.create_backup_endpoint.path)
|
||||||
|
uploader = Uploader(self._client)
|
||||||
|
response = uploader.upload_file(
|
||||||
|
url,
|
||||||
|
filename,
|
||||||
|
meta=params,
|
||||||
|
query_params=params,
|
||||||
|
pbar=pbar,
|
||||||
|
logger=self._client.logger.debug,
|
||||||
|
)
|
||||||
|
|
||||||
|
rq_id = json.loads(response.data)["rq_id"]
|
||||||
|
response = self._client.wait_for_completion(
|
||||||
|
url,
|
||||||
|
success_status=201,
|
||||||
|
positive_statuses=[202],
|
||||||
|
post_params={"rq_id": rq_id},
|
||||||
|
status_check_period=status_check_period,
|
||||||
|
)
|
||||||
|
|
||||||
|
task_id = json.loads(response.data)["id"]
|
||||||
|
self._client.logger.info(f"Task has been imported sucessfully. Task ID: {task_id}")
|
||||||
|
|
||||||
|
return self.retrieve(task_id)
|
||||||
@ -0,0 +1,35 @@
|
|||||||
|
# Copyright (C) 2022 CVAT.ai Corporation
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from cvat_sdk.api_client import apis, models
|
||||||
|
from cvat_sdk.core.proxies.model_proxy import (
|
||||||
|
ModelDeleteMixin,
|
||||||
|
ModelListMixin,
|
||||||
|
ModelRetrieveMixin,
|
||||||
|
ModelUpdateMixin,
|
||||||
|
build_model_bases,
|
||||||
|
)
|
||||||
|
|
||||||
|
_UserEntityBase, _UserRepoBase = build_model_bases(
|
||||||
|
models.User, apis.UsersApi, api_member_name="users_api"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class User(
|
||||||
|
models.IUser, _UserEntityBase, ModelUpdateMixin[models.IPatchedUserRequest], ModelDeleteMixin
|
||||||
|
):
|
||||||
|
_model_partial_update_arg = "patched_user_request"
|
||||||
|
|
||||||
|
|
||||||
|
class UsersRepo(
|
||||||
|
_UserRepoBase,
|
||||||
|
ModelListMixin[User],
|
||||||
|
ModelRetrieveMixin[User],
|
||||||
|
):
|
||||||
|
_entity_type = User
|
||||||
|
|
||||||
|
def retrieve_current_user(self) -> User:
|
||||||
|
return User(self._client, self.api.retrieve_self()[0])
|
||||||
@ -1,308 +0,0 @@
|
|||||||
# Copyright (C) 2020-2022 Intel Corporation
|
|
||||||
# Copyright (C) 2022 CVAT.ai Corporation
|
|
||||||
#
|
|
||||||
# SPDX-License-Identifier: MIT
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import io
|
|
||||||
import mimetypes
|
|
||||||
import os
|
|
||||||
import os.path as osp
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from io import BytesIO
|
|
||||||
from time import sleep
|
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
|
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from cvat_sdk import models
|
|
||||||
from cvat_sdk.api_client.model_utils import OpenApiModel
|
|
||||||
from cvat_sdk.core.downloading import Downloader
|
|
||||||
from cvat_sdk.core.progress import ProgressReporter
|
|
||||||
from cvat_sdk.core.types import ResourceType
|
|
||||||
from cvat_sdk.core.uploading import Uploader
|
|
||||||
from cvat_sdk.core.utils import filter_dict
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from cvat_sdk.core.client import Client
|
|
||||||
|
|
||||||
|
|
||||||
class ModelProxy(ABC):
|
|
||||||
_client: Client
|
|
||||||
_model: OpenApiModel
|
|
||||||
|
|
||||||
def __init__(self, client: Client, model: OpenApiModel) -> None:
|
|
||||||
self.__dict__["_client"] = client
|
|
||||||
self.__dict__["_model"] = model
|
|
||||||
|
|
||||||
def __getattr__(self, __name: str) -> Any:
|
|
||||||
return self._model[__name]
|
|
||||||
|
|
||||||
def __setattr__(self, __name: str, __value: Any) -> None:
|
|
||||||
if __name in self.__dict__:
|
|
||||||
self.__dict__[__name] = __value
|
|
||||||
else:
|
|
||||||
self._model[__name] = __value
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def fetch(self, force: bool = False):
|
|
||||||
"""Fetches model data from the server"""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def commit(self, force: bool = False):
|
|
||||||
"""Commits local changes to the server"""
|
|
||||||
...
|
|
||||||
|
|
||||||
def sync(self):
|
|
||||||
"""Pulls server state and commits local model changes"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def update(self, **kwargs):
|
|
||||||
"""Updates multiple fields at once"""
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class TaskProxy(ModelProxy, models.ITaskRead):
|
|
||||||
def __init__(self, client: Client, task: models.TaskRead):
|
|
||||||
ModelProxy.__init__(self, client=client, model=task)
|
|
||||||
|
|
||||||
def remove(self):
|
|
||||||
self._client.api.tasks_api.destroy(self.id)
|
|
||||||
|
|
||||||
def upload_data(
|
|
||||||
self,
|
|
||||||
resource_type: ResourceType,
|
|
||||||
resources: Sequence[str],
|
|
||||||
*,
|
|
||||||
pbar: Optional[ProgressReporter] = None,
|
|
||||||
params: Optional[Dict[str, Any]] = None,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Add local, remote, or shared files to an existing task.
|
|
||||||
"""
|
|
||||||
client = self._client
|
|
||||||
task_id = self.id
|
|
||||||
|
|
||||||
params = params or {}
|
|
||||||
|
|
||||||
data = {}
|
|
||||||
if resource_type is ResourceType.LOCAL:
|
|
||||||
pass # handled later
|
|
||||||
elif resource_type is ResourceType.REMOTE:
|
|
||||||
data = {f"remote_files[{i}]": f for i, f in enumerate(resources)}
|
|
||||||
elif resource_type is ResourceType.SHARE:
|
|
||||||
data = {f"server_files[{i}]": f for i, f in enumerate(resources)}
|
|
||||||
|
|
||||||
data["image_quality"] = 70
|
|
||||||
data.update(
|
|
||||||
filter_dict(
|
|
||||||
params,
|
|
||||||
keep=[
|
|
||||||
"chunk_size",
|
|
||||||
"copy_data",
|
|
||||||
"image_quality",
|
|
||||||
"sorting_method",
|
|
||||||
"start_frame",
|
|
||||||
"stop_frame",
|
|
||||||
"use_cache",
|
|
||||||
"use_zip_chunks",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if params.get("frame_step") is not None:
|
|
||||||
data["frame_filter"] = f"step={params.get('frame_step')}"
|
|
||||||
|
|
||||||
if resource_type in [ResourceType.REMOTE, ResourceType.SHARE]:
|
|
||||||
client.api.tasks_api.create_data(
|
|
||||||
task_id,
|
|
||||||
data_request=models.DataRequest(**data),
|
|
||||||
_content_type="multipart/form-data",
|
|
||||||
)
|
|
||||||
elif resource_type == ResourceType.LOCAL:
|
|
||||||
url = client._api_map.make_endpoint_url(
|
|
||||||
client.api.tasks_api.create_data_endpoint.path, kwsub={"id": task_id}
|
|
||||||
)
|
|
||||||
|
|
||||||
uploader = Uploader(client)
|
|
||||||
uploader.upload_files(url, resources, pbar=pbar, **data)
|
|
||||||
|
|
||||||
def import_annotations(
|
|
||||||
self,
|
|
||||||
format_name: str,
|
|
||||||
filename: str,
|
|
||||||
*,
|
|
||||||
status_check_period: int = None,
|
|
||||||
pbar: Optional[ProgressReporter] = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Upload annotations for a task in the specified format
|
|
||||||
(e.g. 'YOLO ZIP 1.0').
|
|
||||||
"""
|
|
||||||
client = self._client
|
|
||||||
if status_check_period is None:
|
|
||||||
status_check_period = client.config.status_check_period
|
|
||||||
|
|
||||||
task_id = self.id
|
|
||||||
|
|
||||||
url = client._api_map.make_endpoint_url(
|
|
||||||
client.api.tasks_api.create_annotations_endpoint.path,
|
|
||||||
kwsub={"id": task_id},
|
|
||||||
)
|
|
||||||
params = {"format": format_name, "filename": osp.basename(filename)}
|
|
||||||
|
|
||||||
uploader = Uploader(client)
|
|
||||||
uploader.upload_file(
|
|
||||||
url, filename, pbar=pbar, query_params=params, meta={"filename": params["filename"]}
|
|
||||||
)
|
|
||||||
|
|
||||||
while True:
|
|
||||||
response = client.api.rest_client.POST(
|
|
||||||
url, headers=client.api.get_common_headers(), query_params=params
|
|
||||||
)
|
|
||||||
if response.status == 201:
|
|
||||||
break
|
|
||||||
|
|
||||||
sleep(status_check_period)
|
|
||||||
|
|
||||||
client.logger.info(
|
|
||||||
f"Upload job for Task ID {task_id} with annotation file {filename} finished"
|
|
||||||
)
|
|
||||||
|
|
||||||
def retrieve_frame(
|
|
||||||
self,
|
|
||||||
frame_id: int,
|
|
||||||
*,
|
|
||||||
quality: Optional[str] = None,
|
|
||||||
) -> io.RawIOBase:
|
|
||||||
client = self._client
|
|
||||||
task_id = self.id
|
|
||||||
|
|
||||||
(_, response) = client.api.tasks_api.retrieve_data(task_id, frame_id, quality, type="frame")
|
|
||||||
|
|
||||||
return BytesIO(response.data)
|
|
||||||
|
|
||||||
def download_frames(
|
|
||||||
self,
|
|
||||||
frame_ids: Sequence[int],
|
|
||||||
*,
|
|
||||||
outdir: str = "",
|
|
||||||
quality: str = "original",
|
|
||||||
filename_pattern: str = "task_{task_id}_frame_{frame_id:06d}{frame_ext}",
|
|
||||||
) -> Optional[List[Image.Image]]:
|
|
||||||
"""
|
|
||||||
Download the requested frame numbers for a task and save images as
|
|
||||||
outdir/filename_pattern
|
|
||||||
"""
|
|
||||||
# TODO: add arg descriptions in schema
|
|
||||||
task_id = self.id
|
|
||||||
|
|
||||||
os.makedirs(outdir, exist_ok=True)
|
|
||||||
|
|
||||||
for frame_id in frame_ids:
|
|
||||||
frame_bytes = self.retrieve_frame(frame_id, quality=quality)
|
|
||||||
|
|
||||||
im = Image.open(frame_bytes)
|
|
||||||
mime_type = im.get_format_mimetype() or "image/jpg"
|
|
||||||
im_ext = mimetypes.guess_extension(mime_type)
|
|
||||||
|
|
||||||
# FIXME It is better to use meta information from the server
|
|
||||||
# to determine the extension
|
|
||||||
# replace '.jpe' or '.jpeg' with a more used '.jpg'
|
|
||||||
if im_ext in (".jpe", ".jpeg", None):
|
|
||||||
im_ext = ".jpg"
|
|
||||||
|
|
||||||
outfile = filename_pattern.format(task_id=task_id, frame_id=frame_id, frame_ext=im_ext)
|
|
||||||
im.save(osp.join(outdir, outfile))
|
|
||||||
|
|
||||||
def export_dataset(
|
|
||||||
self,
|
|
||||||
format_name: str,
|
|
||||||
filename: str,
|
|
||||||
*,
|
|
||||||
pbar: Optional[ProgressReporter] = None,
|
|
||||||
status_check_period: int = None,
|
|
||||||
include_images: bool = True,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Download annotations for a task in the specified format (e.g. 'YOLO ZIP 1.0').
|
|
||||||
"""
|
|
||||||
client = self._client
|
|
||||||
if status_check_period is None:
|
|
||||||
status_check_period = client.config.status_check_period
|
|
||||||
|
|
||||||
task_id = self.id
|
|
||||||
|
|
||||||
params = {"filename": self.name, "format": format_name}
|
|
||||||
if include_images:
|
|
||||||
endpoint = client.api.tasks_api.retrieve_dataset_endpoint
|
|
||||||
else:
|
|
||||||
endpoint = client.api.tasks_api.retrieve_annotations_endpoint
|
|
||||||
|
|
||||||
client.logger.info("Waiting for the server to prepare the file...")
|
|
||||||
while True:
|
|
||||||
(_, response) = endpoint.call_with_http_info(id=task_id, **params)
|
|
||||||
client.logger.debug("STATUS {}".format(response.status))
|
|
||||||
if response.status == 201:
|
|
||||||
break
|
|
||||||
sleep(status_check_period)
|
|
||||||
|
|
||||||
params["action"] = "download"
|
|
||||||
url = client._api_map.make_endpoint_url(
|
|
||||||
endpoint.path, kwsub={"id": task_id}, query_params=params
|
|
||||||
)
|
|
||||||
downloader = Downloader(client)
|
|
||||||
downloader.download_file(url, output_path=filename, pbar=pbar)
|
|
||||||
|
|
||||||
client.logger.info(f"Dataset has been exported to {filename}")
|
|
||||||
|
|
||||||
def download_backup(
|
|
||||||
self,
|
|
||||||
filename: str,
|
|
||||||
*,
|
|
||||||
status_check_period: int = None,
|
|
||||||
pbar: Optional[ProgressReporter] = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Download a task backup
|
|
||||||
"""
|
|
||||||
client = self._client
|
|
||||||
if status_check_period is None:
|
|
||||||
status_check_period = client.config.status_check_period
|
|
||||||
|
|
||||||
task_id = self.id
|
|
||||||
|
|
||||||
endpoint = client.api.tasks_api.retrieve_backup_endpoint
|
|
||||||
client.logger.info("Waiting for the server to prepare the file...")
|
|
||||||
while True:
|
|
||||||
(_, response) = endpoint.call_with_http_info(id=task_id)
|
|
||||||
client.logger.debug("STATUS {}".format(response.status))
|
|
||||||
if response.status == 201:
|
|
||||||
break
|
|
||||||
sleep(status_check_period)
|
|
||||||
|
|
||||||
url = client._api_map.make_endpoint_url(
|
|
||||||
endpoint.path, kwsub={"id": task_id}, query_params={"action": "download"}
|
|
||||||
)
|
|
||||||
downloader = Downloader(client)
|
|
||||||
downloader.download_file(url, output_path=filename, pbar=pbar)
|
|
||||||
|
|
||||||
client.logger.info(
|
|
||||||
f"Task {task_id} has been exported sucessfully to {osp.abspath(filename)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def fetch(self, force: bool = False):
|
|
||||||
# TODO: implement revision checking
|
|
||||||
model, _ = self._client.api.tasks_api.retrieve(self.id)
|
|
||||||
self._model = model
|
|
||||||
|
|
||||||
def commit(self, force: bool = False):
|
|
||||||
return super().commit(force)
|
|
||||||
|
|
||||||
def update(self, **kwargs):
|
|
||||||
return super().update(**kwargs)
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return str(self._model)
|
|
||||||
@ -1,18 +0,0 @@
|
|||||||
# Copyright (C) 2022 Intel Corporation
|
|
||||||
# Copyright (C) 2022 CVAT.ai Corporation
|
|
||||||
#
|
|
||||||
# SPDX-License-Identifier: MIT
|
|
||||||
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
|
|
||||||
class ResourceType(Enum):
|
|
||||||
LOCAL = 0
|
|
||||||
SHARE = 1
|
|
||||||
REMOTE = 2
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return self.name.lower()
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return str(self)
|
|
||||||
@ -0,0 +1,26 @@
|
|||||||
|
# Copyright (C) 2022 CVAT.ai Corporation
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
from http import HTTPStatus
|
||||||
|
from time import sleep
|
||||||
|
|
||||||
|
from cvat_sdk.api_client.api_client import Endpoint
|
||||||
|
from urllib3 import HTTPResponse
|
||||||
|
|
||||||
|
|
||||||
|
def export_dataset(
|
||||||
|
endpoint: Endpoint, *, max_retries: int = 20, interval: float = 0.1, **kwargs
|
||||||
|
) -> HTTPResponse:
|
||||||
|
for _ in range(max_retries):
|
||||||
|
(_, response) = endpoint.call_with_http_info(**kwargs, _parse_response=False)
|
||||||
|
if response.status == HTTPStatus.CREATED:
|
||||||
|
break
|
||||||
|
assert response.status == HTTPStatus.ACCEPTED
|
||||||
|
sleep(interval)
|
||||||
|
assert response.status == HTTPStatus.CREATED
|
||||||
|
|
||||||
|
(_, response) = endpoint.call_with_http_info(**kwargs, action="download", _parse_response=False)
|
||||||
|
assert response.status == HTTPStatus.OK
|
||||||
|
|
||||||
|
return response
|
||||||
@ -0,0 +1,236 @@
|
|||||||
|
# Copyright (C) 2022 CVAT.ai Corporation
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import io
|
||||||
|
from logging import Logger
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from cvat_sdk import Client
|
||||||
|
from cvat_sdk.api_client import exceptions, models
|
||||||
|
from cvat_sdk.core.proxies.tasks import ResourceType, Task
|
||||||
|
|
||||||
|
from shared.utils.config import USER_PASS
|
||||||
|
|
||||||
|
|
||||||
|
class TestIssuesUsecases:
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def setup(
|
||||||
|
self,
|
||||||
|
changedb, # force fixture call order to allow DB setup
|
||||||
|
tmp_path: Path,
|
||||||
|
fxt_logger: Tuple[Logger, io.StringIO],
|
||||||
|
fxt_client: Client,
|
||||||
|
fxt_stdout: io.StringIO,
|
||||||
|
admin_user: str,
|
||||||
|
):
|
||||||
|
self.tmp_path = tmp_path
|
||||||
|
_, self.logger_stream = fxt_logger
|
||||||
|
self.client = fxt_client
|
||||||
|
self.stdout = fxt_stdout
|
||||||
|
self.user = admin_user
|
||||||
|
self.client.login((self.user, USER_PASS))
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def fxt_new_task(self, fxt_image_file: Path):
|
||||||
|
task = self.client.tasks.create_from_data(
|
||||||
|
spec={
|
||||||
|
"name": "test_task",
|
||||||
|
"labels": [{"name": "car"}, {"name": "person"}],
|
||||||
|
},
|
||||||
|
resource_type=ResourceType.LOCAL,
|
||||||
|
resources=[str(fxt_image_file)],
|
||||||
|
data_params={"image_quality": 80},
|
||||||
|
)
|
||||||
|
|
||||||
|
return task
|
||||||
|
|
||||||
|
def test_can_retrieve_issue(self, fxt_new_task: Task):
|
||||||
|
issue = self.client.issues.create(
|
||||||
|
models.IssueWriteRequest(
|
||||||
|
frame=0,
|
||||||
|
position=[2.0, 4.0],
|
||||||
|
job=fxt_new_task.get_jobs()[0].id,
|
||||||
|
message="hello",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
retrieved_issue = self.client.issues.retrieve(issue.id)
|
||||||
|
|
||||||
|
assert issue.id == retrieved_issue.id
|
||||||
|
assert self.stdout.getvalue() == ""
|
||||||
|
|
||||||
|
def test_can_list_issues(self, fxt_new_task: Task):
|
||||||
|
issue = self.client.issues.create(
|
||||||
|
models.IssueWriteRequest(
|
||||||
|
frame=0,
|
||||||
|
position=[2.0, 4.0],
|
||||||
|
job=fxt_new_task.get_jobs()[0].id,
|
||||||
|
message="hello",
|
||||||
|
assignee=self.client.users.list()[0].id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
issues = self.client.issues.list()
|
||||||
|
|
||||||
|
assert any(issue.id == j.id for j in issues)
|
||||||
|
assert self.stdout.getvalue() == ""
|
||||||
|
|
||||||
|
def test_can_list_comments(self, fxt_new_task: Task):
|
||||||
|
issue = self.client.issues.create(
|
||||||
|
models.IssueWriteRequest(
|
||||||
|
frame=0,
|
||||||
|
position=[2.0, 4.0],
|
||||||
|
job=fxt_new_task.get_jobs()[0].id,
|
||||||
|
message="hello",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
comment = self.client.comments.create(models.CommentWriteRequest(issue.id, message="hi!"))
|
||||||
|
issue.fetch()
|
||||||
|
|
||||||
|
comment_ids = {c.id for c in issue.comments}
|
||||||
|
|
||||||
|
assert len(comment_ids) == 2
|
||||||
|
assert comment.id in comment_ids
|
||||||
|
assert self.stdout.getvalue() == ""
|
||||||
|
|
||||||
|
def test_can_modify_issue(self, fxt_new_task: Task):
|
||||||
|
issue = self.client.issues.create(
|
||||||
|
models.IssueWriteRequest(
|
||||||
|
frame=0,
|
||||||
|
position=[2.0, 4.0],
|
||||||
|
job=fxt_new_task.get_jobs()[0].id,
|
||||||
|
message="hello",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
issue.update(models.PatchedIssueWriteRequest(resolved=True))
|
||||||
|
|
||||||
|
retrieved_issue = self.client.issues.retrieve(issue.id)
|
||||||
|
assert retrieved_issue.resolved is True
|
||||||
|
assert issue.resolved == retrieved_issue.resolved
|
||||||
|
assert self.stdout.getvalue() == ""
|
||||||
|
|
||||||
|
def test_can_remove_issue(self, fxt_new_task: Task):
|
||||||
|
issue = self.client.issues.create(
|
||||||
|
models.IssueWriteRequest(
|
||||||
|
frame=0,
|
||||||
|
position=[2.0, 4.0],
|
||||||
|
job=fxt_new_task.get_jobs()[0].id,
|
||||||
|
message="hello",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
issue.remove()
|
||||||
|
|
||||||
|
with pytest.raises(exceptions.NotFoundException):
|
||||||
|
issue.fetch()
|
||||||
|
with pytest.raises(exceptions.NotFoundException):
|
||||||
|
self.client.comments.retrieve(issue.comments[0].id)
|
||||||
|
assert self.stdout.getvalue() == ""
|
||||||
|
|
||||||
|
|
||||||
|
class TestCommentsUsecases:
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def setup(
|
||||||
|
self,
|
||||||
|
changedb, # force fixture call order to allow DB setup
|
||||||
|
tmp_path: Path,
|
||||||
|
fxt_logger: Tuple[Logger, io.StringIO],
|
||||||
|
fxt_client: Client,
|
||||||
|
fxt_stdout: io.StringIO,
|
||||||
|
admin_user: str,
|
||||||
|
):
|
||||||
|
self.tmp_path = tmp_path
|
||||||
|
_, self.logger_stream = fxt_logger
|
||||||
|
self.client = fxt_client
|
||||||
|
self.stdout = fxt_stdout
|
||||||
|
self.user = admin_user
|
||||||
|
self.client.login((self.user, USER_PASS))
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def fxt_new_task(self, fxt_image_file: Path):
|
||||||
|
task = self.client.tasks.create_from_data(
|
||||||
|
spec={
|
||||||
|
"name": "test_task",
|
||||||
|
"labels": [{"name": "car"}, {"name": "person"}],
|
||||||
|
},
|
||||||
|
resource_type=ResourceType.LOCAL,
|
||||||
|
resources=[str(fxt_image_file)],
|
||||||
|
data_params={"image_quality": 80},
|
||||||
|
)
|
||||||
|
|
||||||
|
return task
|
||||||
|
|
||||||
|
def test_can_retrieve_comment(self, fxt_new_task: Task):
|
||||||
|
issue = self.client.issues.create(
|
||||||
|
models.IssueWriteRequest(
|
||||||
|
frame=0,
|
||||||
|
position=[2.0, 4.0],
|
||||||
|
job=fxt_new_task.get_jobs()[0].id,
|
||||||
|
message="hello",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
comment = self.client.comments.create(models.CommentWriteRequest(issue.id, message="hi!"))
|
||||||
|
|
||||||
|
retrieved_comment = self.client.comments.retrieve(comment.id)
|
||||||
|
|
||||||
|
assert comment.id == retrieved_comment.id
|
||||||
|
assert self.stdout.getvalue() == ""
|
||||||
|
|
||||||
|
def test_can_list_comments(self, fxt_new_task: Task):
|
||||||
|
issue = self.client.issues.create(
|
||||||
|
models.IssueWriteRequest(
|
||||||
|
frame=0,
|
||||||
|
position=[2.0, 4.0],
|
||||||
|
job=fxt_new_task.get_jobs()[0].id,
|
||||||
|
message="hello",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
comment = self.client.comments.create(models.CommentWriteRequest(issue.id, message="hi!"))
|
||||||
|
|
||||||
|
comments = self.client.comments.list()
|
||||||
|
|
||||||
|
assert any(comment.id == c.id for c in comments)
|
||||||
|
assert self.stdout.getvalue() == ""
|
||||||
|
|
||||||
|
def test_can_modify_comment(self, fxt_new_task: Task):
|
||||||
|
issue = self.client.issues.create(
|
||||||
|
models.IssueWriteRequest(
|
||||||
|
frame=0,
|
||||||
|
position=[2.0, 4.0],
|
||||||
|
job=fxt_new_task.get_jobs()[0].id,
|
||||||
|
message="hello",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
comment = self.client.comments.create(models.CommentWriteRequest(issue.id, message="hi!"))
|
||||||
|
|
||||||
|
comment.update(models.PatchedCommentWriteRequest(message="bar"))
|
||||||
|
|
||||||
|
retrieved_comment = self.client.comments.retrieve(comment.id)
|
||||||
|
assert retrieved_comment.message == "bar"
|
||||||
|
assert comment.message == retrieved_comment.message
|
||||||
|
assert self.stdout.getvalue() == ""
|
||||||
|
|
||||||
|
def test_can_remove_comment(self, fxt_new_task: Task):
|
||||||
|
issue = self.client.issues.create(
|
||||||
|
models.IssueWriteRequest(
|
||||||
|
frame=0,
|
||||||
|
position=[2.0, 4.0],
|
||||||
|
job=fxt_new_task.get_jobs()[0].id,
|
||||||
|
message="hello",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
comment = self.client.comments.create(models.CommentWriteRequest(issue.id, message="hi!"))
|
||||||
|
|
||||||
|
comment.remove()
|
||||||
|
|
||||||
|
with pytest.raises(exceptions.NotFoundException):
|
||||||
|
comment.fetch()
|
||||||
|
assert self.stdout.getvalue() == ""
|
||||||
@ -0,0 +1,280 @@
|
|||||||
|
# Copyright (C) 2022 CVAT.ai Corporation
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import io
|
||||||
|
import os.path as osp
|
||||||
|
from logging import Logger
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from cvat_sdk import Client
|
||||||
|
from cvat_sdk.api_client import models
|
||||||
|
from cvat_sdk.core.proxies.tasks import ResourceType, Task
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from shared.utils.config import USER_PASS
|
||||||
|
|
||||||
|
from .util import make_pbar
|
||||||
|
|
||||||
|
|
||||||
|
class TestJobUsecases:
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def setup(
|
||||||
|
self,
|
||||||
|
changedb, # force fixture call order to allow DB setup
|
||||||
|
tmp_path: Path,
|
||||||
|
fxt_logger: Tuple[Logger, io.StringIO],
|
||||||
|
fxt_client: Client,
|
||||||
|
fxt_stdout: io.StringIO,
|
||||||
|
admin_user: str,
|
||||||
|
):
|
||||||
|
self.tmp_path = tmp_path
|
||||||
|
_, self.logger_stream = fxt_logger
|
||||||
|
self.client = fxt_client
|
||||||
|
self.stdout = fxt_stdout
|
||||||
|
self.user = admin_user
|
||||||
|
self.client.login((self.user, USER_PASS))
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def fxt_new_task(self, fxt_image_file: Path):
|
||||||
|
task = self.client.tasks.create_from_data(
|
||||||
|
spec={
|
||||||
|
"name": "test_task",
|
||||||
|
"labels": [{"name": "car"}, {"name": "person"}],
|
||||||
|
},
|
||||||
|
resource_type=ResourceType.LOCAL,
|
||||||
|
resources=[str(fxt_image_file)],
|
||||||
|
data_params={"image_quality": 80},
|
||||||
|
)
|
||||||
|
|
||||||
|
return task
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def fxt_task_with_shapes(self, fxt_new_task: Task):
|
||||||
|
fxt_new_task.set_annotations(
|
||||||
|
models.LabeledDataRequest(
|
||||||
|
shapes=[
|
||||||
|
models.LabeledShapeRequest(
|
||||||
|
frame=0,
|
||||||
|
label_id=fxt_new_task.labels[0].id,
|
||||||
|
type="rectangle",
|
||||||
|
points=[1, 1, 2, 2],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return fxt_new_task
|
||||||
|
|
||||||
|
def test_can_retrieve_job(self, fxt_new_task: Task):
|
||||||
|
job_id = fxt_new_task.get_jobs()[0].id
|
||||||
|
|
||||||
|
job = self.client.jobs.retrieve(job_id)
|
||||||
|
|
||||||
|
assert job.id == job_id
|
||||||
|
assert self.stdout.getvalue() == ""
|
||||||
|
|
||||||
|
def test_can_list_jobs(self, fxt_new_task: Task):
|
||||||
|
task_job_ids = set(j.id for j in fxt_new_task.get_jobs())
|
||||||
|
|
||||||
|
jobs = self.client.jobs.list()
|
||||||
|
|
||||||
|
assert len(task_job_ids) != 0
|
||||||
|
assert task_job_ids.issubset(j.id for j in jobs)
|
||||||
|
assert self.stdout.getvalue() == ""
|
||||||
|
|
||||||
|
def test_can_update_job_field_directly(self, fxt_new_task: Task):
|
||||||
|
job = self.client.jobs.list()[0]
|
||||||
|
assert not job.assignee
|
||||||
|
new_assignee = self.client.users.list()[0]
|
||||||
|
|
||||||
|
job.update({"assignee": new_assignee.id})
|
||||||
|
|
||||||
|
updated_job = self.client.jobs.retrieve(job.id)
|
||||||
|
assert updated_job.assignee.id == new_assignee.id
|
||||||
|
assert self.stdout.getvalue() == ""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("include_images", (True, False))
|
||||||
|
def test_can_download_dataset(self, fxt_new_task: Task, include_images: bool):
|
||||||
|
pbar_out = io.StringIO()
|
||||||
|
pbar = make_pbar(file=pbar_out)
|
||||||
|
|
||||||
|
task_id = fxt_new_task.id
|
||||||
|
path = str(self.tmp_path / f"task_{task_id}-cvat.zip")
|
||||||
|
job = self.client.jobs.retrieve(task_id)
|
||||||
|
job.export_dataset(
|
||||||
|
format_name="CVAT for images 1.1",
|
||||||
|
filename=path,
|
||||||
|
pbar=pbar,
|
||||||
|
include_images=include_images,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "100%" in pbar_out.getvalue().strip("\r").split("\r")[-1]
|
||||||
|
assert osp.isfile(path)
|
||||||
|
assert self.stdout.getvalue() == ""
|
||||||
|
|
||||||
|
def test_can_download_preview(self, fxt_new_task: Task):
|
||||||
|
frame_encoded = fxt_new_task.get_jobs()[0].get_preview()
|
||||||
|
|
||||||
|
assert Image.open(frame_encoded).size != 0
|
||||||
|
assert self.stdout.getvalue() == ""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("quality", ("compressed", "original"))
|
||||||
|
def test_can_download_frame(self, fxt_new_task: Task, quality: str):
|
||||||
|
frame_encoded = fxt_new_task.get_jobs()[0].get_frame(0, quality=quality)
|
||||||
|
|
||||||
|
assert Image.open(frame_encoded).size != 0
|
||||||
|
assert self.stdout.getvalue() == ""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("quality", ("compressed", "original"))
|
||||||
|
def test_can_download_frames(self, fxt_new_task: Task, quality: str):
|
||||||
|
fxt_new_task.get_jobs()[0].download_frames(
|
||||||
|
[0],
|
||||||
|
quality=quality,
|
||||||
|
outdir=str(self.tmp_path),
|
||||||
|
filename_pattern="frame-{frame_id}{frame_ext}",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert osp.isfile(self.tmp_path / "frame-0.jpg")
|
||||||
|
assert self.stdout.getvalue() == ""
|
||||||
|
|
||||||
|
def test_can_upload_annotations(self, fxt_new_task: Task, fxt_coco_file: Path):
|
||||||
|
pbar_out = io.StringIO()
|
||||||
|
pbar = make_pbar(file=pbar_out)
|
||||||
|
|
||||||
|
fxt_new_task.get_jobs()[0].import_annotations(
|
||||||
|
format_name="COCO 1.0", filename=str(fxt_coco_file), pbar=pbar
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "uploaded" in self.logger_stream.getvalue()
|
||||||
|
assert "100%" in pbar_out.getvalue().strip("\r").split("\r")[-1]
|
||||||
|
assert self.stdout.getvalue() == ""
|
||||||
|
|
||||||
|
def test_can_get_meta(self, fxt_new_task: Task):
|
||||||
|
meta = fxt_new_task.get_jobs()[0].get_meta()
|
||||||
|
|
||||||
|
assert meta.image_quality == 80
|
||||||
|
assert meta.size == 1
|
||||||
|
assert len(meta.frames) == meta.size
|
||||||
|
assert meta.frames[0].name == "img.png"
|
||||||
|
assert meta.frames[0].width == 5
|
||||||
|
assert meta.frames[0].height == 10
|
||||||
|
assert not meta.deleted_frames
|
||||||
|
assert self.stdout.getvalue() == ""
|
||||||
|
|
||||||
|
def test_can_remove_frames(self, fxt_new_task: Task):
|
||||||
|
fxt_new_task.get_jobs()[0].remove_frames_by_ids([0])
|
||||||
|
|
||||||
|
meta = fxt_new_task.get_jobs()[0].get_meta()
|
||||||
|
assert meta.deleted_frames == [0]
|
||||||
|
assert self.stdout.getvalue() == ""
|
||||||
|
|
||||||
|
def test_can_get_issues(self, fxt_new_task: Task):
|
||||||
|
issue = self.client.issues.create(
|
||||||
|
models.IssueWriteRequest(
|
||||||
|
frame=0,
|
||||||
|
position=[2.0, 4.0],
|
||||||
|
job=fxt_new_task.get_jobs()[0].id,
|
||||||
|
message="hello",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
job_issue_ids = set(j.id for j in fxt_new_task.get_jobs()[0].get_issues())
|
||||||
|
|
||||||
|
assert {issue.id} == job_issue_ids
|
||||||
|
assert self.stdout.getvalue() == ""
|
||||||
|
|
||||||
|
def test_can_get_annotations(self, fxt_task_with_shapes: Task):
|
||||||
|
anns = fxt_task_with_shapes.get_jobs()[0].get_annotations()
|
||||||
|
|
||||||
|
assert len(anns.shapes) == 1
|
||||||
|
assert anns.shapes[0].type.value == "rectangle"
|
||||||
|
assert self.stdout.getvalue() == ""
|
||||||
|
|
||||||
|
def test_can_set_annotations(self, fxt_new_task: Task):
|
||||||
|
fxt_new_task.get_jobs()[0].set_annotations(
|
||||||
|
models.LabeledDataRequest(
|
||||||
|
tags=[models.LabeledImageRequest(frame=0, label_id=fxt_new_task.labels[0].id)],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
anns = fxt_new_task.get_jobs()[0].get_annotations()
|
||||||
|
|
||||||
|
assert len(anns.tags) == 1
|
||||||
|
assert self.stdout.getvalue() == ""
|
||||||
|
|
||||||
|
def test_can_clear_annotations(self, fxt_task_with_shapes: Task):
|
||||||
|
fxt_task_with_shapes.get_jobs()[0].remove_annotations()
|
||||||
|
|
||||||
|
anns = fxt_task_with_shapes.get_jobs()[0].get_annotations()
|
||||||
|
assert len(anns.tags) == 0
|
||||||
|
assert len(anns.tracks) == 0
|
||||||
|
assert len(anns.shapes) == 0
|
||||||
|
assert self.stdout.getvalue() == ""
|
||||||
|
|
||||||
|
def test_can_remove_annotations(self, fxt_new_task: Task):
|
||||||
|
fxt_new_task.get_jobs()[0].set_annotations(
|
||||||
|
models.LabeledDataRequest(
|
||||||
|
shapes=[
|
||||||
|
models.LabeledShapeRequest(
|
||||||
|
frame=0,
|
||||||
|
label_id=fxt_new_task.labels[0].id,
|
||||||
|
type="rectangle",
|
||||||
|
points=[1, 1, 2, 2],
|
||||||
|
),
|
||||||
|
models.LabeledShapeRequest(
|
||||||
|
frame=0,
|
||||||
|
label_id=fxt_new_task.labels[0].id,
|
||||||
|
type="rectangle",
|
||||||
|
points=[2, 2, 3, 3],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
anns = fxt_new_task.get_jobs()[0].get_annotations()
|
||||||
|
|
||||||
|
fxt_new_task.get_jobs()[0].remove_annotations(ids=[anns.shapes[0].id])
|
||||||
|
|
||||||
|
anns = fxt_new_task.get_jobs()[0].get_annotations()
|
||||||
|
assert len(anns.tags) == 0
|
||||||
|
assert len(anns.tracks) == 0
|
||||||
|
assert len(anns.shapes) == 1
|
||||||
|
assert self.stdout.getvalue() == ""
|
||||||
|
|
||||||
|
def test_can_update_annotations(self, fxt_task_with_shapes: Task):
|
||||||
|
fxt_task_with_shapes.get_jobs()[0].update_annotations(
|
||||||
|
models.PatchedLabeledDataRequest(
|
||||||
|
shapes=[
|
||||||
|
models.LabeledShapeRequest(
|
||||||
|
frame=0,
|
||||||
|
label_id=fxt_task_with_shapes.labels[0].id,
|
||||||
|
type="rectangle",
|
||||||
|
points=[0, 1, 2, 3],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
tracks=[
|
||||||
|
models.LabeledTrackRequest(
|
||||||
|
frame=0,
|
||||||
|
label_id=fxt_task_with_shapes.labels[0].id,
|
||||||
|
shapes=[
|
||||||
|
models.TrackedShapeRequest(
|
||||||
|
frame=0, type="polygon", points=[3, 2, 2, 3, 3, 4]
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
],
|
||||||
|
tags=[
|
||||||
|
models.LabeledImageRequest(frame=0, label_id=fxt_task_with_shapes.labels[0].id)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
anns = fxt_task_with_shapes.get_jobs()[0].get_annotations()
|
||||||
|
assert len(anns.shapes) == 2
|
||||||
|
assert len(anns.tracks) == 1
|
||||||
|
assert len(anns.tags) == 1
|
||||||
|
assert self.stdout.getvalue() == ""
|
||||||
@ -0,0 +1,73 @@
|
|||||||
|
# Copyright (C) 2022 CVAT.ai Corporation
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import io
|
||||||
|
from logging import Logger
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from cvat_sdk import Client, models
|
||||||
|
from cvat_sdk.api_client import exceptions
|
||||||
|
|
||||||
|
from shared.utils.config import USER_PASS
|
||||||
|
|
||||||
|
|
||||||
|
class TestUserUsecases:
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def setup(
|
||||||
|
self,
|
||||||
|
changedb, # force fixture call order to allow DB setup
|
||||||
|
tmp_path: Path,
|
||||||
|
fxt_logger: Tuple[Logger, io.StringIO],
|
||||||
|
fxt_client: Client,
|
||||||
|
fxt_stdout: io.StringIO,
|
||||||
|
admin_user: str,
|
||||||
|
):
|
||||||
|
self.tmp_path = tmp_path
|
||||||
|
_, self.logger_stream = fxt_logger
|
||||||
|
self.client = fxt_client
|
||||||
|
self.stdout = fxt_stdout
|
||||||
|
self.user = admin_user
|
||||||
|
self.client.login((self.user, USER_PASS))
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
def test_can_retrieve_user(self):
|
||||||
|
me = self.client.users.retrieve_current_user()
|
||||||
|
|
||||||
|
user = self.client.users.retrieve(me.id)
|
||||||
|
|
||||||
|
assert user.id == me.id
|
||||||
|
assert user.username == self.user
|
||||||
|
assert self.stdout.getvalue() == ""
|
||||||
|
|
||||||
|
def test_can_list_users(self):
|
||||||
|
users = self.client.users.list()
|
||||||
|
|
||||||
|
assert self.user in set(u.username for u in users)
|
||||||
|
assert self.stdout.getvalue() == ""
|
||||||
|
|
||||||
|
def test_can_update_user(self):
|
||||||
|
user = self.client.users.retrieve_current_user()
|
||||||
|
|
||||||
|
user.update(models.PatchedUserRequest(first_name="foo", last_name="bar"))
|
||||||
|
|
||||||
|
retrieved_user = self.client.users.retrieve(user.id)
|
||||||
|
assert retrieved_user.first_name == "foo"
|
||||||
|
assert retrieved_user.last_name == "bar"
|
||||||
|
assert user.first_name == retrieved_user.first_name
|
||||||
|
assert user.last_name == retrieved_user.last_name
|
||||||
|
assert self.stdout.getvalue() == ""
|
||||||
|
|
||||||
|
def test_can_remove_user(self):
|
||||||
|
users = self.client.users.list()
|
||||||
|
removed_user = next(u for u in users if u.username != self.user)
|
||||||
|
|
||||||
|
removed_user.remove()
|
||||||
|
|
||||||
|
with pytest.raises(exceptions.NotFoundException):
|
||||||
|
removed_user.fetch()
|
||||||
|
|
||||||
|
assert self.stdout.getvalue() == ""
|
||||||
Loading…
Reference in New Issue