You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

256 lines
8.2 KiB
Python

# Copyright (C) 2020-2022 Intel Corporation
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from __future__ import annotations
import json
import logging
import os.path as osp
import urllib.parse
from time import sleep
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import attrs
from cvat_sdk.api_client import ApiClient, ApiException, ApiValueError, Configuration, models
from cvat_sdk.core.git import create_git_repo
from cvat_sdk.core.helpers import get_paginated_collection
from cvat_sdk.core.progress import ProgressReporter
from cvat_sdk.core.tasks import TaskProxy
from cvat_sdk.core.types import ResourceType
from cvat_sdk.core.uploading import Uploader
from cvat_sdk.core.utils import assert_status
@attrs.define
class Config:
    """Tunable settings shared by a Client instance."""

    # Delay between two consecutive server status polls, in seconds.
    status_check_period: float = attrs.field(default=5)
class Client:
    """
    Manages session and configuration.
    """

    # TODO: Locates resources and APIs.

    def __init__(
        self, url: str, *, logger: Optional[logging.Logger] = None, config: Optional[Config] = None
    ):
        """
        Args:
            url: server URL; used both for the internal URL builder and as the
                ``host`` of the underlying ApiClient configuration.
            logger: logger for client messages; defaults to this module's logger.
            config: client settings; defaults to ``Config()``.
        """
        # TODO: use requests instead of urllib3 in ApiClient
        # TODO: try to autodetect schema
        self._api_map = _CVAT_API_V2(url)
        self.api = ApiClient(Configuration(host=url))
        self.logger = logger or logging.getLogger(__name__)
        self.config = config or Config()

    def __enter__(self):
        self.api.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        return self.api.__exit__(exc_type, exc_value, traceback)

    def close(self):
        """Release the underlying ApiClient (same as leaving a ``with`` block)."""
        return self.__exit__(None, None, None)

    def login(self, credentials: Tuple[str, str]):
        """
        Log in with a ``(username, password)`` pair and use the returned token
        as the ``Authorization`` header for subsequent requests.
        """
        (auth, _) = self.api.auth_api.create_login(
            models.LoginRequest(username=credentials[0], password=credentials[1])
        )

        # The server is expected to establish a session and CSRF cookie on login.
        assert "sessionid" in self.api.cookies
        assert "csrftoken" in self.api.cookies
        self.api.set_default_header("Authorization", "Token " + auth.key)

    def create_task(
        self,
        spec: models.ITaskWriteRequest,
        resource_type: ResourceType,
        resources: Sequence[str],
        *,
        data_params: Optional[Dict[str, Any]] = None,
        annotation_path: str = "",
        annotation_format: str = "CVAT XML 1.1",
        status_check_period: Optional[float] = None,
        dataset_repository_url: str = "",
        use_lfs: bool = False,
        pbar: Optional[ProgressReporter] = None,
    ) -> TaskProxy:
        """
        Create a new task from ``spec``, upload ``resources`` as its data and
        wait until the server reports the data processing as finished.

        Args:
            spec: task creation request.
            resource_type: kind of the resources, forwarded to ``upload_data``.
            resources: resource locations, forwarded to ``upload_data``.
            data_params: extra parameters forwarded to ``upload_data`` as ``params``.
            annotation_path: if non-empty, annotations are imported from this path
                after the task data is processed.
            annotation_format: format name used for the annotation import.
            status_check_period: seconds between status polls; defaults to
                ``self.config.status_check_period``.
            dataset_repository_url: if non-empty, a git repository is created for
                the task with this URL.
            use_lfs: forwarded to the git repository creation.
            pbar: progress reporter used for uploads/imports.

        Returns: the created task, re-fetched from the server after all steps.

        Raises:
            ApiValueError: ``spec`` sets both ``project_id`` and ``labels``.
            ApiException: the server reported the task creation as FAILED.
        """
        if status_check_period is None:
            status_check_period = self.config.status_check_period

        if getattr(spec, "project_id", None) and getattr(spec, "labels", None):
            raise ApiValueError(
                "Can't set labels to a task inside a project. "
                "Tasks inside a project use project's labels.",
                ["labels"],
            )

        (task, _) = self.api.tasks_api.create(spec)
        self.logger.info("Created task ID: %s NAME: %s", task.id, task.name)

        task = TaskProxy(self, task)
        task.upload_data(resource_type, resources, pbar=pbar, params=data_params)

        self.logger.info("Awaiting for task %s creation...", task.id)
        state_values = models.RqStatusStateEnum.allowed_values[("value",)]
        finished_state = state_values["FINISHED"]
        failed_state = state_values["FAILED"]
        # NOTE(review): there is no upper time limit on this polling loop.
        state = None
        while state != finished_state:
            sleep(status_check_period)
            (status, _) = self.api.tasks_api.retrieve_status(task.id)
            state = status.state.value

            self.logger.info(
                "Task %s creation status=%s, message=%s",
                task.id,
                state,
                status.message,
            )

            if state == failed_state:
                raise ApiException(status=state, reason=status.message)

        if annotation_path:
            task.import_annotations(annotation_format, annotation_path, pbar=pbar)

        if dataset_repository_url:
            create_git_repo(
                self,
                task_id=task.id,
                repo_url=dataset_repository_url,
                status_check_period=status_check_period,
                use_lfs=use_lfs,
            )

        task.fetch()

        return task

    def list_tasks(
        self, *, return_json: bool = False, **kwargs
    ) -> Union[List[TaskProxy], str]:
        """
        List all tasks in either basic or JSON format.

        Returns: a list of TaskProxy objects, or - when ``return_json`` is set -
        a JSON string (``json.dumps``) of the collected results.
        """
        results = get_paginated_collection(
            endpoint=self.api.tasks_api.list_endpoint, return_json=return_json, **kwargs
        )

        if return_json:
            return json.dumps(results)
        return [TaskProxy(self, v) for v in results]

    def retrieve_task(self, task_id: int) -> TaskProxy:
        """Fetch a single task by id and wrap it in a TaskProxy."""
        (task, _) = self.api.tasks_api.retrieve(task_id)
        return TaskProxy(self, task)

    def delete_tasks(self, task_ids: Sequence[int]):
        """
        Delete a list of tasks, ignoring those which don't exist.

        Non-2xx/404 responses are logged as warnings instead of raising.
        """
        for task_id in task_ids:
            # _check_status=False: inspect the status code here instead of raising.
            (_, response) = self.api.tasks_api.destroy(task_id, _check_status=False)

            if 200 <= response.status <= 299:
                self.logger.info(f"Task ID {task_id} deleted")
            elif response.status == 404:
                self.logger.info(f"Task ID {task_id} not found")
            else:
                self.logger.warning(
                    f"Failed to delete task ID {task_id}: "
                    f"{response.msg} (status {response.status})"
                )

    def create_task_from_backup(
        self,
        filename: str,
        *,
        status_check_period: Optional[float] = None,
        pbar: Optional[ProgressReporter] = None,
    ) -> TaskProxy:
        """
        Import a task from a backup file.

        Uploads ``filename`` to the task backup endpoint, then polls the same
        endpoint with the returned ``rq_id`` until the server answers 201.

        Args:
            filename: path to the backup file.
            status_check_period: seconds between status polls; defaults to
                ``self.config.status_check_period``.
            pbar: progress reporter used for the upload.

        Returns: the imported task.
        """
        if status_check_period is None:
            status_check_period = self.config.status_check_period

        params = {"filename": osp.basename(filename)}
        url = self._api_map.make_endpoint_url(self.api.tasks_api.create_backup_endpoint.path)
        uploader = Uploader(self)
        response = uploader.upload_file(
            url, filename, meta=params, query_params=params, pbar=pbar, logger=self.logger.debug
        )

        rq_id = json.loads(response.data)["rq_id"]

        # check task status
        while True:
            sleep(status_check_period)

            response = self.api.rest_client.POST(
                url, post_params={"rq_id": rq_id}, headers=self.api.get_common_headers()
            )
            if response.status == 201:
                break
            # Anything other than 201 (done) or 202 (still processing) is an error.
            assert_status(202, response)

        task_id = json.loads(response.data)["id"]
        self.logger.info(f"Task has been imported sucessfully. Task ID: {task_id}")

        return self.retrieve_task(task_id)
