# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

from __future__ import annotations

import io
import json
import mimetypes
import os
import os.path as osp
from enum import Enum
from time import sleep
from typing import Any, Dict, List, Optional, Sequence

from PIL import Image

from cvat_sdk.api_client import apis, exceptions, models
from cvat_sdk.core import git
from cvat_sdk.core.downloading import Downloader
from cvat_sdk.core.progress import ProgressReporter
from cvat_sdk.core.proxies.annotations import AnnotationCrudMixin
from cvat_sdk.core.proxies.jobs import Job
from cvat_sdk.core.proxies.model_proxy import (
    ModelCreateMixin,
    ModelDeleteMixin,
    ModelListMixin,
    ModelRetrieveMixin,
    ModelUpdateMixin,
    build_model_bases,
)
from cvat_sdk.core.uploading import AnnotationUploader, DataUploader, Uploader
from cvat_sdk.core.utils import filter_dict


class ResourceType(Enum):
    LOCAL = 0
    SHARE = 1
    REMOTE = 2

    def __str__(self):
        return self.name.lower()

    def __repr__(self):
        return str(self)
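
    # Note: str() and repr() yield the lowercase member name, so e.g.
    # str(ResourceType.SHARE) == "share".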


_TaskEntityBase, _TaskRepoBase = build_model_bases(
    models.TaskRead, apis.TasksApi, api_member_name="tasks_api"
)


class Task(
    _TaskEntityBase,
    models.ITaskRead,
    ModelUpdateMixin[models.IPatchedTaskWriteRequest],
    ModelDeleteMixin,
    AnnotationCrudMixin,
):
    _model_partial_update_arg = "patched_task_write_request"
    _put_annotations_data_param = "task_annotations_update_request"

    def upload_data(
        self,
        resource_type: ResourceType,
        resources: Sequence[str],
        *,
        pbar: Optional[ProgressReporter] = None,
        params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Add local, remote, or shared files to an existing task.
        """
        params = params or {}

        data = {}
        if resource_type is ResourceType.LOCAL:
            pass  # handled later
        elif resource_type is ResourceType.REMOTE:
            data = {f"remote_files[{i}]": f for i, f in enumerate(resources)}
        elif resource_type is ResourceType.SHARE:
            data = {f"server_files[{i}]": f for i, f in enumerate(resources)}

        data["image_quality"] = 70
        data.update(
            filter_dict(
                params,
                keep=[
                    "chunk_size",
                    "copy_data",
                    "image_quality",
                    "sorting_method",
                    "start_frame",
                    "stop_frame",
                    "use_cache",
                    "use_zip_chunks",
                ],
            )
        )
        if params.get("frame_step") is not None:
            data["frame_filter"] = f"step={params.get('frame_step')}"

        if resource_type in [ResourceType.REMOTE, ResourceType.SHARE]:
            self.api.create_data(
                self.id,
                data_request=models.DataRequest(**data),
                _content_type="multipart/form-data",
            )
        elif resource_type == ResourceType.LOCAL:
            url = self._client.api_map.make_endpoint_url(
                self.api.create_data_endpoint.path, kwsub={"id": self.id}
            )

            DataUploader(self._client).upload_files(url, resources, pbar=pbar, **data)
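
    # Usage sketch (hypothetical task id, file names, and parameters; assumes
    # `client` was obtained via cvat_sdk.make_client):
    #
    #   task = client.tasks.retrieve(42)
    #   task.upload_data(
    #       ResourceType.LOCAL,
    #       ["frames/0001.png", "frames/0002.png"],
    #       params={"image_quality": 90, "use_cache": True},
    #   )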

    def import_annotations(
        self,
        format_name: str,
        filename: str,
        *,
        status_check_period: Optional[int] = None,
        pbar: Optional[ProgressReporter] = None,
    ):
        """
        Upload annotations for a task in the specified format (e.g. 'YOLO ZIP 1.0').
        """
        AnnotationUploader(self._client).upload_file_and_wait(
            self.api.create_annotations_endpoint,
            filename,
            format_name,
            url_params={"id": self.id},
            pbar=pbar,
            status_check_period=status_check_period,
        )

        self._client.logger.info(f"Annotation file '{filename}' for task #{self.id} uploaded")

    def get_frame(
        self,
        frame_id: int,
        *,
        quality: Optional[str] = None,
    ) -> io.RawIOBase:
        params = {}
        if quality:
            params["quality"] = quality

        (_, response) = self.api.retrieve_data(self.id, number=frame_id, **params, type="frame")

        return io.BytesIO(response.data)

    def get_preview(
        self,
    ) -> io.RawIOBase:
        (_, response) = self.api.retrieve_data(self.id, type="preview")

        return io.BytesIO(response.data)
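
    # Usage sketch (hypothetical frame number; "compressed" is assumed to be a
    # quality level the server supports):
    #
    #   from PIL import Image
    #   image = Image.open(task.get_frame(0, quality="compressed"))
    #   image.save("frame_0.png")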

    def download_frames(
        self,
        frame_ids: Sequence[int],
        *,
        outdir: str = ".",
        quality: str = "original",
        filename_pattern: str = "frame_{frame_id:06d}{frame_ext}",
    ) -> None:
        """
        Download the requested frames of a task and save them as images
        named outdir/filename_pattern.
        """
        # TODO: add arg descriptions in schema
        os.makedirs(outdir, exist_ok=True)

        for frame_id in frame_ids:
            frame_bytes = self.get_frame(frame_id, quality=quality)

            im = Image.open(frame_bytes)
            mime_type = im.get_format_mimetype() or "image/jpg"
            im_ext = mimetypes.guess_extension(mime_type)

            # FIXME It is better to use meta information from the server
            # to determine the extension
            # replace '.jpe' or '.jpeg' with the more common '.jpg'
            if im_ext in (".jpe", ".jpeg", None):
                im_ext = ".jpg"

            outfile = filename_pattern.format(frame_id=frame_id, frame_ext=im_ext)
            im.save(osp.join(outdir, outfile))
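
    # Usage sketch (hypothetical frame ids and output directory); produces
    # files such as frames/frame_000000.jpg:
    #
    #   task.download_frames([0, 5, 10], outdir="frames")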

    def export_dataset(
        self,
        format_name: str,
        filename: str,
        *,
        pbar: Optional[ProgressReporter] = None,
        status_check_period: Optional[int] = None,
        include_images: bool = True,
    ) -> None:
        """
        Download annotations for a task in the specified format (e.g. 'YOLO ZIP 1.0').
        """
        if include_images:
            endpoint = self.api.retrieve_dataset_endpoint
        else:
            endpoint = self.api.retrieve_annotations_endpoint

        Downloader(self._client).prepare_and_download_file_from_endpoint(
            endpoint=endpoint,
            filename=filename,
            url_params={"id": self.id},
            query_params={"format": format_name},
            pbar=pbar,
            status_check_period=status_check_period,
        )

        self._client.logger.info(f"Dataset for task {self.id} has been downloaded to {filename}")

    def download_backup(
        self,
        filename: str,
        *,
        status_check_period: Optional[int] = None,
        pbar: Optional[ProgressReporter] = None,
    ) -> None:
        """
        Download a task backup
        """
        Downloader(self._client).prepare_and_download_file_from_endpoint(
            self.api.retrieve_backup_endpoint,
            filename=filename,
            pbar=pbar,
            status_check_period=status_check_period,
            url_params={"id": self.id},
        )

        self._client.logger.info(f"Backup for task {self.id} has been downloaded to {filename}")

    def get_jobs(self) -> List[Job]:
        return [Job(self._client, m) for m in self.api.list_jobs(id=self.id)[0]]

    def get_meta(self) -> models.IDataMetaRead:
        (meta, _) = self.api.retrieve_data_meta(self.id)

        return meta

    def get_frames_info(self) -> List[models.IFrameMeta]:
        return self.get_meta().frames

    def remove_frames_by_ids(self, ids: Sequence[int]) -> None:
        self.api.partial_update_data_meta(
            self.id,
            patched_data_meta_write_request=models.PatchedDataMetaWriteRequest(deleted_frames=ids),
        )
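
    # Usage sketch (reads frame metadata; field names follow the DataMetaRead
    # model):
    #
    #   for frame in task.get_frames_info():
    #       print(frame.name, frame.width, frame.height)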


class TasksRepo(
    _TaskRepoBase,
    ModelCreateMixin[Task, models.ITaskWriteRequest],
    ModelRetrieveMixin[Task],
    ModelListMixin[Task],
    ModelDeleteMixin,
):
    _entity_type = Task

    def create_from_data(
        self,
        spec: models.ITaskWriteRequest,
        resource_type: ResourceType,
        resources: Sequence[str],
        *,
        data_params: Optional[Dict[str, Any]] = None,
        annotation_path: str = "",
        annotation_format: str = "CVAT XML 1.1",
        status_check_period: Optional[int] = None,
        dataset_repository_url: str = "",
        use_lfs: bool = False,
        pbar: Optional[ProgressReporter] = None,
    ) -> Task:
        """
        Create a new task with the given name and labels JSON and
        add the files to it.

        Returns: the created task object
        """
        if status_check_period is None:
            status_check_period = self._client.config.status_check_period

        if getattr(spec, "project_id", None) and getattr(spec, "labels", None):
            raise exceptions.ApiValueError(
                "Can't set labels to a task inside a project. "
                "Tasks inside a project use project's labels.",
                ["labels"],
            )

        task = self.create(spec=spec)
        self._client.logger.info("Created task ID: %s NAME: %s", task.id, task.name)

        task.upload_data(resource_type, resources, pbar=pbar, params=data_params)

        self._client.logger.info("Awaiting task %s creation...", task.id)
        status: models.RqStatus = None
        while status != models.RqStatusStateEnum.allowed_values[("value",)]["FINISHED"]:
            sleep(status_check_period)
            (status, response) = self.api.retrieve_status(task.id)

            self._client.logger.info(
                "Task %s creation status=%s, message=%s",
                task.id,
                status.state.value,
                status.message,
            )

            if status.state.value == models.RqStatusStateEnum.allowed_values[("value",)]["FAILED"]:
                raise exceptions.ApiException(
                    status=status.state.value, reason=status.message, http_resp=response
                )

            # keep only the state string, so the while condition above compares
            # like with like on the next iteration
            status = status.state.value

        if annotation_path:
            task.import_annotations(annotation_format, annotation_path, pbar=pbar)

        if dataset_repository_url:
            git.create_git_repo(
                self._client,
                task_id=task.id,
                repo_url=dataset_repository_url,
                status_check_period=status_check_period,
                use_lfs=use_lfs,
            )

        task.fetch()

        return task
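
    # Usage sketch (hypothetical server, credentials, label set, and files):
    #
    #   from cvat_sdk import make_client, models
    #
    #   with make_client("http://localhost:8080", credentials=("user", "pass")) as client:
    #       task = client.tasks.create_from_data(
    #           spec=models.TaskWriteRequest(
    #               name="example task",
    #               labels=[models.PatchedLabelRequest(name="car")],
    #           ),
    #           resource_type=ResourceType.LOCAL,
    #           resources=["frames/0001.png", "frames/0002.png"],
    #       )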

    def remove_by_ids(self, task_ids: Sequence[int]) -> None:
        """
        Delete a list of tasks, ignoring those which don't exist.
        """
        for task_id in task_ids:
            (_, response) = self.api.destroy(task_id, _check_status=False)

            if 200 <= response.status <= 299:
                self._client.logger.info(f"Task ID {task_id} deleted")
            elif response.status == 404:
                self._client.logger.info(f"Task ID {task_id} not found")
            else:
                self._client.logger.warning(
                    f"Failed to delete task ID {task_id}: "
                    f"{response.msg} (status {response.status})"
                )
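
    # Usage sketch: client.tasks.remove_by_ids([3, 4, 5]) logs the outcome of
    # each deletion instead of raising for ids that don't exist.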

    def create_from_backup(
        self,
        filename: str,
        *,
        status_check_period: Optional[int] = None,
        pbar: Optional[ProgressReporter] = None,
    ) -> Task:
        """
        Import a task from a backup file
        """
        if status_check_period is None:
            status_check_period = self._client.config.status_check_period

        params = {"filename": osp.basename(filename)}
        url = self._client.api_map.make_endpoint_url(self.api.create_backup_endpoint.path)

        uploader = Uploader(self._client)
        response = uploader.upload_file(
            url,
            filename,
            meta=params,
            query_params=params,
            pbar=pbar,
            logger=self._client.logger.debug,
        )

        rq_id = json.loads(response.data)["rq_id"]
        response = self._client.wait_for_completion(
            url,
            success_status=201,
            positive_statuses=[202],
            post_params={"rq_id": rq_id},
            status_check_period=status_check_period,
        )

        task_id = json.loads(response.data)["id"]
        self._client.logger.info(f"Task has been imported successfully. Task ID: {task_id}")

        return self.retrieve(task_id)
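

# Usage sketch (backup round trip; hypothetical paths and ids, `client` as above):
#
#   client.tasks.retrieve(42).download_backup("task_42_backup.zip")
#   restored = client.tasks.create_from_backup("task_42_backup.zip")
#   print(restored.id, restored.name)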