diff --git a/.bandit b/.bandit index 17c7fbe6..ac77d7fc 100644 --- a/.bandit +++ b/.bandit @@ -6,3 +6,4 @@ # B406 : import_xml_sax # B410 : import_lxml skips: B101,B102,B320,B404,B406,B410 +exclude: **/tests/**,tests diff --git a/.github/workflows/bandit.yml b/.github/workflows/bandit.yml index 015a5600..1e7a4a19 100644 --- a/.github/workflows/bandit.yml +++ b/.github/workflows/bandit.yml @@ -33,7 +33,7 @@ jobs: echo "Bandit version: "$(bandit --version | head -1) echo "The files will be checked: "$(echo $CHANGED_FILES) - bandit $CHANGED_FILES --exclude '**/tests/**' -a file --ini ./.bandit -f html -o ./bandit_report/bandit_checks.html + bandit -a file --ini .bandit -f html -o ./bandit_report/bandit_checks.html $CHANGED_FILES deactivate else echo "No files with the \"py\" extension found" diff --git a/CHANGELOG.md b/CHANGELOG.md index 0a46838f..bc170598 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,7 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Possibility to display tags on frame - Support source and target storages (server part) - Tests for import/export annotation, dataset, backup from/to cloud storage -- Added Python SDK package (`cvat-sdk`) +- Added Python SDK package (`cvat-sdk`) () - Previews for jobs - Documentation for LDAP authentication () - OpenCV.js caching and autoload () @@ -28,7 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Bumped nuclio version to 1.8.14 - Simplified running REST API tests. Extended CI-nightly workflow -- REST API tests are partially moved to Python SDK (`users`, `projects`, `tasks`) +- REST API tests are partially moved to Python SDK (`users`, `projects`, `tasks`, `issues`) - cvat-ui: Improve UI/UX on label, create task and create project forms () - Removed link to OpenVINO documentation () - Clarified meaning of chunking for videos @@ -51,6 +51,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Image search in cloud storage () - Reset password functionality () - Creating task with cloud storage data () +- Show empty tasks () ### Security - TDB diff --git a/cvat-cli/src/cvat_cli/cli.py b/cvat-cli/src/cvat_cli/cli.py index 1a20ea3f..3a3dedb2 100644 --- a/cvat-cli/src/cvat_cli/cli.py +++ b/cvat-cli/src/cvat_cli/cli.py @@ -1,4 +1,3 @@ -# Copyright (C) 2020-2022 Intel Corporation # Copyright (C) 2022 CVAT.ai Corporation # # SPDX-License-Identifier: MIT @@ -11,7 +10,7 @@ from typing import Dict, List, Sequence, Tuple import tqdm from cvat_sdk import Client, models from cvat_sdk.core.helpers import TqdmProgressReporter -from cvat_sdk.core.types import ResourceType +from cvat_sdk.core.proxies.tasks import ResourceType class CLI: @@ -26,7 +25,7 @@ class CLI: def tasks_list(self, *, use_json_output: bool = False, **kwargs): """List all tasks in either basic or JSON format.""" - results = self.client.list_tasks(return_json=use_json_output, **kwargs) + results = self.client.tasks.list(return_json=use_json_output, **kwargs) if use_json_output: print(json.dumps(json.loads(results), indent=2)) else: @@ -50,7 +49,7 @@ class CLI: """ Create a new task with the given name and labels JSON and add the files to it. """ - task = self.client.create_task( + task = self.client.tasks.create_from_data( spec=models.TaskWriteRequest(name=name, labels=labels, **kwargs), resource_type=resource_type, resources=resources, @@ -66,7 +65,7 @@ class CLI: def tasks_delete(self, task_ids: Sequence[int]) -> None: """Delete a list of tasks, ignoring those which don't exist.""" - self.client.delete_tasks(task_ids=task_ids) + self.client.tasks.remove_by_ids(task_ids=task_ids) def tasks_frames( self, @@ -80,11 +79,11 @@ class CLI: Download the requested frame numbers for a task and save images as task__frame_.jpg. """ - self.client.retrieve_task(task_id=task_id).download_frames( + self.client.tasks.retrieve(obj_id=task_id).download_frames( frame_ids=frame_ids, outdir=outdir, quality=quality, - filename_pattern="task_{task_id}_frame_{frame_id:06d}{frame_ext}", + filename_pattern=f"task_{task_id}" + "_frame_{frame_id:06d}{frame_ext}", ) def tasks_dump( @@ -99,7 +98,7 @@ class CLI: """ Download annotations for a task in the specified format (e.g. 'YOLO ZIP 1.0'). """ - self.client.retrieve_task(task_id=task_id).export_dataset( + self.client.tasks.retrieve(obj_id=task_id).export_dataset( format_name=fileformat, filename=filename, pbar=self._make_pbar(), @@ -112,7 +111,7 @@ class CLI: ) -> None: """Upload annotations for a task in the specified format (e.g. 'YOLO ZIP 1.0').""" - self.client.retrieve_task(task_id=task_id).import_annotations( + self.client.tasks.retrieve(obj_id=task_id).import_annotations( format_name=fileformat, filename=filename, status_check_period=status_check_period, @@ -121,13 +120,13 @@ class CLI: def tasks_export(self, task_id: str, filename: str, *, status_check_period: int = 2) -> None: """Download a task backup""" - self.client.retrieve_task(task_id=task_id).download_backup( + self.client.tasks.retrieve(obj_id=task_id).download_backup( filename=filename, status_check_period=status_check_period, pbar=self._make_pbar() ) def tasks_import(self, filename: str, *, status_check_period: int = 2) -> None: """Import a task from a backup file""" - self.client.create_task_from_backup( + self.client.tasks.create_from_backup( filename=filename, status_check_period=status_check_period, pbar=self._make_pbar() ) diff --git a/cvat-cli/src/cvat_cli/parser.py b/cvat-cli/src/cvat_cli/parser.py index a4dc8179..89e2088d 100644 --- a/cvat-cli/src/cvat_cli/parser.py +++ b/cvat-cli/src/cvat_cli/parser.py @@ -10,7 +10,7 @@ import logging import os from distutils.util import strtobool -from cvat_sdk.core.types import ResourceType +from cvat_sdk.core.proxies.tasks import ResourceType from .version import VERSION diff --git a/cvat-sdk/.gitignore b/cvat-sdk/.gitignore index 7552f3ba..523f1997 100644 --- a/cvat-sdk/.gitignore +++ b/cvat-sdk/.gitignore @@ -74,4 +74,4 @@ cvat_sdk/api_client/ requirements/ docs/ setup.py -README.md \ No newline at end of file +README.md diff --git a/cvat-sdk/cvat_sdk/core/client.py b/cvat-sdk/cvat_sdk/core/client.py index 1518d0c5..db014def 100644 --- a/cvat-sdk/cvat_sdk/core/client.py +++ b/cvat-sdk/cvat_sdk/core/client.py @@ -1,4 +1,3 @@ -# Copyright (C) 2020-2022 Intel Corporation # Copyright (C) 2022 CVAT.ai Corporation # # SPDX-License-Identifier: MIT @@ -6,23 +5,22 @@ from __future__ import annotations -import json import logging -import os.path as osp import urllib.parse from time import sleep -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Optional, Sequence, Tuple import attrs +import urllib3 -from cvat_sdk.api_client import ApiClient, ApiException, ApiValueError, Configuration, models -from cvat_sdk.core.git import create_git_repo -from cvat_sdk.core.helpers import get_paginated_collection -from cvat_sdk.core.progress import ProgressReporter -from cvat_sdk.core.tasks import TaskProxy -from cvat_sdk.core.types import ResourceType -from cvat_sdk.core.uploading import Uploader -from cvat_sdk.core.utils import assert_status +from cvat_sdk.api_client import ApiClient, Configuration, models +from cvat_sdk.core.helpers import expect_status +from cvat_sdk.core.proxies.issues import CommentsRepo, IssuesRepo +from cvat_sdk.core.proxies.jobs import JobsRepo +from cvat_sdk.core.proxies.model_proxy import Repo +from cvat_sdk.core.proxies.projects import ProjectsRepo +from cvat_sdk.core.proxies.tasks import TasksRepo +from cvat_sdk.core.proxies.users import UsersRepo @attrs.define @@ -43,11 +41,13 @@ class Client: ): # TODO: use requests instead of urllib3 in ApiClient # TODO: try to autodetect schema - self._api_map = _CVAT_API_V2(url) + self.api_map = CVAT_API_V2(url) self.api = ApiClient(Configuration(host=url)) self.logger = logger or logging.getLogger(__name__) self.config = config or Config() + self._repos: Dict[str, Repo] = {} + def __enter__(self): self.api.__enter__() return self @@ -67,150 +67,93 @@ class Client: assert "csrftoken" in self.api.cookies self.api.set_default_header("Authorization", "Token " + auth.key) - def create_task( - self, - spec: models.ITaskWriteRequest, - resource_type: ResourceType, - resources: Sequence[str], + def _has_credentials(self): + return ( + ("sessionid" in self.api.cookies) + or ("csrftoken" in self.api.cookies) + or (self.api.get_common_headers().get("Authorization", "")) + ) + + def logout(self): + if self._has_credentials(): + self.api.auth_api.create_logout() + + def wait_for_completion( + self: Client, + url: str, *, - data_params: Optional[Dict[str, Any]] = None, - annotation_path: str = "", - annotation_format: str = "CVAT XML 1.1", - status_check_period: int = None, - dataset_repository_url: str = "", - use_lfs: bool = False, - pbar: Optional[ProgressReporter] = None, - ) -> TaskProxy: - """ - Create a new task with the given name and labels JSON and - add the files to it. - - Returns: id of the created task - """ + success_status: int, + status_check_period: Optional[int] = None, + query_params: Optional[Dict[str, Any]] = None, + post_params: Optional[Dict[str, Any]] = None, + method: str = "POST", + positive_statuses: Optional[Sequence[int]] = None, + ) -> urllib3.HTTPResponse: if status_check_period is None: status_check_period = self.config.status_check_period - if getattr(spec, "project_id", None) and getattr(spec, "labels", None): - raise ApiValueError( - "Can't set labels to a task inside a project. " - "Tasks inside a project use project's labels.", - ["labels"], - ) - (task, _) = self.api.tasks_api.create(spec) - self.logger.info("Created task ID: %s NAME: %s", task.id, task.name) + positive_statuses = set(positive_statuses) | {success_status} - task = TaskProxy(self, task) - task.upload_data(resource_type, resources, pbar=pbar, params=data_params) - - self.logger.info("Awaiting for task %s creation...", task.id) - status = None - while status != models.RqStatusStateEnum.allowed_values[("value",)]["FINISHED"]: + while True: sleep(status_check_period) - (status, _) = self.api.tasks_api.retrieve_status(task.id) - - self.logger.info( - "Task %s creation status=%s, message=%s", - task.id, - status.state.value, - status.message, - ) - - if status.state.value == models.RqStatusStateEnum.allowed_values[("value",)]["FAILED"]: - raise ApiException(status=status.state.value, reason=status.message) - - status = status.state.value - if annotation_path: - task.import_annotations(annotation_format, annotation_path, pbar=pbar) - - if dataset_repository_url: - create_git_repo( - self, - task_id=task.id, - repo_url=dataset_repository_url, - status_check_period=status_check_period, - use_lfs=use_lfs, + response = self.api.rest_client.request( + method=method, + url=url, + headers=self.api.get_common_headers(), + query_params=query_params, + post_params=post_params, ) - task.fetch() + self.logger.debug("STATUS %s", response.status) + expect_status(positive_statuses, response) + if response.status == success_status: + break - return task + return response - def list_tasks( - self, *, return_json: bool = False, **kwargs - ) -> Union[List[TaskProxy], List[Dict[str, Any]]]: - """List all tasks in either basic or JSON format.""" + def _get_repo(self, key: str) -> Repo: + _repo_map = { + "tasks": TasksRepo, + "projects": ProjectsRepo, + "jobs": JobsRepo, + "users": UsersRepo, + "issues": IssuesRepo, + "comments": CommentsRepo, + } - results = get_paginated_collection( - endpoint=self.api.tasks_api.list_endpoint, return_json=return_json, **kwargs - ) + repo = self._repos.get(key, None) + if repo is None: + repo = _repo_map[key](self) + self._repos[key] = repo + return repo - if return_json: - return json.dumps(results) - - return [TaskProxy(self, v) for v in results] - - def retrieve_task(self, task_id: int) -> TaskProxy: - (task, _) = self.api.tasks_api.retrieve(task_id) - return TaskProxy(self, task) - - def delete_tasks(self, task_ids: Sequence[int]): - """ - Delete a list of tasks, ignoring those which don't exist. - """ - - for task_id in task_ids: - (_, response) = self.api.tasks_api.destroy(task_id, _check_status=False) - if 200 <= response.status <= 299: - self.logger.info(f"Task ID {task_id} deleted") - elif response.status == 404: - self.logger.info(f"Task ID {task_id} not found") - else: - self.logger.warning( - f"Failed to delete task ID {task_id}: " - f"{response.msg} (status {response.status})" - ) - - def create_task_from_backup( - self, - filename: str, - *, - status_check_period: int = None, - pbar: Optional[ProgressReporter] = None, - ) -> TaskProxy: - """ - Import a task from a backup file - """ - if status_check_period is None: - status_check_period = self.config.status_check_period + @property + def tasks(self) -> TasksRepo: + return self._get_repo("tasks") - params = {"filename": osp.basename(filename)} - url = self._api_map.make_endpoint_url(self.api.tasks_api.create_backup_endpoint.path) - uploader = Uploader(self) - response = uploader.upload_file( - url, filename, meta=params, query_params=params, pbar=pbar, logger=self.logger.debug - ) - - rq_id = json.loads(response.data)["rq_id"] + @property + def projects(self) -> ProjectsRepo: + return self._get_repo("projects") - # check task status - while True: - sleep(status_check_period) + @property + def jobs(self) -> JobsRepo: + return self._get_repo("jobs") - response = self.api.rest_client.POST( - url, post_params={"rq_id": rq_id}, headers=self.api.get_common_headers() - ) - if response.status == 201: - break - assert_status(202, response) + @property + def users(self) -> UsersRepo: + return self._get_repo("users") - task_id = json.loads(response.data)["id"] - self.logger.info(f"Task has been imported sucessfully. Task ID: {task_id}") + @property + def issues(self) -> IssuesRepo: + return self._get_repo("issues") - return self.retrieve_task(task_id) + @property + def comments(self) -> CommentsRepo: + return self._get_repo("comments") -class _CVAT_API_V2: +class CVAT_API_V2: """Build parameterized API URLs""" def __init__(self, host, https=False): diff --git a/cvat-sdk/cvat_sdk/core/downloading.py b/cvat-sdk/cvat_sdk/core/downloading.py index 50f5a137..a270ffc4 100644 --- a/cvat-sdk/cvat_sdk/core/downloading.py +++ b/cvat-sdk/cvat_sdk/core/downloading.py @@ -8,8 +8,9 @@ from __future__ import annotations import os import os.path as osp from contextlib import closing -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional +from cvat_sdk.api_client.api_client import Endpoint from cvat_sdk.core.progress import ProgressReporter if TYPE_CHECKING: @@ -17,8 +18,12 @@ if TYPE_CHECKING: class Downloader: + """ + Implements common downloading protocols + """ + def __init__(self, client: Client): - self.client = client + self._client = client def download_file( self, @@ -29,8 +34,7 @@ class Downloader: pbar: Optional[ProgressReporter] = None, ) -> None: """ - Downloads the file from url into a temporary file, then renames it - to the requested name. + Downloads the file from url into a temporary file, then renames it to the requested name. """ CHUNK_SIZE = 10 * 2**20 @@ -41,10 +45,10 @@ class Downloader: if osp.exists(tmp_path): raise FileExistsError(f"Can't write temporary file '{tmp_path}' - file exists") - response = self.client.api.rest_client.GET( + response = self._client.api.rest_client.GET( url, _request_timeout=timeout, - headers=self.client.api.get_common_headers(), + headers=self._client.api.get_common_headers(), _parse_response=False, ) with closing(response): @@ -72,3 +76,38 @@ class Downloader: except: os.unlink(tmp_path) raise + + def prepare_and_download_file_from_endpoint( + self, + endpoint: Endpoint, + filename: str, + *, + url_params: Optional[Dict[str, Any]] = None, + query_params: Optional[Dict[str, Any]] = None, + pbar: Optional[ProgressReporter] = None, + status_check_period: Optional[int] = None, + ): + client = self._client + if status_check_period is None: + status_check_period = client.config.status_check_period + + client.logger.info("Waiting for the server to prepare the file...") + + url = client.api_map.make_endpoint_url( + endpoint.path, kwsub=url_params, query_params=query_params + ) + client.wait_for_completion( + url, + method="GET", + positive_statuses=[202], + success_status=201, + status_check_period=status_check_period, + ) + + query_params = dict(query_params or {}) + query_params["action"] = "download" + url = client.api_map.make_endpoint_url( + endpoint.path, kwsub=url_params, query_params=query_params + ) + downloader = Downloader(client) + downloader.download_file(url, output_path=filename, pbar=pbar) diff --git a/cvat-sdk/cvat_sdk/core/git.py b/cvat-sdk/cvat_sdk/core/git.py index 828d3127..44e71ea9 100644 --- a/cvat-sdk/cvat_sdk/core/git.py +++ b/cvat-sdk/cvat_sdk/core/git.py @@ -1,4 +1,3 @@ -# Copyright (C) 2020-2022 Intel Corporation # Copyright (C) 2022 CVAT.ai Corporation # # SPDX-License-Identifier: MIT @@ -27,7 +26,7 @@ def create_git_repo( common_headers = client.api.get_common_headers() response = client.api.rest_client.POST( - client._api_map.git_create(task_id), + client.api_map.git_create(task_id), post_params={"path": repo_url, "lfs": use_lfs, "tid": task_id}, headers=common_headers, ) @@ -36,7 +35,7 @@ def create_git_repo( client.logger.info(f"Create RQ ID: {rq_id}") client.logger.debug("Awaiting a dataset repository to be created for the task %s...", task_id) - check_url = client._api_map.git_check(rq_id) + check_url = client.api_map.git_check(rq_id) status = None while status != "finished": sleep(status_check_period) diff --git a/cvat-sdk/cvat_sdk/core/helpers.py b/cvat-sdk/cvat_sdk/core/helpers.py index 8b8120c8..6dda3b68 100644 --- a/cvat-sdk/cvat_sdk/core/helpers.py +++ b/cvat-sdk/cvat_sdk/core/helpers.py @@ -6,13 +6,14 @@ from __future__ import annotations import io import json -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, Iterable, List, Optional, Union import tqdm +import urllib3 +from cvat_sdk import exceptions from cvat_sdk.api_client.api_client import Endpoint from cvat_sdk.core.progress import ProgressReporter -from cvat_sdk.core.utils import assert_status def get_paginated_collection( @@ -26,7 +27,7 @@ def get_paginated_collection( page = 1 while True: (page_contents, response) = endpoint.call_with_http_info(**kwargs, page=page) - assert_status(200, response) + expect_status(200, response) if return_json: results.extend(json.loads(response.data).get("results", [])) @@ -86,3 +87,18 @@ class StreamWithProgress: def tell(self): return self.stream.tell() + + +def expect_status(codes: Union[int, Iterable[int]], response: urllib3.HTTPResponse) -> None: + if not hasattr(codes, "__iter__"): + codes = [codes] + + if response.status in codes: + return + + if 300 <= response.status <= 500: + raise exceptions.ApiException(response.status, reason=response.msg, http_resp=response) + else: + raise exceptions.ApiException( + response.status, reason="Unexpected status code received", http_resp=response + ) diff --git a/cvat-sdk/cvat_sdk/core/proxies/__init__.py b/cvat-sdk/cvat_sdk/core/proxies/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cvat-sdk/cvat_sdk/core/proxies/annotations.py b/cvat-sdk/cvat_sdk/core/proxies/annotations.py new file mode 100644 index 00000000..96c50b69 --- /dev/null +++ b/cvat-sdk/cvat_sdk/core/proxies/annotations.py @@ -0,0 +1,66 @@ +# Copyright (C) 2022 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +from abc import ABC +from enum import Enum +from typing import Optional, Sequence + +from cvat_sdk import models +from cvat_sdk.core.proxies.model_proxy import _EntityT + + +class AnnotationUpdateAction(Enum): + CREATE = "create" + UPDATE = "update" + DELETE = "delete" + + +class AnnotationCrudMixin(ABC): + # TODO: refactor + + @property + def _put_annotations_data_param(self) -> str: + ... + + def get_annotations(self: _EntityT) -> models.ILabeledData: + (annotations, _) = self.api.retrieve_annotations(getattr(self, self._model_id_field)) + return annotations + + def set_annotations(self: _EntityT, data: models.ILabeledDataRequest): + self.api.update_annotations( + getattr(self, self._model_id_field), **{self._put_annotations_data_param: data} + ) + + def update_annotations( + self: _EntityT, + data: models.IPatchedLabeledDataRequest, + *, + action: AnnotationUpdateAction = AnnotationUpdateAction.UPDATE, + ): + self.api.partial_update_annotations( + action=action.value, + id=getattr(self, self._model_id_field), + patched_labeled_data_request=data, + ) + + def remove_annotations(self: _EntityT, *, ids: Optional[Sequence[int]] = None): + if ids: + anns = self.get_annotations() + + if not isinstance(ids, set): + ids = set(ids) + + anns_to_remove = models.PatchedLabeledDataRequest( + tags=[models.LabeledImageRequest(**a.to_dict()) for a in anns.tags if a.id in ids], + tracks=[ + models.LabeledTrackRequest(**a.to_dict()) for a in anns.tracks if a.id in ids + ], + shapes=[ + models.LabeledShapeRequest(**a.to_dict()) for a in anns.shapes if a.id in ids + ], + ) + + self.update_annotations(anns_to_remove, action=AnnotationUpdateAction.DELETE) + else: + self.api.destroy_annotations(getattr(self, self._model_id_field)) diff --git a/cvat-sdk/cvat_sdk/core/proxies/issues.py b/cvat-sdk/cvat_sdk/core/proxies/issues.py new file mode 100644 index 00000000..5583fd08 --- /dev/null +++ b/cvat-sdk/cvat_sdk/core/proxies/issues.py @@ -0,0 +1,60 @@ +# Copyright (C) 2022 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +from __future__ import annotations + +from cvat_sdk.api_client import apis, models +from cvat_sdk.core.proxies.model_proxy import ( + ModelCreateMixin, + ModelDeleteMixin, + ModelListMixin, + ModelRetrieveMixin, + ModelUpdateMixin, + build_model_bases, +) + +_CommentEntityBase, _CommentRepoBase = build_model_bases( + models.CommentRead, apis.CommentsApi, api_member_name="comments_api" +) + + +class Comment( + models.ICommentRead, + _CommentEntityBase, + ModelUpdateMixin[models.IPatchedCommentWriteRequest], + ModelDeleteMixin, +): + _model_partial_update_arg = "patched_comment_write_request" + + +class CommentsRepo( + _CommentRepoBase, + ModelListMixin[Comment], + ModelCreateMixin[Comment, models.ICommentWriteRequest], + ModelRetrieveMixin[Comment], +): + _entity_type = Comment + + +_IssueEntityBase, _IssueRepoBase = build_model_bases( + models.IssueRead, apis.IssuesApi, api_member_name="issues_api" +) + + +class Issue( + models.IIssueRead, + _IssueEntityBase, + ModelUpdateMixin[models.IPatchedIssueWriteRequest], + ModelDeleteMixin, +): + _model_partial_update_arg = "patched_issue_write_request" + + +class IssuesRepo( + _IssueRepoBase, + ModelListMixin[Issue], + ModelCreateMixin[Issue, models.IIssueWriteRequest], + ModelRetrieveMixin[Issue], +): + _entity_type = Issue diff --git a/cvat-sdk/cvat_sdk/core/proxies/jobs.py b/cvat-sdk/cvat_sdk/core/proxies/jobs.py new file mode 100644 index 00000000..c1818870 --- /dev/null +++ b/cvat-sdk/cvat_sdk/core/proxies/jobs.py @@ -0,0 +1,166 @@ +# Copyright (C) 2022 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +from __future__ import annotations + +import io +import mimetypes +import os +import os.path as osp +from typing import List, Optional, Sequence + +from PIL import Image + +from cvat_sdk.api_client import apis, models +from cvat_sdk.core.downloading import Downloader +from cvat_sdk.core.helpers import get_paginated_collection +from cvat_sdk.core.progress import ProgressReporter +from cvat_sdk.core.proxies.annotations import AnnotationCrudMixin +from cvat_sdk.core.proxies.issues import Issue +from cvat_sdk.core.proxies.model_proxy import ( + ModelListMixin, + ModelRetrieveMixin, + ModelUpdateMixin, + build_model_bases, +) +from cvat_sdk.core.uploading import AnnotationUploader + +_JobEntityBase, _JobRepoBase = build_model_bases( + models.JobRead, apis.JobsApi, api_member_name="jobs_api" +) + + +class Job( + models.IJobRead, + _JobEntityBase, + ModelUpdateMixin[models.IPatchedJobWriteRequest], + AnnotationCrudMixin, +): + _model_partial_update_arg = "patched_job_write_request" + _put_annotations_data_param = "job_annotations_update_request" + + def import_annotations( + self, + format_name: str, + filename: str, + *, + status_check_period: Optional[int] = None, + pbar: Optional[ProgressReporter] = None, + ): + """ + Upload annotations for a job in the specified format (e.g. 'YOLO ZIP 1.0'). + """ + + AnnotationUploader(self._client).upload_file_and_wait( + self.api.create_annotations_endpoint, + filename, + format_name, + url_params={"id": self.id}, + pbar=pbar, + status_check_period=status_check_period, + ) + + self._client.logger.info(f"Annotation file '{filename}' for job #{self.id} uploaded") + + def export_dataset( + self, + format_name: str, + filename: str, + *, + pbar: Optional[ProgressReporter] = None, + status_check_period: Optional[int] = None, + include_images: bool = True, + ) -> None: + """ + Download annotations for a job in the specified format (e.g. 'YOLO ZIP 1.0'). + """ + if include_images: + endpoint = self.api.retrieve_dataset_endpoint + else: + endpoint = self.api.retrieve_annotations_endpoint + + Downloader(self._client).prepare_and_download_file_from_endpoint( + endpoint=endpoint, + filename=filename, + url_params={"id": self.id}, + query_params={"format": format_name}, + pbar=pbar, + status_check_period=status_check_period, + ) + + self._client.logger.info(f"Dataset for job {self.id} has been downloaded to {filename}") + + def get_frame( + self, + frame_id: int, + *, + quality: Optional[str] = None, + ) -> io.RawIOBase: + (_, response) = self.api.retrieve_data( + self.id, number=frame_id, quality=quality, type="frame" + ) + return io.BytesIO(response.data) + + def get_preview( + self, + ) -> io.RawIOBase: + (_, response) = self.api.retrieve_data(self.id, type="preview") + return io.BytesIO(response.data) + + def download_frames( + self, + frame_ids: Sequence[int], + *, + outdir: str = "", + quality: str = "original", + filename_pattern: str = "frame_{frame_id:06d}{frame_ext}", + ) -> Optional[List[Image.Image]]: + """ + Download the requested frame numbers for a job and save images as outdir/filename_pattern + """ + # TODO: add arg descriptions in schema + os.makedirs(outdir, exist_ok=True) + + for frame_id in frame_ids: + frame_bytes = self.get_frame(frame_id, quality=quality) + + im = Image.open(frame_bytes) + mime_type = im.get_format_mimetype() or "image/jpg" + im_ext = mimetypes.guess_extension(mime_type) + + # FIXME It is better to use meta information from the server + # to determine the extension + # replace '.jpe' or '.jpeg' with a more used '.jpg' + if im_ext in (".jpe", ".jpeg", None): + im_ext = ".jpg" + + outfile = filename_pattern.format(frame_id=frame_id, frame_ext=im_ext) + im.save(osp.join(outdir, outfile)) + + def get_meta(self) -> models.IDataMetaRead: + (meta, _) = self.api.retrieve_data_meta(self.id) + return meta + + def get_frames_info(self) -> List[models.IFrameMeta]: + return self.get_meta().frames + + def remove_frames_by_ids(self, ids: Sequence[int]) -> None: + self._client.api.tasks_api.jobs_partial_update_data_meta( + self.id, + patched_data_meta_write_request=models.PatchedDataMetaWriteRequest(deleted_frames=ids), + ) + + def get_issues(self) -> List[Issue]: + return [Issue(self._client, m) for m in self.api.list_issues(id=self.id)[0]] + + def get_commits(self) -> List[models.IJobCommit]: + return get_paginated_collection(self.api.list_commits_endpoint, id=self.id) + + +class JobsRepo( + _JobRepoBase, + ModelListMixin[Job], + ModelRetrieveMixin[Job], +): + _entity_type = Job diff --git a/cvat-sdk/cvat_sdk/core/proxies/model_proxy.py b/cvat-sdk/cvat_sdk/core/proxies/model_proxy.py new file mode 100644 index 00000000..04673481 --- /dev/null +++ b/cvat-sdk/cvat_sdk/core/proxies/model_proxy.py @@ -0,0 +1,213 @@ +# Copyright (C) 2022 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +from __future__ import annotations + +import json +from abc import ABC +from copy import deepcopy +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generic, + List, + Literal, + Optional, + Tuple, + Type, + TypeVar, + Union, + overload, +) + +from typing_extensions import Self + +from cvat_sdk.api_client.model_utils import IModelData, ModelNormal, to_json +from cvat_sdk.core.helpers import get_paginated_collection + +if TYPE_CHECKING: + from cvat_sdk.core.client import Client + +IModel = TypeVar("IModel", bound=IModelData) +ModelType = TypeVar("ModelType", bound=ModelNormal) +ApiType = TypeVar("ApiType") + + +class ModelProxy(ABC, Generic[ModelType, ApiType]): + _client: Client + + @property + def _api_member_name(self) -> str: + ... + + def __init__(self, client: Client) -> None: + self.__dict__["_client"] = client + + @classmethod + def get_api(cls, client: Client) -> ApiType: + return getattr(client.api, cls._api_member_name) + + @property + def api(self) -> ApiType: + return self.get_api(self._client) + + +class Entity(ModelProxy[ModelType, ApiType]): + """ + Represents a single object. Implements related operations and provides access to data members. + """ + + _model: ModelType + + def __init__(self, client: Client, model: ModelType) -> None: + super().__init__(client) + self.__dict__["_model"] = model + + @property + def _model_id_field(self) -> str: + return "id" + + def __getattr__(self, __name: str) -> Any: + # NOTE: be aware of potential problems with throwing AttributeError from @property + # in derived classes! + # https://medium.com/@ceshine/python-debugging-pitfall-mixed-use-of-property-and-getattr-f89e0ede13f1 + return self._model[__name] + + def __str__(self) -> str: + return str(self._model) + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}: id={getattr(self, self._model_id_field)}>" + + +class Repo(ModelProxy[ModelType, ApiType]): + """ + Represents a collection of corresponding Entity objects. + Implements group and management operations for entities. + """ + + _entity_type: Type[Entity[ModelType, ApiType]] + + +### Utilities + + +def build_model_bases( + mt: Type[ModelType], at: Type[ApiType], *, api_member_name: Optional[str] = None +) -> Tuple[Type[Entity[ModelType, ApiType]], Type[Repo[ModelType, ApiType]]]: + """ + Helps to remove code duplication in declarations of derived classes + """ + + class _EntityBase(Entity[ModelType, ApiType]): + if api_member_name: + _api_member_name = api_member_name + + class _RepoBase(Repo[ModelType, ApiType]): + if api_member_name: + _api_member_name = api_member_name + + return _EntityBase, _RepoBase + + +### CRUD mixins + +_EntityT = TypeVar("_EntityT", bound=Entity) + +#### Repo mixins + + +class ModelCreateMixin(Generic[_EntityT, IModel]): + def create(self: Repo, spec: Union[Dict[str, Any], IModel]) -> _EntityT: + """ + Creates a new object on the server and returns corresponding local object + """ + + (model, _) = self.api.create(spec) + return self._entity_type(self._client, model) + + +class ModelRetrieveMixin(Generic[_EntityT]): + def retrieve(self: Repo, obj_id: int) -> _EntityT: + """ + Retrieves an object from server by ID + """ + + (model, _) = self.api.retrieve(id=obj_id) + return self._entity_type(self._client, model) + + +class ModelListMixin(Generic[_EntityT]): + @overload + def list(self: Repo, *, return_json: Literal[False] = False) -> List[_EntityT]: + ... + + @overload + def list(self: Repo, *, return_json: Literal[True] = False) -> List[Any]: + ... + + def list(self: Repo, *, return_json: bool = False) -> List[Union[_EntityT, Any]]: + """ + Retrieves all objects from the server and returns them in basic or JSON format. + """ + + results = get_paginated_collection(endpoint=self.api.list_endpoint, return_json=return_json) + + if return_json: + return json.dumps(results) + return [self._entity_type(self._client, model) for model in results] + + +#### Entity mixins + + +class ModelUpdateMixin(ABC, Generic[IModel]): + @property + def _model_partial_update_arg(self: Entity) -> str: + ... + + def _export_update_fields( + self: Entity, overrides: Optional[Union[Dict[str, Any], IModel]] = None + ) -> Dict[str, Any]: + # TODO: support field conversion and assignment updating + # fields = to_json(self._model) + + if isinstance(overrides, ModelNormal): + overrides = to_json(overrides) + fields = deepcopy(overrides) + + return fields + + def fetch(self: Entity) -> Self: + """ + Updates current object from the server + """ + + # TODO: implement revision checking + (self._model, _) = self.api.retrieve(id=getattr(self, self._model_id_field)) + return self + + def update(self: Entity, values: Union[Dict[str, Any], IModel]) -> Self: + """ + Commits local model changes to the server + """ + + # TODO: implement revision checking + self.api.partial_update( + id=getattr(self, self._model_id_field), + **{self._model_partial_update_arg: self._export_update_fields(values)}, + ) + + # TODO: use the response model, once input and output models are same + return self.fetch() + + +class ModelDeleteMixin: + def remove(self: Entity) -> None: + """ + Removes current object on the server + """ + + self.api.destroy(id=getattr(self, self._model_id_field)) diff --git a/cvat-sdk/cvat_sdk/core/proxies/projects.py b/cvat-sdk/cvat_sdk/core/proxies/projects.py new file mode 100644 index 00000000..5906a80a --- /dev/null +++ b/cvat-sdk/cvat_sdk/core/proxies/projects.py @@ -0,0 +1,187 @@ +# Copyright (C) 2022 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +from __future__ import annotations + +import json +import os.path as osp +from typing import Optional + +from cvat_sdk.api_client import apis, models +from cvat_sdk.core.downloading import Downloader +from cvat_sdk.core.progress import ProgressReporter +from cvat_sdk.core.proxies.model_proxy import ( + ModelCreateMixin, + ModelDeleteMixin, + ModelListMixin, + ModelRetrieveMixin, + ModelUpdateMixin, + build_model_bases, +) +from cvat_sdk.core.uploading import DatasetUploader, Uploader + +_ProjectEntityBase, _ProjectRepoBase = build_model_bases( + models.ProjectRead, apis.ProjectsApi, api_member_name="projects_api" +) + + +class Project( + _ProjectEntityBase, models.IProjectRead, ModelUpdateMixin[models.IPatchedProjectWriteRequest] +): + _model_partial_update_arg = "patched_project_write_request" + + def import_dataset( + self, + format_name: str, + filename: str, + *, + status_check_period: Optional[int] = None, + pbar: Optional[ProgressReporter] = None, + ): + """ + Import dataset for a project in the specified format (e.g. 'YOLO ZIP 1.0'). + """ + + DatasetUploader(self._client).upload_file_and_wait( + self.api.create_dataset_endpoint, + filename, + format_name, + url_params={"id": self.id}, + pbar=pbar, + status_check_period=status_check_period, + ) + + self._client.logger.info(f"Annotation file '{filename}' for project #{self.id} uploaded") + + def export_dataset( + self, + format_name: str, + filename: str, + *, + pbar: Optional[ProgressReporter] = None, + status_check_period: Optional[int] = None, + include_images: bool = True, + ) -> None: + """ + Download annotations for a project in the specified format (e.g. 'YOLO ZIP 1.0'). + """ + if include_images: + endpoint = self.api.retrieve_dataset_endpoint + else: + endpoint = self.api.retrieve_annotations_endpoint + + Downloader(self._client).prepare_and_download_file_from_endpoint( + endpoint=endpoint, + filename=filename, + url_params={"id": self.id}, + query_params={"format": format_name}, + pbar=pbar, + status_check_period=status_check_period, + ) + + self._client.logger.info(f"Dataset for project {self.id} has been downloaded to {filename}") + + def download_backup( + self, + filename: str, + *, + status_check_period: int = None, + pbar: Optional[ProgressReporter] = None, + ) -> None: + """ + Download a project backup + """ + + Downloader(self._client).prepare_and_download_file_from_endpoint( + self.api.retrieve_backup_endpoint, + filename=filename, + pbar=pbar, + status_check_period=status_check_period, + url_params={"id": self.id}, + ) + + self._client.logger.info(f"Backup for project {self.id} has been downloaded to {filename}") + + def get_annotations(self) -> models.ILabeledData: + (annotations, _) = self.api.retrieve_annotations(self.id) + return annotations + + +class ProjectsRepo( + _ProjectRepoBase, + ModelCreateMixin[Project, models.IProjectWriteRequest], + ModelListMixin[Project], + ModelRetrieveMixin[Project], + ModelDeleteMixin, +): + _entity_type = Project + + def create_from_dataset( + self, + spec: models.IProjectWriteRequest, + *, + dataset_path: str = "", + dataset_format: str = "CVAT XML 1.1", + status_check_period: int = None, + pbar: Optional[ProgressReporter] = None, + ) -> Project: + """ + Create a new project with the given name and labels JSON and + add the files to it. + + Returns: id of the created project + """ + project = self.create(spec=spec) + self._client.logger.info("Created project ID: %s NAME: %s", project.id, project.name) + + if dataset_path: + project.import_dataset( + format_name=dataset_format, + filename=dataset_path, + pbar=pbar, + status_check_period=status_check_period, + ) + + project.fetch() + return project + + def create_from_backup( + self, + filename: str, + *, + status_check_period: int = None, + pbar: Optional[ProgressReporter] = None, + ) -> Project: + """ + Import a project from a backup file + """ + if status_check_period is None: + status_check_period = self.config.status_check_period + + params = {"filename": osp.basename(filename)} + url = self.api_map.make_endpoint_url(self.api.create_backup_endpoint.path) + + uploader = Uploader(self) + response = uploader.upload_file( + url, + filename, + meta=params, + query_params=params, + pbar=pbar, + logger=self._client.logger.debug, + ) + + rq_id = json.loads(response.data)["rq_id"] + response = self._client.wait_for_completion( + url, + success_status=201, + positive_statuses=[202], + post_params={"rq_id": rq_id}, + status_check_period=status_check_period, + ) + + project_id = json.loads(response.data)["id"] + self._client.logger.info(f"Project has been imported sucessfully. Project ID: {project_id}") + + return self.retrieve(project_id) diff --git a/cvat-sdk/cvat_sdk/core/proxies/tasks.py b/cvat-sdk/cvat_sdk/core/proxies/tasks.py new file mode 100644 index 00000000..b0510c78 --- /dev/null +++ b/cvat-sdk/cvat_sdk/core/proxies/tasks.py @@ -0,0 +1,388 @@ +# Copyright (C) 2022 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +from __future__ import annotations + +import io +import json +import mimetypes +import os +import os.path as osp +from enum import Enum +from time import sleep +from typing import Any, Dict, List, Optional, Sequence + +from PIL import Image + +from cvat_sdk.api_client import apis, exceptions, models +from cvat_sdk.core import git +from cvat_sdk.core.downloading import Downloader +from cvat_sdk.core.progress import ProgressReporter +from cvat_sdk.core.proxies.annotations import AnnotationCrudMixin +from cvat_sdk.core.proxies.jobs import Job +from cvat_sdk.core.proxies.model_proxy import ( + ModelCreateMixin, + ModelDeleteMixin, + ModelListMixin, + ModelRetrieveMixin, + ModelUpdateMixin, + build_model_bases, +) +from cvat_sdk.core.uploading import AnnotationUploader, DataUploader, Uploader +from cvat_sdk.core.utils import filter_dict + + +class ResourceType(Enum): + LOCAL = 0 + SHARE = 1 + REMOTE = 2 + + def __str__(self): + return self.name.lower() + + def __repr__(self): + return str(self) + + +_TaskEntityBase, _TaskRepoBase = build_model_bases( + models.TaskRead, apis.TasksApi, api_member_name="tasks_api" +) + + +class Task( + _TaskEntityBase, + models.ITaskRead, + ModelUpdateMixin[models.IPatchedTaskWriteRequest], + ModelDeleteMixin, + AnnotationCrudMixin, +): + _model_partial_update_arg = "patched_task_write_request" + _put_annotations_data_param = "task_annotations_update_request" + + def upload_data( + self, + resource_type: ResourceType, + resources: Sequence[str], + *, + pbar: Optional[ProgressReporter] = None, + params: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Add local, remote, or shared files to an existing task. + """ + params = params or {} + + data = {} + if resource_type is ResourceType.LOCAL: + pass # handled later + elif resource_type is ResourceType.REMOTE: + data = {f"remote_files[{i}]": f for i, f in enumerate(resources)} + elif resource_type is ResourceType.SHARE: + data = {f"server_files[{i}]": f for i, f in enumerate(resources)} + + data["image_quality"] = 70 + data.update( + filter_dict( + params, + keep=[ + "chunk_size", + "copy_data", + "image_quality", + "sorting_method", + "start_frame", + "stop_frame", + "use_cache", + "use_zip_chunks", + ], + ) + ) + if params.get("frame_step") is not None: + data["frame_filter"] = f"step={params.get('frame_step')}" + + if resource_type in [ResourceType.REMOTE, ResourceType.SHARE]: + self.api.create_data( + self.id, + data_request=models.DataRequest(**data), + _content_type="multipart/form-data", + ) + elif resource_type == ResourceType.LOCAL: + url = self._client.api_map.make_endpoint_url( + self.api.create_data_endpoint.path, kwsub={"id": self.id} + ) + + DataUploader(self._client).upload_files(url, resources, pbar=pbar, **data) + + def import_annotations( + self, + format_name: str, + filename: str, + *, + status_check_period: Optional[int] = None, + pbar: Optional[ProgressReporter] = None, + ): + """ + Upload annotations for a task in the specified format (e.g. 'YOLO ZIP 1.0'). + """ + + AnnotationUploader(self._client).upload_file_and_wait( + self.api.create_annotations_endpoint, + filename, + format_name, + url_params={"id": self.id}, + pbar=pbar, + status_check_period=status_check_period, + ) + + self._client.logger.info(f"Annotation file '{filename}' for task #{self.id} uploaded") + + def get_frame( + self, + frame_id: int, + *, + quality: Optional[str] = None, + ) -> io.RawIOBase: + params = {} + if quality: + params["quality"] = quality + (_, response) = self.api.retrieve_data(self.id, number=frame_id, **params, type="frame") + return io.BytesIO(response.data) + + def get_preview( + self, + ) -> io.RawIOBase: + (_, response) = self.api.retrieve_data(self.id, type="preview") + return io.BytesIO(response.data) + + def download_frames( + self, + frame_ids: Sequence[int], + *, + outdir: str = "", + quality: str = "original", + filename_pattern: str = "frame_{frame_id:06d}{frame_ext}", + ) -> Optional[List[Image.Image]]: + """ + Download the requested frame numbers for a task and save images as outdir/filename_pattern + """ + # TODO: add arg descriptions in schema + os.makedirs(outdir, exist_ok=True) + + for frame_id in frame_ids: + frame_bytes = self.get_frame(frame_id, quality=quality) + + im = Image.open(frame_bytes) + mime_type = im.get_format_mimetype() or "image/jpg" + im_ext = mimetypes.guess_extension(mime_type) + + # FIXME It is better to use meta information from the server + # to determine the extension + # replace '.jpe' or '.jpeg' with a more used '.jpg' + if im_ext in (".jpe", ".jpeg", None): + im_ext = ".jpg" + + outfile = filename_pattern.format(frame_id=frame_id, frame_ext=im_ext) + im.save(osp.join(outdir, outfile)) + + def export_dataset( + self, + format_name: str, + filename: str, + *, + pbar: Optional[ProgressReporter] = None, + status_check_period: Optional[int] = None, + include_images: bool = True, + ) -> None: + """ + Download annotations for a task in the specified format (e.g. 'YOLO ZIP 1.0'). + """ + if include_images: + endpoint = self.api.retrieve_dataset_endpoint + else: + endpoint = self.api.retrieve_annotations_endpoint + + Downloader(self._client).prepare_and_download_file_from_endpoint( + endpoint=endpoint, + filename=filename, + url_params={"id": self.id}, + query_params={"format": format_name}, + pbar=pbar, + status_check_period=status_check_period, + ) + + self._client.logger.info(f"Dataset for task {self.id} has been downloaded to {filename}") + + def download_backup( + self, + filename: str, + *, + status_check_period: int = None, + pbar: Optional[ProgressReporter] = None, + ) -> None: + """ + Download a task backup + """ + + Downloader(self._client).prepare_and_download_file_from_endpoint( + self.api.retrieve_backup_endpoint, + filename=filename, + pbar=pbar, + status_check_period=status_check_period, + url_params={"id": self.id}, + ) + + self._client.logger.info(f"Backup for task {self.id} has been downloaded to {filename}") + + def get_jobs(self) -> List[Job]: + return [Job(self._client, m) for m in self.api.list_jobs(id=self.id)[0]] + + def get_meta(self) -> models.IDataMetaRead: + (meta, _) = self.api.retrieve_data_meta(self.id) + return meta + + def get_frames_info(self) -> List[models.IFrameMeta]: + return self.get_meta().frames + + def remove_frames_by_ids(self, ids: Sequence[int]) -> None: + self.api.partial_update_data_meta( + self.id, + patched_data_meta_write_request=models.PatchedDataMetaWriteRequest(deleted_frames=ids), + ) + + +class TasksRepo( + _TaskRepoBase, + ModelCreateMixin[Task, models.ITaskWriteRequest], + ModelRetrieveMixin[Task], + ModelListMixin[Task], + ModelDeleteMixin, +): + _entity_type = Task + + def create_from_data( + self, + spec: models.ITaskWriteRequest, + resource_type: ResourceType, + resources: Sequence[str], + *, + data_params: Optional[Dict[str, Any]] = None, + annotation_path: str = "", + annotation_format: str = "CVAT XML 1.1", + status_check_period: int = None, + dataset_repository_url: str = "", + use_lfs: bool = False, + pbar: Optional[ProgressReporter] = None, + ) -> Task: + """ + Create a new task with the given name and labels JSON and + add the files to it. + + Returns: id of the created task + """ + if status_check_period is None: + status_check_period = self._client.config.status_check_period + + if getattr(spec, "project_id", None) and getattr(spec, "labels", None): + raise exceptions.ApiValueError( + "Can't set labels to a task inside a project. " + "Tasks inside a project use project's labels.", + ["labels"], + ) + + task = self.create(spec=spec) + self._client.logger.info("Created task ID: %s NAME: %s", task.id, task.name) + + task.upload_data(resource_type, resources, pbar=pbar, params=data_params) + + self._client.logger.info("Awaiting for task %s creation...", task.id) + status: models.RqStatus = None + while status != models.RqStatusStateEnum.allowed_values[("value",)]["FINISHED"]: + sleep(status_check_period) + (status, response) = self.api.retrieve_status(task.id) + + self._client.logger.info( + "Task %s creation status=%s, message=%s", + task.id, + status.state.value, + status.message, + ) + + if status.state.value == models.RqStatusStateEnum.allowed_values[("value",)]["FAILED"]: + raise exceptions.ApiException( + status=status.state.value, reason=status.message, http_resp=response + ) + + status = status.state.value + + if annotation_path: + task.import_annotations(annotation_format, annotation_path, pbar=pbar) + + if dataset_repository_url: + git.create_git_repo( + self, + task_id=task.id, + repo_url=dataset_repository_url, + status_check_period=status_check_period, + use_lfs=use_lfs, + ) + + task.fetch() + + return task + + def remove_by_ids(self, task_ids: Sequence[int]) -> None: + """ + Delete a list of tasks, ignoring those which don't exist. + """ + + for task_id in task_ids: + (_, response) = self.api.destroy(task_id, _check_status=False) + + if 200 <= response.status <= 299: + self._client.logger.info(f"Task ID {task_id} deleted") + elif response.status == 404: + self._client.logger.info(f"Task ID {task_id} not found") + else: + self._client.logger.warning( + f"Failed to delete task ID {task_id}: " + f"{response.msg} (status {response.status})" + ) + + def create_from_backup( + self, + filename: str, + *, + status_check_period: int = None, + pbar: Optional[ProgressReporter] = None, + ) -> Task: + """ + Import a task from a backup file + """ + if status_check_period is None: + status_check_period = self._client.config.status_check_period + + params = {"filename": osp.basename(filename)} + url = self._client.api_map.make_endpoint_url(self.api.create_backup_endpoint.path) + uploader = Uploader(self._client) + response = uploader.upload_file( + url, + filename, + meta=params, + query_params=params, + pbar=pbar, + logger=self._client.logger.debug, + ) + + rq_id = json.loads(response.data)["rq_id"] + response = self._client.wait_for_completion( + url, + success_status=201, + positive_statuses=[202], + post_params={"rq_id": rq_id}, + status_check_period=status_check_period, + ) + + task_id = json.loads(response.data)["id"] + self._client.logger.info(f"Task has been imported sucessfully. Task ID: {task_id}") + + return self.retrieve(task_id) diff --git a/cvat-sdk/cvat_sdk/core/proxies/users.py b/cvat-sdk/cvat_sdk/core/proxies/users.py new file mode 100644 index 00000000..bc3a69bc --- /dev/null +++ b/cvat-sdk/cvat_sdk/core/proxies/users.py @@ -0,0 +1,35 @@ +# Copyright (C) 2022 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +from __future__ import annotations + +from cvat_sdk.api_client import apis, models +from cvat_sdk.core.proxies.model_proxy import ( + ModelDeleteMixin, + ModelListMixin, + ModelRetrieveMixin, + ModelUpdateMixin, + build_model_bases, +) + +_UserEntityBase, _UserRepoBase = build_model_bases( + models.User, apis.UsersApi, api_member_name="users_api" +) + + +class User( + models.IUser, _UserEntityBase, ModelUpdateMixin[models.IPatchedUserRequest], ModelDeleteMixin +): + _model_partial_update_arg = "patched_user_request" + + +class UsersRepo( + _UserRepoBase, + ModelListMixin[User], + ModelRetrieveMixin[User], +): + _entity_type = User + + def retrieve_current_user(self) -> User: + return User(self._client, self.api.retrieve_self()[0]) diff --git a/cvat-sdk/cvat_sdk/core/tasks.py b/cvat-sdk/cvat_sdk/core/tasks.py deleted file mode 100644 index 33e6b965..00000000 --- a/cvat-sdk/cvat_sdk/core/tasks.py +++ /dev/null @@ -1,308 +0,0 @@ -# Copyright (C) 2020-2022 Intel Corporation -# Copyright (C) 2022 CVAT.ai Corporation -# -# SPDX-License-Identifier: MIT - -from __future__ import annotations - -import io -import mimetypes -import os -import os.path as osp -from abc import ABC, abstractmethod -from io import BytesIO -from time import sleep -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence - -from PIL import Image - -from cvat_sdk import models -from cvat_sdk.api_client.model_utils import OpenApiModel -from cvat_sdk.core.downloading import Downloader -from cvat_sdk.core.progress import ProgressReporter -from cvat_sdk.core.types import ResourceType -from cvat_sdk.core.uploading import Uploader -from cvat_sdk.core.utils import filter_dict - -if TYPE_CHECKING: - from cvat_sdk.core.client import Client - - -class ModelProxy(ABC): - _client: Client - _model: OpenApiModel - - def __init__(self, client: Client, model: OpenApiModel) -> None: - self.__dict__["_client"] = client - self.__dict__["_model"] = model - - def __getattr__(self, __name: str) -> Any: - return self._model[__name] - - def __setattr__(self, __name: str, __value: Any) -> None: - if __name in self.__dict__: - self.__dict__[__name] = __value - else: - self._model[__name] = __value - - @abstractmethod - def fetch(self, force: bool = False): - """Fetches model data from the server""" - ... - - @abstractmethod - def commit(self, force: bool = False): - """Commits local changes to the server""" - ... - - def sync(self): - """Pulls server state and commits local model changes""" - raise NotImplementedError - - @abstractmethod - def update(self, **kwargs): - """Updates multiple fields at once""" - ... - - -class TaskProxy(ModelProxy, models.ITaskRead): - def __init__(self, client: Client, task: models.TaskRead): - ModelProxy.__init__(self, client=client, model=task) - - def remove(self): - self._client.api.tasks_api.destroy(self.id) - - def upload_data( - self, - resource_type: ResourceType, - resources: Sequence[str], - *, - pbar: Optional[ProgressReporter] = None, - params: Optional[Dict[str, Any]] = None, - ) -> None: - """ - Add local, remote, or shared files to an existing task. - """ - client = self._client - task_id = self.id - - params = params or {} - - data = {} - if resource_type is ResourceType.LOCAL: - pass # handled later - elif resource_type is ResourceType.REMOTE: - data = {f"remote_files[{i}]": f for i, f in enumerate(resources)} - elif resource_type is ResourceType.SHARE: - data = {f"server_files[{i}]": f for i, f in enumerate(resources)} - - data["image_quality"] = 70 - data.update( - filter_dict( - params, - keep=[ - "chunk_size", - "copy_data", - "image_quality", - "sorting_method", - "start_frame", - "stop_frame", - "use_cache", - "use_zip_chunks", - ], - ) - ) - if params.get("frame_step") is not None: - data["frame_filter"] = f"step={params.get('frame_step')}" - - if resource_type in [ResourceType.REMOTE, ResourceType.SHARE]: - client.api.tasks_api.create_data( - task_id, - data_request=models.DataRequest(**data), - _content_type="multipart/form-data", - ) - elif resource_type == ResourceType.LOCAL: - url = client._api_map.make_endpoint_url( - client.api.tasks_api.create_data_endpoint.path, kwsub={"id": task_id} - ) - - uploader = Uploader(client) - uploader.upload_files(url, resources, pbar=pbar, **data) - - def import_annotations( - self, - format_name: str, - filename: str, - *, - status_check_period: int = None, - pbar: Optional[ProgressReporter] = None, - ): - """ - Upload annotations for a task in the specified format - (e.g. 'YOLO ZIP 1.0'). - """ - client = self._client - if status_check_period is None: - status_check_period = client.config.status_check_period - - task_id = self.id - - url = client._api_map.make_endpoint_url( - client.api.tasks_api.create_annotations_endpoint.path, - kwsub={"id": task_id}, - ) - params = {"format": format_name, "filename": osp.basename(filename)} - - uploader = Uploader(client) - uploader.upload_file( - url, filename, pbar=pbar, query_params=params, meta={"filename": params["filename"]} - ) - - while True: - response = client.api.rest_client.POST( - url, headers=client.api.get_common_headers(), query_params=params - ) - if response.status == 201: - break - - sleep(status_check_period) - - client.logger.info( - f"Upload job for Task ID {task_id} with annotation file {filename} finished" - ) - - def retrieve_frame( - self, - frame_id: int, - *, - quality: Optional[str] = None, - ) -> io.RawIOBase: - client = self._client - task_id = self.id - - (_, response) = client.api.tasks_api.retrieve_data(task_id, frame_id, quality, type="frame") - - return BytesIO(response.data) - - def download_frames( - self, - frame_ids: Sequence[int], - *, - outdir: str = "", - quality: str = "original", - filename_pattern: str = "task_{task_id}_frame_{frame_id:06d}{frame_ext}", - ) -> Optional[List[Image.Image]]: - """ - Download the requested frame numbers for a task and save images as - outdir/filename_pattern - """ - # TODO: add arg descriptions in schema - task_id = self.id - - os.makedirs(outdir, exist_ok=True) - - for frame_id in frame_ids: - frame_bytes = self.retrieve_frame(frame_id, quality=quality) - - im = Image.open(frame_bytes) - mime_type = im.get_format_mimetype() or "image/jpg" - im_ext = mimetypes.guess_extension(mime_type) - - # FIXME It is better to use meta information from the server - # to determine the extension - # replace '.jpe' or '.jpeg' with a more used '.jpg' - if im_ext in (".jpe", ".jpeg", None): - im_ext = ".jpg" - - outfile = filename_pattern.format(task_id=task_id, frame_id=frame_id, frame_ext=im_ext) - im.save(osp.join(outdir, outfile)) - - def export_dataset( - self, - format_name: str, - filename: str, - *, - pbar: Optional[ProgressReporter] = None, - status_check_period: int = None, - include_images: bool = True, - ) -> None: - """ - Download annotations for a task in the specified format (e.g. 'YOLO ZIP 1.0'). - """ - client = self._client - if status_check_period is None: - status_check_period = client.config.status_check_period - - task_id = self.id - - params = {"filename": self.name, "format": format_name} - if include_images: - endpoint = client.api.tasks_api.retrieve_dataset_endpoint - else: - endpoint = client.api.tasks_api.retrieve_annotations_endpoint - - client.logger.info("Waiting for the server to prepare the file...") - while True: - (_, response) = endpoint.call_with_http_info(id=task_id, **params) - client.logger.debug("STATUS {}".format(response.status)) - if response.status == 201: - break - sleep(status_check_period) - - params["action"] = "download" - url = client._api_map.make_endpoint_url( - endpoint.path, kwsub={"id": task_id}, query_params=params - ) - downloader = Downloader(client) - downloader.download_file(url, output_path=filename, pbar=pbar) - - client.logger.info(f"Dataset has been exported to {filename}") - - def download_backup( - self, - filename: str, - *, - status_check_period: int = None, - pbar: Optional[ProgressReporter] = None, - ): - """ - Download a task backup - """ - client = self._client - if status_check_period is None: - status_check_period = client.config.status_check_period - - task_id = self.id - - endpoint = client.api.tasks_api.retrieve_backup_endpoint - client.logger.info("Waiting for the server to prepare the file...") - while True: - (_, response) = endpoint.call_with_http_info(id=task_id) - client.logger.debug("STATUS {}".format(response.status)) - if response.status == 201: - break - sleep(status_check_period) - - url = client._api_map.make_endpoint_url( - endpoint.path, kwsub={"id": task_id}, query_params={"action": "download"} - ) - downloader = Downloader(client) - downloader.download_file(url, output_path=filename, pbar=pbar) - - client.logger.info( - f"Task {task_id} has been exported sucessfully to {osp.abspath(filename)}" - ) - - def fetch(self, force: bool = False): - # TODO: implement revision checking - model, _ = self._client.api.tasks_api.retrieve(self.id) - self._model = model - - def commit(self, force: bool = False): - return super().commit(force) - - def update(self, **kwargs): - return super().update(**kwargs) - - def __str__(self) -> str: - return str(self._model) diff --git a/cvat-sdk/cvat_sdk/core/types.py b/cvat-sdk/cvat_sdk/core/types.py deleted file mode 100644 index f1d8d89e..00000000 --- a/cvat-sdk/cvat_sdk/core/types.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (C) 2022 Intel Corporation -# Copyright (C) 2022 CVAT.ai Corporation -# -# SPDX-License-Identifier: MIT - -from enum import Enum - - -class ResourceType(Enum): - LOCAL = 0 - SHARE = 1 - REMOTE = 2 - - def __str__(self): - return self.name.lower() - - def __repr__(self): - return str(self) diff --git a/cvat-sdk/cvat_sdk/core/uploading.py b/cvat-sdk/cvat_sdk/core/uploading.py index 9b8c0b72..93d4764e 100644 --- a/cvat-sdk/cvat_sdk/core/uploading.py +++ b/cvat-sdk/cvat_sdk/core/uploading.py @@ -7,16 +7,15 @@ from __future__ import annotations import os import os.path as osp from contextlib import ExitStack, closing -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple import requests import urllib3 -from cvat_sdk.api_client import ApiClient +from cvat_sdk.api_client.api_client import ApiClient, Endpoint from cvat_sdk.api_client.rest import RESTClientObject -from cvat_sdk.core.helpers import StreamWithProgress +from cvat_sdk.core.helpers import StreamWithProgress, expect_status from cvat_sdk.core.progress import ProgressReporter -from cvat_sdk.core.utils import assert_status if TYPE_CHECKING: from cvat_sdk.core.client import Client @@ -25,57 +24,12 @@ MAX_REQUEST_SIZE = 100 * 2**20 class Uploader: - def __init__(self, client: Client): - self.client = client - - def upload_files( - self, - url: str, - resources: List[str], - *, - pbar: Optional[ProgressReporter] = None, - **kwargs, - ): - bulk_file_groups, separate_files, total_size = self._split_files_by_requests(resources) - - if pbar is not None: - pbar.start(total_size, desc="Uploading data") - - self._tus_start_upload(url) + """ + Implements common uploading protocols + """ - for group, group_size in bulk_file_groups: - with ExitStack() as es: - files = {} - for i, filename in enumerate(group): - files[f"client_files[{i}]"] = ( - filename, - es.enter_context(closing(open(filename, "rb"))).read(), - ) - response = self.client.api.rest_client.POST( - url, - post_params=dict(**kwargs, **files), - headers={ - "Content-Type": "multipart/form-data", - "Upload-Multiple": "", - **self.client.api.get_common_headers(), - }, - ) - assert_status(200, response) - - if pbar is not None: - pbar.advance(group_size) - - for filename in separate_files: - # TODO: check if basename produces invalid paths here, can lead to overwriting - self._upload_file_data_with_tus( - url, - filename, - meta={"filename": osp.basename(filename)}, - pbar=pbar, - logger=self.client.logger.debug, - ) - - self._tus_finish_upload(url, fields=kwargs) + def __init__(self, client: Client): + self._client = client def upload_file( self, @@ -121,6 +75,27 @@ class Uploader: ) return self._tus_finish_upload(url, query_params=query_params, fields=fields) + def _wait_for_completion( + self, + url: str, + *, + success_status: int, + status_check_period: Optional[int] = None, + query_params: Optional[Dict[str, Any]] = None, + post_params: Optional[Dict[str, Any]] = None, + method: str = "POST", + positive_statuses: Optional[Sequence[int]] = None, + ) -> urllib3.HTTPResponse: + return self._client.wait_for_completion( + url, + success_status=success_status, + status_check_period=status_check_period, + query_params=query_params, + post_params=post_params, + method=method, + positive_statuses=positive_statuses, + ) + def _split_files_by_requests( self, filenames: List[str] ) -> Tuple[List[Tuple[List[str], int]], List[str], int]: @@ -268,7 +243,7 @@ class Uploader: input_file = StreamWithProgress(input_file, pbar, length=file_size) tus_uploader = self._make_tus_uploader( - self.client.api, + self._client.api, url=url.rstrip("/") + "/", metadata=meta, file_stream=input_file, @@ -278,26 +253,131 @@ class Uploader: tus_uploader.upload() def _tus_start_upload(self, url, *, query_params=None): - response = self.client.api.rest_client.POST( + response = self._client.api.rest_client.POST( url, query_params=query_params, headers={ "Upload-Start": "", - **self.client.api.get_common_headers(), + **self._client.api.get_common_headers(), }, ) - assert_status(202, response) + expect_status(202, response) return response def _tus_finish_upload(self, url, *, query_params=None, fields=None): - response = self.client.api.rest_client.POST( + response = self._client.api.rest_client.POST( url, headers={ "Upload-Finish": "", - **self.client.api.get_common_headers(), + **self._client.api.get_common_headers(), }, query_params=query_params, post_params=fields, ) - assert_status(202, response) + expect_status(202, response) return response + + +class AnnotationUploader(Uploader): + def upload_file_and_wait( + self, + endpoint: Endpoint, + filename: str, + format_name: str, + *, + url_params: Optional[Dict[str, Any]] = None, + pbar: Optional[ProgressReporter] = None, + status_check_period: Optional[int] = None, + ): + url = self._client.api_map.make_endpoint_url(endpoint.path, kwsub=url_params) + params = {"format": format_name, "filename": osp.basename(filename)} + self.upload_file( + url, filename, pbar=pbar, query_params=params, meta={"filename": params["filename"]} + ) + + self._wait_for_completion( + url, + success_status=201, + positive_statuses=[202], + status_check_period=status_check_period, + query_params=params, + method="POST", + ) + + +class DatasetUploader(Uploader): + def upload_file_and_wait( + self, + endpoint: Endpoint, + filename: str, + format_name: str, + *, + url_params: Optional[Dict[str, Any]] = None, + pbar: Optional[ProgressReporter] = None, + status_check_period: Optional[int] = None, + ): + url = self._client.api_map.make_endpoint_url(endpoint.path, kwsub=url_params) + params = {"format": format_name, "filename": osp.basename(filename)} + self.upload_file( + url, filename, pbar=pbar, query_params=params, meta={"filename": params["filename"]} + ) + + self._wait_for_completion( + url, + success_status=201, + positive_statuses=[202], + status_check_period=status_check_period, + query_params=params, + method="GET", + ) + + +class DataUploader(Uploader): + def upload_files( + self, + url: str, + resources: List[str], + *, + pbar: Optional[ProgressReporter] = None, + **kwargs, + ): + bulk_file_groups, separate_files, total_size = self._split_files_by_requests(resources) + + if pbar is not None: + pbar.start(total_size, desc="Uploading data") + + self._tus_start_upload(url) + + for group, group_size in bulk_file_groups: + with ExitStack() as es: + files = {} + for i, filename in enumerate(group): + files[f"client_files[{i}]"] = ( + filename, + es.enter_context(closing(open(filename, "rb"))).read(), + ) + response = self._client.api.rest_client.POST( + url, + post_params=dict(**kwargs, **files), + headers={ + "Content-Type": "multipart/form-data", + "Upload-Multiple": "", + **self._client.api.get_common_headers(), + }, + ) + expect_status(200, response) + + if pbar is not None: + pbar.advance(group_size) + + for filename in separate_files: + # TODO: check if basename produces invalid paths here, can lead to overwriting + self._upload_file_data_with_tus( + url, + filename, + meta={"filename": osp.basename(filename)}, + pbar=pbar, + logger=self._client.logger.debug, + ) + + self._tus_finish_upload(url, fields=kwargs) diff --git a/cvat-sdk/cvat_sdk/core/utils.py b/cvat-sdk/cvat_sdk/core/utils.py index d931e2f8..407b6d3e 100644 --- a/cvat-sdk/cvat_sdk/core/utils.py +++ b/cvat-sdk/cvat_sdk/core/utils.py @@ -1,4 +1,3 @@ -# Copyright (C) 2022 Intel Corporation # Copyright (C) 2022 CVAT.ai Corporation # # SPDX-License-Identifier: MIT @@ -7,13 +6,6 @@ from __future__ import annotations from typing import Any, Dict, Sequence -import urllib3 - - -def assert_status(code: int, response: urllib3.HTTPResponse) -> None: - if response.status != code: - raise Exception(f"Unexpected status code received {response.status}") - def filter_dict( d: Dict[str, Any], *, keep: Sequence[str] = None, drop: Sequence[str] = None diff --git a/cvat-sdk/gen/postprocess.py b/cvat-sdk/gen/postprocess.py index 51661520..8e2476a4 100755 --- a/cvat-sdk/gen/postprocess.py +++ b/cvat-sdk/gen/postprocess.py @@ -48,7 +48,7 @@ class Processor: tokenized_path = tokenized_path[2:] prefix = tokenized_path[0] + "_" - if new_name.startswith(prefix): + if new_name.startswith(prefix) and tokenized_path[0] in operation["tags"]: new_name = new_name[len(prefix) :] return new_name diff --git a/cvat-sdk/gen/templates/openapi-generator/api_client.mustache b/cvat-sdk/gen/templates/openapi-generator/api_client.mustache index 537cd1cc..0cf16488 100644 --- a/cvat-sdk/gen/templates/openapi-generator/api_client.mustache +++ b/cvat-sdk/gen/templates/openapi-generator/api_client.mustache @@ -345,6 +345,9 @@ class ApiClient(object): """ if response_schema == (file_type,): + # TODO: response schema can be "oneOf" with a file option, + # this implementation does not cover this. + # handle file downloading # save response body into a tmp file and return the instance content_disposition = response.getheader("Content-Disposition") diff --git a/cvat-sdk/gen/templates/openapi-generator/model.mustache b/cvat-sdk/gen/templates/openapi-generator/model.mustache index 75f75645..3a063f33 100644 --- a/cvat-sdk/gen/templates/openapi-generator/model.mustache +++ b/cvat-sdk/gen/templates/openapi-generator/model.mustache @@ -9,6 +9,7 @@ import sys # noqa: F401 from {{packageName}}.model_utils import ( # noqa: F401 ApiTypeError, + IModelData, ModelComposed, ModelNormal, ModelSimple, diff --git a/cvat-sdk/gen/templates/openapi-generator/model_templates/model_normal.mustache b/cvat-sdk/gen/templates/openapi-generator/model_templates/model_normal.mustache index 06eae9fa..73ef8d59 100644 --- a/cvat-sdk/gen/templates/openapi-generator/model_templates/model_normal.mustache +++ b/cvat-sdk/gen/templates/openapi-generator/model_templates/model_normal.mustache @@ -1,5 +1,5 @@ -class I{{classname}}: +class I{{classname}}(IModelData): """ NOTE: This class is auto generated by OpenAPI Generator. Ref: https://openapi-generator.tech diff --git a/cvat-sdk/gen/templates/openapi-generator/model_templates/model_simple.mustache b/cvat-sdk/gen/templates/openapi-generator/model_templates/model_simple.mustache index de03cd5b..66487bf4 100644 --- a/cvat-sdk/gen/templates/openapi-generator/model_templates/model_simple.mustache +++ b/cvat-sdk/gen/templates/openapi-generator/model_templates/model_simple.mustache @@ -1,5 +1,5 @@ -class I{{classname}}: +class I{{classname}}(IModelData): """ NOTE: This class is auto generated by OpenAPI Generator. Ref: https://openapi-generator.tech diff --git a/cvat-sdk/gen/templates/openapi-generator/model_utils.mustache b/cvat-sdk/gen/templates/openapi-generator/model_utils.mustache index 3b3cbd1c..09887e41 100644 --- a/cvat-sdk/gen/templates/openapi-generator/model_utils.mustache +++ b/cvat-sdk/gen/templates/openapi-generator/model_utils.mustache @@ -113,6 +113,11 @@ def composed_model_input_classes(cls): return [] +class IModelData: + """ + The base class for model data. Declares model fields and their types for better introspection + """ + class OpenApiModel(object): """The base class for all OpenAPIModels""" diff --git a/cvat-sdk/gen/templates/requirements/base.txt b/cvat-sdk/gen/templates/requirements/base.txt index 77c2adea..695db060 100644 --- a/cvat-sdk/gen/templates/requirements/base.txt +++ b/cvat-sdk/gen/templates/requirements/base.txt @@ -3,3 +3,4 @@ attrs >= 21.4.0 tqdm >= 4.64.0 tuspy == 0.2.5 # have it pinned, because SDK has lots of patched TUS code +typing_extensions >= 4.2.0 diff --git a/cvat-ui/package.json b/cvat-ui/package.json index 9d2803a5..9245ca36 100644 --- a/cvat-ui/package.json +++ b/cvat-ui/package.json @@ -1,6 +1,6 @@ { "name": "cvat-ui", - "version": "1.41.0", + "version": "1.41.1", "description": "CVAT single-page application", "main": "src/index.tsx", "scripts": { diff --git a/cvat-ui/src/components/tasks-page/tasks-page.tsx b/cvat-ui/src/components/tasks-page/tasks-page.tsx index d537b77c..c4d0e4c3 100644 --- a/cvat-ui/src/components/tasks-page/tasks-page.tsx +++ b/cvat-ui/src/components/tasks-page/tasks-page.tsx @@ -70,7 +70,7 @@ function TasksPageComponent(props: Props): JSX.Element {