From 7218f4e2834722b4eccead8096cff7821348facd Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Tue, 6 Sep 2022 22:32:58 +0300 Subject: [PATCH] Add host schema auto-detection in SDK (#4910) --- CHANGELOG.md | 1 + cvat-cli/README.md | 18 ++- cvat-cli/src/cvat_cli/cli.py | 2 +- cvat-cli/src/cvat_cli/parser.py | 11 +- cvat-sdk/cvat_sdk/core/client.py | 113 ++++++++++++------ cvat-sdk/cvat_sdk/core/downloading.py | 4 +- cvat-sdk/cvat_sdk/core/exceptions.py | 11 ++ cvat-sdk/cvat_sdk/core/git.py | 6 +- cvat-sdk/cvat_sdk/core/proxies/jobs.py | 2 +- cvat-sdk/cvat_sdk/core/proxies/model_proxy.py | 2 +- cvat-sdk/cvat_sdk/core/uploading.py | 14 +-- cvat-sdk/cvat_sdk/exceptions.py | 1 + site/content/en/docs/manual/advanced/cli.md | 4 +- tests/python/sdk/fixtures.py | 2 +- tests/python/sdk/test_client.py | 71 +++++++++++ 15 files changed, 198 insertions(+), 64 deletions(-) create mode 100644 cvat-sdk/cvat_sdk/core/exceptions.py create mode 100644 tests/python/sdk/test_client.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 0858d96d..8c400474 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,7 @@ Skeleton (), () - [COCO Keypoints](https://cocodataset.org/#keypoints-2020) format support () - Support for Oracle OCI Buckets () +- `cvat-sdk` and `cvat-cli` packages on PyPI () ### Changed - Bumped nuclio version to 1.8.14 diff --git a/cvat-cli/README.md b/cvat-cli/README.md index aec0afdb..71c19b79 100644 --- a/cvat-cli/README.md +++ b/cvat-cli/README.md @@ -1,5 +1,20 @@ # Command-line client for CVAT +A simple command line interface for working with CVAT tasks. At the moment it +implements a basic feature set but may serve as the starting point for a more +comprehensive CVAT administration tool in the future. + +Overview of functionality: + +- Create a new task (supports name, bug tracker, project, labels JSON, local/share/remote files) +- Delete tasks (supports deleting a list of task IDs) +- List all tasks (supports basic CSV or JSON output) +- Download JPEG frames (supports a list of frame IDs) +- Dump annotations (supports all formats via format string) +- Upload annotations for a task in the specified format (e.g. 'YOLO ZIP 1.0') +- Export and download a whole task +- Import a task + ## Installation `pip install cvat-cli` @@ -10,7 +25,7 @@ $ cvat-cli --help usage: cvat-cli [-h] [--version] [--auth USER:[PASS]] - [--server-host SERVER_HOST] [--server-port SERVER_PORT] [--https] [--debug] + [--server-host SERVER_HOST] [--server-port SERVER_PORT] [--debug] {create,delete,ls,frames,dump,upload,export,import} ... Perform common operations related to CVAT tasks. @@ -28,7 +43,6 @@ optional arguments: host (default: localhost) --server-port SERVER_PORT port (default: 8080) - --https force https connection (default: try to detect automatically) --debug show debug output ``` diff --git a/cvat-cli/src/cvat_cli/cli.py b/cvat-cli/src/cvat_cli/cli.py index 3a3dedb2..4609f954 100644 --- a/cvat-cli/src/cvat_cli/cli.py +++ b/cvat-cli/src/cvat_cli/cli.py @@ -19,7 +19,7 @@ class CLI: # allow arbitrary kwargs in models # TODO: will silently ignore invalid args, so remove this ASAP - self.client.api.configuration.discard_unknown_keys = True + self.client.api_client.configuration.discard_unknown_keys = True self.client.login(credentials) diff --git a/cvat-cli/src/cvat_cli/parser.py b/cvat-cli/src/cvat_cli/parser.py index 8aebf5fc..43bca8eb 100644 --- a/cvat-cli/src/cvat_cli/parser.py +++ b/cvat-cli/src/cvat_cli/parser.py @@ -66,13 +66,10 @@ def make_cmdline_parser() -> argparse.ArgumentParser: "--server-host", type=str, default="localhost", help="host (default: %(default)s)" ) parser.add_argument( - "--server-port", type=int, default="8080", help="port (default: %(default)s)" - ) - parser.add_argument( - "--https", - default=False, - action="store_true", - help="force https connection (default: try to detect automatically)", + "--server-port", + type=int, + default=None, + help="port (default: 80 for http and 443 for https connections)", ) parser.add_argument( "--debug", diff --git a/cvat-sdk/cvat_sdk/core/client.py b/cvat-sdk/cvat_sdk/core/client.py index db014def..cb8c4205 100644 --- a/cvat-sdk/cvat_sdk/core/client.py +++ b/cvat-sdk/cvat_sdk/core/client.py @@ -7,13 +7,16 @@ from __future__ import annotations import logging import urllib.parse +from contextlib import suppress from time import sleep from typing import Any, Dict, Optional, Sequence, Tuple import attrs import urllib3 +import urllib3.exceptions from cvat_sdk.api_client import ApiClient, Configuration, models +from cvat_sdk.core.exceptions import InvalidHostException from cvat_sdk.core.helpers import expect_status from cvat_sdk.core.proxies.issues import CommentsRepo, IssuesRepo from cvat_sdk.core.proxies.jobs import JobsRepo @@ -34,49 +37,89 @@ class Client: Manages session and configuration. """ - # TODO: Locates resources and APIs. - def __init__( self, url: str, *, logger: Optional[logging.Logger] = None, config: Optional[Config] = None ): - # TODO: use requests instead of urllib3 in ApiClient - # TODO: try to autodetect schema + url = self._validate_and_prepare_url(url) self.api_map = CVAT_API_V2(url) - self.api = ApiClient(Configuration(host=url)) + self.api_client = ApiClient(Configuration(host=self.api_map.host)) self.logger = logger or logging.getLogger(__name__) self.config = config or Config() self._repos: Dict[str, Repo] = {} + ALLOWED_SCHEMAS = ("https", "http") + + @classmethod + def _validate_and_prepare_url(cls, url: str) -> str: + url_parts = url.split("://", maxsplit=1) + if len(url_parts) == 2: + schema, base_url = url_parts + else: + schema = "" + base_url = url + + if schema and schema not in cls.ALLOWED_SCHEMAS: + raise InvalidHostException( + f"Invalid url schema '{schema}', expected " + f"one of , {', '.join(cls.ALLOWED_SCHEMAS)}" + ) + + if not schema: + schema = cls._detect_schema(base_url) + url = f"{schema}://{base_url}" + + return url + + @classmethod + def _detect_schema(cls, base_url: str) -> str: + for schema in cls.ALLOWED_SCHEMAS: + with ApiClient(Configuration(host=f"{schema}://{base_url}")) as api_client: + with suppress(urllib3.exceptions.RequestError): + (_, response) = api_client.schema_api.retrieve( + _request_timeout=5, _parse_response=False, _check_status=False + ) + + if response.status == 401: + return schema + + raise InvalidHostException( + "Failed to detect host schema automatically, please check " + "the server url and try to specify schema explicitly" + ) + def __enter__(self): - self.api.__enter__() + self.api_client.__enter__() return self - def __exit__(self, exc_type, exc_value, traceback): - return self.api.__exit__(exc_type, exc_value, traceback) + def __exit__(self, exc_type, exc_value, traceback) -> None: + return self.api_client.__exit__(exc_type, exc_value, traceback) - def close(self): + def close(self) -> None: return self.__exit__(None, None, None) - def login(self, credentials: Tuple[str, str]): - (auth, _) = self.api.auth_api.create_login( + def login(self, credentials: Tuple[str, str]) -> None: + (auth, _) = self.api_client.auth_api.create_login( models.LoginRequest(username=credentials[0], password=credentials[1]) ) - assert "sessionid" in self.api.cookies - assert "csrftoken" in self.api.cookies - self.api.set_default_header("Authorization", "Token " + auth.key) + assert "sessionid" in self.api_client.cookies + assert "csrftoken" in self.api_client.cookies + self.api_client.set_default_header("Authorization", "Token " + auth.key) - def _has_credentials(self): + def has_credentials(self) -> bool: return ( - ("sessionid" in self.api.cookies) - or ("csrftoken" in self.api.cookies) - or (self.api.get_common_headers().get("Authorization", "")) + ("sessionid" in self.api_client.cookies) + or ("csrftoken" in self.api_client.cookies) + or bool(self.api_client.get_common_headers().get("Authorization", "")) ) - def logout(self): - if self._has_credentials(): - self.api.auth_api.create_logout() + def logout(self) -> None: + if self.has_credentials(): + self.api_client.auth_api.create_logout() + self.api_client.cookies.pop("sessionid", None) + self.api_client.cookies.pop("csrftoken", None) + self.api_client.default_headers.pop("Authorization", None) def wait_for_completion( self: Client, @@ -97,10 +140,10 @@ class Client: while True: sleep(status_check_period) - response = self.api.rest_client.request( + response = self.api_client.rest_client.request( method=method, url=url, - headers=self.api.get_common_headers(), + headers=self.api_client.get_common_headers(), query_params=query_params, post_params=post_params, ) @@ -156,21 +199,15 @@ class Client: class CVAT_API_V2: """Build parameterized API URLs""" - def __init__(self, host, https=False): - if host.startswith("https://"): - https = True - if host.startswith("http://") or host.startswith("https://"): - host = host.replace("http://", "") - host = host.replace("https://", "") - scheme = "https" if https else "http" - self.host = "{}://{}".format(scheme, host) + def __init__(self, host: str): + self.host = host self.base = self.host + "/api/" - self.git = f"{scheme}://{host}/git/repository/" + self.git = self.host + "/git/repository/" - def git_create(self, task_id): + def git_create(self, task_id: int) -> str: return self.git + f"create/{task_id}" - def git_check(self, rq_id): + def git_check(self, rq_id: int) -> str: return self.git + f"check/{rq_id}" def make_endpoint_url( @@ -190,9 +227,13 @@ class CVAT_API_V2: def make_client( - host: str, *, port: int = 8080, credentials: Optional[Tuple[int, int]] = None + host: str, *, port: Optional[int] = None, credentials: Optional[Tuple[int, int]] = None ) -> Client: - client = Client(url=f"{host}:{port}") + url = host + if port: + url = f"{url}:{port}" + + client = Client(url=url) if credentials is not None: client.login(credentials) return client diff --git a/cvat-sdk/cvat_sdk/core/downloading.py b/cvat-sdk/cvat_sdk/core/downloading.py index a270ffc4..7caf7138 100644 --- a/cvat-sdk/cvat_sdk/core/downloading.py +++ b/cvat-sdk/cvat_sdk/core/downloading.py @@ -45,10 +45,10 @@ class Downloader: if osp.exists(tmp_path): raise FileExistsError(f"Can't write temporary file '{tmp_path}' - file exists") - response = self._client.api.rest_client.GET( + response = self._client.api_client.rest_client.GET( url, _request_timeout=timeout, - headers=self._client.api.get_common_headers(), + headers=self._client.api_client.get_common_headers(), _parse_response=False, ) with closing(response): diff --git a/cvat-sdk/cvat_sdk/core/exceptions.py b/cvat-sdk/cvat_sdk/core/exceptions.py new file mode 100644 index 00000000..c458bf02 --- /dev/null +++ b/cvat-sdk/cvat_sdk/core/exceptions.py @@ -0,0 +1,11 @@ +# Copyright (C) 2022 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + + +class CvatSdkException(Exception): + """Base class for SDK exceptions""" + + +class InvalidHostException(CvatSdkException): + """Indicates an invalid hostname error""" diff --git a/cvat-sdk/cvat_sdk/core/git.py b/cvat-sdk/cvat_sdk/core/git.py index 44e71ea9..bd7c1207 100644 --- a/cvat-sdk/cvat_sdk/core/git.py +++ b/cvat-sdk/cvat_sdk/core/git.py @@ -23,9 +23,9 @@ def create_git_repo( if status_check_period is None: status_check_period = client.config.status_check_period - common_headers = client.api.get_common_headers() + common_headers = client.api_client.get_common_headers() - response = client.api.rest_client.POST( + response = client.api_client.rest_client.POST( client.api_map.git_create(task_id), post_params={"path": repo_url, "lfs": use_lfs, "tid": task_id}, headers=common_headers, @@ -39,7 +39,7 @@ def create_git_repo( status = None while status != "finished": sleep(status_check_period) - response = client.api.rest_client.GET(check_url, headers=common_headers) + response = client.api_client.rest_client.GET(check_url, headers=common_headers) response_json = json.loads(response.data) status = response_json["status"] if status == "failed" or status == "unknown": diff --git a/cvat-sdk/cvat_sdk/core/proxies/jobs.py b/cvat-sdk/cvat_sdk/core/proxies/jobs.py index c1818870..26985a91 100644 --- a/cvat-sdk/cvat_sdk/core/proxies/jobs.py +++ b/cvat-sdk/cvat_sdk/core/proxies/jobs.py @@ -146,7 +146,7 @@ class Job( return self.get_meta().frames def remove_frames_by_ids(self, ids: Sequence[int]) -> None: - self._client.api.tasks_api.jobs_partial_update_data_meta( + self._client.api_client.tasks_api.jobs_partial_update_data_meta( self.id, patched_data_meta_write_request=models.PatchedDataMetaWriteRequest(deleted_frames=ids), ) diff --git a/cvat-sdk/cvat_sdk/core/proxies/model_proxy.py b/cvat-sdk/cvat_sdk/core/proxies/model_proxy.py index 04673481..61e149d6 100644 --- a/cvat-sdk/cvat_sdk/core/proxies/model_proxy.py +++ b/cvat-sdk/cvat_sdk/core/proxies/model_proxy.py @@ -47,7 +47,7 @@ class ModelProxy(ABC, Generic[ModelType, ApiType]): @classmethod def get_api(cls, client: Client) -> ApiType: - return getattr(client.api, cls._api_member_name) + return getattr(client.api_client, cls._api_member_name) @property def api(self) -> ApiType: diff --git a/cvat-sdk/cvat_sdk/core/uploading.py b/cvat-sdk/cvat_sdk/core/uploading.py index 93d4764e..3bfe630f 100644 --- a/cvat-sdk/cvat_sdk/core/uploading.py +++ b/cvat-sdk/cvat_sdk/core/uploading.py @@ -243,7 +243,7 @@ class Uploader: input_file = StreamWithProgress(input_file, pbar, length=file_size) tus_uploader = self._make_tus_uploader( - self._client.api, + self._client.api_client, url=url.rstrip("/") + "/", metadata=meta, file_stream=input_file, @@ -253,23 +253,23 @@ class Uploader: tus_uploader.upload() def _tus_start_upload(self, url, *, query_params=None): - response = self._client.api.rest_client.POST( + response = self._client.api_client.rest_client.POST( url, query_params=query_params, headers={ "Upload-Start": "", - **self._client.api.get_common_headers(), + **self._client.api_client.get_common_headers(), }, ) expect_status(202, response) return response def _tus_finish_upload(self, url, *, query_params=None, fields=None): - response = self._client.api.rest_client.POST( + response = self._client.api_client.rest_client.POST( url, headers={ "Upload-Finish": "", - **self._client.api.get_common_headers(), + **self._client.api_client.get_common_headers(), }, query_params=query_params, post_params=fields, @@ -356,13 +356,13 @@ class DataUploader(Uploader): filename, es.enter_context(closing(open(filename, "rb"))).read(), ) - response = self._client.api.rest_client.POST( + response = self._client.api_client.rest_client.POST( url, post_params=dict(**kwargs, **files), headers={ "Content-Type": "multipart/form-data", "Upload-Multiple": "", - **self._client.api.get_common_headers(), + **self._client.api_client.get_common_headers(), }, ) expect_status(200, response) diff --git a/cvat-sdk/cvat_sdk/exceptions.py b/cvat-sdk/cvat_sdk/exceptions.py index 2901582f..e28c7d38 100644 --- a/cvat-sdk/cvat_sdk/exceptions.py +++ b/cvat-sdk/cvat_sdk/exceptions.py @@ -11,3 +11,4 @@ from cvat_sdk.api_client.exceptions import ( ApiValueError, OpenApiException, ) +from cvat_sdk.core.exceptions import CvatSdkException diff --git a/site/content/en/docs/manual/advanced/cli.md b/site/content/en/docs/manual/advanced/cli.md index e5f9c9a7..047f7f61 100644 --- a/site/content/en/docs/manual/advanced/cli.md +++ b/site/content/en/docs/manual/advanced/cli.md @@ -28,7 +28,7 @@ To access the CLI, you need to have python in environment, as well as a clone of the CVAT repository and the necessary modules: ```bash -pip install 'git+https://github.com/cvat-ai/cvat#subdirectory=cvat-cli' +pip install cvat-cli ``` You can get help with `cvat-cli --help`. @@ -51,8 +51,6 @@ optional arguments: host (default: localhost) --server-port SERVER_PORT port (default: 8080) - --https - using https connection (default: False) --debug show debug output ``` diff --git a/tests/python/sdk/fixtures.py b/tests/python/sdk/fixtures.py index 40816f2f..6cb46995 100644 --- a/tests/python/sdk/fixtures.py +++ b/tests/python/sdk/fixtures.py @@ -19,7 +19,7 @@ def fxt_client(fxt_logger): logger, _ = fxt_logger client = Client(BASE_URL, logger=logger) - api_client = client.api + api_client = client.api_client for k in api_client.configuration.logger: api_client.configuration.logger[k] = logger client.config.status_check_period = 0.01 diff --git a/tests/python/sdk/test_client.py b/tests/python/sdk/test_client.py new file mode 100644 index 00000000..d36bf52b --- /dev/null +++ b/tests/python/sdk/test_client.py @@ -0,0 +1,71 @@ +# Copyright (C) 2022 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +import io +from logging import Logger +from typing import Tuple + +import pytest +from cvat_sdk import Client +from cvat_sdk.core.client import make_client +from cvat_sdk.core.exceptions import InvalidHostException +from cvat_sdk.exceptions import ApiException + +from shared.utils.config import BASE_URL, USER_PASS + + +class TestClientUsecases: + @pytest.fixture(autouse=True) + def setup( + self, + changedb, # force fixture call order to allow DB setup + fxt_logger: Tuple[Logger, io.StringIO], + fxt_client: Client, + fxt_stdout: io.StringIO, + admin_user: str, + ): + _, self.logger_stream = fxt_logger + self.client = fxt_client + self.stdout = fxt_stdout + self.user = admin_user + + yield + + def test_can_login_with_basic_auth(self): + self.client.login((self.user, USER_PASS)) + + assert self.client.has_credentials() + + def test_can_fail_to_login_with_basic_auth(self): + with pytest.raises(ApiException): + self.client.login((self.user, USER_PASS + "123")) + + def test_can_logout(self): + self.client.login((self.user, USER_PASS)) + + self.client.logout() + + assert not self.client.has_credentials() + + +def test_can_detect_server_schema_if_not_provided(): + host, port = BASE_URL.split("://", maxsplit=1)[1].rsplit(":", maxsplit=1) + client = make_client(host=host, port=int(port)) + assert client.api_map.host == "http://" + host + ":" + port + + +def test_can_fail_to_detect_server_schema_if_not_provided(): + host, port = BASE_URL.split("://", maxsplit=1)[1].rsplit(":", maxsplit=1) + with pytest.raises(InvalidHostException) as capture: + make_client(host=host, port=int(port) + 1) + + assert capture.match(r"Failed to detect host schema automatically") + + +def test_can_reject_invalid_server_schema(): + host, port = BASE_URL.split("://", maxsplit=1)[1].rsplit(":", maxsplit=1) + with pytest.raises(InvalidHostException) as capture: + make_client(host="ftp://" + host, port=int(port) + 1) + + assert capture.match(r"Invalid url schema 'ftp'")