You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
247 lines
7.3 KiB
Python
247 lines
7.3 KiB
Python
# Copyright (C) 2022 CVAT.ai Corporation
|
|
#
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import urllib.parse
|
|
from contextlib import suppress
|
|
from time import sleep
|
|
from typing import Any, Dict, Optional, Sequence, Tuple
|
|
|
|
import attrs
|
|
import urllib3
|
|
import urllib3.exceptions
|
|
|
|
from cvat_sdk.api_client import ApiClient, Configuration, models
|
|
from cvat_sdk.core.exceptions import InvalidHostException
|
|
from cvat_sdk.core.helpers import expect_status
|
|
from cvat_sdk.core.proxies.issues import CommentsRepo, IssuesRepo
|
|
from cvat_sdk.core.proxies.jobs import JobsRepo
|
|
from cvat_sdk.core.proxies.model_proxy import Repo
|
|
from cvat_sdk.core.proxies.projects import ProjectsRepo
|
|
from cvat_sdk.core.proxies.tasks import TasksRepo
|
|
from cvat_sdk.core.proxies.users import UsersRepo
|
|
|
|
|
|
@attrs.define
class Config:
    """Tunable settings shared by a ``Client`` instance."""

    status_check_period: float = attrs.field(default=5)
    """Delay between consecutive status polls, in seconds"""

    verify_ssl: Optional[bool] = attrs.field(default=None)
    """
    Whether the server's SSL certificate should be verified.
    """
|
|
|
|
|
|
class Client:
    """
    Manages session and configuration.

    Wraps a configured ``ApiClient`` for one CVAT server and exposes lazily
    created repositories (``tasks``, ``projects``, ``jobs``, ``users``,
    ``issues``, ``comments``) bound to this client. Usable as a context
    manager: entering/exiting delegates to the underlying ``ApiClient``.
    """

    # URL schemas accepted in the server address; auto-detection probes them in this order.
    ALLOWED_SCHEMAS = ("https", "http")

    def __init__(
        self, url: str, *, logger: Optional[logging.Logger] = None, config: Optional[Config] = None
    ):
        """
        Args:
            url: server address, with or without an explicit ``http``/``https`` schema.
                If the schema is omitted, it is detected by probing the server.
            logger: logger to use; defaults to this module's logger.
            config: client settings; defaults to ``Config()``.

        Raises:
            InvalidHostException: if the schema is not allowed or cannot be detected.
        """
        self.logger = logger or logging.getLogger(__name__)
        self.config = config or Config()

        # The config is read first so schema detection honors the same
        # verify_ssl setting as the real API client created below.
        url = self._validate_and_prepare_url(url, verify_ssl=self.config.verify_ssl)
        self.api_map = CVAT_API_V2(url)
        self.api_client = ApiClient(
            Configuration(host=self.api_map.host, verify_ssl=self.config.verify_ssl)
        )

        # Repository objects, created on first access by _get_repo().
        self._repos: Dict[str, Repo] = {}

    @classmethod
    def _validate_and_prepare_url(cls, url: str, *, verify_ssl: Optional[bool] = None) -> str:
        """
        Validate the URL schema and prepend a detected one if it is missing.

        Args:
            url: server address, e.g. ``"https://host:8080"`` or ``"host:8080"``.
            verify_ssl: passed to the probe requests used for schema detection.

        Returns:
            The URL with an explicit schema.

        Raises:
            InvalidHostException: if the given schema is not in ``ALLOWED_SCHEMAS``,
                or no schema was given and none of them could be detected.
        """
        url_parts = url.split("://", maxsplit=1)
        if len(url_parts) == 2:
            schema, base_url = url_parts
        else:
            schema = ""
            base_url = url

        if schema and schema not in cls.ALLOWED_SCHEMAS:
            raise InvalidHostException(
                f"Invalid url schema '{schema}', expected "
                f"one of <none>, {', '.join(cls.ALLOWED_SCHEMAS)}"
            )

        if not schema:
            schema = cls._detect_schema(base_url, verify_ssl=verify_ssl)
            url = f"{schema}://{base_url}"

        return url

    @classmethod
    def _detect_schema(cls, base_url: str, *, verify_ssl: Optional[bool] = None) -> str:
        """
        Find the first schema in ``ALLOWED_SCHEMAS`` at which the server answers.

        Each schema is probed with a ``retrieve_about`` request (5 s timeout).

        Raises:
            InvalidHostException: if no schema produced an accepted response.
        """
        for schema in cls.ALLOWED_SCHEMAS:
            with ApiClient(
                Configuration(host=f"{schema}://{base_url}", verify_ssl=verify_ssl)
            ) as api_client:
                # A request-level failure only means this schema doesn't work here.
                with suppress(urllib3.exceptions.RequestError):
                    (_, response) = api_client.server_api.retrieve_about(
                        _request_timeout=5, _parse_response=False, _check_status=False
                    )

                    # Checked inside the suppress block so a failed request can
                    # never leave `response` unbound or stale from a previous schema.
                    # 401 (auth required) and 200 (anonymous access allowed) both
                    # show the endpoint answered at this schema.
                    if response.status in (200, 401):
                        return schema

        raise InvalidHostException(
            "Failed to detect host schema automatically, please check "
            "the server url and try to specify schema explicitly"
        )

    def __enter__(self):
        self.api_client.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        return self.api_client.__exit__(exc_type, exc_value, traceback)

    def close(self) -> None:
        """Release the underlying API client (same as leaving a ``with`` block)."""
        return self.__exit__(None, None, None)

    def login(self, credentials: Tuple[str, str]) -> None:
        """
        Log in with a ``(username, password)`` pair.

        After the login request, checks that session cookies were set and
        installs the returned key as a ``Token`` ``Authorization`` header.
        """
        (auth, _) = self.api_client.auth_api.create_login(
            models.LoginRequest(username=credentials[0], password=credentials[1])
        )

        assert "sessionid" in self.api_client.cookies
        assert "csrftoken" in self.api_client.cookies
        self.api_client.set_default_header("Authorization", "Token " + auth.key)

    def has_credentials(self) -> bool:
        """Return True if any session cookie or an Authorization header is present."""
        return (
            ("sessionid" in self.api_client.cookies)
            or ("csrftoken" in self.api_client.cookies)
            or bool(self.api_client.get_common_headers().get("Authorization", ""))
        )

    def logout(self) -> None:
        """
        Log out on the server if credentials are present, then drop local
        session cookies and the Authorization header unconditionally.
        """
        if self.has_credentials():
            self.api_client.auth_api.create_logout()
        self.api_client.cookies.pop("sessionid", None)
        self.api_client.cookies.pop("csrftoken", None)
        self.api_client.default_headers.pop("Authorization", None)

    def wait_for_completion(
        self: Client,
        url: str,
        *,
        success_status: int,
        status_check_period: Optional[float] = None,
        query_params: Optional[Dict[str, Any]] = None,
        post_params: Optional[Dict[str, Any]] = None,
        method: str = "POST",
        positive_statuses: Optional[Sequence[int]] = None,
    ) -> urllib3.HTTPResponse:
        """
        Poll ``url`` until it responds with ``success_status``.

        Sleeps ``status_check_period`` seconds before every request (defaults to
        ``self.config.status_check_period``). Each response status is passed to
        ``expect_status`` together with the allowed set
        ``positive_statuses | {success_status}`` (presumably raising on any
        other status — confirm in ``cvat_sdk.core.helpers``).

        Returns:
            The response whose status equals ``success_status``.
        """
        if status_check_period is None:
            status_check_period = self.config.status_check_period

        # positive_statuses defaults to None, which means "no extra statuses".
        allowed_statuses = set(positive_statuses or ()) | {success_status}

        while True:
            sleep(status_check_period)

            response = self.api_client.rest_client.request(
                method=method,
                url=url,
                headers=self.api_client.get_common_headers(),
                query_params=query_params,
                post_params=post_params,
            )

            self.logger.debug("STATUS %s", response.status)
            expect_status(allowed_statuses, response)
            if response.status == success_status:
                break

        return response

    def _get_repo(self, key: str) -> Repo:
        """Return the cached repository for ``key``, creating it on first use."""
        _repo_map = {
            "tasks": TasksRepo,
            "projects": ProjectsRepo,
            "jobs": JobsRepo,
            "users": UsersRepo,
            "issues": IssuesRepo,
            "comments": CommentsRepo,
        }

        repo = self._repos.get(key)
        if repo is None:
            repo = _repo_map[key](self)
            self._repos[key] = repo
        return repo

    @property
    def tasks(self) -> TasksRepo:
        return self._get_repo("tasks")

    @property
    def projects(self) -> ProjectsRepo:
        return self._get_repo("projects")

    @property
    def jobs(self) -> JobsRepo:
        return self._get_repo("jobs")

    @property
    def users(self) -> UsersRepo:
        return self._get_repo("users")

    @property
    def issues(self) -> IssuesRepo:
        return self._get_repo("issues")

    @property
    def comments(self) -> CommentsRepo:
        return self._get_repo("comments")
