You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
199 lines
5.7 KiB
Python
199 lines
5.7 KiB
Python
# Copyright (C) 2022 CVAT.ai Corporation
|
|
#
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import urllib.parse
|
|
from time import sleep
|
|
from typing import Any, Dict, Optional, Sequence, Tuple
|
|
|
|
import attrs
|
|
import urllib3
|
|
|
|
from cvat_sdk.api_client import ApiClient, Configuration, models
|
|
from cvat_sdk.core.helpers import expect_status
|
|
from cvat_sdk.core.proxies.issues import CommentsRepo, IssuesRepo
|
|
from cvat_sdk.core.proxies.jobs import JobsRepo
|
|
from cvat_sdk.core.proxies.model_proxy import Repo
|
|
from cvat_sdk.core.proxies.projects import ProjectsRepo
|
|
from cvat_sdk.core.proxies.tasks import TasksRepo
|
|
from cvat_sdk.core.proxies.users import UsersRepo
|
|
|
|
|
|
@attrs.define
class Config:
    """
    Client-side settings.

    ``Client.wait_for_completion`` reads ``status_check_period`` as its
    default delay between status requests.
    """

    status_check_period: float = 5
    """In seconds"""
|
|
|
|
|
|
class Client:
    """
    Manages session and configuration.
    """

    # TODO: Locates resources and APIs.

    def __init__(
        self, url: str, *, logger: Optional[logging.Logger] = None, config: Optional[Config] = None
    ):
        """
        Create a client for the server at ``url``.

        :param url: server address; used both to build the URL map
            (``CVAT_API_V2``) and as the ``host`` of the API configuration
        :param logger: logger to use; defaults to this module's logger
        :param config: client settings; defaults to ``Config()``
        """
        # TODO: use requests instead of urllib3 in ApiClient
        # TODO: try to autodetect schema
        self.api_map = CVAT_API_V2(url)
        self.api = ApiClient(Configuration(host=url))
        self.logger = logger or logging.getLogger(__name__)
        self.config = config or Config()

        # Repositories are created on first access, see _get_repo()
        self._repos: Dict[str, Repo] = {}

    def __enter__(self):
        """Enter the underlying API client's context and return this client."""
        self.api.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Delegate context exit to the underlying API client."""
        return self.api.__exit__(exc_type, exc_value, traceback)

    def close(self):
        """Leave the context without an active exception."""
        return self.__exit__(None, None, None)

    def login(self, credentials: Tuple[str, str]):
        """
        Log in and remember the returned token for subsequent requests.

        :param credentials: a ``(username, password)`` pair
        """
        username, password = credentials[0], credentials[1]
        (auth, _) = self.api.auth_api.create_login(
            models.LoginRequest(username=username, password=password)
        )

        # The login response is expected to have populated the session cookies
        assert "sessionid" in self.api.cookies
        assert "csrftoken" in self.api.cookies
        self.api.set_default_header("Authorization", "Token " + auth.key)

    def _has_credentials(self) -> bool:
        """
        Tell whether any session cookie or an Authorization header is set.

        Always returns a real ``bool`` (never the header string itself).
        """
        cookies = self.api.cookies
        return bool(
            ("sessionid" in cookies)
            or ("csrftoken" in cookies)
            or self.api.get_common_headers().get("Authorization", "")
        )

    def logout(self):
        """Call the logout endpoint, but only if some credentials are present."""
        if self._has_credentials():
            self.api.auth_api.create_logout()

    def wait_for_completion(
        self,
        url: str,
        *,
        success_status: int,
        status_check_period: Optional[float] = None,
        query_params: Optional[Dict[str, Any]] = None,
        post_params: Optional[Dict[str, Any]] = None,
        method: str = "POST",
        positive_statuses: Optional[Sequence[int]] = None,
    ) -> urllib3.HTTPResponse:
        """
        Repeat a request until the response has ``success_status``.

        The method sleeps *before* every request, including the first one.

        :param url: URL to request
        :param success_status: HTTP status that ends the wait
        :param status_check_period: delay between requests, in seconds;
            defaults to ``self.config.status_check_period``
        :param query_params: query parameters passed to each request
        :param post_params: POST parameters passed to each request
        :param method: HTTP method, "POST" by default
        :param positive_statuses: extra statuses that are passed to
            ``expect_status`` together with ``success_status``; may be omitted
        :return: the response that carried ``success_status``
        """
        if status_check_period is None:
            status_check_period = self.config.status_check_period

        # positive_statuses defaults to None; set(None) would raise TypeError,
        # so treat a missing value as "no extra statuses".
        accepted_statuses = set(positive_statuses or ()) | {success_status}

        while True:
            sleep(status_check_period)

            response = self.api.rest_client.request(
                method=method,
                url=url,
                headers=self.api.get_common_headers(),
                query_params=query_params,
                post_params=post_params,
            )

            self.logger.debug("STATUS %s", response.status)
            # NOTE(review): expect_status presumably raises when the status is
            # not in the accepted set - confirm in cvat_sdk.core.helpers
            expect_status(accepted_statuses, response)
            if response.status == success_status:
                break

        return response

    def _get_repo(self, key: str) -> Repo:
        """
        Return the repository registered under ``key``, creating it on first use.

        :raises KeyError: if ``key`` is not a known repository name
        """
        _repo_map = {
            "tasks": TasksRepo,
            "projects": ProjectsRepo,
            "jobs": JobsRepo,
            "users": UsersRepo,
            "issues": IssuesRepo,
            "comments": CommentsRepo,
        }

        repo = self._repos.get(key, None)
        if repo is None:
            repo = _repo_map[key](self)
            self._repos[key] = repo
        return repo

    @property
    def tasks(self) -> TasksRepo:
        """Repository of tasks."""
        return self._get_repo("tasks")

    @property
    def projects(self) -> ProjectsRepo:
        """Repository of projects."""
        return self._get_repo("projects")

    @property
    def jobs(self) -> JobsRepo:
        """Repository of jobs."""
        return self._get_repo("jobs")

    @property
    def users(self) -> UsersRepo:
        """Repository of users."""
        return self._get_repo("users")

    @property
    def issues(self) -> IssuesRepo:
        """Repository of issues."""
        return self._get_repo("issues")

    @property
    def comments(self) -> CommentsRepo:
        """Repository of comments."""
        return self._get_repo("comments")
|
|
|
|
|
|
class CVAT_API_V2:
    """Build parameterized API URLs"""

    def __init__(self, host, https=False):
        """
        Normalize ``host`` and precompute the base URLs.

        :param host: server address, with or without a leading
            ``http://`` / ``https://`` scheme
        :param https: use the ``https`` scheme; forced to True when ``host``
            starts with ``https://``
        """
        # Strip only a *leading* scheme. str.replace() would also remove
        # "http://" occurrences further inside the address.
        if host.startswith("https://"):
            https = True
            host = host[len("https://") :]
        elif host.startswith("http://"):
            host = host[len("http://") :]

        # Paths appended below start with "/", so drop trailing slashes
        # to avoid producing "//api/".
        host = host.rstrip("/")

        scheme = "https" if https else "http"
        self.host = "{}://{}".format(scheme, host)
        self.base = self.host + "/api/"
        self.git = f"{scheme}://{host}/git/repository/"

    def git_create(self, task_id):
        """Return the git "create" URL for ``task_id``."""
        return self.git + f"create/{task_id}"

    def git_check(self, rq_id):
        """Return the git "check" URL for ``rq_id``."""
        return self.git + f"check/{rq_id}"

    def make_endpoint_url(
        self,
        path: str,
        *,
        psub: Optional[Sequence[Any]] = None,
        kwsub: Optional[Dict[str, Any]] = None,
        query_params: Optional[Dict[str, Any]] = None,
    ) -> str:
        """
        Build a full URL from ``self.host`` and an endpoint path.

        :param path: endpoint path appended to the host; may contain
            ``str.format`` placeholders
        :param psub: positional values for the placeholders
        :param kwsub: keyword values for the placeholders
        :param query_params: if non-empty, URL-encoded and appended after "?"
        :return: the resulting URL
        """
        url = self.host + path
        # Formatting is applied only when substitutions are given, so a path
        # with literal braces and no substitutions is left untouched.
        if psub or kwsub:
            url = url.format(*(psub or []), **(kwsub or {}))
        if query_params:
            url += "?" + urllib.parse.urlencode(query_params)
        return url
|
|
|
|
|
|
def make_client(
    host: str, *, port: int = 8080, credentials: Optional[Tuple[str, str]] = None
) -> Client:
    """
    Create a ``Client`` for ``host:port`` and optionally log in.

    :param host: server address, joined with ``port`` as ``"{host}:{port}"``
    :param port: server port, 8080 by default
    :param credentials: a ``(username, password)`` pair passed to
        ``Client.login``; no login is attempted when it is None
    :return: the created client
    """
    client = Client(url=f"{host}:{port}")
    if credentials is not None:
        client.login(credentials)
    return client
|