class _CVAT_API_V2:
    """Build parameterized API URLs"""

    def __init__(self, host, https=False):
        """
        Args:
            host: server address, with or without a leading ``http://`` /
                ``https://`` scheme.
            https: use the https scheme; forced to True when ``host`` starts
                with ``https://``.
        """
        # Strip only the *leading* scheme. str.replace() would also remove
        # "http://" / "https://" occurring later in the string (e.g. in a path).
        if host.startswith("https://"):
            https = True
            host = host[len("https://"):]
        elif host.startswith("http://"):
            host = host[len("http://"):]

        scheme = "https" if https else "http"
        self.host = "{}://{}".format(scheme, host)
        self.base = self.host + "/api/"
        self.git = f"{scheme}://{host}/git/repository/"

    def git_create(self, task_id):
        """URL of the git repository creation endpoint for ``task_id``."""
        return self.git + f"create/{task_id}"

    def git_check(self, rq_id):
        """URL of the git repository status endpoint for request ``rq_id``."""
        return self.git + f"check/{rq_id}"

    def make_endpoint_url(
        self,
        path: str,
        *,
        psub: Optional[Sequence[Any]] = None,
        kwsub: Optional[Dict[str, Any]] = None,
        query_params: Optional[Dict[str, Any]] = None,
    ) -> str:
        """
        Join ``self.host`` and ``path``, fill ``str.format`` placeholders in the
        result from ``psub`` (positional) and ``kwsub`` (named), and append
        ``query_params`` as a urlencoded query string when given.
        """
        url = self.host + path
        if psub or kwsub:
            url = url.format(*(psub or []), **(kwsub or {}))
        if query_params:
            url += "?" + urllib.parse.urlencode(query_params)
        return url
def make_client(
    host: str, *, port: int = 8080, credentials: Optional[Tuple[str, str]] = None
) -> Client:
    """
    Create a Client for ``f"{host}:{port}"`` and optionally log in.

    Args:
        host: server address; ``:{port}`` is always appended to it.
        port: server port.
        credentials: optional ``(username, password)`` pair passed to
            ``Client.login``.

    Returns: the created Client.
    """
    client = Client(url=f"{host}:{port}")
    if credentials is not None:
        client.login(credentials)
    return client