SDK layer 2 - cover RC1 usecases (#4813)
parent
b60d3b481a
commit
53697ecac5
@ -0,0 +1,66 @@
|
||||
# Copyright (C) 2022 CVAT.ai Corporation
|
||||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from abc import ABC
|
||||
from enum import Enum
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from cvat_sdk import models
|
||||
from cvat_sdk.core.proxies.model_proxy import _EntityT
|
||||
|
||||
|
||||
class AnnotationUpdateAction(Enum):
|
||||
CREATE = "create"
|
||||
UPDATE = "update"
|
||||
DELETE = "delete"
|
||||
|
||||
|
||||
class AnnotationCrudMixin(ABC):
|
||||
# TODO: refactor
|
||||
|
||||
@property
|
||||
def _put_annotations_data_param(self) -> str:
|
||||
...
|
||||
|
||||
def get_annotations(self: _EntityT) -> models.ILabeledData:
|
||||
(annotations, _) = self.api.retrieve_annotations(getattr(self, self._model_id_field))
|
||||
return annotations
|
||||
|
||||
def set_annotations(self: _EntityT, data: models.ILabeledDataRequest):
|
||||
self.api.update_annotations(
|
||||
getattr(self, self._model_id_field), **{self._put_annotations_data_param: data}
|
||||
)
|
||||
|
||||
def update_annotations(
|
||||
self: _EntityT,
|
||||
data: models.IPatchedLabeledDataRequest,
|
||||
*,
|
||||
action: AnnotationUpdateAction = AnnotationUpdateAction.UPDATE,
|
||||
):
|
||||
self.api.partial_update_annotations(
|
||||
action=action.value,
|
||||
id=getattr(self, self._model_id_field),
|
||||
patched_labeled_data_request=data,
|
||||
)
|
||||
|
||||
def remove_annotations(self: _EntityT, *, ids: Optional[Sequence[int]] = None):
|
||||
if ids:
|
||||
anns = self.get_annotations()
|
||||
|
||||
if not isinstance(ids, set):
|
||||
ids = set(ids)
|
||||
|
||||
anns_to_remove = models.PatchedLabeledDataRequest(
|
||||
tags=[models.LabeledImageRequest(**a.to_dict()) for a in anns.tags if a.id in ids],
|
||||
tracks=[
|
||||
models.LabeledTrackRequest(**a.to_dict()) for a in anns.tracks if a.id in ids
|
||||
],
|
||||
shapes=[
|
||||
models.LabeledShapeRequest(**a.to_dict()) for a in anns.shapes if a.id in ids
|
||||
],
|
||||
)
|
||||
|
||||
self.update_annotations(anns_to_remove, action=AnnotationUpdateAction.DELETE)
|
||||
else:
|
||||
self.api.destroy_annotations(getattr(self, self._model_id_field))
|
||||
@ -0,0 +1,60 @@
|
||||
# Copyright (C) 2022 CVAT.ai Corporation
|
||||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from cvat_sdk.api_client import apis, models
|
||||
from cvat_sdk.core.proxies.model_proxy import (
|
||||
ModelCreateMixin,
|
||||
ModelDeleteMixin,
|
||||
ModelListMixin,
|
||||
ModelRetrieveMixin,
|
||||
ModelUpdateMixin,
|
||||
build_model_bases,
|
||||
)
|
||||
|
||||
_CommentEntityBase, _CommentRepoBase = build_model_bases(
|
||||
models.CommentRead, apis.CommentsApi, api_member_name="comments_api"
|
||||
)
|
||||
|
||||
|
||||
class Comment(
|
||||
models.ICommentRead,
|
||||
_CommentEntityBase,
|
||||
ModelUpdateMixin[models.IPatchedCommentWriteRequest],
|
||||
ModelDeleteMixin,
|
||||
):
|
||||
_model_partial_update_arg = "patched_comment_write_request"
|
||||
|
||||
|
||||
class CommentsRepo(
|
||||
_CommentRepoBase,
|
||||
ModelListMixin[Comment],
|
||||
ModelCreateMixin[Comment, models.ICommentWriteRequest],
|
||||
ModelRetrieveMixin[Comment],
|
||||
):
|
||||
_entity_type = Comment
|
||||
|
||||
|
||||
_IssueEntityBase, _IssueRepoBase = build_model_bases(
|
||||
models.IssueRead, apis.IssuesApi, api_member_name="issues_api"
|
||||
)
|
||||
|
||||
|
||||
class Issue(
|
||||
models.IIssueRead,
|
||||
_IssueEntityBase,
|
||||
ModelUpdateMixin[models.IPatchedIssueWriteRequest],
|
||||
ModelDeleteMixin,
|
||||
):
|
||||
_model_partial_update_arg = "patched_issue_write_request"
|
||||
|
||||
|
||||
class IssuesRepo(
|
||||
_IssueRepoBase,
|
||||
ModelListMixin[Issue],
|
||||
ModelCreateMixin[Issue, models.IIssueWriteRequest],
|
||||
ModelRetrieveMixin[Issue],
|
||||
):
|
||||
_entity_type = Issue
|
||||
@ -0,0 +1,166 @@
|
||||
# Copyright (C) 2022 CVAT.ai Corporation
|
||||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import mimetypes
|
||||
import os
|
||||
import os.path as osp
|
||||
from typing import List, Optional, Sequence
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from cvat_sdk.api_client import apis, models
|
||||
from cvat_sdk.core.downloading import Downloader
|
||||
from cvat_sdk.core.helpers import get_paginated_collection
|
||||
from cvat_sdk.core.progress import ProgressReporter
|
||||
from cvat_sdk.core.proxies.annotations import AnnotationCrudMixin
|
||||
from cvat_sdk.core.proxies.issues import Issue
|
||||
from cvat_sdk.core.proxies.model_proxy import (
|
||||
ModelListMixin,
|
||||
ModelRetrieveMixin,
|
||||
ModelUpdateMixin,
|
||||
build_model_bases,
|
||||
)
|
||||
from cvat_sdk.core.uploading import AnnotationUploader
|
||||
|
||||
_JobEntityBase, _JobRepoBase = build_model_bases(
|
||||
models.JobRead, apis.JobsApi, api_member_name="jobs_api"
|
||||
)
|
||||
|
||||
|
||||
class Job(
|
||||
models.IJobRead,
|
||||
_JobEntityBase,
|
||||
ModelUpdateMixin[models.IPatchedJobWriteRequest],
|
||||
AnnotationCrudMixin,
|
||||
):
|
||||
_model_partial_update_arg = "patched_job_write_request"
|
||||
_put_annotations_data_param = "job_annotations_update_request"
|
||||
|
||||
def import_annotations(
|
||||
self,
|
||||
format_name: str,
|
||||
filename: str,
|
||||
*,
|
||||
status_check_period: Optional[int] = None,
|
||||
pbar: Optional[ProgressReporter] = None,
|
||||
):
|
||||
"""
|
||||
Upload annotations for a job in the specified format (e.g. 'YOLO ZIP 1.0').
|
||||
"""
|
||||
|
||||
AnnotationUploader(self._client).upload_file_and_wait(
|
||||
self.api.create_annotations_endpoint,
|
||||
filename,
|
||||
format_name,
|
||||
url_params={"id": self.id},
|
||||
pbar=pbar,
|
||||
status_check_period=status_check_period,
|
||||
)
|
||||
|
||||
self._client.logger.info(f"Annotation file '{filename}' for job #{self.id} uploaded")
|
||||
|
||||
def export_dataset(
|
||||
self,
|
||||
format_name: str,
|
||||
filename: str,
|
||||
*,
|
||||
pbar: Optional[ProgressReporter] = None,
|
||||
status_check_period: Optional[int] = None,
|
||||
include_images: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Download annotations for a job in the specified format (e.g. 'YOLO ZIP 1.0').
|
||||
"""
|
||||
if include_images:
|
||||
endpoint = self.api.retrieve_dataset_endpoint
|
||||
else:
|
||||
endpoint = self.api.retrieve_annotations_endpoint
|
||||
|
||||
Downloader(self._client).prepare_and_download_file_from_endpoint(
|
||||
endpoint=endpoint,
|
||||
filename=filename,
|
||||
url_params={"id": self.id},
|
||||
query_params={"format": format_name},
|
||||
pbar=pbar,
|
||||
status_check_period=status_check_period,
|
||||
)
|
||||
|
||||
self._client.logger.info(f"Dataset for job {self.id} has been downloaded to {filename}")
|
||||
|
||||
def get_frame(
|
||||
self,
|
||||
frame_id: int,
|
||||
*,
|
||||
quality: Optional[str] = None,
|
||||
) -> io.RawIOBase:
|
||||
(_, response) = self.api.retrieve_data(
|
||||
self.id, number=frame_id, quality=quality, type="frame"
|
||||
)
|
||||
return io.BytesIO(response.data)
|
||||
|
||||
def get_preview(
|
||||
self,
|
||||
) -> io.RawIOBase:
|
||||
(_, response) = self.api.retrieve_data(self.id, type="preview")
|
||||
return io.BytesIO(response.data)
|
||||
|
||||
def download_frames(
|
||||
self,
|
||||
frame_ids: Sequence[int],
|
||||
*,
|
||||
outdir: str = "",
|
||||
quality: str = "original",
|
||||
filename_pattern: str = "frame_{frame_id:06d}{frame_ext}",
|
||||
) -> Optional[List[Image.Image]]:
|
||||
"""
|
||||
Download the requested frame numbers for a job and save images as outdir/filename_pattern
|
||||
"""
|
||||
# TODO: add arg descriptions in schema
|
||||
os.makedirs(outdir, exist_ok=True)
|
||||
|
||||
for frame_id in frame_ids:
|
||||
frame_bytes = self.get_frame(frame_id, quality=quality)
|
||||
|
||||
im = Image.open(frame_bytes)
|
||||
mime_type = im.get_format_mimetype() or "image/jpg"
|
||||
im_ext = mimetypes.guess_extension(mime_type)
|
||||
|
||||
# FIXME It is better to use meta information from the server
|
||||
# to determine the extension
|
||||
# replace '.jpe' or '.jpeg' with a more used '.jpg'
|
||||
if im_ext in (".jpe", ".jpeg", None):
|
||||
im_ext = ".jpg"
|
||||
|
||||
outfile = filename_pattern.format(frame_id=frame_id, frame_ext=im_ext)
|
||||
im.save(osp.join(outdir, outfile))
|
||||
|
||||
def get_meta(self) -> models.IDataMetaRead:
|
||||
(meta, _) = self.api.retrieve_data_meta(self.id)
|
||||
return meta
|
||||
|
||||
def get_frames_info(self) -> List[models.IFrameMeta]:
|
||||
return self.get_meta().frames
|
||||
|
||||
def remove_frames_by_ids(self, ids: Sequence[int]) -> None:
|
||||
self._client.api.tasks_api.jobs_partial_update_data_meta(
|
||||
self.id,
|
||||
patched_data_meta_write_request=models.PatchedDataMetaWriteRequest(deleted_frames=ids),
|
||||
)
|
||||
|
||||
def get_issues(self) -> List[Issue]:
|
||||
return [Issue(self._client, m) for m in self.api.list_issues(id=self.id)[0]]
|
||||
|
||||
def get_commits(self) -> List[models.IJobCommit]:
|
||||
return get_paginated_collection(self.api.list_commits_endpoint, id=self.id)
|
||||
|
||||
|
||||
class JobsRepo(
|
||||
_JobRepoBase,
|
||||
ModelListMixin[Job],
|
||||
ModelRetrieveMixin[Job],
|
||||
):
|
||||
_entity_type = Job
|
||||
@ -0,0 +1,213 @@
|
||||
# Copyright (C) 2022 CVAT.ai Corporation
|
||||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from abc import ABC
|
||||
from copy import deepcopy
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from cvat_sdk.api_client.model_utils import IModelData, ModelNormal, to_json
|
||||
from cvat_sdk.core.helpers import get_paginated_collection
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from cvat_sdk.core.client import Client
|
||||
|
||||
IModel = TypeVar("IModel", bound=IModelData)
|
||||
ModelType = TypeVar("ModelType", bound=ModelNormal)
|
||||
ApiType = TypeVar("ApiType")
|
||||
|
||||
|
||||
class ModelProxy(ABC, Generic[ModelType, ApiType]):
|
||||
_client: Client
|
||||
|
||||
@property
|
||||
def _api_member_name(self) -> str:
|
||||
...
|
||||
|
||||
def __init__(self, client: Client) -> None:
|
||||
self.__dict__["_client"] = client
|
||||
|
||||
@classmethod
|
||||
def get_api(cls, client: Client) -> ApiType:
|
||||
return getattr(client.api, cls._api_member_name)
|
||||
|
||||
@property
|
||||
def api(self) -> ApiType:
|
||||
return self.get_api(self._client)
|
||||
|
||||
|
||||
class Entity(ModelProxy[ModelType, ApiType]):
|
||||
"""
|
||||
Represents a single object. Implements related operations and provides access to data members.
|
||||
"""
|
||||
|
||||
_model: ModelType
|
||||
|
||||
def __init__(self, client: Client, model: ModelType) -> None:
|
||||
super().__init__(client)
|
||||
self.__dict__["_model"] = model
|
||||
|
||||
@property
|
||||
def _model_id_field(self) -> str:
|
||||
return "id"
|
||||
|
||||
def __getattr__(self, __name: str) -> Any:
|
||||
# NOTE: be aware of potential problems with throwing AttributeError from @property
|
||||
# in derived classes!
|
||||
# https://medium.com/@ceshine/python-debugging-pitfall-mixed-use-of-property-and-getattr-f89e0ede13f1
|
||||
return self._model[__name]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self._model)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.__class__.__name__}: id={getattr(self, self._model_id_field)}>"
|
||||
|
||||
|
||||
class Repo(ModelProxy[ModelType, ApiType]):
|
||||
"""
|
||||
Represents a collection of corresponding Entity objects.
|
||||
Implements group and management operations for entities.
|
||||
"""
|
||||
|
||||
_entity_type: Type[Entity[ModelType, ApiType]]
|
||||
|
||||
|
||||
### Utilities
|
||||
|
||||
|
||||
def build_model_bases(
|
||||
mt: Type[ModelType], at: Type[ApiType], *, api_member_name: Optional[str] = None
|
||||
) -> Tuple[Type[Entity[ModelType, ApiType]], Type[Repo[ModelType, ApiType]]]:
|
||||
"""
|
||||
Helps to remove code duplication in declarations of derived classes
|
||||
"""
|
||||
|
||||
class _EntityBase(Entity[ModelType, ApiType]):
|
||||
if api_member_name:
|
||||
_api_member_name = api_member_name
|
||||
|
||||
class _RepoBase(Repo[ModelType, ApiType]):
|
||||
if api_member_name:
|
||||
_api_member_name = api_member_name
|
||||
|
||||
return _EntityBase, _RepoBase
|
||||
|
||||
|
||||
### CRUD mixins
|
||||
|
||||
_EntityT = TypeVar("_EntityT", bound=Entity)
|
||||
|
||||
#### Repo mixins
|
||||
|
||||
|
||||
class ModelCreateMixin(Generic[_EntityT, IModel]):
|
||||
def create(self: Repo, spec: Union[Dict[str, Any], IModel]) -> _EntityT:
|
||||
"""
|
||||
Creates a new object on the server and returns corresponding local object
|
||||
"""
|
||||
|
||||
(model, _) = self.api.create(spec)
|
||||
return self._entity_type(self._client, model)
|
||||
|
||||
|
||||
class ModelRetrieveMixin(Generic[_EntityT]):
|
||||
def retrieve(self: Repo, obj_id: int) -> _EntityT:
|
||||
"""
|
||||
Retrieves an object from server by ID
|
||||
"""
|
||||
|
||||
(model, _) = self.api.retrieve(id=obj_id)
|
||||
return self._entity_type(self._client, model)
|
||||
|
||||
|
||||
class ModelListMixin(Generic[_EntityT]):
|
||||
@overload
|
||||
def list(self: Repo, *, return_json: Literal[False] = False) -> List[_EntityT]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def list(self: Repo, *, return_json: Literal[True] = False) -> List[Any]:
|
||||
...
|
||||
|
||||
def list(self: Repo, *, return_json: bool = False) -> List[Union[_EntityT, Any]]:
|
||||
"""
|
||||
Retrieves all objects from the server and returns them in basic or JSON format.
|
||||
"""
|
||||
|
||||
results = get_paginated_collection(endpoint=self.api.list_endpoint, return_json=return_json)
|
||||
|
||||
if return_json:
|
||||
return json.dumps(results)
|
||||
return [self._entity_type(self._client, model) for model in results]
|
||||
|
||||
|
||||
#### Entity mixins
|
||||
|
||||
|
||||
class ModelUpdateMixin(ABC, Generic[IModel]):
|
||||
@property
|
||||
def _model_partial_update_arg(self: Entity) -> str:
|
||||
...
|
||||
|
||||
def _export_update_fields(
|
||||
self: Entity, overrides: Optional[Union[Dict[str, Any], IModel]] = None
|
||||
) -> Dict[str, Any]:
|
||||
# TODO: support field conversion and assignment updating
|
||||
# fields = to_json(self._model)
|
||||
|
||||
if isinstance(overrides, ModelNormal):
|
||||
overrides = to_json(overrides)
|
||||
fields = deepcopy(overrides)
|
||||
|
||||
return fields
|
||||
|
||||
def fetch(self: Entity) -> Self:
|
||||
"""
|
||||
Updates current object from the server
|
||||
"""
|
||||
|
||||
# TODO: implement revision checking
|
||||
(self._model, _) = self.api.retrieve(id=getattr(self, self._model_id_field))
|
||||
return self
|
||||
|
||||
def update(self: Entity, values: Union[Dict[str, Any], IModel]) -> Self:
|
||||
"""
|
||||
Commits local model changes to the server
|
||||
"""
|
||||
|
||||
# TODO: implement revision checking
|
||||
self.api.partial_update(
|
||||
id=getattr(self, self._model_id_field),
|
||||
**{self._model_partial_update_arg: self._export_update_fields(values)},
|
||||
)
|
||||
|
||||
# TODO: use the response model, once input and output models are same
|
||||
return self.fetch()
|
||||
|
||||
|
||||
class ModelDeleteMixin:
|
||||
def remove(self: Entity) -> None:
|
||||
"""
|
||||
Removes current object on the server
|
||||
"""
|
||||
|
||||
self.api.destroy(id=getattr(self, self._model_id_field))
|
||||
@ -0,0 +1,187 @@
|
||||
# Copyright (C) 2022 CVAT.ai Corporation
|
||||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os.path as osp
|
||||
from typing import Optional
|
||||
|
||||
from cvat_sdk.api_client import apis, models
|
||||
from cvat_sdk.core.downloading import Downloader
|
||||
from cvat_sdk.core.progress import ProgressReporter
|
||||
from cvat_sdk.core.proxies.model_proxy import (
|
||||
ModelCreateMixin,
|
||||
ModelDeleteMixin,
|
||||
ModelListMixin,
|
||||
ModelRetrieveMixin,
|
||||
ModelUpdateMixin,
|
||||
build_model_bases,
|
||||
)
|
||||
from cvat_sdk.core.uploading import DatasetUploader, Uploader
|
||||
|
||||
_ProjectEntityBase, _ProjectRepoBase = build_model_bases(
|
||||
models.ProjectRead, apis.ProjectsApi, api_member_name="projects_api"
|
||||
)
|
||||
|
||||
|
||||
class Project(
|
||||
_ProjectEntityBase, models.IProjectRead, ModelUpdateMixin[models.IPatchedProjectWriteRequest]
|
||||
):
|
||||
_model_partial_update_arg = "patched_project_write_request"
|
||||
|
||||
def import_dataset(
|
||||
self,
|
||||
format_name: str,
|
||||
filename: str,
|
||||
*,
|
||||
status_check_period: Optional[int] = None,
|
||||
pbar: Optional[ProgressReporter] = None,
|
||||
):
|
||||
"""
|
||||
Import dataset for a project in the specified format (e.g. 'YOLO ZIP 1.0').
|
||||
"""
|
||||
|
||||
DatasetUploader(self._client).upload_file_and_wait(
|
||||
self.api.create_dataset_endpoint,
|
||||
filename,
|
||||
format_name,
|
||||
url_params={"id": self.id},
|
||||
pbar=pbar,
|
||||
status_check_period=status_check_period,
|
||||
)
|
||||
|
||||
self._client.logger.info(f"Annotation file '{filename}' for project #{self.id} uploaded")
|
||||
|
||||
def export_dataset(
|
||||
self,
|
||||
format_name: str,
|
||||
filename: str,
|
||||
*,
|
||||
pbar: Optional[ProgressReporter] = None,
|
||||
status_check_period: Optional[int] = None,
|
||||
include_images: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Download annotations for a project in the specified format (e.g. 'YOLO ZIP 1.0').
|
||||
"""
|
||||
if include_images:
|
||||
endpoint = self.api.retrieve_dataset_endpoint
|
||||
else:
|
||||
endpoint = self.api.retrieve_annotations_endpoint
|
||||
|
||||
Downloader(self._client).prepare_and_download_file_from_endpoint(
|
||||
endpoint=endpoint,
|
||||
filename=filename,
|
||||
url_params={"id": self.id},
|
||||
query_params={"format": format_name},
|
||||
pbar=pbar,
|
||||
status_check_period=status_check_period,
|
||||
)
|
||||
|
||||
self._client.logger.info(f"Dataset for project {self.id} has been downloaded to {filename}")
|
||||
|
||||
def download_backup(
|
||||
self,
|
||||
filename: str,
|
||||
*,
|
||||
status_check_period: int = None,
|
||||
pbar: Optional[ProgressReporter] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Download a project backup
|
||||
"""
|
||||
|
||||
Downloader(self._client).prepare_and_download_file_from_endpoint(
|
||||
self.api.retrieve_backup_endpoint,
|
||||
filename=filename,
|
||||
pbar=pbar,
|
||||
status_check_period=status_check_period,
|
||||
url_params={"id": self.id},
|
||||
)
|
||||
|
||||
self._client.logger.info(f"Backup for project {self.id} has been downloaded to {filename}")
|
||||
|
||||
def get_annotations(self) -> models.ILabeledData:
|
||||
(annotations, _) = self.api.retrieve_annotations(self.id)
|
||||
return annotations
|
||||
|
||||
|
||||
class ProjectsRepo(
|
||||
_ProjectRepoBase,
|
||||
ModelCreateMixin[Project, models.IProjectWriteRequest],
|
||||
ModelListMixin[Project],
|
||||
ModelRetrieveMixin[Project],
|
||||
ModelDeleteMixin,
|
||||
):
|
||||
_entity_type = Project
|
||||
|
||||
def create_from_dataset(
|
||||
self,
|
||||
spec: models.IProjectWriteRequest,
|
||||
*,
|
||||
dataset_path: str = "",
|
||||
dataset_format: str = "CVAT XML 1.1",
|
||||
status_check_period: int = None,
|
||||
pbar: Optional[ProgressReporter] = None,
|
||||
) -> Project:
|
||||
"""
|
||||
Create a new project with the given name and labels JSON and
|
||||
add the files to it.
|
||||
|
||||
Returns: id of the created project
|
||||
"""
|
||||
project = self.create(spec=spec)
|
||||
self._client.logger.info("Created project ID: %s NAME: %s", project.id, project.name)
|
||||
|
||||
if dataset_path:
|
||||
project.import_dataset(
|
||||
format_name=dataset_format,
|
||||
filename=dataset_path,
|
||||
pbar=pbar,
|
||||
status_check_period=status_check_period,
|
||||
)
|
||||
|
||||
project.fetch()
|
||||
return project
|
||||
|
||||
def create_from_backup(
|
||||
self,
|
||||
filename: str,
|
||||
*,
|
||||
status_check_period: int = None,
|
||||
pbar: Optional[ProgressReporter] = None,
|
||||
) -> Project:
|
||||
"""
|
||||
Import a project from a backup file
|
||||
"""
|
||||
if status_check_period is None:
|
||||
status_check_period = self.config.status_check_period
|
||||
|
||||
params = {"filename": osp.basename(filename)}
|
||||
url = self.api_map.make_endpoint_url(self.api.create_backup_endpoint.path)
|
||||
|
||||
uploader = Uploader(self)
|
||||
response = uploader.upload_file(
|
||||
url,
|
||||
filename,
|
||||
meta=params,
|
||||
query_params=params,
|
||||
pbar=pbar,
|
||||
logger=self._client.logger.debug,
|
||||
)
|
||||
|
||||
rq_id = json.loads(response.data)["rq_id"]
|
||||
response = self._client.wait_for_completion(
|
||||
url,
|
||||
success_status=201,
|
||||
positive_statuses=[202],
|
||||
post_params={"rq_id": rq_id},
|
||||
status_check_period=status_check_period,
|
||||
)
|
||||
|
||||
project_id = json.loads(response.data)["id"]
|
||||
self._client.logger.info(f"Project has been imported sucessfully. Project ID: {project_id}")
|
||||
|
||||
return self.retrieve(project_id)
|
||||
@ -0,0 +1,388 @@
|
||||
# Copyright (C) 2022 CVAT.ai Corporation
|
||||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
import os.path as osp
|
||||
from enum import Enum
|
||||
from time import sleep
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from cvat_sdk.api_client import apis, exceptions, models
|
||||
from cvat_sdk.core import git
|
||||
from cvat_sdk.core.downloading import Downloader
|
||||
from cvat_sdk.core.progress import ProgressReporter
|
||||
from cvat_sdk.core.proxies.annotations import AnnotationCrudMixin
|
||||
from cvat_sdk.core.proxies.jobs import Job
|
||||
from cvat_sdk.core.proxies.model_proxy import (
|
||||
ModelCreateMixin,
|
||||
ModelDeleteMixin,
|
||||
ModelListMixin,
|
||||
ModelRetrieveMixin,
|
||||
ModelUpdateMixin,
|
||||
build_model_bases,
|
||||
)
|
||||
from cvat_sdk.core.uploading import AnnotationUploader, DataUploader, Uploader
|
||||
from cvat_sdk.core.utils import filter_dict
|
||||
|
||||
|
||||
class ResourceType(Enum):
|
||||
LOCAL = 0
|
||||
SHARE = 1
|
||||
REMOTE = 2
|
||||
|
||||
def __str__(self):
|
||||
return self.name.lower()
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
|
||||
_TaskEntityBase, _TaskRepoBase = build_model_bases(
|
||||
models.TaskRead, apis.TasksApi, api_member_name="tasks_api"
|
||||
)
|
||||
|
||||
|
||||
class Task(
|
||||
_TaskEntityBase,
|
||||
models.ITaskRead,
|
||||
ModelUpdateMixin[models.IPatchedTaskWriteRequest],
|
||||
ModelDeleteMixin,
|
||||
AnnotationCrudMixin,
|
||||
):
|
||||
_model_partial_update_arg = "patched_task_write_request"
|
||||
_put_annotations_data_param = "task_annotations_update_request"
|
||||
|
||||
def upload_data(
|
||||
self,
|
||||
resource_type: ResourceType,
|
||||
resources: Sequence[str],
|
||||
*,
|
||||
pbar: Optional[ProgressReporter] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Add local, remote, or shared files to an existing task.
|
||||
"""
|
||||
params = params or {}
|
||||
|
||||
data = {}
|
||||
if resource_type is ResourceType.LOCAL:
|
||||
pass # handled later
|
||||
elif resource_type is ResourceType.REMOTE:
|
||||
data = {f"remote_files[{i}]": f for i, f in enumerate(resources)}
|
||||
elif resource_type is ResourceType.SHARE:
|
||||
data = {f"server_files[{i}]": f for i, f in enumerate(resources)}
|
||||
|
||||
data["image_quality"] = 70
|
||||
data.update(
|
||||
filter_dict(
|
||||
params,
|
||||
keep=[
|
||||
"chunk_size",
|
||||
"copy_data",
|
||||
"image_quality",
|
||||
"sorting_method",
|
||||
"start_frame",
|
||||
"stop_frame",
|
||||
"use_cache",
|
||||
"use_zip_chunks",
|
||||
],
|
||||
)
|
||||
)
|
||||
if params.get("frame_step") is not None:
|
||||
data["frame_filter"] = f"step={params.get('frame_step')}"
|
||||
|
||||
if resource_type in [ResourceType.REMOTE, ResourceType.SHARE]:
|
||||
self.api.create_data(
|
||||
self.id,
|
||||
data_request=models.DataRequest(**data),
|
||||
_content_type="multipart/form-data",
|
||||
)
|
||||
elif resource_type == ResourceType.LOCAL:
|
||||
url = self._client.api_map.make_endpoint_url(
|
||||
self.api.create_data_endpoint.path, kwsub={"id": self.id}
|
||||
)
|
||||
|
||||
DataUploader(self._client).upload_files(url, resources, pbar=pbar, **data)
|
||||
|
||||
def import_annotations(
|
||||
self,
|
||||
format_name: str,
|
||||
filename: str,
|
||||
*,
|
||||
status_check_period: Optional[int] = None,
|
||||
pbar: Optional[ProgressReporter] = None,
|
||||
):
|
||||
"""
|
||||
Upload annotations for a task in the specified format (e.g. 'YOLO ZIP 1.0').
|
||||
"""
|
||||
|
||||
AnnotationUploader(self._client).upload_file_and_wait(
|
||||
self.api.create_annotations_endpoint,
|
||||
filename,
|
||||
format_name,
|
||||
url_params={"id": self.id},
|
||||
pbar=pbar,
|
||||
status_check_period=status_check_period,
|
||||
)
|
||||
|
||||
self._client.logger.info(f"Annotation file '{filename}' for task #{self.id} uploaded")
|
||||
|
||||
def get_frame(
|
||||
self,
|
||||
frame_id: int,
|
||||
*,
|
||||
quality: Optional[str] = None,
|
||||
) -> io.RawIOBase:
|
||||
params = {}
|
||||
if quality:
|
||||
params["quality"] = quality
|
||||
(_, response) = self.api.retrieve_data(self.id, number=frame_id, **params, type="frame")
|
||||
return io.BytesIO(response.data)
|
||||
|
||||
def get_preview(
|
||||
self,
|
||||
) -> io.RawIOBase:
|
||||
(_, response) = self.api.retrieve_data(self.id, type="preview")
|
||||
return io.BytesIO(response.data)
|
||||
|
||||
def download_frames(
|
||||
self,
|
||||
frame_ids: Sequence[int],
|
||||
*,
|
||||
outdir: str = "",
|
||||
quality: str = "original",
|
||||
filename_pattern: str = "frame_{frame_id:06d}{frame_ext}",
|
||||
) -> Optional[List[Image.Image]]:
|
||||
"""
|
||||
Download the requested frame numbers for a task and save images as outdir/filename_pattern
|
||||
"""
|
||||
# TODO: add arg descriptions in schema
|
||||
os.makedirs(outdir, exist_ok=True)
|
||||
|
||||
for frame_id in frame_ids:
|
||||
frame_bytes = self.get_frame(frame_id, quality=quality)
|
||||
|
||||
im = Image.open(frame_bytes)
|
||||
mime_type = im.get_format_mimetype() or "image/jpg"
|
||||
im_ext = mimetypes.guess_extension(mime_type)
|
||||
|
||||
# FIXME It is better to use meta information from the server
|
||||
# to determine the extension
|
||||
# replace '.jpe' or '.jpeg' with a more used '.jpg'
|
||||
if im_ext in (".jpe", ".jpeg", None):
|
||||
im_ext = ".jpg"
|
||||
|
||||
outfile = filename_pattern.format(frame_id=frame_id, frame_ext=im_ext)
|
||||
im.save(osp.join(outdir, outfile))
|
||||
|
||||
def export_dataset(
|
||||
self,
|
||||
format_name: str,
|
||||
filename: str,
|
||||
*,
|
||||
pbar: Optional[ProgressReporter] = None,
|
||||
status_check_period: Optional[int] = None,
|
||||
include_images: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Download annotations for a task in the specified format (e.g. 'YOLO ZIP 1.0').
|
||||
"""
|
||||
if include_images:
|
||||
endpoint = self.api.retrieve_dataset_endpoint
|
||||
else:
|
||||
endpoint = self.api.retrieve_annotations_endpoint
|
||||
|
||||
Downloader(self._client).prepare_and_download_file_from_endpoint(
|
||||
endpoint=endpoint,
|
||||
filename=filename,
|
||||
url_params={"id": self.id},
|
||||
query_params={"format": format_name},
|
||||
pbar=pbar,
|
||||
status_check_period=status_check_period,
|
||||
)
|
||||
|
||||
self._client.logger.info(f"Dataset for task {self.id} has been downloaded to {filename}")
|
||||
|
||||
def download_backup(
|
||||
self,
|
||||
filename: str,
|
||||
*,
|
||||
status_check_period: int = None,
|
||||
pbar: Optional[ProgressReporter] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Download a task backup
|
||||
"""
|
||||
|
||||
Downloader(self._client).prepare_and_download_file_from_endpoint(
|
||||
self.api.retrieve_backup_endpoint,
|
||||
filename=filename,
|
||||
pbar=pbar,
|
||||
status_check_period=status_check_period,
|
||||
url_params={"id": self.id},
|
||||
)
|
||||
|
||||
self._client.logger.info(f"Backup for task {self.id} has been downloaded to {filename}")
|
||||
|
||||
def get_jobs(self) -> List[Job]:
|
||||
return [Job(self._client, m) for m in self.api.list_jobs(id=self.id)[0]]
|
||||
|
||||
def get_meta(self) -> models.IDataMetaRead:
|
||||
(meta, _) = self.api.retrieve_data_meta(self.id)
|
||||
return meta
|
||||
|
||||
def get_frames_info(self) -> List[models.IFrameMeta]:
|
||||
return self.get_meta().frames
|
||||
|
||||
def remove_frames_by_ids(self, ids: Sequence[int]) -> None:
|
||||
self.api.partial_update_data_meta(
|
||||
self.id,
|
||||
patched_data_meta_write_request=models.PatchedDataMetaWriteRequest(deleted_frames=ids),
|
||||
)
|
||||
|
||||
|
||||
class TasksRepo(
|
||||
_TaskRepoBase,
|
||||
ModelCreateMixin[Task, models.ITaskWriteRequest],
|
||||
ModelRetrieveMixin[Task],
|
||||
ModelListMixin[Task],
|
||||
ModelDeleteMixin,
|
||||
):
|
||||
_entity_type = Task
|
||||
|
||||
def create_from_data(
|
||||
self,
|
||||
spec: models.ITaskWriteRequest,
|
||||
resource_type: ResourceType,
|
||||
resources: Sequence[str],
|
||||
*,
|
||||
data_params: Optional[Dict[str, Any]] = None,
|
||||
annotation_path: str = "",
|
||||
annotation_format: str = "CVAT XML 1.1",
|
||||
status_check_period: int = None,
|
||||
dataset_repository_url: str = "",
|
||||
use_lfs: bool = False,
|
||||
pbar: Optional[ProgressReporter] = None,
|
||||
) -> Task:
|
||||
"""
|
||||
Create a new task with the given name and labels JSON and
|
||||
add the files to it.
|
||||
|
||||
Returns: id of the created task
|
||||
"""
|
||||
if status_check_period is None:
|
||||
status_check_period = self._client.config.status_check_period
|
||||
|
||||
if getattr(spec, "project_id", None) and getattr(spec, "labels", None):
|
||||
raise exceptions.ApiValueError(
|
||||
"Can't set labels to a task inside a project. "
|
||||
"Tasks inside a project use project's labels.",
|
||||
["labels"],
|
||||
)
|
||||
|
||||
task = self.create(spec=spec)
|
||||
self._client.logger.info("Created task ID: %s NAME: %s", task.id, task.name)
|
||||
|
||||
task.upload_data(resource_type, resources, pbar=pbar, params=data_params)
|
||||
|
||||
self._client.logger.info("Awaiting for task %s creation...", task.id)
|
||||
status: models.RqStatus = None
|
||||
while status != models.RqStatusStateEnum.allowed_values[("value",)]["FINISHED"]:
|
||||
sleep(status_check_period)
|
||||
(status, response) = self.api.retrieve_status(task.id)
|
||||
|
||||
self._client.logger.info(
|
||||
"Task %s creation status=%s, message=%s",
|
||||
task.id,
|
||||
status.state.value,
|
||||
status.message,
|
||||
)
|
||||
|
||||
if status.state.value == models.RqStatusStateEnum.allowed_values[("value",)]["FAILED"]:
|
||||
raise exceptions.ApiException(
|
||||
status=status.state.value, reason=status.message, http_resp=response
|
||||
)
|
||||
|
||||
status = status.state.value
|
||||
|
||||
if annotation_path:
|
||||
task.import_annotations(annotation_format, annotation_path, pbar=pbar)
|
||||
|
||||
if dataset_repository_url:
|
||||
git.create_git_repo(
|
||||
self,
|
||||
task_id=task.id,
|
||||
repo_url=dataset_repository_url,
|
||||
status_check_period=status_check_period,
|
||||
use_lfs=use_lfs,
|
||||
)
|
||||
|
||||
task.fetch()
|
||||
|
||||
return task
|
||||
|
||||
def remove_by_ids(self, task_ids: Sequence[int]) -> None:
|
||||
"""
|
||||
Delete a list of tasks, ignoring those which don't exist.
|
||||
"""
|
||||
|
||||
for task_id in task_ids:
|
||||
(_, response) = self.api.destroy(task_id, _check_status=False)
|
||||
|
||||
if 200 <= response.status <= 299:
|
||||
self._client.logger.info(f"Task ID {task_id} deleted")
|
||||
elif response.status == 404:
|
||||
self._client.logger.info(f"Task ID {task_id} not found")
|
||||
else:
|
||||
self._client.logger.warning(
|
||||
f"Failed to delete task ID {task_id}: "
|
||||
f"{response.msg} (status {response.status})"
|
||||
)
|
||||
|
||||
def create_from_backup(
|
||||
self,
|
||||
filename: str,
|
||||
*,
|
||||
status_check_period: int = None,
|
||||
pbar: Optional[ProgressReporter] = None,
|
||||
) -> Task:
|
||||
"""
|
||||
Import a task from a backup file
|
||||
"""
|
||||
if status_check_period is None:
|
||||
status_check_period = self._client.config.status_check_period
|
||||
|
||||
params = {"filename": osp.basename(filename)}
|
||||
url = self._client.api_map.make_endpoint_url(self.api.create_backup_endpoint.path)
|
||||
uploader = Uploader(self._client)
|
||||
response = uploader.upload_file(
|
||||
url,
|
||||
filename,
|
||||
meta=params,
|
||||
query_params=params,
|
||||
pbar=pbar,
|
||||
logger=self._client.logger.debug,
|
||||
)
|
||||
|
||||
rq_id = json.loads(response.data)["rq_id"]
|
||||
response = self._client.wait_for_completion(
|
||||
url,
|
||||
success_status=201,
|
||||
positive_statuses=[202],
|
||||
post_params={"rq_id": rq_id},
|
||||
status_check_period=status_check_period,
|
||||
)
|
||||
|
||||
task_id = json.loads(response.data)["id"]
|
||||
self._client.logger.info(f"Task has been imported sucessfully. Task ID: {task_id}")
|
||||
|
||||
return self.retrieve(task_id)
|
||||
@ -0,0 +1,35 @@
|
||||
# Copyright (C) 2022 CVAT.ai Corporation
|
||||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from cvat_sdk.api_client import apis, models
|
||||
from cvat_sdk.core.proxies.model_proxy import (
|
||||
ModelDeleteMixin,
|
||||
ModelListMixin,
|
||||
ModelRetrieveMixin,
|
||||
ModelUpdateMixin,
|
||||
build_model_bases,
|
||||
)
|
||||
|
||||
_UserEntityBase, _UserRepoBase = build_model_bases(
|
||||
models.User, apis.UsersApi, api_member_name="users_api"
|
||||
)
|
||||
|
||||
|
||||
class User(
|
||||
models.IUser, _UserEntityBase, ModelUpdateMixin[models.IPatchedUserRequest], ModelDeleteMixin
|
||||
):
|
||||
_model_partial_update_arg = "patched_user_request"
|
||||
|
||||
|
||||
class UsersRepo(
|
||||
_UserRepoBase,
|
||||
ModelListMixin[User],
|
||||
ModelRetrieveMixin[User],
|
||||
):
|
||||
_entity_type = User
|
||||
|
||||
def retrieve_current_user(self) -> User:
|
||||
return User(self._client, self.api.retrieve_self()[0])
|
||||
@ -1,308 +0,0 @@
|
||||
# Copyright (C) 2020-2022 Intel Corporation
|
||||
# Copyright (C) 2022 CVAT.ai Corporation
|
||||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import mimetypes
|
||||
import os
|
||||
import os.path as osp
|
||||
from abc import ABC, abstractmethod
|
||||
from io import BytesIO
|
||||
from time import sleep
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from cvat_sdk import models
|
||||
from cvat_sdk.api_client.model_utils import OpenApiModel
|
||||
from cvat_sdk.core.downloading import Downloader
|
||||
from cvat_sdk.core.progress import ProgressReporter
|
||||
from cvat_sdk.core.types import ResourceType
|
||||
from cvat_sdk.core.uploading import Uploader
|
||||
from cvat_sdk.core.utils import filter_dict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from cvat_sdk.core.client import Client
|
||||
|
||||
|
||||
class ModelProxy(ABC):
|
||||
_client: Client
|
||||
_model: OpenApiModel
|
||||
|
||||
def __init__(self, client: Client, model: OpenApiModel) -> None:
|
||||
self.__dict__["_client"] = client
|
||||
self.__dict__["_model"] = model
|
||||
|
||||
def __getattr__(self, __name: str) -> Any:
|
||||
return self._model[__name]
|
||||
|
||||
def __setattr__(self, __name: str, __value: Any) -> None:
|
||||
if __name in self.__dict__:
|
||||
self.__dict__[__name] = __value
|
||||
else:
|
||||
self._model[__name] = __value
|
||||
|
||||
@abstractmethod
|
||||
def fetch(self, force: bool = False):
|
||||
"""Fetches model data from the server"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def commit(self, force: bool = False):
|
||||
"""Commits local changes to the server"""
|
||||
...
|
||||
|
||||
def sync(self):
|
||||
"""Pulls server state and commits local model changes"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def update(self, **kwargs):
|
||||
"""Updates multiple fields at once"""
|
||||
...
|
||||
|
||||
|
||||
class TaskProxy(ModelProxy, models.ITaskRead):
|
||||
def __init__(self, client: Client, task: models.TaskRead):
|
||||
ModelProxy.__init__(self, client=client, model=task)
|
||||
|
||||
def remove(self):
|
||||
self._client.api.tasks_api.destroy(self.id)
|
||||
|
||||
def upload_data(
|
||||
self,
|
||||
resource_type: ResourceType,
|
||||
resources: Sequence[str],
|
||||
*,
|
||||
pbar: Optional[ProgressReporter] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Add local, remote, or shared files to an existing task.
|
||||
"""
|
||||
client = self._client
|
||||
task_id = self.id
|
||||
|
||||
params = params or {}
|
||||
|
||||
data = {}
|
||||
if resource_type is ResourceType.LOCAL:
|
||||
pass # handled later
|
||||
elif resource_type is ResourceType.REMOTE:
|
||||
data = {f"remote_files[{i}]": f for i, f in enumerate(resources)}
|
||||
elif resource_type is ResourceType.SHARE:
|
||||
data = {f"server_files[{i}]": f for i, f in enumerate(resources)}
|
||||
|
||||
data["image_quality"] = 70
|
||||
data.update(
|
||||
filter_dict(
|
||||
params,
|
||||
keep=[
|
||||
"chunk_size",
|
||||
"copy_data",
|
||||
"image_quality",
|
||||
"sorting_method",
|
||||
"start_frame",
|
||||
"stop_frame",
|
||||
"use_cache",
|
||||
"use_zip_chunks",
|
||||
],
|
||||
)
|
||||
)
|
||||
if params.get("frame_step") is not None:
|
||||
data["frame_filter"] = f"step={params.get('frame_step')}"
|
||||
|
||||
if resource_type in [ResourceType.REMOTE, ResourceType.SHARE]:
|
||||
client.api.tasks_api.create_data(
|
||||
task_id,
|
||||
data_request=models.DataRequest(**data),
|
||||
_content_type="multipart/form-data",
|
||||
)
|
||||
elif resource_type == ResourceType.LOCAL:
|
||||
url = client._api_map.make_endpoint_url(
|
||||
client.api.tasks_api.create_data_endpoint.path, kwsub={"id": task_id}
|
||||
)
|
||||
|
||||
uploader = Uploader(client)
|
||||
uploader.upload_files(url, resources, pbar=pbar, **data)
|
||||
|
||||
def import_annotations(
|
||||
self,
|
||||
format_name: str,
|
||||
filename: str,
|
||||
*,
|
||||
status_check_period: int = None,
|
||||
pbar: Optional[ProgressReporter] = None,
|
||||
):
|
||||
"""
|
||||
Upload annotations for a task in the specified format
|
||||
(e.g. 'YOLO ZIP 1.0').
|
||||
"""
|
||||
client = self._client
|
||||
if status_check_period is None:
|
||||
status_check_period = client.config.status_check_period
|
||||
|
||||
task_id = self.id
|
||||
|
||||
url = client._api_map.make_endpoint_url(
|
||||
client.api.tasks_api.create_annotations_endpoint.path,
|
||||
kwsub={"id": task_id},
|
||||
)
|
||||
params = {"format": format_name, "filename": osp.basename(filename)}
|
||||
|
||||
uploader = Uploader(client)
|
||||
uploader.upload_file(
|
||||
url, filename, pbar=pbar, query_params=params, meta={"filename": params["filename"]}
|
||||
)
|
||||
|
||||
while True:
|
||||
response = client.api.rest_client.POST(
|
||||
url, headers=client.api.get_common_headers(), query_params=params
|
||||
)
|
||||
if response.status == 201:
|
||||
break
|
||||
|
||||
sleep(status_check_period)
|
||||
|
||||
client.logger.info(
|
||||
f"Upload job for Task ID {task_id} with annotation file {filename} finished"
|
||||
)
|
||||
|
||||
def retrieve_frame(
|
||||
self,
|
||||
frame_id: int,
|
||||
*,
|
||||
quality: Optional[str] = None,
|
||||
) -> io.RawIOBase:
|
||||
client = self._client
|
||||
task_id = self.id
|
||||
|
||||
(_, response) = client.api.tasks_api.retrieve_data(task_id, frame_id, quality, type="frame")
|
||||
|
||||
return BytesIO(response.data)
|
||||
|
||||
def download_frames(
|
||||
self,
|
||||
frame_ids: Sequence[int],
|
||||
*,
|
||||
outdir: str = "",
|
||||
quality: str = "original",
|
||||
filename_pattern: str = "task_{task_id}_frame_{frame_id:06d}{frame_ext}",
|
||||
) -> Optional[List[Image.Image]]:
|
||||
"""
|
||||
Download the requested frame numbers for a task and save images as
|
||||
outdir/filename_pattern
|
||||
"""
|
||||
# TODO: add arg descriptions in schema
|
||||
task_id = self.id
|
||||
|
||||
os.makedirs(outdir, exist_ok=True)
|
||||
|
||||
for frame_id in frame_ids:
|
||||
frame_bytes = self.retrieve_frame(frame_id, quality=quality)
|
||||
|
||||
im = Image.open(frame_bytes)
|
||||
mime_type = im.get_format_mimetype() or "image/jpg"
|
||||
im_ext = mimetypes.guess_extension(mime_type)
|
||||
|
||||
# FIXME It is better to use meta information from the server
|
||||
# to determine the extension
|
||||
# replace '.jpe' or '.jpeg' with a more used '.jpg'
|
||||
if im_ext in (".jpe", ".jpeg", None):
|
||||
im_ext = ".jpg"
|
||||
|
||||
outfile = filename_pattern.format(task_id=task_id, frame_id=frame_id, frame_ext=im_ext)
|
||||
im.save(osp.join(outdir, outfile))
|
||||
|
||||
def export_dataset(
|
||||
self,
|
||||
format_name: str,
|
||||
filename: str,
|
||||
*,
|
||||
pbar: Optional[ProgressReporter] = None,
|
||||
status_check_period: int = None,
|
||||
include_images: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Download annotations for a task in the specified format (e.g. 'YOLO ZIP 1.0').
|
||||
"""
|
||||
client = self._client
|
||||
if status_check_period is None:
|
||||
status_check_period = client.config.status_check_period
|
||||
|
||||
task_id = self.id
|
||||
|
||||
params = {"filename": self.name, "format": format_name}
|
||||
if include_images:
|
||||
endpoint = client.api.tasks_api.retrieve_dataset_endpoint
|
||||
else:
|
||||
endpoint = client.api.tasks_api.retrieve_annotations_endpoint
|
||||
|
||||
client.logger.info("Waiting for the server to prepare the file...")
|
||||
while True:
|
||||
(_, response) = endpoint.call_with_http_info(id=task_id, **params)
|
||||
client.logger.debug("STATUS {}".format(response.status))
|
||||
if response.status == 201:
|
||||
break
|
||||
sleep(status_check_period)
|
||||
|
||||
params["action"] = "download"
|
||||
url = client._api_map.make_endpoint_url(
|
||||
endpoint.path, kwsub={"id": task_id}, query_params=params
|
||||
)
|
||||
downloader = Downloader(client)
|
||||
downloader.download_file(url, output_path=filename, pbar=pbar)
|
||||
|
||||
client.logger.info(f"Dataset has been exported to {filename}")
|
||||
|
||||
def download_backup(
|
||||
self,
|
||||
filename: str,
|
||||
*,
|
||||
status_check_period: int = None,
|
||||
pbar: Optional[ProgressReporter] = None,
|
||||
):
|
||||
"""
|
||||
Download a task backup
|
||||
"""
|
||||
client = self._client
|
||||
if status_check_period is None:
|
||||
status_check_period = client.config.status_check_period
|
||||
|
||||
task_id = self.id
|
||||
|
||||
endpoint = client.api.tasks_api.retrieve_backup_endpoint
|
||||
client.logger.info("Waiting for the server to prepare the file...")
|
||||
while True:
|
||||
(_, response) = endpoint.call_with_http_info(id=task_id)
|
||||
client.logger.debug("STATUS {}".format(response.status))
|
||||
if response.status == 201:
|
||||
break
|
||||
sleep(status_check_period)
|
||||
|
||||
url = client._api_map.make_endpoint_url(
|
||||
endpoint.path, kwsub={"id": task_id}, query_params={"action": "download"}
|
||||
)
|
||||
downloader = Downloader(client)
|
||||
downloader.download_file(url, output_path=filename, pbar=pbar)
|
||||
|
||||
client.logger.info(
|
||||
f"Task {task_id} has been exported sucessfully to {osp.abspath(filename)}"
|
||||
)
|
||||
|
||||
def fetch(self, force: bool = False):
|
||||
# TODO: implement revision checking
|
||||
model, _ = self._client.api.tasks_api.retrieve(self.id)
|
||||
self._model = model
|
||||
|
||||
def commit(self, force: bool = False):
|
||||
return super().commit(force)
|
||||
|
||||
def update(self, **kwargs):
|
||||
return super().update(**kwargs)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self._model)
|
||||
@ -1,18 +0,0 @@
|
||||
# Copyright (C) 2022 Intel Corporation
|
||||
# Copyright (C) 2022 CVAT.ai Corporation
|
||||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ResourceType(Enum):
|
||||
LOCAL = 0
|
||||
SHARE = 1
|
||||
REMOTE = 2
|
||||
|
||||
def __str__(self):
|
||||
return self.name.lower()
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
@ -0,0 +1,26 @@
|
||||
# Copyright (C) 2022 CVAT.ai Corporation
|
||||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from http import HTTPStatus
|
||||
from time import sleep
|
||||
|
||||
from cvat_sdk.api_client.api_client import Endpoint
|
||||
from urllib3 import HTTPResponse
|
||||
|
||||
|
||||
def export_dataset(
|
||||
endpoint: Endpoint, *, max_retries: int = 20, interval: float = 0.1, **kwargs
|
||||
) -> HTTPResponse:
|
||||
for _ in range(max_retries):
|
||||
(_, response) = endpoint.call_with_http_info(**kwargs, _parse_response=False)
|
||||
if response.status == HTTPStatus.CREATED:
|
||||
break
|
||||
assert response.status == HTTPStatus.ACCEPTED
|
||||
sleep(interval)
|
||||
assert response.status == HTTPStatus.CREATED
|
||||
|
||||
(_, response) = endpoint.call_with_http_info(**kwargs, action="download", _parse_response=False)
|
||||
assert response.status == HTTPStatus.OK
|
||||
|
||||
return response
|
||||
@ -0,0 +1,236 @@
|
||||
# Copyright (C) 2022 CVAT.ai Corporation
|
||||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import io
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
from cvat_sdk import Client
|
||||
from cvat_sdk.api_client import exceptions, models
|
||||
from cvat_sdk.core.proxies.tasks import ResourceType, Task
|
||||
|
||||
from shared.utils.config import USER_PASS
|
||||
|
||||
|
||||
class TestIssuesUsecases:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(
|
||||
self,
|
||||
changedb, # force fixture call order to allow DB setup
|
||||
tmp_path: Path,
|
||||
fxt_logger: Tuple[Logger, io.StringIO],
|
||||
fxt_client: Client,
|
||||
fxt_stdout: io.StringIO,
|
||||
admin_user: str,
|
||||
):
|
||||
self.tmp_path = tmp_path
|
||||
_, self.logger_stream = fxt_logger
|
||||
self.client = fxt_client
|
||||
self.stdout = fxt_stdout
|
||||
self.user = admin_user
|
||||
self.client.login((self.user, USER_PASS))
|
||||
|
||||
yield
|
||||
|
||||
@pytest.fixture
|
||||
def fxt_new_task(self, fxt_image_file: Path):
|
||||
task = self.client.tasks.create_from_data(
|
||||
spec={
|
||||
"name": "test_task",
|
||||
"labels": [{"name": "car"}, {"name": "person"}],
|
||||
},
|
||||
resource_type=ResourceType.LOCAL,
|
||||
resources=[str(fxt_image_file)],
|
||||
data_params={"image_quality": 80},
|
||||
)
|
||||
|
||||
return task
|
||||
|
||||
def test_can_retrieve_issue(self, fxt_new_task: Task):
|
||||
issue = self.client.issues.create(
|
||||
models.IssueWriteRequest(
|
||||
frame=0,
|
||||
position=[2.0, 4.0],
|
||||
job=fxt_new_task.get_jobs()[0].id,
|
||||
message="hello",
|
||||
)
|
||||
)
|
||||
|
||||
retrieved_issue = self.client.issues.retrieve(issue.id)
|
||||
|
||||
assert issue.id == retrieved_issue.id
|
||||
assert self.stdout.getvalue() == ""
|
||||
|
||||
def test_can_list_issues(self, fxt_new_task: Task):
|
||||
issue = self.client.issues.create(
|
||||
models.IssueWriteRequest(
|
||||
frame=0,
|
||||
position=[2.0, 4.0],
|
||||
job=fxt_new_task.get_jobs()[0].id,
|
||||
message="hello",
|
||||
assignee=self.client.users.list()[0].id,
|
||||
)
|
||||
)
|
||||
|
||||
issues = self.client.issues.list()
|
||||
|
||||
assert any(issue.id == j.id for j in issues)
|
||||
assert self.stdout.getvalue() == ""
|
||||
|
||||
def test_can_list_comments(self, fxt_new_task: Task):
|
||||
issue = self.client.issues.create(
|
||||
models.IssueWriteRequest(
|
||||
frame=0,
|
||||
position=[2.0, 4.0],
|
||||
job=fxt_new_task.get_jobs()[0].id,
|
||||
message="hello",
|
||||
)
|
||||
)
|
||||
comment = self.client.comments.create(models.CommentWriteRequest(issue.id, message="hi!"))
|
||||
issue.fetch()
|
||||
|
||||
comment_ids = {c.id for c in issue.comments}
|
||||
|
||||
assert len(comment_ids) == 2
|
||||
assert comment.id in comment_ids
|
||||
assert self.stdout.getvalue() == ""
|
||||
|
||||
def test_can_modify_issue(self, fxt_new_task: Task):
|
||||
issue = self.client.issues.create(
|
||||
models.IssueWriteRequest(
|
||||
frame=0,
|
||||
position=[2.0, 4.0],
|
||||
job=fxt_new_task.get_jobs()[0].id,
|
||||
message="hello",
|
||||
)
|
||||
)
|
||||
|
||||
issue.update(models.PatchedIssueWriteRequest(resolved=True))
|
||||
|
||||
retrieved_issue = self.client.issues.retrieve(issue.id)
|
||||
assert retrieved_issue.resolved is True
|
||||
assert issue.resolved == retrieved_issue.resolved
|
||||
assert self.stdout.getvalue() == ""
|
||||
|
||||
def test_can_remove_issue(self, fxt_new_task: Task):
|
||||
issue = self.client.issues.create(
|
||||
models.IssueWriteRequest(
|
||||
frame=0,
|
||||
position=[2.0, 4.0],
|
||||
job=fxt_new_task.get_jobs()[0].id,
|
||||
message="hello",
|
||||
)
|
||||
)
|
||||
|
||||
issue.remove()
|
||||
|
||||
with pytest.raises(exceptions.NotFoundException):
|
||||
issue.fetch()
|
||||
with pytest.raises(exceptions.NotFoundException):
|
||||
self.client.comments.retrieve(issue.comments[0].id)
|
||||
assert self.stdout.getvalue() == ""
|
||||
|
||||
|
||||
class TestCommentsUsecases:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(
|
||||
self,
|
||||
changedb, # force fixture call order to allow DB setup
|
||||
tmp_path: Path,
|
||||
fxt_logger: Tuple[Logger, io.StringIO],
|
||||
fxt_client: Client,
|
||||
fxt_stdout: io.StringIO,
|
||||
admin_user: str,
|
||||
):
|
||||
self.tmp_path = tmp_path
|
||||
_, self.logger_stream = fxt_logger
|
||||
self.client = fxt_client
|
||||
self.stdout = fxt_stdout
|
||||
self.user = admin_user
|
||||
self.client.login((self.user, USER_PASS))
|
||||
|
||||
yield
|
||||
|
||||
@pytest.fixture
|
||||
def fxt_new_task(self, fxt_image_file: Path):
|
||||
task = self.client.tasks.create_from_data(
|
||||
spec={
|
||||
"name": "test_task",
|
||||
"labels": [{"name": "car"}, {"name": "person"}],
|
||||
},
|
||||
resource_type=ResourceType.LOCAL,
|
||||
resources=[str(fxt_image_file)],
|
||||
data_params={"image_quality": 80},
|
||||
)
|
||||
|
||||
return task
|
||||
|
||||
def test_can_retrieve_comment(self, fxt_new_task: Task):
|
||||
issue = self.client.issues.create(
|
||||
models.IssueWriteRequest(
|
||||
frame=0,
|
||||
position=[2.0, 4.0],
|
||||
job=fxt_new_task.get_jobs()[0].id,
|
||||
message="hello",
|
||||
)
|
||||
)
|
||||
comment = self.client.comments.create(models.CommentWriteRequest(issue.id, message="hi!"))
|
||||
|
||||
retrieved_comment = self.client.comments.retrieve(comment.id)
|
||||
|
||||
assert comment.id == retrieved_comment.id
|
||||
assert self.stdout.getvalue() == ""
|
||||
|
||||
def test_can_list_comments(self, fxt_new_task: Task):
|
||||
issue = self.client.issues.create(
|
||||
models.IssueWriteRequest(
|
||||
frame=0,
|
||||
position=[2.0, 4.0],
|
||||
job=fxt_new_task.get_jobs()[0].id,
|
||||
message="hello",
|
||||
)
|
||||
)
|
||||
comment = self.client.comments.create(models.CommentWriteRequest(issue.id, message="hi!"))
|
||||
|
||||
comments = self.client.comments.list()
|
||||
|
||||
assert any(comment.id == c.id for c in comments)
|
||||
assert self.stdout.getvalue() == ""
|
||||
|
||||
def test_can_modify_comment(self, fxt_new_task: Task):
|
||||
issue = self.client.issues.create(
|
||||
models.IssueWriteRequest(
|
||||
frame=0,
|
||||
position=[2.0, 4.0],
|
||||
job=fxt_new_task.get_jobs()[0].id,
|
||||
message="hello",
|
||||
)
|
||||
)
|
||||
comment = self.client.comments.create(models.CommentWriteRequest(issue.id, message="hi!"))
|
||||
|
||||
comment.update(models.PatchedCommentWriteRequest(message="bar"))
|
||||
|
||||
retrieved_comment = self.client.comments.retrieve(comment.id)
|
||||
assert retrieved_comment.message == "bar"
|
||||
assert comment.message == retrieved_comment.message
|
||||
assert self.stdout.getvalue() == ""
|
||||
|
||||
def test_can_remove_comment(self, fxt_new_task: Task):
|
||||
issue = self.client.issues.create(
|
||||
models.IssueWriteRequest(
|
||||
frame=0,
|
||||
position=[2.0, 4.0],
|
||||
job=fxt_new_task.get_jobs()[0].id,
|
||||
message="hello",
|
||||
)
|
||||
)
|
||||
comment = self.client.comments.create(models.CommentWriteRequest(issue.id, message="hi!"))
|
||||
|
||||
comment.remove()
|
||||
|
||||
with pytest.raises(exceptions.NotFoundException):
|
||||
comment.fetch()
|
||||
assert self.stdout.getvalue() == ""
|
||||
@ -0,0 +1,280 @@
|
||||
# Copyright (C) 2022 CVAT.ai Corporation
|
||||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import io
|
||||
import os.path as osp
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
from cvat_sdk import Client
|
||||
from cvat_sdk.api_client import models
|
||||
from cvat_sdk.core.proxies.tasks import ResourceType, Task
|
||||
from PIL import Image
|
||||
|
||||
from shared.utils.config import USER_PASS
|
||||
|
||||
from .util import make_pbar
|
||||
|
||||
|
||||
class TestJobUsecases:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(
|
||||
self,
|
||||
changedb, # force fixture call order to allow DB setup
|
||||
tmp_path: Path,
|
||||
fxt_logger: Tuple[Logger, io.StringIO],
|
||||
fxt_client: Client,
|
||||
fxt_stdout: io.StringIO,
|
||||
admin_user: str,
|
||||
):
|
||||
self.tmp_path = tmp_path
|
||||
_, self.logger_stream = fxt_logger
|
||||
self.client = fxt_client
|
||||
self.stdout = fxt_stdout
|
||||
self.user = admin_user
|
||||
self.client.login((self.user, USER_PASS))
|
||||
|
||||
yield
|
||||
|
||||
@pytest.fixture
|
||||
def fxt_new_task(self, fxt_image_file: Path):
|
||||
task = self.client.tasks.create_from_data(
|
||||
spec={
|
||||
"name": "test_task",
|
||||
"labels": [{"name": "car"}, {"name": "person"}],
|
||||
},
|
||||
resource_type=ResourceType.LOCAL,
|
||||
resources=[str(fxt_image_file)],
|
||||
data_params={"image_quality": 80},
|
||||
)
|
||||
|
||||
return task
|
||||
|
||||
@pytest.fixture
|
||||
def fxt_task_with_shapes(self, fxt_new_task: Task):
|
||||
fxt_new_task.set_annotations(
|
||||
models.LabeledDataRequest(
|
||||
shapes=[
|
||||
models.LabeledShapeRequest(
|
||||
frame=0,
|
||||
label_id=fxt_new_task.labels[0].id,
|
||||
type="rectangle",
|
||||
points=[1, 1, 2, 2],
|
||||
),
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
return fxt_new_task
|
||||
|
||||
def test_can_retrieve_job(self, fxt_new_task: Task):
|
||||
job_id = fxt_new_task.get_jobs()[0].id
|
||||
|
||||
job = self.client.jobs.retrieve(job_id)
|
||||
|
||||
assert job.id == job_id
|
||||
assert self.stdout.getvalue() == ""
|
||||
|
||||
def test_can_list_jobs(self, fxt_new_task: Task):
|
||||
task_job_ids = set(j.id for j in fxt_new_task.get_jobs())
|
||||
|
||||
jobs = self.client.jobs.list()
|
||||
|
||||
assert len(task_job_ids) != 0
|
||||
assert task_job_ids.issubset(j.id for j in jobs)
|
||||
assert self.stdout.getvalue() == ""
|
||||
|
||||
def test_can_update_job_field_directly(self, fxt_new_task: Task):
|
||||
job = self.client.jobs.list()[0]
|
||||
assert not job.assignee
|
||||
new_assignee = self.client.users.list()[0]
|
||||
|
||||
job.update({"assignee": new_assignee.id})
|
||||
|
||||
updated_job = self.client.jobs.retrieve(job.id)
|
||||
assert updated_job.assignee.id == new_assignee.id
|
||||
assert self.stdout.getvalue() == ""
|
||||
|
||||
@pytest.mark.parametrize("include_images", (True, False))
|
||||
def test_can_download_dataset(self, fxt_new_task: Task, include_images: bool):
|
||||
pbar_out = io.StringIO()
|
||||
pbar = make_pbar(file=pbar_out)
|
||||
|
||||
task_id = fxt_new_task.id
|
||||
path = str(self.tmp_path / f"task_{task_id}-cvat.zip")
|
||||
job = self.client.jobs.retrieve(task_id)
|
||||
job.export_dataset(
|
||||
format_name="CVAT for images 1.1",
|
||||
filename=path,
|
||||
pbar=pbar,
|
||||
include_images=include_images,
|
||||
)
|
||||
|
||||
assert "100%" in pbar_out.getvalue().strip("\r").split("\r")[-1]
|
||||
assert osp.isfile(path)
|
||||
assert self.stdout.getvalue() == ""
|
||||
|
||||
def test_can_download_preview(self, fxt_new_task: Task):
|
||||
frame_encoded = fxt_new_task.get_jobs()[0].get_preview()
|
||||
|
||||
assert Image.open(frame_encoded).size != 0
|
||||
assert self.stdout.getvalue() == ""
|
||||
|
||||
@pytest.mark.parametrize("quality", ("compressed", "original"))
|
||||
def test_can_download_frame(self, fxt_new_task: Task, quality: str):
|
||||
frame_encoded = fxt_new_task.get_jobs()[0].get_frame(0, quality=quality)
|
||||
|
||||
assert Image.open(frame_encoded).size != 0
|
||||
assert self.stdout.getvalue() == ""
|
||||
|
||||
@pytest.mark.parametrize("quality", ("compressed", "original"))
|
||||
def test_can_download_frames(self, fxt_new_task: Task, quality: str):
|
||||
fxt_new_task.get_jobs()[0].download_frames(
|
||||
[0],
|
||||
quality=quality,
|
||||
outdir=str(self.tmp_path),
|
||||
filename_pattern="frame-{frame_id}{frame_ext}",
|
||||
)
|
||||
|
||||
assert osp.isfile(self.tmp_path / "frame-0.jpg")
|
||||
assert self.stdout.getvalue() == ""
|
||||
|
||||
def test_can_upload_annotations(self, fxt_new_task: Task, fxt_coco_file: Path):
|
||||
pbar_out = io.StringIO()
|
||||
pbar = make_pbar(file=pbar_out)
|
||||
|
||||
fxt_new_task.get_jobs()[0].import_annotations(
|
||||
format_name="COCO 1.0", filename=str(fxt_coco_file), pbar=pbar
|
||||
)
|
||||
|
||||
assert "uploaded" in self.logger_stream.getvalue()
|
||||
assert "100%" in pbar_out.getvalue().strip("\r").split("\r")[-1]
|
||||
assert self.stdout.getvalue() == ""
|
||||
|
||||
def test_can_get_meta(self, fxt_new_task: Task):
|
||||
meta = fxt_new_task.get_jobs()[0].get_meta()
|
||||
|
||||
assert meta.image_quality == 80
|
||||
assert meta.size == 1
|
||||
assert len(meta.frames) == meta.size
|
||||
assert meta.frames[0].name == "img.png"
|
||||
assert meta.frames[0].width == 5
|
||||
assert meta.frames[0].height == 10
|
||||
assert not meta.deleted_frames
|
||||
assert self.stdout.getvalue() == ""
|
||||
|
||||
def test_can_remove_frames(self, fxt_new_task: Task):
|
||||
fxt_new_task.get_jobs()[0].remove_frames_by_ids([0])
|
||||
|
||||
meta = fxt_new_task.get_jobs()[0].get_meta()
|
||||
assert meta.deleted_frames == [0]
|
||||
assert self.stdout.getvalue() == ""
|
||||
|
||||
def test_can_get_issues(self, fxt_new_task: Task):
|
||||
issue = self.client.issues.create(
|
||||
models.IssueWriteRequest(
|
||||
frame=0,
|
||||
position=[2.0, 4.0],
|
||||
job=fxt_new_task.get_jobs()[0].id,
|
||||
message="hello",
|
||||
)
|
||||
)
|
||||
|
||||
job_issue_ids = set(j.id for j in fxt_new_task.get_jobs()[0].get_issues())
|
||||
|
||||
assert {issue.id} == job_issue_ids
|
||||
assert self.stdout.getvalue() == ""
|
||||
|
||||
def test_can_get_annotations(self, fxt_task_with_shapes: Task):
|
||||
anns = fxt_task_with_shapes.get_jobs()[0].get_annotations()
|
||||
|
||||
assert len(anns.shapes) == 1
|
||||
assert anns.shapes[0].type.value == "rectangle"
|
||||
assert self.stdout.getvalue() == ""
|
||||
|
||||
def test_can_set_annotations(self, fxt_new_task: Task):
|
||||
fxt_new_task.get_jobs()[0].set_annotations(
|
||||
models.LabeledDataRequest(
|
||||
tags=[models.LabeledImageRequest(frame=0, label_id=fxt_new_task.labels[0].id)],
|
||||
)
|
||||
)
|
||||
|
||||
anns = fxt_new_task.get_jobs()[0].get_annotations()
|
||||
|
||||
assert len(anns.tags) == 1
|
||||
assert self.stdout.getvalue() == ""
|
||||
|
||||
def test_can_clear_annotations(self, fxt_task_with_shapes: Task):
|
||||
fxt_task_with_shapes.get_jobs()[0].remove_annotations()
|
||||
|
||||
anns = fxt_task_with_shapes.get_jobs()[0].get_annotations()
|
||||
assert len(anns.tags) == 0
|
||||
assert len(anns.tracks) == 0
|
||||
assert len(anns.shapes) == 0
|
||||
assert self.stdout.getvalue() == ""
|
||||
|
||||
def test_can_remove_annotations(self, fxt_new_task: Task):
|
||||
fxt_new_task.get_jobs()[0].set_annotations(
|
||||
models.LabeledDataRequest(
|
||||
shapes=[
|
||||
models.LabeledShapeRequest(
|
||||
frame=0,
|
||||
label_id=fxt_new_task.labels[0].id,
|
||||
type="rectangle",
|
||||
points=[1, 1, 2, 2],
|
||||
),
|
||||
models.LabeledShapeRequest(
|
||||
frame=0,
|
||||
label_id=fxt_new_task.labels[0].id,
|
||||
type="rectangle",
|
||||
points=[2, 2, 3, 3],
|
||||
),
|
||||
],
|
||||
)
|
||||
)
|
||||
anns = fxt_new_task.get_jobs()[0].get_annotations()
|
||||
|
||||
fxt_new_task.get_jobs()[0].remove_annotations(ids=[anns.shapes[0].id])
|
||||
|
||||
anns = fxt_new_task.get_jobs()[0].get_annotations()
|
||||
assert len(anns.tags) == 0
|
||||
assert len(anns.tracks) == 0
|
||||
assert len(anns.shapes) == 1
|
||||
assert self.stdout.getvalue() == ""
|
||||
|
||||
def test_can_update_annotations(self, fxt_task_with_shapes: Task):
|
||||
fxt_task_with_shapes.get_jobs()[0].update_annotations(
|
||||
models.PatchedLabeledDataRequest(
|
||||
shapes=[
|
||||
models.LabeledShapeRequest(
|
||||
frame=0,
|
||||
label_id=fxt_task_with_shapes.labels[0].id,
|
||||
type="rectangle",
|
||||
points=[0, 1, 2, 3],
|
||||
),
|
||||
],
|
||||
tracks=[
|
||||
models.LabeledTrackRequest(
|
||||
frame=0,
|
||||
label_id=fxt_task_with_shapes.labels[0].id,
|
||||
shapes=[
|
||||
models.TrackedShapeRequest(
|
||||
frame=0, type="polygon", points=[3, 2, 2, 3, 3, 4]
|
||||
),
|
||||
],
|
||||
)
|
||||
],
|
||||
tags=[
|
||||
models.LabeledImageRequest(frame=0, label_id=fxt_task_with_shapes.labels[0].id)
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
anns = fxt_task_with_shapes.get_jobs()[0].get_annotations()
|
||||
assert len(anns.shapes) == 2
|
||||
assert len(anns.tracks) == 1
|
||||
assert len(anns.tags) == 1
|
||||
assert self.stdout.getvalue() == ""
|
||||
@ -0,0 +1,73 @@
|
||||
# Copyright (C) 2022 CVAT.ai Corporation
|
||||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import io
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
from cvat_sdk import Client, models
|
||||
from cvat_sdk.api_client import exceptions
|
||||
|
||||
from shared.utils.config import USER_PASS
|
||||
|
||||
|
||||
class TestUserUsecases:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(
|
||||
self,
|
||||
changedb, # force fixture call order to allow DB setup
|
||||
tmp_path: Path,
|
||||
fxt_logger: Tuple[Logger, io.StringIO],
|
||||
fxt_client: Client,
|
||||
fxt_stdout: io.StringIO,
|
||||
admin_user: str,
|
||||
):
|
||||
self.tmp_path = tmp_path
|
||||
_, self.logger_stream = fxt_logger
|
||||
self.client = fxt_client
|
||||
self.stdout = fxt_stdout
|
||||
self.user = admin_user
|
||||
self.client.login((self.user, USER_PASS))
|
||||
|
||||
yield
|
||||
|
||||
def test_can_retrieve_user(self):
|
||||
me = self.client.users.retrieve_current_user()
|
||||
|
||||
user = self.client.users.retrieve(me.id)
|
||||
|
||||
assert user.id == me.id
|
||||
assert user.username == self.user
|
||||
assert self.stdout.getvalue() == ""
|
||||
|
||||
def test_can_list_users(self):
|
||||
users = self.client.users.list()
|
||||
|
||||
assert self.user in set(u.username for u in users)
|
||||
assert self.stdout.getvalue() == ""
|
||||
|
||||
def test_can_update_user(self):
|
||||
user = self.client.users.retrieve_current_user()
|
||||
|
||||
user.update(models.PatchedUserRequest(first_name="foo", last_name="bar"))
|
||||
|
||||
retrieved_user = self.client.users.retrieve(user.id)
|
||||
assert retrieved_user.first_name == "foo"
|
||||
assert retrieved_user.last_name == "bar"
|
||||
assert user.first_name == retrieved_user.first_name
|
||||
assert user.last_name == retrieved_user.last_name
|
||||
assert self.stdout.getvalue() == ""
|
||||
|
||||
def test_can_remove_user(self):
|
||||
users = self.client.users.list()
|
||||
removed_user = next(u for u in users if u.username != self.user)
|
||||
|
||||
removed_user.remove()
|
||||
|
||||
with pytest.raises(exceptions.NotFoundException):
|
||||
removed_user.fetch()
|
||||
|
||||
assert self.stdout.getvalue() == ""
|
||||
Loading…
Reference in New Issue