You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
309 lines
9.5 KiB
Python
309 lines
9.5 KiB
Python
# Copyright (C) 2020-2022 Intel Corporation
|
|
# Copyright (C) 2022 CVAT.ai Corporation
|
|
#
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
from __future__ import annotations
|
|
|
|
import io
|
|
import mimetypes
|
|
import os
|
|
import os.path as osp
|
|
from abc import ABC, abstractmethod
|
|
from io import BytesIO
|
|
from time import sleep
|
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
|
|
|
|
from PIL import Image
|
|
|
|
from cvat_sdk import models
|
|
from cvat_sdk.api_client.model_utils import OpenApiModel
|
|
from cvat_sdk.core.downloading import Downloader
|
|
from cvat_sdk.core.progress import ProgressReporter
|
|
from cvat_sdk.core.types import ResourceType
|
|
from cvat_sdk.core.uploading import Uploader
|
|
from cvat_sdk.core.utils import filter_dict
|
|
|
|
if TYPE_CHECKING:
|
|
from cvat_sdk.core.client import Client
|
|
|
|
|
|
class ModelProxy(ABC):
    """
    Base class for client-side proxies that wrap a server model.

    Attribute reads that are not resolved by normal lookup are forwarded to
    the wrapped model via item access (``proxy.x`` reads ``model["x"]``).
    Attribute writes go to the wrapped model as well, unless the name is
    already stored in the proxy's own instance ``__dict__``.
    """

    _client: Client
    _model: OpenApiModel

    def __init__(self, client: Client, model: OpenApiModel) -> None:
        # Store through __dict__ directly: the custom __setattr__ below would
        # otherwise try to forward these names to a model that is not set yet.
        self.__dict__["_client"] = client
        self.__dict__["_model"] = model

    def __getattr__(self, name: str) -> Any:
        """
        Resolve an attribute that normal lookup did not find by reading it
        from the wrapped model.

        Raises:
            AttributeError: if the proxy has no wrapped model yet, or the
                model has no such field.
        """
        # Read "_model" from __dict__ rather than via self._model. On an
        # instance that has not been initialized (bare __new__, copy, pickle)
        # the attribute is absent and self._model would re-enter this method
        # until RecursionError.
        try:
            model = self.__dict__["_model"]
        except KeyError:
            raise AttributeError(name) from None

        # __getattr__ must signal a missing attribute with AttributeError so
        # that hasattr() and getattr(obj, name, default) keep working.
        try:
            return model[name]
        except KeyError as exc:
            raise AttributeError(name) from exc

    def __setattr__(self, name: str, value: Any) -> None:
        """
        Set an attribute on the proxy itself if it already exists in the
        instance ``__dict__`` (e.g. ``_client``, ``_model``); otherwise write
        it to the wrapped model.
        """
        if name in self.__dict__:
            self.__dict__[name] = value
        else:
            self._model[name] = value

    @abstractmethod
    def fetch(self, force: bool = False):
        """Fetches model data from the server"""
        ...

    @abstractmethod
    def commit(self, force: bool = False):
        """Commits local changes to the server"""
        ...

    def sync(self):
        """Pulls server state and commits local model changes"""
        # Not implemented in the base class.
        raise NotImplementedError

    @abstractmethod
    def update(self, **kwargs):
        """Updates multiple fields at once"""
        ...
