Add host schema auto-detection in SDK (#4910)

main
Maxim Zhiltsov 3 years ago committed by GitHub
parent 1dcba5a843
commit 7218f4e283
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -28,6 +28,7 @@ Skeleton (<https://github.com/cvat-ai/cvat/pull/1>), (<https://github.com/opencv
- Added confirmation when remove a track (<https://github.com/opencv/cvat/pull/4846>)
- [COCO Keypoints](https://cocodataset.org/#keypoints-2020) format support (<https://github.com/opencv/cvat/pull/4821>)
- Support for Oracle OCI Buckets (<https://github.com/opencv/cvat/pull/4876>)
- `cvat-sdk` and `cvat-cli` packages on PyPI (<https://github.com/opencv/cvat/pull/4903>)
### Changed
- Bumped nuclio version to 1.8.14

@ -1,5 +1,20 @@
# Command-line client for CVAT
A simple command line interface for working with CVAT tasks. At the moment it
implements a basic feature set but may serve as the starting point for a more
comprehensive CVAT administration tool in the future.
Overview of functionality:
- Create a new task (supports name, bug tracker, project, labels JSON, local/share/remote files)
- Delete tasks (supports deleting a list of task IDs)
- List all tasks (supports basic CSV or JSON output)
- Download JPEG frames (supports a list of frame IDs)
- Dump annotations (supports all formats via format string)
- Upload annotations for a task in the specified format (e.g. 'YOLO ZIP 1.0')
- Export and download a whole task
- Import a task
## Installation
`pip install cvat-cli`
@ -10,7 +25,7 @@
$ cvat-cli --help
usage: cvat-cli [-h] [--version] [--auth USER:[PASS]]
[--server-host SERVER_HOST] [--server-port SERVER_PORT] [--https] [--debug]
[--server-host SERVER_HOST] [--server-port SERVER_PORT] [--debug]
{create,delete,ls,frames,dump,upload,export,import} ...
Perform common operations related to CVAT tasks.
@ -28,7 +43,6 @@ optional arguments:
host (default: localhost)
--server-port SERVER_PORT
port (default: 8080)
--https force https connection (default: try to detect automatically)
--debug show debug output
```

@ -19,7 +19,7 @@ class CLI:
# allow arbitrary kwargs in models
# TODO: will silently ignore invalid args, so remove this ASAP
self.client.api.configuration.discard_unknown_keys = True
self.client.api_client.configuration.discard_unknown_keys = True
self.client.login(credentials)

@ -66,13 +66,10 @@ def make_cmdline_parser() -> argparse.ArgumentParser:
"--server-host", type=str, default="localhost", help="host (default: %(default)s)"
)
parser.add_argument(
"--server-port", type=int, default="8080", help="port (default: %(default)s)"
)
parser.add_argument(
"--https",
default=False,
action="store_true",
help="force https connection (default: try to detect automatically)",
"--server-port",
type=int,
default=None,
help="port (default: 80 for http and 443 for https connections)",
)
parser.add_argument(
"--debug",

@ -7,13 +7,16 @@ from __future__ import annotations
import logging
import urllib.parse
from contextlib import suppress
from time import sleep
from typing import Any, Dict, Optional, Sequence, Tuple
import attrs
import urllib3
import urllib3.exceptions
from cvat_sdk.api_client import ApiClient, Configuration, models
from cvat_sdk.core.exceptions import InvalidHostException
from cvat_sdk.core.helpers import expect_status
from cvat_sdk.core.proxies.issues import CommentsRepo, IssuesRepo
from cvat_sdk.core.proxies.jobs import JobsRepo
@ -34,49 +37,89 @@ class Client:
Manages session and configuration.
"""
# TODO: Locates resources and APIs.
def __init__(
self, url: str, *, logger: Optional[logging.Logger] = None, config: Optional[Config] = None
):
# TODO: use requests instead of urllib3 in ApiClient
# TODO: try to autodetect schema
url = self._validate_and_prepare_url(url)
self.api_map = CVAT_API_V2(url)
self.api = ApiClient(Configuration(host=url))
self.api_client = ApiClient(Configuration(host=self.api_map.host))
self.logger = logger or logging.getLogger(__name__)
self.config = config or Config()
self._repos: Dict[str, Repo] = {}
ALLOWED_SCHEMAS = ("https", "http")
@classmethod
def _validate_and_prepare_url(cls, url: str) -> str:
url_parts = url.split("://", maxsplit=1)
if len(url_parts) == 2:
schema, base_url = url_parts
else:
schema = ""
base_url = url
if schema and schema not in cls.ALLOWED_SCHEMAS:
raise InvalidHostException(
f"Invalid url schema '{schema}', expected "
f"one of <none>, {', '.join(cls.ALLOWED_SCHEMAS)}"
)
if not schema:
schema = cls._detect_schema(base_url)
url = f"{schema}://{base_url}"
return url
@classmethod
def _detect_schema(cls, base_url: str) -> str:
for schema in cls.ALLOWED_SCHEMAS:
with ApiClient(Configuration(host=f"{schema}://{base_url}")) as api_client:
with suppress(urllib3.exceptions.RequestError):
(_, response) = api_client.schema_api.retrieve(
_request_timeout=5, _parse_response=False, _check_status=False
)
if response.status == 401:
return schema
raise InvalidHostException(
"Failed to detect host schema automatically, please check "
"the server url and try to specify schema explicitly"
)
def __enter__(self):
self.api.__enter__()
self.api_client.__enter__()
return self
def __exit__(self, exc_type, exc_value, traceback):
return self.api.__exit__(exc_type, exc_value, traceback)
def __exit__(self, exc_type, exc_value, traceback) -> None:
return self.api_client.__exit__(exc_type, exc_value, traceback)
def close(self):
def close(self) -> None:
return self.__exit__(None, None, None)
def login(self, credentials: Tuple[str, str]):
(auth, _) = self.api.auth_api.create_login(
def login(self, credentials: Tuple[str, str]) -> None:
(auth, _) = self.api_client.auth_api.create_login(
models.LoginRequest(username=credentials[0], password=credentials[1])
)
assert "sessionid" in self.api.cookies
assert "csrftoken" in self.api.cookies
self.api.set_default_header("Authorization", "Token " + auth.key)
assert "sessionid" in self.api_client.cookies
assert "csrftoken" in self.api_client.cookies
self.api_client.set_default_header("Authorization", "Token " + auth.key)
def _has_credentials(self):
def has_credentials(self) -> bool:
return (
("sessionid" in self.api.cookies)
or ("csrftoken" in self.api.cookies)
or (self.api.get_common_headers().get("Authorization", ""))
("sessionid" in self.api_client.cookies)
or ("csrftoken" in self.api_client.cookies)
or bool(self.api_client.get_common_headers().get("Authorization", ""))
)
def logout(self):
if self._has_credentials():
self.api.auth_api.create_logout()
def logout(self) -> None:
if self.has_credentials():
self.api_client.auth_api.create_logout()
self.api_client.cookies.pop("sessionid", None)
self.api_client.cookies.pop("csrftoken", None)
self.api_client.default_headers.pop("Authorization", None)
def wait_for_completion(
self: Client,
@ -97,10 +140,10 @@ class Client:
while True:
sleep(status_check_period)
response = self.api.rest_client.request(
response = self.api_client.rest_client.request(
method=method,
url=url,
headers=self.api.get_common_headers(),
headers=self.api_client.get_common_headers(),
query_params=query_params,
post_params=post_params,
)
@ -156,21 +199,15 @@ class Client:
class CVAT_API_V2:
"""Build parameterized API URLs"""
def __init__(self, host, https=False):
if host.startswith("https://"):
https = True
if host.startswith("http://") or host.startswith("https://"):
host = host.replace("http://", "")
host = host.replace("https://", "")
scheme = "https" if https else "http"
self.host = "{}://{}".format(scheme, host)
def __init__(self, host: str):
self.host = host
self.base = self.host + "/api/"
self.git = f"{scheme}://{host}/git/repository/"
self.git = self.host + "/git/repository/"
def git_create(self, task_id):
def git_create(self, task_id: int) -> str:
return self.git + f"create/{task_id}"
def git_check(self, rq_id):
def git_check(self, rq_id: int) -> str:
return self.git + f"check/{rq_id}"
def make_endpoint_url(
@ -190,9 +227,13 @@ class CVAT_API_V2:
def make_client(
host: str, *, port: int = 8080, credentials: Optional[Tuple[int, int]] = None
host: str, *, port: Optional[int] = None, credentials: Optional[Tuple[int, int]] = None
) -> Client:
client = Client(url=f"{host}:{port}")
url = host
if port:
url = f"{url}:{port}"
client = Client(url=url)
if credentials is not None:
client.login(credentials)
return client

@ -45,10 +45,10 @@ class Downloader:
if osp.exists(tmp_path):
raise FileExistsError(f"Can't write temporary file '{tmp_path}' - file exists")
response = self._client.api.rest_client.GET(
response = self._client.api_client.rest_client.GET(
url,
_request_timeout=timeout,
headers=self._client.api.get_common_headers(),
headers=self._client.api_client.get_common_headers(),
_parse_response=False,
)
with closing(response):

@ -0,0 +1,11 @@
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
class CvatSdkException(Exception):
"""Base class for SDK exceptions"""
class InvalidHostException(CvatSdkException):
"""Indicates an invalid hostname error"""

@ -23,9 +23,9 @@ def create_git_repo(
if status_check_period is None:
status_check_period = client.config.status_check_period
common_headers = client.api.get_common_headers()
common_headers = client.api_client.get_common_headers()
response = client.api.rest_client.POST(
response = client.api_client.rest_client.POST(
client.api_map.git_create(task_id),
post_params={"path": repo_url, "lfs": use_lfs, "tid": task_id},
headers=common_headers,
@ -39,7 +39,7 @@ def create_git_repo(
status = None
while status != "finished":
sleep(status_check_period)
response = client.api.rest_client.GET(check_url, headers=common_headers)
response = client.api_client.rest_client.GET(check_url, headers=common_headers)
response_json = json.loads(response.data)
status = response_json["status"]
if status == "failed" or status == "unknown":

@ -146,7 +146,7 @@ class Job(
return self.get_meta().frames
def remove_frames_by_ids(self, ids: Sequence[int]) -> None:
self._client.api.tasks_api.jobs_partial_update_data_meta(
self._client.api_client.tasks_api.jobs_partial_update_data_meta(
self.id,
patched_data_meta_write_request=models.PatchedDataMetaWriteRequest(deleted_frames=ids),
)

@ -47,7 +47,7 @@ class ModelProxy(ABC, Generic[ModelType, ApiType]):
@classmethod
def get_api(cls, client: Client) -> ApiType:
return getattr(client.api, cls._api_member_name)
return getattr(client.api_client, cls._api_member_name)
@property
def api(self) -> ApiType:

@ -243,7 +243,7 @@ class Uploader:
input_file = StreamWithProgress(input_file, pbar, length=file_size)
tus_uploader = self._make_tus_uploader(
self._client.api,
self._client.api_client,
url=url.rstrip("/") + "/",
metadata=meta,
file_stream=input_file,
@ -253,23 +253,23 @@ class Uploader:
tus_uploader.upload()
def _tus_start_upload(self, url, *, query_params=None):
response = self._client.api.rest_client.POST(
response = self._client.api_client.rest_client.POST(
url,
query_params=query_params,
headers={
"Upload-Start": "",
**self._client.api.get_common_headers(),
**self._client.api_client.get_common_headers(),
},
)
expect_status(202, response)
return response
def _tus_finish_upload(self, url, *, query_params=None, fields=None):
response = self._client.api.rest_client.POST(
response = self._client.api_client.rest_client.POST(
url,
headers={
"Upload-Finish": "",
**self._client.api.get_common_headers(),
**self._client.api_client.get_common_headers(),
},
query_params=query_params,
post_params=fields,
@ -356,13 +356,13 @@ class DataUploader(Uploader):
filename,
es.enter_context(closing(open(filename, "rb"))).read(),
)
response = self._client.api.rest_client.POST(
response = self._client.api_client.rest_client.POST(
url,
post_params=dict(**kwargs, **files),
headers={
"Content-Type": "multipart/form-data",
"Upload-Multiple": "",
**self._client.api.get_common_headers(),
**self._client.api_client.get_common_headers(),
},
)
expect_status(200, response)

@ -11,3 +11,4 @@ from cvat_sdk.api_client.exceptions import (
ApiValueError,
OpenApiException,
)
from cvat_sdk.core.exceptions import CvatSdkException

@ -28,7 +28,7 @@ To access the CLI, you need to have python in environment,
as well as a clone of the CVAT repository and the necessary modules:
```bash
pip install 'git+https://github.com/cvat-ai/cvat#subdirectory=cvat-cli'
pip install cvat-cli
```
You can get help with `cvat-cli --help`.
@ -51,8 +51,6 @@ optional arguments:
host (default: localhost)
--server-port SERVER_PORT
port (default: 8080)
--https
using https connection (default: False)
--debug show debug output
```

@ -19,7 +19,7 @@ def fxt_client(fxt_logger):
logger, _ = fxt_logger
client = Client(BASE_URL, logger=logger)
api_client = client.api
api_client = client.api_client
for k in api_client.configuration.logger:
api_client.configuration.logger[k] = logger
client.config.status_check_period = 0.01

@ -0,0 +1,71 @@
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
import io
from logging import Logger
from typing import Tuple
import pytest
from cvat_sdk import Client
from cvat_sdk.core.client import make_client
from cvat_sdk.core.exceptions import InvalidHostException
from cvat_sdk.exceptions import ApiException
from shared.utils.config import BASE_URL, USER_PASS
class TestClientUsecases:
@pytest.fixture(autouse=True)
def setup(
self,
changedb, # force fixture call order to allow DB setup
fxt_logger: Tuple[Logger, io.StringIO],
fxt_client: Client,
fxt_stdout: io.StringIO,
admin_user: str,
):
_, self.logger_stream = fxt_logger
self.client = fxt_client
self.stdout = fxt_stdout
self.user = admin_user
yield
def test_can_login_with_basic_auth(self):
self.client.login((self.user, USER_PASS))
assert self.client.has_credentials()
def test_can_fail_to_login_with_basic_auth(self):
with pytest.raises(ApiException):
self.client.login((self.user, USER_PASS + "123"))
def test_can_logout(self):
self.client.login((self.user, USER_PASS))
self.client.logout()
assert not self.client.has_credentials()
def test_can_detect_server_schema_if_not_provided():
host, port = BASE_URL.split("://", maxsplit=1)[1].rsplit(":", maxsplit=1)
client = make_client(host=host, port=int(port))
assert client.api_map.host == "http://" + host + ":" + port
def test_can_fail_to_detect_server_schema_if_not_provided():
host, port = BASE_URL.split("://", maxsplit=1)[1].rsplit(":", maxsplit=1)
with pytest.raises(InvalidHostException) as capture:
make_client(host=host, port=int(port) + 1)
assert capture.match(r"Failed to detect host schema automatically")
def test_can_reject_invalid_server_schema():
host, port = BASE_URL.split("://", maxsplit=1)[1].rsplit(":", maxsplit=1)
with pytest.raises(InvalidHostException) as capture:
make_client(host="ftp://" + host, port=int(port) + 1)
assert capture.match(r"Invalid url schema 'ftp'")
Loading…
Cancel
Save