You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

139 lines
4.7 KiB
Python

# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from __future__ import annotations

import json
from typing import Dict, List, Optional, Sequence, Tuple

import tqdm
from cvat_sdk import Client, models
from cvat_sdk.core.helpers import TqdmProgressReporter
from cvat_sdk.core.proxies.tasks import ResourceType
class CLI:
    """Command-line front end that maps CLI actions onto CVAT SDK client calls.

    Each ``tasks_*`` method forwards its arguments to ``self.client.tasks``
    (or to a task object retrieved from it). Methods that transfer data pass
    a tqdm-backed progress reporter created by ``_make_pbar``.
    """

    def __init__(self, client: Client, credentials: Tuple[str, str]):
        """Store *client*, log in and check the server version.

        Args:
            client: SDK client used for every subsequent request.
            credentials: Passed unchanged to ``client.login``; annotated as a
                pair of strings (presumably user name and password).
        """
        self.client = client

        # allow arbitrary kwargs in models
        # TODO: will silently ignore invalid args, so remove this ASAP
        self.client.api_client.configuration.discard_unknown_keys = True

        self.client.login(credentials)
        self.client.check_server_version(fail_if_unsupported=False)

    def tasks_list(self, *, use_json_output: bool = False, **kwargs) -> None:
        """List all tasks in either basic or JSON format.

        Args:
            use_json_output: When true, the SDK result is parsed as JSON text
                and re-printed with a 2-space indent. Otherwise the ``id`` of
                each returned task is printed on its own line.
            **kwargs: Forwarded unchanged to ``client.tasks.list``.
        """
        results = self.client.tasks.list(return_json=use_json_output, **kwargs)
        if use_json_output:
            print(json.dumps(json.loads(results), indent=2))
        else:
            for r in results:
                print(r.id)

    def tasks_create(
        self,
        name: str,
        labels: List[Dict[str, str]],
        resources: Sequence[str],
        *,
        resource_type: ResourceType = ResourceType.LOCAL,
        annotation_path: str = "",
        annotation_format: str = "CVAT XML 1.1",
        status_check_period: int = 2,
        dataset_repository_url: str = "",
        lfs: bool = False,
        **kwargs,
    ) -> None:
        """
        Create a new task with the given name and labels JSON and add the files to it.

        Args:
            name: Task name placed in the task specification.
            labels: Label definitions placed in the task specification.
            resources: Files to attach; interpreted according to *resource_type*.
            resource_type: Kind of the entries in *resources*.
            annotation_path: Forwarded to ``create_from_data``.
            annotation_format: Forwarded to ``create_from_data``.
            status_check_period: Forwarded to ``create_from_data``
                (presumably a polling interval in seconds - confirm in the SDK).
            dataset_repository_url: Forwarded to ``create_from_data``.
            lfs: Forwarded to ``create_from_data`` as ``use_lfs``.
            **kwargs: Extra options, see the note in the body.
        """
        # The same kwargs are given to both the task spec and the data
        # parameters. This relies on ``discard_unknown_keys`` being enabled in
        # ``__init__`` so that keys unknown to the model are not rejected.
        task = self.client.tasks.create_from_data(
            spec=models.TaskWriteRequest(name=name, labels=labels, **kwargs),
            resource_type=resource_type,
            resources=resources,
            data_params=kwargs,
            annotation_path=annotation_path,
            annotation_format=annotation_format,
            status_check_period=status_check_period,
            dataset_repository_url=dataset_repository_url,
            use_lfs=lfs,
            pbar=self._make_pbar(),
        )
        print("Created task id", task.id)

    def tasks_delete(self, task_ids: Sequence[int]) -> None:
        """Delete a list of tasks, ignoring those which don't exist.

        The handling of missing IDs happens inside
        ``client.tasks.remove_by_ids`` and is not visible here - confirm it
        in the SDK.
        """
        self.client.tasks.remove_by_ids(task_ids=task_ids)

    def tasks_frames(
        self,
        task_id: int,
        frame_ids: Sequence[int],
        *,
        outdir: str = "",
        quality: str = "original",
    ) -> None:
        """
        Download the requested frame numbers for a task and save images as
        task_<ID>_frame_<FRAME><EXT>, where <FRAME> is zero-padded to six
        digits.

        Args:
            task_id: ID of the task to retrieve.
            frame_ids: Frame numbers to download.
            outdir: Forwarded to ``download_frames``.
            quality: Forwarded to ``download_frames``.
        """
        self.client.tasks.retrieve(obj_id=task_id).download_frames(
            frame_ids=frame_ids,
            outdir=outdir,
            quality=quality,
            # Only the task id is substituted here; the ``{frame_id}`` and
            # ``{frame_ext}`` placeholders stay literal for ``download_frames``.
            filename_pattern=f"task_{task_id}" + "_frame_{frame_id:06d}{frame_ext}",
        )

    def tasks_dump(
        self,
        task_id: int,
        fileformat: str,
        filename: str,
        *,
        status_check_period: int = 2,
        include_images: bool = False,
    ) -> None:
        """
        Export a task's dataset in the specified format (e.g. 'YOLO ZIP 1.0').

        Args:
            task_id: ID of the task to retrieve.
            fileformat: Forwarded to ``export_dataset`` as ``format_name``.
            filename: Forwarded to ``export_dataset``.
            status_check_period: Forwarded to ``export_dataset``.
            include_images: Forwarded to ``export_dataset``.
        """
        self.client.tasks.retrieve(obj_id=task_id).export_dataset(
            format_name=fileformat,
            filename=filename,
            pbar=self._make_pbar(),
            status_check_period=status_check_period,
            include_images=include_images,
        )

    def tasks_upload(
        self, task_id: int, fileformat: str, filename: str, *, status_check_period: int = 2
    ) -> None:
        """Upload annotations for a task in the specified format
        (e.g. 'YOLO ZIP 1.0').

        Args:
            task_id: ID of the task to retrieve.
            fileformat: Forwarded to ``import_annotations`` as ``format_name``.
            filename: Forwarded to ``import_annotations``.
            status_check_period: Forwarded to ``import_annotations``.
        """
        self.client.tasks.retrieve(obj_id=task_id).import_annotations(
            format_name=fileformat,
            filename=filename,
            status_check_period=status_check_period,
            pbar=self._make_pbar(),
        )

    def tasks_export(self, task_id: int, filename: str, *, status_check_period: int = 2) -> None:
        """Download a task backup.

        Args:
            task_id: ID of the task to retrieve.
            filename: Forwarded to ``download_backup``.
            status_check_period: Forwarded to ``download_backup``.
        """
        self.client.tasks.retrieve(obj_id=task_id).download_backup(
            filename=filename, status_check_period=status_check_period, pbar=self._make_pbar()
        )

    def tasks_import(self, filename: str, *, status_check_period: int = 2) -> None:
        """Import a task from a backup file.

        Args:
            filename: Forwarded to ``create_from_backup``.
            status_check_period: Forwarded to ``create_from_backup``.
        """
        self.client.tasks.create_from_backup(
            filename=filename, status_check_period=status_check_period, pbar=self._make_pbar()
        )

    def _make_pbar(self, title: Optional[str] = None) -> TqdmProgressReporter:
        """Return a progress reporter wrapping a new byte-unit tqdm bar.

        Args:
            title: Used as the bar description; ``None`` leaves it unset.
        """
        return TqdmProgressReporter(
            tqdm.tqdm(unit_scale=True, unit="B", unit_divisor=1024, desc=title)
        )