Check server version in SDK (#4935)

main
Maxim Zhiltsov 3 years ago committed by GitHub
parent 55913f0096
commit 2e15025434
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -33,10 +33,15 @@ def configure_logger(level):
def build_client(parsed_args: SimpleNamespace, logger: logging.Logger) -> Client:
config = Config(verify_ssl=not parsed_args.insecure)
url = parsed_args.server_host
if parsed_args.server_port:
url += f":{parsed_args.server_port}"
return Client(
url="{host}:{port}".format(host=parsed_args.server_host, port=parsed_args.server_port),
url=url,
logger=logger,
config=config,
check_server_version=False, # version is checked after auth to support versions < 2.3
)

@ -23,6 +23,8 @@ class CLI:
self.client.login(credentials)
self.client.check_server_version(fail_if_unsupported=False)
def tasks_list(self, *, use_json_output: bool = False, **kwargs):
"""List all tasks in either basic or JSON format."""
results = self.client.tasks.list(return_json=use_json_output, **kwargs)

@ -70,10 +70,10 @@ schema/
.openapi-generator/
# Generated code
cvat_sdk/api_client/
cvat_sdk/version.py
requirements/
docs/
setup.py
README.md
MANIFEST.in
/cvat_sdk/api_client/
/cvat_sdk/version.py
/requirements/
/docs/
/setup.py
/README.md
/MANIFEST.in

@ -12,11 +12,12 @@ from time import sleep
from typing import Any, Dict, Optional, Sequence, Tuple
import attrs
import packaging.version as pv
import urllib3
import urllib3.exceptions
from cvat_sdk.api_client import ApiClient, Configuration, models
from cvat_sdk.core.exceptions import InvalidHostException
from cvat_sdk.api_client import ApiClient, Configuration, exceptions, models
from cvat_sdk.core.exceptions import IncompatibleVersionException, InvalidHostException
from cvat_sdk.core.helpers import expect_status
from cvat_sdk.core.proxies.issues import CommentsRepo, IssuesRepo
from cvat_sdk.core.proxies.jobs import JobsRepo
@ -24,17 +25,19 @@ from cvat_sdk.core.proxies.model_proxy import Repo
from cvat_sdk.core.proxies.projects import ProjectsRepo
from cvat_sdk.core.proxies.tasks import TasksRepo
from cvat_sdk.core.proxies.users import UsersRepo
from cvat_sdk.version import VERSION
@attrs.define
class Config:
status_check_period: float = 5
"""In seconds"""
"""Operation status check period, in seconds"""
allow_unsupported_server: bool = True
"""Allow to use SDK with an unsupported server version. If disabled, raise an exception"""
verify_ssl: Optional[bool] = None
"""
Whether to verify host SSL certificate or not.
"""
"""Whether to verify host SSL certificate or not"""
class Client:
@ -42,9 +45,21 @@ class Client:
Manages session and configuration.
"""
SUPPORTED_SERVER_VERSIONS = (
pv.Version("2.0"),
pv.Version("2.1"),
pv.Version("2.2"),
pv.Version("2.3"),
)
def __init__(
self, url: str, *, logger: Optional[logging.Logger] = None, config: Optional[Config] = None
):
self,
url: str,
*,
logger: Optional[logging.Logger] = None,
config: Optional[Config] = None,
check_server_version: bool = True,
) -> None:
url = self._validate_and_prepare_url(url)
self.logger = logger or logging.getLogger(__name__)
self.config = config or Config()
@ -53,6 +68,9 @@ class Client:
Configuration(host=self.api_map.host, verify_ssl=self.config.verify_ssl)
)
if check_server_version:
self.check_server_version()
self._repos: Dict[str, Repo] = {}
ALLOWED_SCHEMAS = ("https", "http")
@ -87,12 +105,14 @@ class Client:
_request_timeout=5, _parse_response=False, _check_status=False
)
if response.status == 401:
if response.status in [200, 401]:
# Server versions prior to 2.3.0 respond with unauthorized
# 2.3.0 allows unauthorized access
return schema
raise InvalidHostException(
"Failed to detect host schema automatically, please check "
"the server url and try to specify schema explicitly"
"the server url and try to specify 'https://' or 'http://' explicitly"
)
def __enter__(self):
@ -162,6 +182,44 @@ class Client:
return response
def check_server_version(self, fail_if_unsupported: Optional[bool] = None) -> None:
if fail_if_unsupported is None:
fail_if_unsupported = not self.config.allow_unsupported_server
try:
server_version = self.get_server_version()
except exceptions.ApiException as e:
msg = (
"Failed to retrieve server API version: %s. "
"Some SDK functions may not work properly with this server."
) % (e,)
self.logger.warning(msg)
if fail_if_unsupported:
raise IncompatibleVersionException(msg)
return
sdk_version = pv.Version(VERSION)
# We only check base version match. Micro releases and fixes do not affect
# API compatibility in general.
if all(
server_version.base_version != sv.base_version for sv in self.SUPPORTED_SERVER_VERSIONS
):
msg = (
"Server version '%s' is not compatible with SDK version '%s'. "
"Some SDK functions may not work properly with this server. "
"You can continue using this SDK, or you can "
"try to update with 'pip install cvat-sdk'."
) % (server_version, sdk_version)
self.logger.warning(msg)
if fail_if_unsupported:
raise IncompatibleVersionException(msg)
def get_server_version(self) -> pv.Version:
# TODO: allow to use this endpoint unauthorized
(about, _) = self.api_client.server_api.retrieve_about()
return pv.Version(about.version)
def _get_repo(self, key: str) -> Repo:
_repo_map = {
"tasks": TasksRepo,

@ -9,3 +9,7 @@ class CvatSdkException(Exception):
class InvalidHostException(CvatSdkException):
"""Indicates an invalid hostname error"""
class IncompatibleVersionException(CvatSdkException):
"""Indicates server and SDK version mismatch"""

@ -1,7 +1,8 @@
-r api_client.txt
attrs >= 21.4.0
packaging >= 21.3
Pillow >= 9.0.1
tqdm >= 4.64.0
tuspy == 0.2.5 # have it pinned, because SDK has lots of patched TUS code
typing_extensions >= 4.2.0
typing_extensions >= 4.2.0

@ -504,7 +504,7 @@ class ServerAboutAPITestCase(APITestCase):
def test_api_v2_server_about_no_auth(self):
response = self._run_api_v2_server_about(None)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_api_server_about_versions_admin(self):
for version in settings.REST_FRAMEWORK['ALLOWED_VERSIONS']:

@ -98,7 +98,9 @@ class ServerViewSet(viewsets.ViewSet):
responses={
'200': AboutSerializer,
})
@action(detail=False, methods=['GET'], serializer_class=AboutSerializer)
@action(detail=False, methods=['GET'], serializer_class=AboutSerializer,
permission_classes=[] # This endpoint is available for everyone
)
def about(request):
from cvat import __version__ as cvat_version
about = {

@ -23,6 +23,7 @@ import subprocess
import mimetypes
from corsheaders.defaults import default_headers
from distutils.util import strtobool
from cvat import __version__
mimetypes.add_type("application/wasm", ".wasm", True)
@ -517,7 +518,7 @@ SPECTACULAR_SETTINGS = {
# Statically set schema version. May also be an empty string. When used together with
# view versioning, will become '0.0.0 (v2)' for 'v2' versioned requests.
# Set VERSION to None if only the request version should be rendered.
'VERSION': '2.1.0',
'VERSION': __version__,
'CONTACT': {
'name': 'CVAT.ai team',
'url': 'https://github.com/cvat-ai/cvat',

@ -7,9 +7,10 @@ import json
import os
from pathlib import Path
import packaging.version as pv
import pytest
from cvat_cli.cli import CLI
from cvat_sdk import make_client
from cvat_sdk import Client, make_client
from cvat_sdk.api_client import exceptions
from cvat_sdk.core.proxies.tasks import ResourceType, Task
from PIL import Image
@ -190,6 +191,17 @@ class TestCLI:
assert task_id != fxt_new_task.id
assert self.client.tasks.retrieve(task_id).size == fxt_new_task.size
def test_can_warn_on_mismatching_server_version(self, monkeypatch, caplog):
def mocked_version(_):
return pv.Version("0")
# We don't actually run a separate process in the tests here, so it works
monkeypatch.setattr(Client, "get_server_version", mocked_version)
self.run_cli("ls")
assert "Server version '0' is not compatible with SDK version" in caplog.text
@pytest.mark.parametrize("verify", [True, False])
def test_can_control_ssl_verification_with_arg(self, monkeypatch, verify: bool):
# TODO: Very hacky implementation, improve it, if possible

@ -0,0 +1,36 @@
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from http import HTTPStatus
import pytest
from shared.utils.config import make_api_client
@pytest.mark.usefixtures('dontchangedb')
class TestGetServer:
def test_can_retrieve_about_unauthorized(self):
with make_api_client(user=None, password=None) as api_client:
(data, response) = api_client.server_api.retrieve_about()
assert response.status == HTTPStatus.OK
assert data.version
def test_can_retrieve_formats(self, admin_user: str):
with make_api_client(admin_user) as api_client:
(data, response) = api_client.server_api.retrieve_annotation_formats()
assert response.status == HTTPStatus.OK
assert len(data.importers) != 0
assert len(data.exporters) != 0
@pytest.mark.usefixtures('dontchangedb')
class TestGetSchema:
def test_can_get_schema_unauthorized(self):
with make_api_client(user=None, password=None) as api_client:
(data, response) = api_client.schema_api.retrieve()
assert response.status == HTTPStatus.OK
assert data

@ -3,13 +3,15 @@
# SPDX-License-Identifier: MIT
import io
from contextlib import ExitStack
from logging import Logger
from typing import Tuple
import packaging.version as pv
import pytest
from cvat_sdk import Client
from cvat_sdk.core.client import Config, make_client
from cvat_sdk.core.exceptions import InvalidHostException
from cvat_sdk.core.exceptions import IncompatibleVersionException, InvalidHostException
from cvat_sdk.exceptions import ApiException
from shared.utils.config import BASE_URL, USER_PASS
@ -48,6 +50,13 @@ class TestClientUsecases:
assert not self.client.has_credentials()
def test_can_get_server_version(self):
self.client.login((self.user, USER_PASS))
version = self.client.get_server_version()
assert (version.major, version.minor) >= (2, 0)
def test_can_detect_server_schema_if_not_provided():
host, port = BASE_URL.split("://", maxsplit=1)[1].rsplit(":", maxsplit=1)
@ -71,6 +80,72 @@ def test_can_reject_invalid_server_schema():
assert capture.match(r"Invalid url schema 'ftp'")
@pytest.mark.parametrize("raise_exception", (True, False))
def test_can_warn_on_mismatching_server_version(
fxt_logger: Tuple[Logger, io.StringIO], monkeypatch, raise_exception: bool
):
logger, logger_stream = fxt_logger
def mocked_version(_):
return pv.Version("0")
monkeypatch.setattr(Client, "get_server_version", mocked_version)
config = Config()
with ExitStack() as es:
if raise_exception:
config.allow_unsupported_server = False
es.enter_context(pytest.raises(IncompatibleVersionException))
Client(url=BASE_URL, logger=logger, config=config)
assert "Server version '0' is not compatible with SDK version" in logger_stream.getvalue()
@pytest.mark.parametrize("do_check", (True, False))
def test_can_check_server_version_in_ctor(
fxt_logger: Tuple[Logger, io.StringIO], monkeypatch, do_check: bool
):
logger, logger_stream = fxt_logger
def mocked_version(_):
return pv.Version("0")
monkeypatch.setattr(Client, "get_server_version", mocked_version)
config = Config()
config.allow_unsupported_server = False
with ExitStack() as es:
if do_check:
es.enter_context(pytest.raises(IncompatibleVersionException))
Client(url=BASE_URL, logger=logger, config=config, check_server_version=do_check)
assert (
"Server version '0' is not compatible with SDK version" in logger_stream.getvalue()
) == do_check
def test_can_check_server_version_in_method(fxt_logger: Tuple[Logger, io.StringIO], monkeypatch):
logger, logger_stream = fxt_logger
def mocked_version(_):
return pv.Version("0")
monkeypatch.setattr(Client, "get_server_version", mocked_version)
config = Config()
config.allow_unsupported_server = False
client = Client(url=BASE_URL, logger=logger, config=config, check_server_version=False)
with client, pytest.raises(IncompatibleVersionException):
client.check_server_version()
assert "Server version '0' is not compatible with SDK version" in logger_stream.getvalue()
@pytest.mark.parametrize("verify", [True, False])
def test_can_control_ssl_verification_with_config(verify: bool):
config = Config(verify_ssl=verify)

Loading…
Cancel
Save