# Copyright (C) 2022 CVAT.ai Corporation # # SPDX-License-Identifier: MIT from __future__ import annotations import logging import urllib.parse from contextlib import suppress from time import sleep from typing import Any, Dict, Optional, Sequence, Tuple import attrs import packaging.version as pv import urllib3 import urllib3.exceptions from cvat_sdk.api_client import ApiClient, Configuration, exceptions, models from cvat_sdk.core.exceptions import IncompatibleVersionException, InvalidHostException from cvat_sdk.core.helpers import expect_status from cvat_sdk.core.proxies.issues import CommentsRepo, IssuesRepo from cvat_sdk.core.proxies.jobs import JobsRepo from cvat_sdk.core.proxies.model_proxy import Repo from cvat_sdk.core.proxies.projects import ProjectsRepo from cvat_sdk.core.proxies.tasks import TasksRepo from cvat_sdk.core.proxies.users import UsersRepo from cvat_sdk.version import VERSION @attrs.define class Config: status_check_period: float = 5 """Operation status check period, in seconds""" allow_unsupported_server: bool = True """Allow to use SDK with an unsupported server version. If disabled, raise an exception""" verify_ssl: Optional[bool] = None """Whether to verify host SSL certificate or not""" class Client: """ Manages session and configuration. """ SUPPORTED_SERVER_VERSIONS = ( pv.Version("2.0"), pv.Version("2.1"), pv.Version("2.2"), pv.Version("2.3"), ) def __init__( self, url: str, *, logger: Optional[logging.Logger] = None, config: Optional[Config] = None, check_server_version: bool = True, ) -> None: url = self._validate_and_prepare_url(url) self.logger = logger or logging.getLogger(__name__) self.config = config or Config() self.api_map = CVAT_API_V2(url) self.api_client = ApiClient( Configuration(host=self.api_map.host, verify_ssl=self.config.verify_ssl) ) if check_server_version: self.check_server_version() self._repos: Dict[str, Repo] = {} ALLOWED_SCHEMAS = ("https", "http") @classmethod def _validate_and_prepare_url(cls, url: str) -> str: url_parts = url.split("://", maxsplit=1) if len(url_parts) == 2: schema, base_url = url_parts else: schema = "" base_url = url if schema and schema not in cls.ALLOWED_SCHEMAS: raise InvalidHostException( f"Invalid url schema '{schema}', expected " f"one of , {', '.join(cls.ALLOWED_SCHEMAS)}" ) if not schema: schema = cls._detect_schema(base_url) url = f"{schema}://{base_url}" return url @classmethod def _detect_schema(cls, base_url: str) -> str: for schema in cls.ALLOWED_SCHEMAS: with ApiClient(Configuration(host=f"{schema}://{base_url}")) as api_client: with suppress(urllib3.exceptions.RequestError): (_, response) = api_client.server_api.retrieve_about( _request_timeout=5, _parse_response=False, _check_status=False ) if response.status in [200, 401]: # Server versions prior to 2.3.0 respond with unauthorized # 2.3.0 allows unauthorized access return schema raise InvalidHostException( "Failed to detect host schema automatically, please check " "the server url and try to specify 'https://' or 'http://' explicitly" ) def __enter__(self): self.api_client.__enter__() return self def __exit__(self, exc_type, exc_value, traceback) -> None: return self.api_client.__exit__(exc_type, exc_value, traceback) def close(self) -> None: return self.__exit__(None, None, None) def login(self, credentials: Tuple[str, str]) -> None: (auth, _) = self.api_client.auth_api.create_login( models.LoginRequest(username=credentials[0], password=credentials[1]) ) assert "sessionid" in self.api_client.cookies assert "csrftoken" in self.api_client.cookies self.api_client.set_default_header("Authorization", "Token " + auth.key) def has_credentials(self) -> bool: return ( ("sessionid" in self.api_client.cookies) or ("csrftoken" in self.api_client.cookies) or bool(self.api_client.get_common_headers().get("Authorization", "")) ) def logout(self) -> None: if self.has_credentials(): self.api_client.auth_api.create_logout() self.api_client.cookies.pop("sessionid", None) self.api_client.cookies.pop("csrftoken", None) self.api_client.default_headers.pop("Authorization", None) def wait_for_completion( self: Client, url: str, *, success_status: int, status_check_period: Optional[int] = None, query_params: Optional[Dict[str, Any]] = None, post_params: Optional[Dict[str, Any]] = None, method: str = "POST", positive_statuses: Optional[Sequence[int]] = None, ) -> urllib3.HTTPResponse: if status_check_period is None: status_check_period = self.config.status_check_period positive_statuses = set(positive_statuses) | {success_status} while True: sleep(status_check_period) response = self.api_client.rest_client.request( method=method, url=url, headers=self.api_client.get_common_headers(), query_params=query_params, post_params=post_params, ) self.logger.debug("STATUS %s", response.status) expect_status(positive_statuses, response) if response.status == success_status: break return response def check_server_version(self, fail_if_unsupported: Optional[bool] = None) -> None: if fail_if_unsupported is None: fail_if_unsupported = not self.config.allow_unsupported_server try: server_version = self.get_server_version() except exceptions.ApiException as e: msg = ( "Failed to retrieve server API version: %s. " "Some SDK functions may not work properly with this server." ) % (e,) self.logger.warning(msg) if fail_if_unsupported: raise IncompatibleVersionException(msg) return sdk_version = pv.Version(VERSION) # We only check base version match. Micro releases and fixes do not affect # API compatibility in general. if all( server_version.base_version != sv.base_version for sv in self.SUPPORTED_SERVER_VERSIONS ): msg = ( "Server version '%s' is not compatible with SDK version '%s'. " "Some SDK functions may not work properly with this server. " "You can continue using this SDK, or you can " "try to update with 'pip install cvat-sdk'." ) % (server_version, sdk_version) self.logger.warning(msg) if fail_if_unsupported: raise IncompatibleVersionException(msg) def get_server_version(self) -> pv.Version: # TODO: allow to use this endpoint unauthorized (about, _) = self.api_client.server_api.retrieve_about() return pv.Version(about.version) def _get_repo(self, key: str) -> Repo: _repo_map = { "tasks": TasksRepo, "projects": ProjectsRepo, "jobs": JobsRepo, "users": UsersRepo, "issues": IssuesRepo, "comments": CommentsRepo, } repo = self._repos.get(key, None) if repo is None: repo = _repo_map[key](self) self._repos[key] = repo return repo @property def tasks(self) -> TasksRepo: return self._get_repo("tasks") @property def projects(self) -> ProjectsRepo: return self._get_repo("projects") @property def jobs(self) -> JobsRepo: return self._get_repo("jobs") @property def users(self) -> UsersRepo: return self._get_repo("users") @property def issues(self) -> IssuesRepo: return self._get_repo("issues") @property def comments(self) -> CommentsRepo: return self._get_repo("comments") class CVAT_API_V2: """Build parameterized API URLs""" def __init__(self, host: str): self.host = host self.base = self.host + "/api/" self.git = self.host + "/git/repository/" def git_create(self, task_id: int) -> str: return self.git + f"create/{task_id}" def git_check(self, rq_id: int) -> str: return self.git + f"check/{rq_id}" def make_endpoint_url( self, path: str, *, psub: Optional[Sequence[Any]] = None, kwsub: Optional[Dict[str, Any]] = None, query_params: Optional[Dict[str, Any]] = None, ) -> str: url = self.host + path if psub or kwsub: url = url.format(*(psub or []), **(kwsub or {})) if query_params: url += "?" + urllib.parse.urlencode(query_params) return url def make_client( host: str, *, port: Optional[int] = None, credentials: Optional[Tuple[int, int]] = None ) -> Client: url = host if port: url = f"{url}:{port}" client = Client(url=url) if credentials is not None: client.login(credentials) return client