# Copyright (C) 2022 CVAT.ai Corporation # # SPDX-License-Identifier: MIT from __future__ import annotations import io import json import mimetypes import os import os.path as osp from enum import Enum from time import sleep from typing import Any, Dict, List, Optional, Sequence from PIL import Image from cvat_sdk.api_client import apis, exceptions, models from cvat_sdk.core import git from cvat_sdk.core.downloading import Downloader from cvat_sdk.core.progress import ProgressReporter from cvat_sdk.core.proxies.annotations import AnnotationCrudMixin from cvat_sdk.core.proxies.jobs import Job from cvat_sdk.core.proxies.model_proxy import ( ModelCreateMixin, ModelDeleteMixin, ModelListMixin, ModelRetrieveMixin, ModelUpdateMixin, build_model_bases, ) from cvat_sdk.core.uploading import AnnotationUploader, DataUploader, Uploader from cvat_sdk.core.utils import filter_dict class ResourceType(Enum): LOCAL = 0 SHARE = 1 REMOTE = 2 def __str__(self): return self.name.lower() def __repr__(self): return str(self) _TaskEntityBase, _TaskRepoBase = build_model_bases( models.TaskRead, apis.TasksApi, api_member_name="tasks_api" ) class Task( _TaskEntityBase, models.ITaskRead, ModelUpdateMixin[models.IPatchedTaskWriteRequest], ModelDeleteMixin, AnnotationCrudMixin, ): _model_partial_update_arg = "patched_task_write_request" _put_annotations_data_param = "task_annotations_update_request" def upload_data( self, resource_type: ResourceType, resources: Sequence[str], *, pbar: Optional[ProgressReporter] = None, params: Optional[Dict[str, Any]] = None, ) -> None: """ Add local, remote, or shared files to an existing task. """ params = params or {} data = {} if resource_type is ResourceType.LOCAL: pass # handled later elif resource_type is ResourceType.REMOTE: data = {f"remote_files[{i}]": f for i, f in enumerate(resources)} elif resource_type is ResourceType.SHARE: data = {f"server_files[{i}]": f for i, f in enumerate(resources)} data["image_quality"] = 70 data.update( filter_dict( params, keep=[ "chunk_size", "copy_data", "image_quality", "sorting_method", "start_frame", "stop_frame", "use_cache", "use_zip_chunks", ], ) ) if params.get("frame_step") is not None: data["frame_filter"] = f"step={params.get('frame_step')}" if resource_type in [ResourceType.REMOTE, ResourceType.SHARE]: self.api.create_data( self.id, data_request=models.DataRequest(**data), _content_type="multipart/form-data", ) elif resource_type == ResourceType.LOCAL: url = self._client.api_map.make_endpoint_url( self.api.create_data_endpoint.path, kwsub={"id": self.id} ) DataUploader(self._client).upload_files(url, resources, pbar=pbar, **data) def import_annotations( self, format_name: str, filename: str, *, status_check_period: Optional[int] = None, pbar: Optional[ProgressReporter] = None, ): """ Upload annotations for a task in the specified format (e.g. 'YOLO ZIP 1.0'). """ AnnotationUploader(self._client).upload_file_and_wait( self.api.create_annotations_endpoint, filename, format_name, url_params={"id": self.id}, pbar=pbar, status_check_period=status_check_period, ) self._client.logger.info(f"Annotation file '{filename}' for task #{self.id} uploaded") def get_frame( self, frame_id: int, *, quality: Optional[str] = None, ) -> io.RawIOBase: params = {} if quality: params["quality"] = quality (_, response) = self.api.retrieve_data(self.id, number=frame_id, **params, type="frame") return io.BytesIO(response.data) def get_preview( self, ) -> io.RawIOBase: (_, response) = self.api.retrieve_data(self.id, type="preview") return io.BytesIO(response.data) def download_frames( self, frame_ids: Sequence[int], *, outdir: str = "", quality: str = "original", filename_pattern: str = "frame_{frame_id:06d}{frame_ext}", ) -> Optional[List[Image.Image]]: """ Download the requested frame numbers for a task and save images as outdir/filename_pattern """ # TODO: add arg descriptions in schema os.makedirs(outdir, exist_ok=True) for frame_id in frame_ids: frame_bytes = self.get_frame(frame_id, quality=quality) im = Image.open(frame_bytes) mime_type = im.get_format_mimetype() or "image/jpg" im_ext = mimetypes.guess_extension(mime_type) # FIXME It is better to use meta information from the server # to determine the extension # replace '.jpe' or '.jpeg' with a more used '.jpg' if im_ext in (".jpe", ".jpeg", None): im_ext = ".jpg" outfile = filename_pattern.format(frame_id=frame_id, frame_ext=im_ext) im.save(osp.join(outdir, outfile)) def export_dataset( self, format_name: str, filename: str, *, pbar: Optional[ProgressReporter] = None, status_check_period: Optional[int] = None, include_images: bool = True, ) -> None: """ Download annotations for a task in the specified format (e.g. 'YOLO ZIP 1.0'). """ if include_images: endpoint = self.api.retrieve_dataset_endpoint else: endpoint = self.api.retrieve_annotations_endpoint Downloader(self._client).prepare_and_download_file_from_endpoint( endpoint=endpoint, filename=filename, url_params={"id": self.id}, query_params={"format": format_name}, pbar=pbar, status_check_period=status_check_period, ) self._client.logger.info(f"Dataset for task {self.id} has been downloaded to {filename}") def download_backup( self, filename: str, *, status_check_period: int = None, pbar: Optional[ProgressReporter] = None, ) -> None: """ Download a task backup """ Downloader(self._client).prepare_and_download_file_from_endpoint( self.api.retrieve_backup_endpoint, filename=filename, pbar=pbar, status_check_period=status_check_period, url_params={"id": self.id}, ) self._client.logger.info(f"Backup for task {self.id} has been downloaded to {filename}") def get_jobs(self) -> List[Job]: return [Job(self._client, m) for m in self.api.list_jobs(id=self.id)[0]] def get_meta(self) -> models.IDataMetaRead: (meta, _) = self.api.retrieve_data_meta(self.id) return meta def get_frames_info(self) -> List[models.IFrameMeta]: return self.get_meta().frames def remove_frames_by_ids(self, ids: Sequence[int]) -> None: self.api.partial_update_data_meta( self.id, patched_data_meta_write_request=models.PatchedDataMetaWriteRequest(deleted_frames=ids), ) class TasksRepo( _TaskRepoBase, ModelCreateMixin[Task, models.ITaskWriteRequest], ModelRetrieveMixin[Task], ModelListMixin[Task], ModelDeleteMixin, ): _entity_type = Task def create_from_data( self, spec: models.ITaskWriteRequest, resource_type: ResourceType, resources: Sequence[str], *, data_params: Optional[Dict[str, Any]] = None, annotation_path: str = "", annotation_format: str = "CVAT XML 1.1", status_check_period: int = None, dataset_repository_url: str = "", use_lfs: bool = False, pbar: Optional[ProgressReporter] = None, ) -> Task: """ Create a new task with the given name and labels JSON and add the files to it. Returns: id of the created task """ if status_check_period is None: status_check_period = self._client.config.status_check_period if getattr(spec, "project_id", None) and getattr(spec, "labels", None): raise exceptions.ApiValueError( "Can't set labels to a task inside a project. " "Tasks inside a project use project's labels.", ["labels"], ) task = self.create(spec=spec) self._client.logger.info("Created task ID: %s NAME: %s", task.id, task.name) task.upload_data(resource_type, resources, pbar=pbar, params=data_params) self._client.logger.info("Awaiting for task %s creation...", task.id) status: models.RqStatus = None while status != models.RqStatusStateEnum.allowed_values[("value",)]["FINISHED"]: sleep(status_check_period) (status, response) = self.api.retrieve_status(task.id) self._client.logger.info( "Task %s creation status=%s, message=%s", task.id, status.state.value, status.message, ) if status.state.value == models.RqStatusStateEnum.allowed_values[("value",)]["FAILED"]: raise exceptions.ApiException( status=status.state.value, reason=status.message, http_resp=response ) status = status.state.value if annotation_path: task.import_annotations(annotation_format, annotation_path, pbar=pbar) if dataset_repository_url: git.create_git_repo( self, task_id=task.id, repo_url=dataset_repository_url, status_check_period=status_check_period, use_lfs=use_lfs, ) task.fetch() return task def remove_by_ids(self, task_ids: Sequence[int]) -> None: """ Delete a list of tasks, ignoring those which don't exist. """ for task_id in task_ids: (_, response) = self.api.destroy(task_id, _check_status=False) if 200 <= response.status <= 299: self._client.logger.info(f"Task ID {task_id} deleted") elif response.status == 404: self._client.logger.info(f"Task ID {task_id} not found") else: self._client.logger.warning( f"Failed to delete task ID {task_id}: " f"{response.msg} (status {response.status})" ) def create_from_backup( self, filename: str, *, status_check_period: int = None, pbar: Optional[ProgressReporter] = None, ) -> Task: """ Import a task from a backup file """ if status_check_period is None: status_check_period = self._client.config.status_check_period params = {"filename": osp.basename(filename)} url = self._client.api_map.make_endpoint_url(self.api.create_backup_endpoint.path) uploader = Uploader(self._client) response = uploader.upload_file( url, filename, meta=params, query_params=params, pbar=pbar, logger=self._client.logger.debug, ) rq_id = json.loads(response.data)["rq_id"] response = self._client.wait_for_completion( url, success_status=201, positive_statuses=[202], post_params={"rq_id": rq_id}, status_check_period=status_check_period, ) task_id = json.loads(response.data)["id"] self._client.logger.info(f"Task has been imported sucessfully. Task ID: {task_id}") return self.retrieve(task_id)