# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

from __future__ import annotations

import io
import json
import mimetypes
import os
import os.path as osp
import shutil
from enum import Enum
from time import sleep
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence

from PIL import Image

from cvat_sdk.api_client import apis, exceptions, models
from cvat_sdk.core import git
from cvat_sdk.core.downloading import Downloader
from cvat_sdk.core.progress import ProgressReporter
from cvat_sdk.core.proxies.annotations import AnnotationCrudMixin
from cvat_sdk.core.proxies.jobs import Job
from cvat_sdk.core.proxies.model_proxy import (
    ModelCreateMixin,
    ModelDeleteMixin,
    ModelListMixin,
    ModelRetrieveMixin,
    ModelUpdateMixin,
    build_model_bases,
)
from cvat_sdk.core.uploading import AnnotationUploader, DataUploader, Uploader
from cvat_sdk.core.utils import filter_dict

if TYPE_CHECKING:
    from _typeshed import SupportsWrite


class ResourceType(Enum):
    LOCAL = 0
    SHARE = 1
    REMOTE = 2

    def __str__(self):
        return self.name.lower()

    def __repr__(self):
        return str(self)


_TaskEntityBase, _TaskRepoBase = build_model_bases(
    models.TaskRead, apis.TasksApi, api_member_name="tasks_api"
)


class Task(
    _TaskEntityBase,
    models.ITaskRead,
    ModelUpdateMixin[models.IPatchedTaskWriteRequest],
    ModelDeleteMixin,
    AnnotationCrudMixin,
):
    _model_partial_update_arg = "patched_task_write_request"
    _put_annotations_data_param = "task_annotations_update_request"

    def upload_data(
        self,
        resource_type: ResourceType,
        resources: Sequence[str],
        *,
        pbar: Optional[ProgressReporter] = None,
        params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Add local, remote, or shared files to an existing task.
        """
        params = params or {}

        data = {}
        if resource_type is ResourceType.LOCAL:
            pass  # handled later
        elif resource_type is ResourceType.REMOTE:
            data["remote_files"] = resources
        elif resource_type is ResourceType.SHARE:
            data["server_files"] = resources

        data["image_quality"] = 70
        data.update(
            filter_dict(
                params,
                keep=[
                    "chunk_size",
                    "copy_data",
                    "image_quality",
                    "sorting_method",
                    "start_frame",
                    "stop_frame",
                    "use_cache",
                    "use_zip_chunks",
                ],
            )
        )
        if params.get("frame_step") is not None:
            data["frame_filter"] = f"step={params.get('frame_step')}"

        if resource_type in [ResourceType.REMOTE, ResourceType.SHARE]:
            self.api.create_data(
                self.id,
                data_request=models.DataRequest(**data),
            )
        elif resource_type == ResourceType.LOCAL:
            url = self._client.api_map.make_endpoint_url(
                self.api.create_data_endpoint.path, kwsub={"id": self.id}
            )

            DataUploader(self._client).upload_files(url, resources, pbar=pbar, **data)
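
    # Usage sketch (illustrative, not part of the upstream file): assumes a
    # connected `client` whose `tasks` repository returned this task.
    #
    #   task = client.tasks.retrieve(42)
    #   task.upload_data(
    #       ResourceType.LOCAL,
    #       ["frame_0.png", "frame_1.png"],
    #       params={"image_quality": 90},
    #   )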

    def import_annotations(
        self,
        format_name: str,
        filename: str,
        *,
        status_check_period: Optional[int] = None,
        pbar: Optional[ProgressReporter] = None,
    ):
        """
        Upload annotations for a task in the specified format (e.g. 'YOLO ZIP 1.0').
        """

        AnnotationUploader(self._client).upload_file_and_wait(
            self.api.create_annotations_endpoint,
            filename,
            format_name,
            url_params={"id": self.id},
            pbar=pbar,
            status_check_period=status_check_period,
        )

        self._client.logger.info(f"Annotation file '{filename}' for task #{self.id} uploaded")

    def get_frame(
        self,
        frame_id: int,
        *,
        quality: Optional[str] = None,
    ) -> io.RawIOBase:
        params = {}
        if quality:
            params["quality"] = quality
        (_, response) = self.api.retrieve_data(self.id, number=frame_id, **params, type="frame")
        return io.BytesIO(response.data)
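
    # Usage sketch (illustrative): the returned buffer can be opened with PIL,
    # just as download_frames() below does internally.
    #
    #   from PIL import Image
    #   frame = Image.open(task.get_frame(0, quality="compressed"))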

    def get_preview(
        self,
    ) -> io.RawIOBase:
        (_, response) = self.api.retrieve_data(self.id, type="preview")
        return io.BytesIO(response.data)

    def download_chunk(
        self,
        chunk_id: int,
        output_file: SupportsWrite[bytes],
        *,
        quality: Optional[str] = None,
    ) -> None:
        params = {}
        if quality:
            params["quality"] = quality
        (_, response) = self.api.retrieve_data(
            self.id, number=chunk_id, **params, type="chunk", _parse_response=False
        )

        with response:
            shutil.copyfileobj(response, output_file)

    def download_frames(
        self,
        frame_ids: Sequence[int],
        *,
        outdir: str = "",
        quality: str = "original",
        filename_pattern: str = "frame_{frame_id:06d}{frame_ext}",
    ) -> Optional[List[Image.Image]]:
        """
        Download the requested frame numbers for a task and save images as outdir/filename_pattern.
        """
        # TODO: add arg descriptions in schema
        if outdir:  # os.makedirs() raises FileNotFoundError on the default ""
            os.makedirs(outdir, exist_ok=True)

        for frame_id in frame_ids:
            frame_bytes = self.get_frame(frame_id, quality=quality)

            im = Image.open(frame_bytes)
            mime_type = im.get_format_mimetype() or "image/jpg"
            im_ext = mimetypes.guess_extension(mime_type)

            # FIXME It is better to use meta information from the server
            # to determine the extension
            # replace '.jpe' or '.jpeg' with the more common '.jpg'
            if im_ext in (".jpe", ".jpeg", None):
                im_ext = ".jpg"

            outfile = filename_pattern.format(frame_id=frame_id, frame_ext=im_ext)
            im.save(osp.join(outdir, outfile))
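
    # Usage sketch (illustrative): saves compressed copies of three frames
    # into ./frames using the default filename pattern.
    #
    #   task.download_frames([0, 1, 2], outdir="frames", quality="compressed")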

    def export_dataset(
        self,
        format_name: str,
        filename: str,
        *,
        pbar: Optional[ProgressReporter] = None,
        status_check_period: Optional[int] = None,
        include_images: bool = True,
    ) -> None:
        """
        Download annotations for a task in the specified format (e.g. 'YOLO ZIP 1.0').
        """
        if include_images:
            endpoint = self.api.retrieve_dataset_endpoint
        else:
            endpoint = self.api.retrieve_annotations_endpoint

        Downloader(self._client).prepare_and_download_file_from_endpoint(
            endpoint=endpoint,
            filename=filename,
            url_params={"id": self.id},
            query_params={"format": format_name},
            pbar=pbar,
            status_check_period=status_check_period,
        )

        self._client.logger.info(f"Dataset for task {self.id} has been downloaded to {filename}")

    def download_backup(
        self,
        filename: str,
        *,
        status_check_period: Optional[int] = None,
        pbar: Optional[ProgressReporter] = None,
    ) -> None:
        """
        Download a task backup.
        """

        Downloader(self._client).prepare_and_download_file_from_endpoint(
            self.api.retrieve_backup_endpoint,
            filename=filename,
            pbar=pbar,
            status_check_period=status_check_period,
            url_params={"id": self.id},
        )

        self._client.logger.info(f"Backup for task {self.id} has been downloaded to {filename}")

    def get_jobs(self) -> List[Job]:
        return [Job(self._client, m) for m in self.api.list_jobs(id=self.id)[0]]
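
    # Usage sketch (illustrative): job fields such as `stage` come from the
    # server's JobRead schema and may differ between server versions.
    #
    #   for job in task.get_jobs():
    #       print(job.id, job.stage)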

    def get_meta(self) -> models.IDataMetaRead:
        (meta, _) = self.api.retrieve_data_meta(self.id)
        return meta

    def get_frames_info(self) -> List[models.IFrameMeta]:
        return self.get_meta().frames

    def remove_frames_by_ids(self, ids: Sequence[int]) -> None:
        self.api.partial_update_data_meta(
            self.id,
            patched_data_meta_write_request=models.PatchedDataMetaWriteRequest(deleted_frames=ids),
        )
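
    # Usage sketch (illustrative): marks frames 3 and 5 as deleted in the
    # task's data meta.
    #
    #   task.remove_frames_by_ids([3, 5])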


class TasksRepo(
    _TaskRepoBase,
    ModelCreateMixin[Task, models.ITaskWriteRequest],
    ModelRetrieveMixin[Task],
    ModelListMixin[Task],
    ModelDeleteMixin,
):
    _entity_type = Task

    def create_from_data(
        self,
        spec: models.ITaskWriteRequest,
        resource_type: ResourceType,
        resources: Sequence[str],
        *,
        data_params: Optional[Dict[str, Any]] = None,
        annotation_path: str = "",
        annotation_format: str = "CVAT XML 1.1",
        status_check_period: Optional[int] = None,
        dataset_repository_url: str = "",
        use_lfs: bool = False,
        pbar: Optional[ProgressReporter] = None,
    ) -> Task:
        """
        Create a new task with the given name and labels, and
        add the files to it.

        Returns: the created task
        """
        if status_check_period is None:
            status_check_period = self._client.config.status_check_period

        if getattr(spec, "project_id", None) and getattr(spec, "labels", None):
            raise exceptions.ApiValueError(
                "Can't set labels to a task inside a project. "
                "Tasks inside a project use project's labels.",
                ["labels"],
            )

        task = self.create(spec=spec)
        self._client.logger.info("Created task ID: %s NAME: %s", task.id, task.name)

        task.upload_data(resource_type, resources, pbar=pbar, params=data_params)

        self._client.logger.info("Awaiting task %s creation...", task.id)
        status: models.RqStatus = None
        while status != models.RqStatusStateEnum.allowed_values[("value",)]["FINISHED"]:
            sleep(status_check_period)
            (status, response) = self.api.retrieve_status(task.id)

            self._client.logger.info(
                "Task %s creation status=%s, message=%s",
                task.id,
                status.state.value,
                status.message,
            )

            if status.state.value == models.RqStatusStateEnum.allowed_values[("value",)]["FAILED"]:
                raise exceptions.ApiException(
                    status=status.state.value, reason=status.message, http_resp=response
                )

            status = status.state.value

        if annotation_path:
            task.import_annotations(annotation_format, annotation_path, pbar=pbar)

        if dataset_repository_url:
            git.create_git_repo(
                self._client,
                task_id=task.id,
                repo_url=dataset_repository_url,
                status_check_period=status_check_period,
                use_lfs=use_lfs,
            )

        task.fetch()

        return task
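
    # Usage sketch (illustrative): the spec and label payload follow the
    # generated API schema; the names used here are examples only.
    #
    #   task = client.tasks.create_from_data(
    #       spec=models.TaskWriteRequest(name="example task", labels=[{"name": "cat"}]),
    #       resource_type=ResourceType.LOCAL,
    #       resources=["frame_0.png", "frame_1.png"],
    #   )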

    def remove_by_ids(self, task_ids: Sequence[int]) -> None:
        """
        Delete a list of tasks, ignoring those which don't exist.
        """

        for task_id in task_ids:
            (_, response) = self.api.destroy(task_id, _check_status=False)

            if 200 <= response.status <= 299:
                self._client.logger.info(f"Task ID {task_id} deleted")
            elif response.status == 404:
                self._client.logger.info(f"Task ID {task_id} not found")
            else:
                self._client.logger.warning(
                    f"Failed to delete task ID {task_id}: "
                    f"{response.msg} (status {response.status})"
                )

    def create_from_backup(
        self,
        filename: str,
        *,
        status_check_period: Optional[int] = None,
        pbar: Optional[ProgressReporter] = None,
    ) -> Task:
        """
        Import a task from a backup file.
        """
        if status_check_period is None:
            status_check_period = self._client.config.status_check_period

        params = {"filename": osp.basename(filename)}
        url = self._client.api_map.make_endpoint_url(self.api.create_backup_endpoint.path)

        uploader = Uploader(self._client)
        response = uploader.upload_file(
            url,
            filename,
            meta=params,
            query_params=params,
            pbar=pbar,
            logger=self._client.logger.debug,
        )

        rq_id = json.loads(response.data)["rq_id"]
        response = self._client.wait_for_completion(
            url,
            success_status=201,
            positive_statuses=[202],
            post_params={"rq_id": rq_id},
            status_check_period=status_check_period,
        )

        task_id = json.loads(response.data)["id"]
        self._client.logger.info(f"Task has been imported successfully. Task ID: {task_id}")

        return self.retrieve(task_id)
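
    # Usage sketch (illustrative): restores a task from a backup archive and
    # returns the new Task proxy.
    #
    #   restored = client.tasks.create_from_backup("task_42_backup.zip")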