|
|
|
|
|
|
class CVAT_API_V2:
    """Helper that composes absolute CVAT API URLs from a server host."""

    def __init__(self, host: str):
        self.host = host
        self.base = f"{host}/api/"
        self.git = f"{host}/git/repository/"

    def git_create(self, task_id: int) -> str:
        """URL of the git-repository creation endpoint for ``task_id``."""
        return f"{self.git}create/{task_id}"

    def git_check(self, rq_id: int) -> str:
        """URL of the git-repository request-status endpoint for ``rq_id``."""
        return f"{self.git}check/{rq_id}"

    def make_endpoint_url(
        self,
        path: str,
        *,
        psub: Optional[Sequence[Any]] = None,
        kwsub: Optional[Dict[str, Any]] = None,
        query_params: Optional[Dict[str, Any]] = None,
    ) -> str:
        """
        Build ``host + path``, optionally filling ``str.format`` placeholders and
        appending a URL-encoded query string.

        Placeholders are substituted only when ``psub`` or ``kwsub`` is non-empty;
        the query string is appended only when ``query_params`` is non-empty.
        """
        positional = psub or ()
        named = kwsub or {}

        full_url = self.host + path
        if positional or named:
            full_url = full_url.format(*positional, **named)

        if not query_params:
            return full_url
        return f"{full_url}?{urllib.parse.urlencode(query_params)}"
|
|
|
|
|
|
def make_client(
    host: str, *, port: Optional[int] = None, credentials: Optional[Tuple[str, str]] = None
) -> Client:
    """
    Create a ``Client`` for a CVAT server and optionally log in.

    Args:
        host: server host, optionally with an ``http``/``https`` schema.
        port: if truthy, appended to ``host`` as ``host:port``.
        credentials: optional ``(username, password)`` pair passed to ``Client.login``.

    Returns:
        The new client. If login raises, the client is closed before the
        exception propagates.
    """
    url = host
    if port:
        url = f"{url}:{port}"

    client = Client(url=url)
    if credentials is not None:
        try:
            client.login(credentials)
        except BaseException:
            # Don't leak the underlying API client when login fails.
            client.close()
            raise
    return client
|