Check server version in SDK (#4935)

main
Maxim Zhiltsov 3 years ago committed by GitHub
parent 55913f0096
commit 2e15025434
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -33,10 +33,15 @@ def configure_logger(level):
def build_client(parsed_args: SimpleNamespace, logger: logging.Logger) -> Client: def build_client(parsed_args: SimpleNamespace, logger: logging.Logger) -> Client:
config = Config(verify_ssl=not parsed_args.insecure) config = Config(verify_ssl=not parsed_args.insecure)
url = parsed_args.server_host
if parsed_args.server_port:
url += f":{parsed_args.server_port}"
return Client( return Client(
url="{host}:{port}".format(host=parsed_args.server_host, port=parsed_args.server_port), url=url,
logger=logger, logger=logger,
config=config, config=config,
check_server_version=False, # version is checked after auth to support versions < 2.3
) )

@ -23,6 +23,8 @@ class CLI:
self.client.login(credentials) self.client.login(credentials)
self.client.check_server_version(fail_if_unsupported=False)
def tasks_list(self, *, use_json_output: bool = False, **kwargs): def tasks_list(self, *, use_json_output: bool = False, **kwargs):
"""List all tasks in either basic or JSON format.""" """List all tasks in either basic or JSON format."""
results = self.client.tasks.list(return_json=use_json_output, **kwargs) results = self.client.tasks.list(return_json=use_json_output, **kwargs)

@ -70,10 +70,10 @@ schema/
.openapi-generator/ .openapi-generator/
# Generated code # Generated code
cvat_sdk/api_client/ /cvat_sdk/api_client/
cvat_sdk/version.py /cvat_sdk/version.py
requirements/ /requirements/
docs/ /docs/
setup.py /setup.py
README.md /README.md
MANIFEST.in /MANIFEST.in

@ -12,11 +12,12 @@ from time import sleep
from typing import Any, Dict, Optional, Sequence, Tuple from typing import Any, Dict, Optional, Sequence, Tuple
import attrs import attrs
import packaging.version as pv
import urllib3 import urllib3
import urllib3.exceptions import urllib3.exceptions
from cvat_sdk.api_client import ApiClient, Configuration, models from cvat_sdk.api_client import ApiClient, Configuration, exceptions, models
from cvat_sdk.core.exceptions import InvalidHostException from cvat_sdk.core.exceptions import IncompatibleVersionException, InvalidHostException
from cvat_sdk.core.helpers import expect_status from cvat_sdk.core.helpers import expect_status
from cvat_sdk.core.proxies.issues import CommentsRepo, IssuesRepo from cvat_sdk.core.proxies.issues import CommentsRepo, IssuesRepo
from cvat_sdk.core.proxies.jobs import JobsRepo from cvat_sdk.core.proxies.jobs import JobsRepo
@ -24,17 +25,19 @@ from cvat_sdk.core.proxies.model_proxy import Repo
from cvat_sdk.core.proxies.projects import ProjectsRepo from cvat_sdk.core.proxies.projects import ProjectsRepo
from cvat_sdk.core.proxies.tasks import TasksRepo from cvat_sdk.core.proxies.tasks import TasksRepo
from cvat_sdk.core.proxies.users import UsersRepo from cvat_sdk.core.proxies.users import UsersRepo
from cvat_sdk.version import VERSION
@attrs.define @attrs.define
class Config: class Config:
status_check_period: float = 5 status_check_period: float = 5
"""In seconds""" """Operation status check period, in seconds"""
allow_unsupported_server: bool = True
"""Allow to use SDK with an unsupported server version. If disabled, raise an exception"""
verify_ssl: Optional[bool] = None verify_ssl: Optional[bool] = None
""" """Whether to verify host SSL certificate or not"""
Whether to verify host SSL certificate or not.
"""
class Client: class Client:
@ -42,9 +45,21 @@ class Client:
Manages session and configuration. Manages session and configuration.
""" """
SUPPORTED_SERVER_VERSIONS = (
pv.Version("2.0"),
pv.Version("2.1"),
pv.Version("2.2"),
pv.Version("2.3"),
)
def __init__( def __init__(
self, url: str, *, logger: Optional[logging.Logger] = None, config: Optional[Config] = None self,
): url: str,
*,
logger: Optional[logging.Logger] = None,
config: Optional[Config] = None,
check_server_version: bool = True,
) -> None:
url = self._validate_and_prepare_url(url) url = self._validate_and_prepare_url(url)
self.logger = logger or logging.getLogger(__name__) self.logger = logger or logging.getLogger(__name__)
self.config = config or Config() self.config = config or Config()
@ -53,6 +68,9 @@ class Client:
Configuration(host=self.api_map.host, verify_ssl=self.config.verify_ssl) Configuration(host=self.api_map.host, verify_ssl=self.config.verify_ssl)
) )
if check_server_version:
self.check_server_version()
self._repos: Dict[str, Repo] = {} self._repos: Dict[str, Repo] = {}
ALLOWED_SCHEMAS = ("https", "http") ALLOWED_SCHEMAS = ("https", "http")
@ -87,12 +105,14 @@ class Client:
_request_timeout=5, _parse_response=False, _check_status=False _request_timeout=5, _parse_response=False, _check_status=False
) )
if response.status == 401: if response.status in [200, 401]:
# Server versions prior to 2.3.0 respond with unauthorized
# 2.3.0 allows unauthorized access
return schema return schema
raise InvalidHostException( raise InvalidHostException(
"Failed to detect host schema automatically, please check " "Failed to detect host schema automatically, please check "
"the server url and try to specify schema explicitly" "the server url and try to specify 'https://' or 'http://' explicitly"
) )
def __enter__(self): def __enter__(self):
@ -162,6 +182,44 @@ class Client:
return response return response
def check_server_version(self, fail_if_unsupported: Optional[bool] = None) -> None:
if fail_if_unsupported is None:
fail_if_unsupported = not self.config.allow_unsupported_server
try:
server_version = self.get_server_version()
except exceptions.ApiException as e:
msg = (
"Failed to retrieve server API version: %s. "
"Some SDK functions may not work properly with this server."
) % (e,)
self.logger.warning(msg)
if fail_if_unsupported:
raise IncompatibleVersionException(msg)
return
sdk_version = pv.Version(VERSION)
# We only check base version match. Micro releases and fixes do not affect
# API compatibility in general.
if all(
server_version.base_version != sv.base_version for sv in self.SUPPORTED_SERVER_VERSIONS
):
msg = (
"Server version '%s' is not compatible with SDK version '%s'. "
"Some SDK functions may not work properly with this server. "
"You can continue using this SDK, or you can "
"try to update with 'pip install cvat-sdk'."
) % (server_version, sdk_version)
self.logger.warning(msg)
if fail_if_unsupported:
raise IncompatibleVersionException(msg)
def get_server_version(self) -> pv.Version:
# TODO: allow to use this endpoint unauthorized
(about, _) = self.api_client.server_api.retrieve_about()
return pv.Version(about.version)
def _get_repo(self, key: str) -> Repo: def _get_repo(self, key: str) -> Repo:
_repo_map = { _repo_map = {
"tasks": TasksRepo, "tasks": TasksRepo,

@ -9,3 +9,7 @@ class CvatSdkException(Exception):
class InvalidHostException(CvatSdkException): class InvalidHostException(CvatSdkException):
"""Indicates an invalid hostname error""" """Indicates an invalid hostname error"""
class IncompatibleVersionException(CvatSdkException):
"""Indicates server and SDK version mismatch"""

@ -1,7 +1,8 @@
-r api_client.txt -r api_client.txt
attrs >= 21.4.0 attrs >= 21.4.0
packaging >= 21.3
Pillow >= 9.0.1 Pillow >= 9.0.1
tqdm >= 4.64.0 tqdm >= 4.64.0
tuspy == 0.2.5 # have it pinned, because SDK has lots of patched TUS code tuspy == 0.2.5 # have it pinned, because SDK has lots of patched TUS code
typing_extensions >= 4.2.0 typing_extensions >= 4.2.0

@ -504,7 +504,7 @@ class ServerAboutAPITestCase(APITestCase):
def test_api_v2_server_about_no_auth(self): def test_api_v2_server_about_no_auth(self):
response = self._run_api_v2_server_about(None) response = self._run_api_v2_server_about(None)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_api_server_about_versions_admin(self): def test_api_server_about_versions_admin(self):
for version in settings.REST_FRAMEWORK['ALLOWED_VERSIONS']: for version in settings.REST_FRAMEWORK['ALLOWED_VERSIONS']:

@ -98,7 +98,9 @@ class ServerViewSet(viewsets.ViewSet):
responses={ responses={
'200': AboutSerializer, '200': AboutSerializer,
}) })
@action(detail=False, methods=['GET'], serializer_class=AboutSerializer) @action(detail=False, methods=['GET'], serializer_class=AboutSerializer,
permission_classes=[] # This endpoint is available for everyone
)
def about(request): def about(request):
from cvat import __version__ as cvat_version from cvat import __version__ as cvat_version
about = { about = {

@ -23,6 +23,7 @@ import subprocess
import mimetypes import mimetypes
from corsheaders.defaults import default_headers from corsheaders.defaults import default_headers
from distutils.util import strtobool from distutils.util import strtobool
from cvat import __version__
mimetypes.add_type("application/wasm", ".wasm", True) mimetypes.add_type("application/wasm", ".wasm", True)
@ -517,7 +518,7 @@ SPECTACULAR_SETTINGS = {
# Statically set schema version. May also be an empty string. When used together with # Statically set schema version. May also be an empty string. When used together with
# view versioning, will become '0.0.0 (v2)' for 'v2' versioned requests. # view versioning, will become '0.0.0 (v2)' for 'v2' versioned requests.
# Set VERSION to None if only the request version should be rendered. # Set VERSION to None if only the request version should be rendered.
'VERSION': '2.1.0', 'VERSION': __version__,
'CONTACT': { 'CONTACT': {
'name': 'CVAT.ai team', 'name': 'CVAT.ai team',
'url': 'https://github.com/cvat-ai/cvat', 'url': 'https://github.com/cvat-ai/cvat',

@ -7,9 +7,10 @@ import json
import os import os
from pathlib import Path from pathlib import Path
import packaging.version as pv
import pytest import pytest
from cvat_cli.cli import CLI from cvat_cli.cli import CLI
from cvat_sdk import make_client from cvat_sdk import Client, make_client
from cvat_sdk.api_client import exceptions from cvat_sdk.api_client import exceptions
from cvat_sdk.core.proxies.tasks import ResourceType, Task from cvat_sdk.core.proxies.tasks import ResourceType, Task
from PIL import Image from PIL import Image
@ -190,6 +191,17 @@ class TestCLI:
assert task_id != fxt_new_task.id assert task_id != fxt_new_task.id
assert self.client.tasks.retrieve(task_id).size == fxt_new_task.size assert self.client.tasks.retrieve(task_id).size == fxt_new_task.size
def test_can_warn_on_mismatching_server_version(self, monkeypatch, caplog):
def mocked_version(_):
return pv.Version("0")
# We don't actually run a separate process in the tests here, so it works
monkeypatch.setattr(Client, "get_server_version", mocked_version)
self.run_cli("ls")
assert "Server version '0' is not compatible with SDK version" in caplog.text
@pytest.mark.parametrize("verify", [True, False]) @pytest.mark.parametrize("verify", [True, False])
def test_can_control_ssl_verification_with_arg(self, monkeypatch, verify: bool): def test_can_control_ssl_verification_with_arg(self, monkeypatch, verify: bool):
# TODO: Very hacky implementation, improve it, if possible # TODO: Very hacky implementation, improve it, if possible

@ -0,0 +1,36 @@
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from http import HTTPStatus
import pytest
from shared.utils.config import make_api_client
@pytest.mark.usefixtures('dontchangedb')
class TestGetServer:
def test_can_retrieve_about_unauthorized(self):
with make_api_client(user=None, password=None) as api_client:
(data, response) = api_client.server_api.retrieve_about()
assert response.status == HTTPStatus.OK
assert data.version
def test_can_retrieve_formats(self, admin_user: str):
with make_api_client(admin_user) as api_client:
(data, response) = api_client.server_api.retrieve_annotation_formats()
assert response.status == HTTPStatus.OK
assert len(data.importers) != 0
assert len(data.exporters) != 0
@pytest.mark.usefixtures('dontchangedb')
class TestGetSchema:
def test_can_get_schema_unauthorized(self):
with make_api_client(user=None, password=None) as api_client:
(data, response) = api_client.schema_api.retrieve()
assert response.status == HTTPStatus.OK
assert data

@ -3,13 +3,15 @@
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
import io import io
from contextlib import ExitStack
from logging import Logger from logging import Logger
from typing import Tuple from typing import Tuple
import packaging.version as pv
import pytest import pytest
from cvat_sdk import Client from cvat_sdk import Client
from cvat_sdk.core.client import Config, make_client from cvat_sdk.core.client import Config, make_client
from cvat_sdk.core.exceptions import InvalidHostException from cvat_sdk.core.exceptions import IncompatibleVersionException, InvalidHostException
from cvat_sdk.exceptions import ApiException from cvat_sdk.exceptions import ApiException
from shared.utils.config import BASE_URL, USER_PASS from shared.utils.config import BASE_URL, USER_PASS
@ -48,6 +50,13 @@ class TestClientUsecases:
assert not self.client.has_credentials() assert not self.client.has_credentials()
def test_can_get_server_version(self):
self.client.login((self.user, USER_PASS))
version = self.client.get_server_version()
assert (version.major, version.minor) >= (2, 0)
def test_can_detect_server_schema_if_not_provided(): def test_can_detect_server_schema_if_not_provided():
host, port = BASE_URL.split("://", maxsplit=1)[1].rsplit(":", maxsplit=1) host, port = BASE_URL.split("://", maxsplit=1)[1].rsplit(":", maxsplit=1)
@ -71,6 +80,72 @@ def test_can_reject_invalid_server_schema():
assert capture.match(r"Invalid url schema 'ftp'") assert capture.match(r"Invalid url schema 'ftp'")
@pytest.mark.parametrize("raise_exception", (True, False))
def test_can_warn_on_mismatching_server_version(
fxt_logger: Tuple[Logger, io.StringIO], monkeypatch, raise_exception: bool
):
logger, logger_stream = fxt_logger
def mocked_version(_):
return pv.Version("0")
monkeypatch.setattr(Client, "get_server_version", mocked_version)
config = Config()
with ExitStack() as es:
if raise_exception:
config.allow_unsupported_server = False
es.enter_context(pytest.raises(IncompatibleVersionException))
Client(url=BASE_URL, logger=logger, config=config)
assert "Server version '0' is not compatible with SDK version" in logger_stream.getvalue()
@pytest.mark.parametrize("do_check", (True, False))
def test_can_check_server_version_in_ctor(
fxt_logger: Tuple[Logger, io.StringIO], monkeypatch, do_check: bool
):
logger, logger_stream = fxt_logger
def mocked_version(_):
return pv.Version("0")
monkeypatch.setattr(Client, "get_server_version", mocked_version)
config = Config()
config.allow_unsupported_server = False
with ExitStack() as es:
if do_check:
es.enter_context(pytest.raises(IncompatibleVersionException))
Client(url=BASE_URL, logger=logger, config=config, check_server_version=do_check)
assert (
"Server version '0' is not compatible with SDK version" in logger_stream.getvalue()
) == do_check
def test_can_check_server_version_in_method(fxt_logger: Tuple[Logger, io.StringIO], monkeypatch):
logger, logger_stream = fxt_logger
def mocked_version(_):
return pv.Version("0")
monkeypatch.setattr(Client, "get_server_version", mocked_version)
config = Config()
config.allow_unsupported_server = False
client = Client(url=BASE_URL, logger=logger, config=config, check_server_version=False)
with client, pytest.raises(IncompatibleVersionException):
client.check_server_version()
assert "Server version '0' is not compatible with SDK version" in logger_stream.getvalue()
@pytest.mark.parametrize("verify", [True, False]) @pytest.mark.parametrize("verify", [True, False])
def test_can_control_ssl_verification_with_config(verify: bool): def test_can_control_ssl_verification_with_config(verify: bool):
config = Config(verify_ssl=verify) config = Config(verify_ssl=verify)

Loading…
Cancel
Save