|
|
|
|
|
|
class TaskProxy(ModelProxy, models.ITaskRead):
    """
    Proxy for a single task: exposes the fields of the wrapped task model as
    attributes and adds task-level operations (data upload, annotation
    import, frame retrieval, dataset export, backup download).
    """

    def __init__(self, client: Client, task: models.TaskRead):
        ModelProxy.__init__(self, client=client, model=task)

    def remove(self):
        """Delete this task on the server."""
        self._client.api.tasks_api.destroy(self.id)

    def upload_data(
        self,
        resource_type: ResourceType,
        resources: Sequence[str],
        *,
        pbar: Optional[ProgressReporter] = None,
        params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Add local, remote, or shared files to an existing task.

        Args:
            resource_type: where the files come from. REMOTE and SHARE
                resources are sent as a single data request; LOCAL resources
                are uploaded through ``Uploader``.
            resources: file paths or URLs, interpreted per ``resource_type``.
            pbar: optional progress reporter; only used for LOCAL uploads.
            params: optional extra data parameters. Only the keys
                ``chunk_size``, ``copy_data``, ``image_quality``,
                ``sorting_method``, ``start_frame``, ``stop_frame``,
                ``use_cache`` and ``use_zip_chunks`` are forwarded;
                ``frame_step`` is translated into a ``frame_filter`` value.
        """
        client = self._client
        task_id = self.id

        params = params or {}

        data = {}
        if resource_type is ResourceType.LOCAL:
            pass  # handled later
        elif resource_type is ResourceType.REMOTE:
            data = {f"remote_files[{i}]": f for i, f in enumerate(resources)}
        elif resource_type is ResourceType.SHARE:
            data = {f"server_files[{i}]": f for i, f in enumerate(resources)}

        # Default quality; overridden below if the caller supplied one.
        data["image_quality"] = 70
        data.update(
            filter_dict(
                params,
                keep=[
                    "chunk_size",
                    "copy_data",
                    "image_quality",
                    "sorting_method",
                    "start_frame",
                    "stop_frame",
                    "use_cache",
                    "use_zip_chunks",
                ],
            )
        )
        if params.get("frame_step") is not None:
            data["frame_filter"] = f"step={params.get('frame_step')}"

        if resource_type in [ResourceType.REMOTE, ResourceType.SHARE]:
            client.api.tasks_api.create_data(
                task_id,
                data_request=models.DataRequest(**data),
                _content_type="multipart/form-data",
            )
        elif resource_type == ResourceType.LOCAL:
            url = client._api_map.make_endpoint_url(
                client.api.tasks_api.create_data_endpoint.path, kwsub={"id": task_id}
            )

            uploader = Uploader(client)
            uploader.upload_files(url, resources, pbar=pbar, **data)

    def import_annotations(
        self,
        format_name: str,
        filename: str,
        *,
        status_check_period: Optional[int] = None,
        pbar: Optional[ProgressReporter] = None,
    ):
        """
        Upload annotations for a task in the specified format
        (e.g. 'YOLO ZIP 1.0').

        Args:
            format_name: annotation format name sent as the ``format`` query
                parameter.
            filename: path of the local annotation file; only its base name
                is sent to the server as ``filename``.
            status_check_period: delay passed to ``sleep`` between status
                polls; defaults to ``client.config.status_check_period``.
            pbar: optional progress reporter for the file upload.
        """
        client = self._client
        if status_check_period is None:
            status_check_period = client.config.status_check_period

        task_id = self.id

        url = client._api_map.make_endpoint_url(
            client.api.tasks_api.create_annotations_endpoint.path,
            kwsub={"id": task_id},
        )
        params = {"format": format_name, "filename": osp.basename(filename)}

        uploader = Uploader(client)
        uploader.upload_file(
            url, filename, pbar=pbar, query_params=params, meta={"filename": params["filename"]}
        )

        # Poll the same endpoint until the server answers 201.
        while True:
            response = client.api.rest_client.POST(
                url, headers=client.api.get_common_headers(), query_params=params
            )
            if response.status == 201:
                break

            sleep(status_check_period)

        client.logger.info(
            f"Upload job for Task ID {task_id} with annotation file {filename} finished"
        )

    def retrieve_frame(
        self,
        frame_id: int,
        *,
        quality: Optional[str] = None,
    ) -> io.RawIOBase:
        """
        Request a single frame of the task and return its raw bytes wrapped
        in an in-memory ``BytesIO`` stream.

        Args:
            frame_id: frame number to retrieve.
            quality: quality selector forwarded to the server request.
        """
        client = self._client
        task_id = self.id

        (_, response) = client.api.tasks_api.retrieve_data(task_id, frame_id, quality, type="frame")

        return BytesIO(response.data)

    def download_frames(
        self,
        frame_ids: Sequence[int],
        *,
        outdir: str = "",
        quality: str = "original",
        filename_pattern: str = "task_{task_id}_frame_{frame_id:06d}{frame_ext}",
    ) -> Optional[List[Image.Image]]:
        """
        Download the requested frame numbers for a task and save images as
        outdir/filename_pattern

        Args:
            frame_ids: frame numbers to download.
            outdir: target directory, created if missing. An empty string
                (the default) saves into the current working directory.
            quality: quality selector passed to ``retrieve_frame``.
            filename_pattern: ``str.format`` pattern for the output file
                name; receives ``task_id``, ``frame_id`` and ``frame_ext``.

        Returns:
            None. The images are written to disk and not returned.
        """
        # TODO: add arg descriptions in schema
        task_id = self.id

        # os.makedirs("") raises FileNotFoundError, so only create a
        # directory when one was actually requested.
        if outdir:
            os.makedirs(outdir, exist_ok=True)

        for frame_id in frame_ids:
            frame_bytes = self.retrieve_frame(frame_id, quality=quality)

            im = Image.open(frame_bytes)
            mime_type = im.get_format_mimetype() or "image/jpg"
            im_ext = mimetypes.guess_extension(mime_type)

            # FIXME It is better to use meta information from the server
            # to determine the extension
            # replace '.jpe' or '.jpeg' with a more used '.jpg'
            if im_ext in (".jpe", ".jpeg", None):
                im_ext = ".jpg"

            outfile = filename_pattern.format(task_id=task_id, frame_id=frame_id, frame_ext=im_ext)
            im.save(osp.join(outdir, outfile))

    def export_dataset(
        self,
        format_name: str,
        filename: str,
        *,
        pbar: Optional[ProgressReporter] = None,
        status_check_period: Optional[int] = None,
        include_images: bool = True,
    ) -> None:
        """
        Download annotations for a task in the specified format (e.g. 'YOLO ZIP 1.0').

        Args:
            format_name: export format name sent as the ``format`` query
                parameter.
            filename: local path the downloaded file is written to.
            pbar: optional progress reporter for the download.
            status_check_period: delay passed to ``sleep`` between status
                polls; defaults to ``client.config.status_check_period``.
            include_images: when true the dataset endpoint is used,
                otherwise the annotations endpoint.
        """
        client = self._client
        if status_check_period is None:
            status_check_period = client.config.status_check_period

        task_id = self.id

        params = {"filename": self.name, "format": format_name}
        if include_images:
            endpoint = client.api.tasks_api.retrieve_dataset_endpoint
        else:
            endpoint = client.api.tasks_api.retrieve_annotations_endpoint

        client.logger.info("Waiting for the server to prepare the file...")
        # Poll until the server answers 201, then request the download.
        while True:
            (_, response) = endpoint.call_with_http_info(id=task_id, **params)
            client.logger.debug("STATUS {}".format(response.status))
            if response.status == 201:
                break
            sleep(status_check_period)

        params["action"] = "download"
        url = client._api_map.make_endpoint_url(
            endpoint.path, kwsub={"id": task_id}, query_params=params
        )
        downloader = Downloader(client)
        downloader.download_file(url, output_path=filename, pbar=pbar)

        client.logger.info(f"Dataset has been exported to {filename}")

    def download_backup(
        self,
        filename: str,
        *,
        status_check_period: Optional[int] = None,
        pbar: Optional[ProgressReporter] = None,
    ):
        """
        Download a task backup

        Args:
            filename: local path the backup file is written to.
            status_check_period: delay passed to ``sleep`` between status
                polls; defaults to ``client.config.status_check_period``.
            pbar: optional progress reporter for the download.
        """
        client = self._client
        if status_check_period is None:
            status_check_period = client.config.status_check_period

        task_id = self.id

        endpoint = client.api.tasks_api.retrieve_backup_endpoint
        client.logger.info("Waiting for the server to prepare the file...")
        # Poll until the server answers 201, then request the download.
        while True:
            (_, response) = endpoint.call_with_http_info(id=task_id)
            client.logger.debug("STATUS {}".format(response.status))
            if response.status == 201:
                break
            sleep(status_check_period)

        url = client._api_map.make_endpoint_url(
            endpoint.path, kwsub={"id": task_id}, query_params={"action": "download"}
        )
        downloader = Downloader(client)
        downloader.download_file(url, output_path=filename, pbar=pbar)

        client.logger.info(
            f"Task {task_id} has been exported sucessfully to {osp.abspath(filename)}"
        )

    def fetch(self, force: bool = False):
        """
        Re-read the task from the server and replace the wrapped model.

        ``force`` is currently unused.
        """
        # TODO: implement revision checking
        model, _ = self._client.api.tasks_api.retrieve(self.id)
        self._model = model

    def commit(self, force: bool = False):
        """Commits local changes to the server (delegates to the base class)."""
        return super().commit(force)

    def update(self, **kwargs):
        """Updates multiple fields at once (delegates to the base class)."""
        return super().update(**kwargs)

    def __str__(self) -> str:
        """Return the string form of the wrapped model."""
        return str(self._model)
|