You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

167 lines
5.1 KiB
Python

# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from __future__ import annotations
import io
import mimetypes
import os
import os.path as osp
from typing import List, Optional, Sequence
from PIL import Image
from cvat_sdk.api_client import apis, models
from cvat_sdk.core.downloading import Downloader
from cvat_sdk.core.helpers import get_paginated_collection
from cvat_sdk.core.progress import ProgressReporter
from cvat_sdk.core.proxies.annotations import AnnotationCrudMixin
from cvat_sdk.core.proxies.issues import Issue
from cvat_sdk.core.proxies.model_proxy import (
ModelListMixin,
ModelRetrieveMixin,
ModelUpdateMixin,
build_model_bases,
)
from cvat_sdk.core.uploading import AnnotationUploader
_JobEntityBase, _JobRepoBase = build_model_bases(
models.JobRead, apis.JobsApi, api_member_name="jobs_api"
)
class Job(
models.IJobRead,
_JobEntityBase,
ModelUpdateMixin[models.IPatchedJobWriteRequest],
AnnotationCrudMixin,
):
_model_partial_update_arg = "patched_job_write_request"
_put_annotations_data_param = "job_annotations_update_request"
def import_annotations(
self,
format_name: str,
filename: str,
*,
status_check_period: Optional[int] = None,
pbar: Optional[ProgressReporter] = None,
):
"""
Upload annotations for a job in the specified format (e.g. 'YOLO ZIP 1.0').
"""
AnnotationUploader(self._client).upload_file_and_wait(
self.api.create_annotations_endpoint,
filename,
format_name,
url_params={"id": self.id},
pbar=pbar,
status_check_period=status_check_period,
)
self._client.logger.info(f"Annotation file '{filename}' for job #{self.id} uploaded")
def export_dataset(
self,
format_name: str,
filename: str,
*,
pbar: Optional[ProgressReporter] = None,
status_check_period: Optional[int] = None,
include_images: bool = True,
) -> None:
"""
Download annotations for a job in the specified format (e.g. 'YOLO ZIP 1.0').
"""
if include_images:
endpoint = self.api.retrieve_dataset_endpoint
else:
endpoint = self.api.retrieve_annotations_endpoint
Downloader(self._client).prepare_and_download_file_from_endpoint(
endpoint=endpoint,
filename=filename,
url_params={"id": self.id},
query_params={"format": format_name},
pbar=pbar,
status_check_period=status_check_period,
)
self._client.logger.info(f"Dataset for job {self.id} has been downloaded to {filename}")
def get_frame(
self,
frame_id: int,
*,
quality: Optional[str] = None,
) -> io.RawIOBase:
(_, response) = self.api.retrieve_data(
self.id, number=frame_id, quality=quality, type="frame"
)
return io.BytesIO(response.data)
def get_preview(
self,
) -> io.RawIOBase:
(_, response) = self.api.retrieve_data(self.id, type="preview")
return io.BytesIO(response.data)
def download_frames(
self,
frame_ids: Sequence[int],
*,
outdir: str = "",
quality: str = "original",
filename_pattern: str = "frame_{frame_id:06d}{frame_ext}",
) -> Optional[List[Image.Image]]:
"""
Download the requested frame numbers for a job and save images as outdir/filename_pattern
"""
# TODO: add arg descriptions in schema
os.makedirs(outdir, exist_ok=True)
for frame_id in frame_ids:
frame_bytes = self.get_frame(frame_id, quality=quality)
im = Image.open(frame_bytes)
mime_type = im.get_format_mimetype() or "image/jpg"
im_ext = mimetypes.guess_extension(mime_type)
# FIXME It is better to use meta information from the server
# to determine the extension
# replace '.jpe' or '.jpeg' with a more used '.jpg'
if im_ext in (".jpe", ".jpeg", None):
im_ext = ".jpg"
outfile = filename_pattern.format(frame_id=frame_id, frame_ext=im_ext)
im.save(osp.join(outdir, outfile))
def get_meta(self) -> models.IDataMetaRead:
(meta, _) = self.api.retrieve_data_meta(self.id)
return meta
def get_frames_info(self) -> List[models.IFrameMeta]:
return self.get_meta().frames
def remove_frames_by_ids(self, ids: Sequence[int]) -> None:
self._client.api_client.tasks_api.jobs_partial_update_data_meta(
self.id,
patched_data_meta_write_request=models.PatchedDataMetaWriteRequest(deleted_frames=ids),
)
def get_issues(self) -> List[Issue]:
return [Issue(self._client, m) for m in self.api.list_issues(id=self.id)[0]]
def get_commits(self) -> List[models.IJobCommit]:
return get_paginated_collection(self.api.list_commits_endpoint, id=self.id)
class JobsRepo(
_JobRepoBase,
ModelListMixin[Job],
ModelRetrieveMixin[Job],
):
_entity_type = Job