SDK layer 2 - cover RC1 usecases (#4813)

main
Maxim Zhiltsov 4 years ago committed by GitHub
parent b60d3b481a
commit 53697ecac5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -6,3 +6,4 @@
# B406 : import_xml_sax # B406 : import_xml_sax
# B410 : import_lxml # B410 : import_lxml
skips: B101,B102,B320,B404,B406,B410 skips: B101,B102,B320,B404,B406,B410
exclude: **/tests/**,tests

@ -33,7 +33,7 @@ jobs:
echo "Bandit version: "$(bandit --version | head -1) echo "Bandit version: "$(bandit --version | head -1)
echo "The files will be checked: "$(echo $CHANGED_FILES) echo "The files will be checked: "$(echo $CHANGED_FILES)
bandit $CHANGED_FILES --exclude '**/tests/**' -a file --ini ./.bandit -f html -o ./bandit_report/bandit_checks.html bandit -a file --ini .bandit -f html -o ./bandit_report/bandit_checks.html $CHANGED_FILES
deactivate deactivate
else else
echo "No files with the \"py\" extension found" echo "No files with the \"py\" extension found"

@ -17,7 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Possibility to display tags on frame - Possibility to display tags on frame
- Support source and target storages (server part) - Support source and target storages (server part)
- Tests for import/export annotation, dataset, backup from/to cloud storage - Tests for import/export annotation, dataset, backup from/to cloud storage
- Added Python SDK package (`cvat-sdk`) - Added Python SDK package (`cvat-sdk`) (<https://github.com/opencv/cvat/pull/4813>)
- Previews for jobs - Previews for jobs
- Documentation for LDAP authentication (<https://github.com/cvat-ai/cvat/pull/39>) - Documentation for LDAP authentication (<https://github.com/cvat-ai/cvat/pull/39>)
- OpenCV.js caching and autoload (<https://github.com/cvat-ai/cvat/pull/30>) - OpenCV.js caching and autoload (<https://github.com/cvat-ai/cvat/pull/30>)
@ -28,7 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed ### Changed
- Bumped nuclio version to 1.8.14 - Bumped nuclio version to 1.8.14
- Simplified running REST API tests. Extended CI-nightly workflow - Simplified running REST API tests. Extended CI-nightly workflow
- REST API tests are partially moved to Python SDK (`users`, `projects`, `tasks`) - REST API tests are partially moved to Python SDK (`users`, `projects`, `tasks`, `issues`)
- cvat-ui: Improve UI/UX on label, create task and create project forms (<https://github.com/cvat-ai/cvat/pull/7>) - cvat-ui: Improve UI/UX on label, create task and create project forms (<https://github.com/cvat-ai/cvat/pull/7>)
- Removed link to OpenVINO documentation (<https://github.com/cvat-ai/cvat/pull/35>) - Removed link to OpenVINO documentation (<https://github.com/cvat-ai/cvat/pull/35>)
- Clarified meaning of chunking for videos - Clarified meaning of chunking for videos
@ -51,6 +51,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Image search in cloud storage (<https://github.com/cvat-ai/cvat/pull/8>) - Image search in cloud storage (<https://github.com/cvat-ai/cvat/pull/8>)
- Reset password functionality (<https://github.com/cvat-ai/cvat/pull/52>) - Reset password functionality (<https://github.com/cvat-ai/cvat/pull/52>)
- Creating task with cloud storage data (<https://github.com/cvat-ai/cvat/pull/116>) - Creating task with cloud storage data (<https://github.com/cvat-ai/cvat/pull/116>)
- Show empty tasks (<https://github.com/cvat-ai/cvat/pull/100>)
### Security ### Security
- TDB - TDB

@ -1,4 +1,3 @@
# Copyright (C) 2020-2022 Intel Corporation
# Copyright (C) 2022 CVAT.ai Corporation # Copyright (C) 2022 CVAT.ai Corporation
# #
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
@ -11,7 +10,7 @@ from typing import Dict, List, Sequence, Tuple
import tqdm import tqdm
from cvat_sdk import Client, models from cvat_sdk import Client, models
from cvat_sdk.core.helpers import TqdmProgressReporter from cvat_sdk.core.helpers import TqdmProgressReporter
from cvat_sdk.core.types import ResourceType from cvat_sdk.core.proxies.tasks import ResourceType
class CLI: class CLI:
@ -26,7 +25,7 @@ class CLI:
def tasks_list(self, *, use_json_output: bool = False, **kwargs): def tasks_list(self, *, use_json_output: bool = False, **kwargs):
"""List all tasks in either basic or JSON format.""" """List all tasks in either basic or JSON format."""
results = self.client.list_tasks(return_json=use_json_output, **kwargs) results = self.client.tasks.list(return_json=use_json_output, **kwargs)
if use_json_output: if use_json_output:
print(json.dumps(json.loads(results), indent=2)) print(json.dumps(json.loads(results), indent=2))
else: else:
@ -50,7 +49,7 @@ class CLI:
""" """
Create a new task with the given name and labels JSON and add the files to it. Create a new task with the given name and labels JSON and add the files to it.
""" """
task = self.client.create_task( task = self.client.tasks.create_from_data(
spec=models.TaskWriteRequest(name=name, labels=labels, **kwargs), spec=models.TaskWriteRequest(name=name, labels=labels, **kwargs),
resource_type=resource_type, resource_type=resource_type,
resources=resources, resources=resources,
@ -66,7 +65,7 @@ class CLI:
def tasks_delete(self, task_ids: Sequence[int]) -> None: def tasks_delete(self, task_ids: Sequence[int]) -> None:
"""Delete a list of tasks, ignoring those which don't exist.""" """Delete a list of tasks, ignoring those which don't exist."""
self.client.delete_tasks(task_ids=task_ids) self.client.tasks.remove_by_ids(task_ids=task_ids)
def tasks_frames( def tasks_frames(
self, self,
@ -80,11 +79,11 @@ class CLI:
Download the requested frame numbers for a task and save images as Download the requested frame numbers for a task and save images as
task_<ID>_frame_<FRAME>.jpg. task_<ID>_frame_<FRAME>.jpg.
""" """
self.client.retrieve_task(task_id=task_id).download_frames( self.client.tasks.retrieve(obj_id=task_id).download_frames(
frame_ids=frame_ids, frame_ids=frame_ids,
outdir=outdir, outdir=outdir,
quality=quality, quality=quality,
filename_pattern="task_{task_id}_frame_{frame_id:06d}{frame_ext}", filename_pattern=f"task_{task_id}" + "_frame_{frame_id:06d}{frame_ext}",
) )
def tasks_dump( def tasks_dump(
@ -99,7 +98,7 @@ class CLI:
""" """
Download annotations for a task in the specified format (e.g. 'YOLO ZIP 1.0'). Download annotations for a task in the specified format (e.g. 'YOLO ZIP 1.0').
""" """
self.client.retrieve_task(task_id=task_id).export_dataset( self.client.tasks.retrieve(obj_id=task_id).export_dataset(
format_name=fileformat, format_name=fileformat,
filename=filename, filename=filename,
pbar=self._make_pbar(), pbar=self._make_pbar(),
@ -112,7 +111,7 @@ class CLI:
) -> None: ) -> None:
"""Upload annotations for a task in the specified format """Upload annotations for a task in the specified format
(e.g. 'YOLO ZIP 1.0').""" (e.g. 'YOLO ZIP 1.0')."""
self.client.retrieve_task(task_id=task_id).import_annotations( self.client.tasks.retrieve(obj_id=task_id).import_annotations(
format_name=fileformat, format_name=fileformat,
filename=filename, filename=filename,
status_check_period=status_check_period, status_check_period=status_check_period,
@ -121,13 +120,13 @@ class CLI:
def tasks_export(self, task_id: str, filename: str, *, status_check_period: int = 2) -> None: def tasks_export(self, task_id: str, filename: str, *, status_check_period: int = 2) -> None:
"""Download a task backup""" """Download a task backup"""
self.client.retrieve_task(task_id=task_id).download_backup( self.client.tasks.retrieve(obj_id=task_id).download_backup(
filename=filename, status_check_period=status_check_period, pbar=self._make_pbar() filename=filename, status_check_period=status_check_period, pbar=self._make_pbar()
) )
def tasks_import(self, filename: str, *, status_check_period: int = 2) -> None: def tasks_import(self, filename: str, *, status_check_period: int = 2) -> None:
"""Import a task from a backup file""" """Import a task from a backup file"""
self.client.create_task_from_backup( self.client.tasks.create_from_backup(
filename=filename, status_check_period=status_check_period, pbar=self._make_pbar() filename=filename, status_check_period=status_check_period, pbar=self._make_pbar()
) )

@ -10,7 +10,7 @@ import logging
import os import os
from distutils.util import strtobool from distutils.util import strtobool
from cvat_sdk.core.types import ResourceType from cvat_sdk.core.proxies.tasks import ResourceType
from .version import VERSION from .version import VERSION

@ -74,4 +74,4 @@ cvat_sdk/api_client/
requirements/ requirements/
docs/ docs/
setup.py setup.py
README.md README.md

@ -1,4 +1,3 @@
# Copyright (C) 2020-2022 Intel Corporation
# Copyright (C) 2022 CVAT.ai Corporation # Copyright (C) 2022 CVAT.ai Corporation
# #
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
@ -6,23 +5,22 @@
from __future__ import annotations from __future__ import annotations
import json
import logging import logging
import os.path as osp
import urllib.parse import urllib.parse
from time import sleep from time import sleep
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union from typing import Any, Dict, Optional, Sequence, Tuple
import attrs import attrs
import urllib3
from cvat_sdk.api_client import ApiClient, ApiException, ApiValueError, Configuration, models from cvat_sdk.api_client import ApiClient, Configuration, models
from cvat_sdk.core.git import create_git_repo from cvat_sdk.core.helpers import expect_status
from cvat_sdk.core.helpers import get_paginated_collection from cvat_sdk.core.proxies.issues import CommentsRepo, IssuesRepo
from cvat_sdk.core.progress import ProgressReporter from cvat_sdk.core.proxies.jobs import JobsRepo
from cvat_sdk.core.tasks import TaskProxy from cvat_sdk.core.proxies.model_proxy import Repo
from cvat_sdk.core.types import ResourceType from cvat_sdk.core.proxies.projects import ProjectsRepo
from cvat_sdk.core.uploading import Uploader from cvat_sdk.core.proxies.tasks import TasksRepo
from cvat_sdk.core.utils import assert_status from cvat_sdk.core.proxies.users import UsersRepo
@attrs.define @attrs.define
@ -43,11 +41,13 @@ class Client:
): ):
# TODO: use requests instead of urllib3 in ApiClient # TODO: use requests instead of urllib3 in ApiClient
# TODO: try to autodetect schema # TODO: try to autodetect schema
self._api_map = _CVAT_API_V2(url) self.api_map = CVAT_API_V2(url)
self.api = ApiClient(Configuration(host=url)) self.api = ApiClient(Configuration(host=url))
self.logger = logger or logging.getLogger(__name__) self.logger = logger or logging.getLogger(__name__)
self.config = config or Config() self.config = config or Config()
self._repos: Dict[str, Repo] = {}
def __enter__(self): def __enter__(self):
self.api.__enter__() self.api.__enter__()
return self return self
@ -67,150 +67,93 @@ class Client:
assert "csrftoken" in self.api.cookies assert "csrftoken" in self.api.cookies
self.api.set_default_header("Authorization", "Token " + auth.key) self.api.set_default_header("Authorization", "Token " + auth.key)
def create_task( def _has_credentials(self):
self, return (
spec: models.ITaskWriteRequest, ("sessionid" in self.api.cookies)
resource_type: ResourceType, or ("csrftoken" in self.api.cookies)
resources: Sequence[str], or (self.api.get_common_headers().get("Authorization", ""))
)
def logout(self):
if self._has_credentials():
self.api.auth_api.create_logout()
def wait_for_completion(
self: Client,
url: str,
*, *,
data_params: Optional[Dict[str, Any]] = None, success_status: int,
annotation_path: str = "", status_check_period: Optional[int] = None,
annotation_format: str = "CVAT XML 1.1", query_params: Optional[Dict[str, Any]] = None,
status_check_period: int = None, post_params: Optional[Dict[str, Any]] = None,
dataset_repository_url: str = "", method: str = "POST",
use_lfs: bool = False, positive_statuses: Optional[Sequence[int]] = None,
pbar: Optional[ProgressReporter] = None, ) -> urllib3.HTTPResponse:
) -> TaskProxy:
"""
Create a new task with the given name and labels JSON and
add the files to it.
Returns: id of the created task
"""
if status_check_period is None: if status_check_period is None:
status_check_period = self.config.status_check_period status_check_period = self.config.status_check_period
if getattr(spec, "project_id", None) and getattr(spec, "labels", None): positive_statuses = set(positive_statuses) | {success_status}
raise ApiValueError(
"Can't set labels to a task inside a project. "
"Tasks inside a project use project's labels.",
["labels"],
)
(task, _) = self.api.tasks_api.create(spec)
self.logger.info("Created task ID: %s NAME: %s", task.id, task.name)
task = TaskProxy(self, task) while True:
task.upload_data(resource_type, resources, pbar=pbar, params=data_params)
self.logger.info("Awaiting for task %s creation...", task.id)
status = None
while status != models.RqStatusStateEnum.allowed_values[("value",)]["FINISHED"]:
sleep(status_check_period) sleep(status_check_period)
(status, _) = self.api.tasks_api.retrieve_status(task.id)
self.logger.info(
"Task %s creation status=%s, message=%s",
task.id,
status.state.value,
status.message,
)
if status.state.value == models.RqStatusStateEnum.allowed_values[("value",)]["FAILED"]:
raise ApiException(status=status.state.value, reason=status.message)
status = status.state.value
if annotation_path: response = self.api.rest_client.request(
task.import_annotations(annotation_format, annotation_path, pbar=pbar) method=method,
url=url,
if dataset_repository_url: headers=self.api.get_common_headers(),
create_git_repo( query_params=query_params,
self, post_params=post_params,
task_id=task.id,
repo_url=dataset_repository_url,
status_check_period=status_check_period,
use_lfs=use_lfs,
) )
task.fetch() self.logger.debug("STATUS %s", response.status)
expect_status(positive_statuses, response)
if response.status == success_status:
break
return task return response
def list_tasks( def _get_repo(self, key: str) -> Repo:
self, *, return_json: bool = False, **kwargs _repo_map = {
) -> Union[List[TaskProxy], List[Dict[str, Any]]]: "tasks": TasksRepo,
"""List all tasks in either basic or JSON format.""" "projects": ProjectsRepo,
"jobs": JobsRepo,
"users": UsersRepo,
"issues": IssuesRepo,
"comments": CommentsRepo,
}
results = get_paginated_collection( repo = self._repos.get(key, None)
endpoint=self.api.tasks_api.list_endpoint, return_json=return_json, **kwargs if repo is None:
) repo = _repo_map[key](self)
self._repos[key] = repo
return repo
if return_json: @property
return json.dumps(results) def tasks(self) -> TasksRepo:
return self._get_repo("tasks")
return [TaskProxy(self, v) for v in results]
def retrieve_task(self, task_id: int) -> TaskProxy:
(task, _) = self.api.tasks_api.retrieve(task_id)
return TaskProxy(self, task)
def delete_tasks(self, task_ids: Sequence[int]):
"""
Delete a list of tasks, ignoring those which don't exist.
"""
for task_id in task_ids:
(_, response) = self.api.tasks_api.destroy(task_id, _check_status=False)
if 200 <= response.status <= 299:
self.logger.info(f"Task ID {task_id} deleted")
elif response.status == 404:
self.logger.info(f"Task ID {task_id} not found")
else:
self.logger.warning(
f"Failed to delete task ID {task_id}: "
f"{response.msg} (status {response.status})"
)
def create_task_from_backup(
self,
filename: str,
*,
status_check_period: int = None,
pbar: Optional[ProgressReporter] = None,
) -> TaskProxy:
"""
Import a task from a backup file
"""
if status_check_period is None:
status_check_period = self.config.status_check_period
params = {"filename": osp.basename(filename)} @property
url = self._api_map.make_endpoint_url(self.api.tasks_api.create_backup_endpoint.path) def projects(self) -> ProjectsRepo:
uploader = Uploader(self) return self._get_repo("projects")
response = uploader.upload_file(
url, filename, meta=params, query_params=params, pbar=pbar, logger=self.logger.debug
)
rq_id = json.loads(response.data)["rq_id"]
# check task status @property
while True: def jobs(self) -> JobsRepo:
sleep(status_check_period) return self._get_repo("jobs")
response = self.api.rest_client.POST( @property
url, post_params={"rq_id": rq_id}, headers=self.api.get_common_headers() def users(self) -> UsersRepo:
) return self._get_repo("users")
if response.status == 201:
break
assert_status(202, response)
task_id = json.loads(response.data)["id"] @property
self.logger.info(f"Task has been imported sucessfully. Task ID: {task_id}") def issues(self) -> IssuesRepo:
return self._get_repo("issues")
return self.retrieve_task(task_id) @property
def comments(self) -> CommentsRepo:
return self._get_repo("comments")
class _CVAT_API_V2: class CVAT_API_V2:
"""Build parameterized API URLs""" """Build parameterized API URLs"""
def __init__(self, host, https=False): def __init__(self, host, https=False):

@ -8,8 +8,9 @@ from __future__ import annotations
import os import os
import os.path as osp import os.path as osp
from contextlib import closing from contextlib import closing
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Any, Dict, Optional
from cvat_sdk.api_client.api_client import Endpoint
from cvat_sdk.core.progress import ProgressReporter from cvat_sdk.core.progress import ProgressReporter
if TYPE_CHECKING: if TYPE_CHECKING:
@ -17,8 +18,12 @@ if TYPE_CHECKING:
class Downloader: class Downloader:
"""
Implements common downloading protocols
"""
def __init__(self, client: Client): def __init__(self, client: Client):
self.client = client self._client = client
def download_file( def download_file(
self, self,
@ -29,8 +34,7 @@ class Downloader:
pbar: Optional[ProgressReporter] = None, pbar: Optional[ProgressReporter] = None,
) -> None: ) -> None:
""" """
Downloads the file from url into a temporary file, then renames it Downloads the file from url into a temporary file, then renames it to the requested name.
to the requested name.
""" """
CHUNK_SIZE = 10 * 2**20 CHUNK_SIZE = 10 * 2**20
@ -41,10 +45,10 @@ class Downloader:
if osp.exists(tmp_path): if osp.exists(tmp_path):
raise FileExistsError(f"Can't write temporary file '{tmp_path}' - file exists") raise FileExistsError(f"Can't write temporary file '{tmp_path}' - file exists")
response = self.client.api.rest_client.GET( response = self._client.api.rest_client.GET(
url, url,
_request_timeout=timeout, _request_timeout=timeout,
headers=self.client.api.get_common_headers(), headers=self._client.api.get_common_headers(),
_parse_response=False, _parse_response=False,
) )
with closing(response): with closing(response):
@ -72,3 +76,38 @@ class Downloader:
except: except:
os.unlink(tmp_path) os.unlink(tmp_path)
raise raise
def prepare_and_download_file_from_endpoint(
self,
endpoint: Endpoint,
filename: str,
*,
url_params: Optional[Dict[str, Any]] = None,
query_params: Optional[Dict[str, Any]] = None,
pbar: Optional[ProgressReporter] = None,
status_check_period: Optional[int] = None,
):
client = self._client
if status_check_period is None:
status_check_period = client.config.status_check_period
client.logger.info("Waiting for the server to prepare the file...")
url = client.api_map.make_endpoint_url(
endpoint.path, kwsub=url_params, query_params=query_params
)
client.wait_for_completion(
url,
method="GET",
positive_statuses=[202],
success_status=201,
status_check_period=status_check_period,
)
query_params = dict(query_params or {})
query_params["action"] = "download"
url = client.api_map.make_endpoint_url(
endpoint.path, kwsub=url_params, query_params=query_params
)
downloader = Downloader(client)
downloader.download_file(url, output_path=filename, pbar=pbar)

@ -1,4 +1,3 @@
# Copyright (C) 2020-2022 Intel Corporation
# Copyright (C) 2022 CVAT.ai Corporation # Copyright (C) 2022 CVAT.ai Corporation
# #
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
@ -27,7 +26,7 @@ def create_git_repo(
common_headers = client.api.get_common_headers() common_headers = client.api.get_common_headers()
response = client.api.rest_client.POST( response = client.api.rest_client.POST(
client._api_map.git_create(task_id), client.api_map.git_create(task_id),
post_params={"path": repo_url, "lfs": use_lfs, "tid": task_id}, post_params={"path": repo_url, "lfs": use_lfs, "tid": task_id},
headers=common_headers, headers=common_headers,
) )
@ -36,7 +35,7 @@ def create_git_repo(
client.logger.info(f"Create RQ ID: {rq_id}") client.logger.info(f"Create RQ ID: {rq_id}")
client.logger.debug("Awaiting a dataset repository to be created for the task %s...", task_id) client.logger.debug("Awaiting a dataset repository to be created for the task %s...", task_id)
check_url = client._api_map.git_check(rq_id) check_url = client.api_map.git_check(rq_id)
status = None status = None
while status != "finished": while status != "finished":
sleep(status_check_period) sleep(status_check_period)

@ -6,13 +6,14 @@ from __future__ import annotations
import io import io
import json import json
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, Iterable, List, Optional, Union
import tqdm import tqdm
import urllib3
from cvat_sdk import exceptions
from cvat_sdk.api_client.api_client import Endpoint from cvat_sdk.api_client.api_client import Endpoint
from cvat_sdk.core.progress import ProgressReporter from cvat_sdk.core.progress import ProgressReporter
from cvat_sdk.core.utils import assert_status
def get_paginated_collection( def get_paginated_collection(
@ -26,7 +27,7 @@ def get_paginated_collection(
page = 1 page = 1
while True: while True:
(page_contents, response) = endpoint.call_with_http_info(**kwargs, page=page) (page_contents, response) = endpoint.call_with_http_info(**kwargs, page=page)
assert_status(200, response) expect_status(200, response)
if return_json: if return_json:
results.extend(json.loads(response.data).get("results", [])) results.extend(json.loads(response.data).get("results", []))
@ -86,3 +87,18 @@ class StreamWithProgress:
def tell(self): def tell(self):
return self.stream.tell() return self.stream.tell()
def expect_status(codes: Union[int, Iterable[int]], response: urllib3.HTTPResponse) -> None:
if not hasattr(codes, "__iter__"):
codes = [codes]
if response.status in codes:
return
if 300 <= response.status <= 500:
raise exceptions.ApiException(response.status, reason=response.msg, http_resp=response)
else:
raise exceptions.ApiException(
response.status, reason="Unexpected status code received", http_resp=response
)

@ -0,0 +1,66 @@
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from abc import ABC
from enum import Enum
from typing import Optional, Sequence
from cvat_sdk import models
from cvat_sdk.core.proxies.model_proxy import _EntityT
class AnnotationUpdateAction(Enum):
CREATE = "create"
UPDATE = "update"
DELETE = "delete"
class AnnotationCrudMixin(ABC):
# TODO: refactor
@property
def _put_annotations_data_param(self) -> str:
...
def get_annotations(self: _EntityT) -> models.ILabeledData:
(annotations, _) = self.api.retrieve_annotations(getattr(self, self._model_id_field))
return annotations
def set_annotations(self: _EntityT, data: models.ILabeledDataRequest):
self.api.update_annotations(
getattr(self, self._model_id_field), **{self._put_annotations_data_param: data}
)
def update_annotations(
self: _EntityT,
data: models.IPatchedLabeledDataRequest,
*,
action: AnnotationUpdateAction = AnnotationUpdateAction.UPDATE,
):
self.api.partial_update_annotations(
action=action.value,
id=getattr(self, self._model_id_field),
patched_labeled_data_request=data,
)
def remove_annotations(self: _EntityT, *, ids: Optional[Sequence[int]] = None):
if ids:
anns = self.get_annotations()
if not isinstance(ids, set):
ids = set(ids)
anns_to_remove = models.PatchedLabeledDataRequest(
tags=[models.LabeledImageRequest(**a.to_dict()) for a in anns.tags if a.id in ids],
tracks=[
models.LabeledTrackRequest(**a.to_dict()) for a in anns.tracks if a.id in ids
],
shapes=[
models.LabeledShapeRequest(**a.to_dict()) for a in anns.shapes if a.id in ids
],
)
self.update_annotations(anns_to_remove, action=AnnotationUpdateAction.DELETE)
else:
self.api.destroy_annotations(getattr(self, self._model_id_field))

@ -0,0 +1,60 @@
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from __future__ import annotations
from cvat_sdk.api_client import apis, models
from cvat_sdk.core.proxies.model_proxy import (
ModelCreateMixin,
ModelDeleteMixin,
ModelListMixin,
ModelRetrieveMixin,
ModelUpdateMixin,
build_model_bases,
)
_CommentEntityBase, _CommentRepoBase = build_model_bases(
models.CommentRead, apis.CommentsApi, api_member_name="comments_api"
)
class Comment(
models.ICommentRead,
_CommentEntityBase,
ModelUpdateMixin[models.IPatchedCommentWriteRequest],
ModelDeleteMixin,
):
_model_partial_update_arg = "patched_comment_write_request"
class CommentsRepo(
_CommentRepoBase,
ModelListMixin[Comment],
ModelCreateMixin[Comment, models.ICommentWriteRequest],
ModelRetrieveMixin[Comment],
):
_entity_type = Comment
_IssueEntityBase, _IssueRepoBase = build_model_bases(
models.IssueRead, apis.IssuesApi, api_member_name="issues_api"
)
class Issue(
models.IIssueRead,
_IssueEntityBase,
ModelUpdateMixin[models.IPatchedIssueWriteRequest],
ModelDeleteMixin,
):
_model_partial_update_arg = "patched_issue_write_request"
class IssuesRepo(
_IssueRepoBase,
ModelListMixin[Issue],
ModelCreateMixin[Issue, models.IIssueWriteRequest],
ModelRetrieveMixin[Issue],
):
_entity_type = Issue

@ -0,0 +1,166 @@
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from __future__ import annotations
import io
import mimetypes
import os
import os.path as osp
from typing import List, Optional, Sequence
from PIL import Image
from cvat_sdk.api_client import apis, models
from cvat_sdk.core.downloading import Downloader
from cvat_sdk.core.helpers import get_paginated_collection
from cvat_sdk.core.progress import ProgressReporter
from cvat_sdk.core.proxies.annotations import AnnotationCrudMixin
from cvat_sdk.core.proxies.issues import Issue
from cvat_sdk.core.proxies.model_proxy import (
ModelListMixin,
ModelRetrieveMixin,
ModelUpdateMixin,
build_model_bases,
)
from cvat_sdk.core.uploading import AnnotationUploader
_JobEntityBase, _JobRepoBase = build_model_bases(
models.JobRead, apis.JobsApi, api_member_name="jobs_api"
)
class Job(
models.IJobRead,
_JobEntityBase,
ModelUpdateMixin[models.IPatchedJobWriteRequest],
AnnotationCrudMixin,
):
_model_partial_update_arg = "patched_job_write_request"
_put_annotations_data_param = "job_annotations_update_request"
def import_annotations(
self,
format_name: str,
filename: str,
*,
status_check_period: Optional[int] = None,
pbar: Optional[ProgressReporter] = None,
):
"""
Upload annotations for a job in the specified format (e.g. 'YOLO ZIP 1.0').
"""
AnnotationUploader(self._client).upload_file_and_wait(
self.api.create_annotations_endpoint,
filename,
format_name,
url_params={"id": self.id},
pbar=pbar,
status_check_period=status_check_period,
)
self._client.logger.info(f"Annotation file '{filename}' for job #{self.id} uploaded")
def export_dataset(
self,
format_name: str,
filename: str,
*,
pbar: Optional[ProgressReporter] = None,
status_check_period: Optional[int] = None,
include_images: bool = True,
) -> None:
"""
Download annotations for a job in the specified format (e.g. 'YOLO ZIP 1.0').
"""
if include_images:
endpoint = self.api.retrieve_dataset_endpoint
else:
endpoint = self.api.retrieve_annotations_endpoint
Downloader(self._client).prepare_and_download_file_from_endpoint(
endpoint=endpoint,
filename=filename,
url_params={"id": self.id},
query_params={"format": format_name},
pbar=pbar,
status_check_period=status_check_period,
)
self._client.logger.info(f"Dataset for job {self.id} has been downloaded to {filename}")
def get_frame(
self,
frame_id: int,
*,
quality: Optional[str] = None,
) -> io.RawIOBase:
(_, response) = self.api.retrieve_data(
self.id, number=frame_id, quality=quality, type="frame"
)
return io.BytesIO(response.data)
def get_preview(
self,
) -> io.RawIOBase:
(_, response) = self.api.retrieve_data(self.id, type="preview")
return io.BytesIO(response.data)
def download_frames(
self,
frame_ids: Sequence[int],
*,
outdir: str = "",
quality: str = "original",
filename_pattern: str = "frame_{frame_id:06d}{frame_ext}",
) -> Optional[List[Image.Image]]:
"""
Download the requested frame numbers for a job and save images as outdir/filename_pattern
"""
# TODO: add arg descriptions in schema
os.makedirs(outdir, exist_ok=True)
for frame_id in frame_ids:
frame_bytes = self.get_frame(frame_id, quality=quality)
im = Image.open(frame_bytes)
mime_type = im.get_format_mimetype() or "image/jpg"
im_ext = mimetypes.guess_extension(mime_type)
# FIXME It is better to use meta information from the server
# to determine the extension
# replace '.jpe' or '.jpeg' with a more used '.jpg'
if im_ext in (".jpe", ".jpeg", None):
im_ext = ".jpg"
outfile = filename_pattern.format(frame_id=frame_id, frame_ext=im_ext)
im.save(osp.join(outdir, outfile))
def get_meta(self) -> models.IDataMetaRead:
(meta, _) = self.api.retrieve_data_meta(self.id)
return meta
def get_frames_info(self) -> List[models.IFrameMeta]:
return self.get_meta().frames
def remove_frames_by_ids(self, ids: Sequence[int]) -> None:
self._client.api.tasks_api.jobs_partial_update_data_meta(
self.id,
patched_data_meta_write_request=models.PatchedDataMetaWriteRequest(deleted_frames=ids),
)
def get_issues(self) -> List[Issue]:
return [Issue(self._client, m) for m in self.api.list_issues(id=self.id)[0]]
def get_commits(self) -> List[models.IJobCommit]:
return get_paginated_collection(self.api.list_commits_endpoint, id=self.id)
class JobsRepo(
_JobRepoBase,
ModelListMixin[Job],
ModelRetrieveMixin[Job],
):
_entity_type = Job

@ -0,0 +1,213 @@
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from __future__ import annotations
import json
from abc import ABC
from copy import deepcopy
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
List,
Literal,
Optional,
Tuple,
Type,
TypeVar,
Union,
overload,
)
from typing_extensions import Self
from cvat_sdk.api_client.model_utils import IModelData, ModelNormal, to_json
from cvat_sdk.core.helpers import get_paginated_collection
if TYPE_CHECKING:
from cvat_sdk.core.client import Client
IModel = TypeVar("IModel", bound=IModelData)
ModelType = TypeVar("ModelType", bound=ModelNormal)
ApiType = TypeVar("ApiType")
class ModelProxy(ABC, Generic[ModelType, ApiType]):
_client: Client
@property
def _api_member_name(self) -> str:
...
def __init__(self, client: Client) -> None:
self.__dict__["_client"] = client
@classmethod
def get_api(cls, client: Client) -> ApiType:
return getattr(client.api, cls._api_member_name)
@property
def api(self) -> ApiType:
return self.get_api(self._client)
class Entity(ModelProxy[ModelType, ApiType]):
"""
Represents a single object. Implements related operations and provides access to data members.
"""
_model: ModelType
def __init__(self, client: Client, model: ModelType) -> None:
super().__init__(client)
self.__dict__["_model"] = model
@property
def _model_id_field(self) -> str:
return "id"
def __getattr__(self, __name: str) -> Any:
# NOTE: be aware of potential problems with throwing AttributeError from @property
# in derived classes!
# https://medium.com/@ceshine/python-debugging-pitfall-mixed-use-of-property-and-getattr-f89e0ede13f1
return self._model[__name]
def __str__(self) -> str:
return str(self._model)
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: id={getattr(self, self._model_id_field)}>"
class Repo(ModelProxy[ModelType, ApiType]):
"""
Represents a collection of corresponding Entity objects.
Implements group and management operations for entities.
"""
_entity_type: Type[Entity[ModelType, ApiType]]
### Utilities
def build_model_bases(
mt: Type[ModelType], at: Type[ApiType], *, api_member_name: Optional[str] = None
) -> Tuple[Type[Entity[ModelType, ApiType]], Type[Repo[ModelType, ApiType]]]:
"""
Helps to remove code duplication in declarations of derived classes
"""
class _EntityBase(Entity[ModelType, ApiType]):
if api_member_name:
_api_member_name = api_member_name
class _RepoBase(Repo[ModelType, ApiType]):
if api_member_name:
_api_member_name = api_member_name
return _EntityBase, _RepoBase
### CRUD mixins
_EntityT = TypeVar("_EntityT", bound=Entity)
#### Repo mixins
class ModelCreateMixin(Generic[_EntityT, IModel]):
def create(self: Repo, spec: Union[Dict[str, Any], IModel]) -> _EntityT:
"""
Creates a new object on the server and returns corresponding local object
"""
(model, _) = self.api.create(spec)
return self._entity_type(self._client, model)
class ModelRetrieveMixin(Generic[_EntityT]):
def retrieve(self: Repo, obj_id: int) -> _EntityT:
"""
Retrieves an object from server by ID
"""
(model, _) = self.api.retrieve(id=obj_id)
return self._entity_type(self._client, model)
class ModelListMixin(Generic[_EntityT]):
@overload
def list(self: Repo, *, return_json: Literal[False] = False) -> List[_EntityT]:
...
@overload
def list(self: Repo, *, return_json: Literal[True] = False) -> List[Any]:
...
def list(self: Repo, *, return_json: bool = False) -> List[Union[_EntityT, Any]]:
"""
Retrieves all objects from the server and returns them in basic or JSON format.
"""
results = get_paginated_collection(endpoint=self.api.list_endpoint, return_json=return_json)
if return_json:
return json.dumps(results)
return [self._entity_type(self._client, model) for model in results]
#### Entity mixins
class ModelUpdateMixin(ABC, Generic[IModel]):
@property
def _model_partial_update_arg(self: Entity) -> str:
...
def _export_update_fields(
self: Entity, overrides: Optional[Union[Dict[str, Any], IModel]] = None
) -> Dict[str, Any]:
# TODO: support field conversion and assignment updating
# fields = to_json(self._model)
if isinstance(overrides, ModelNormal):
overrides = to_json(overrides)
fields = deepcopy(overrides)
return fields
def fetch(self: Entity) -> Self:
"""
Updates current object from the server
"""
# TODO: implement revision checking
(self._model, _) = self.api.retrieve(id=getattr(self, self._model_id_field))
return self
def update(self: Entity, values: Union[Dict[str, Any], IModel]) -> Self:
"""
Commits local model changes to the server
"""
# TODO: implement revision checking
self.api.partial_update(
id=getattr(self, self._model_id_field),
**{self._model_partial_update_arg: self._export_update_fields(values)},
)
# TODO: use the response model, once input and output models are same
return self.fetch()
class ModelDeleteMixin:
def remove(self: Entity) -> None:
"""
Removes current object on the server
"""
self.api.destroy(id=getattr(self, self._model_id_field))

@ -0,0 +1,187 @@
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from __future__ import annotations
import json
import os.path as osp
from typing import Optional
from cvat_sdk.api_client import apis, models
from cvat_sdk.core.downloading import Downloader
from cvat_sdk.core.progress import ProgressReporter
from cvat_sdk.core.proxies.model_proxy import (
ModelCreateMixin,
ModelDeleteMixin,
ModelListMixin,
ModelRetrieveMixin,
ModelUpdateMixin,
build_model_bases,
)
from cvat_sdk.core.uploading import DatasetUploader, Uploader
_ProjectEntityBase, _ProjectRepoBase = build_model_bases(
models.ProjectRead, apis.ProjectsApi, api_member_name="projects_api"
)
class Project(
_ProjectEntityBase, models.IProjectRead, ModelUpdateMixin[models.IPatchedProjectWriteRequest]
):
_model_partial_update_arg = "patched_project_write_request"
def import_dataset(
self,
format_name: str,
filename: str,
*,
status_check_period: Optional[int] = None,
pbar: Optional[ProgressReporter] = None,
):
"""
Import dataset for a project in the specified format (e.g. 'YOLO ZIP 1.0').
"""
DatasetUploader(self._client).upload_file_and_wait(
self.api.create_dataset_endpoint,
filename,
format_name,
url_params={"id": self.id},
pbar=pbar,
status_check_period=status_check_period,
)
self._client.logger.info(f"Annotation file '{filename}' for project #{self.id} uploaded")
def export_dataset(
self,
format_name: str,
filename: str,
*,
pbar: Optional[ProgressReporter] = None,
status_check_period: Optional[int] = None,
include_images: bool = True,
) -> None:
"""
Download annotations for a project in the specified format (e.g. 'YOLO ZIP 1.0').
"""
if include_images:
endpoint = self.api.retrieve_dataset_endpoint
else:
endpoint = self.api.retrieve_annotations_endpoint
Downloader(self._client).prepare_and_download_file_from_endpoint(
endpoint=endpoint,
filename=filename,
url_params={"id": self.id},
query_params={"format": format_name},
pbar=pbar,
status_check_period=status_check_period,
)
self._client.logger.info(f"Dataset for project {self.id} has been downloaded to {filename}")
def download_backup(
self,
filename: str,
*,
status_check_period: int = None,
pbar: Optional[ProgressReporter] = None,
) -> None:
"""
Download a project backup
"""
Downloader(self._client).prepare_and_download_file_from_endpoint(
self.api.retrieve_backup_endpoint,
filename=filename,
pbar=pbar,
status_check_period=status_check_period,
url_params={"id": self.id},
)
self._client.logger.info(f"Backup for project {self.id} has been downloaded to {filename}")
def get_annotations(self) -> models.ILabeledData:
(annotations, _) = self.api.retrieve_annotations(self.id)
return annotations
class ProjectsRepo(
_ProjectRepoBase,
ModelCreateMixin[Project, models.IProjectWriteRequest],
ModelListMixin[Project],
ModelRetrieveMixin[Project],
ModelDeleteMixin,
):
_entity_type = Project
def create_from_dataset(
self,
spec: models.IProjectWriteRequest,
*,
dataset_path: str = "",
dataset_format: str = "CVAT XML 1.1",
status_check_period: int = None,
pbar: Optional[ProgressReporter] = None,
) -> Project:
"""
Create a new project with the given name and labels JSON and
add the files to it.
Returns: id of the created project
"""
project = self.create(spec=spec)
self._client.logger.info("Created project ID: %s NAME: %s", project.id, project.name)
if dataset_path:
project.import_dataset(
format_name=dataset_format,
filename=dataset_path,
pbar=pbar,
status_check_period=status_check_period,
)
project.fetch()
return project
def create_from_backup(
self,
filename: str,
*,
status_check_period: int = None,
pbar: Optional[ProgressReporter] = None,
) -> Project:
"""
Import a project from a backup file
"""
if status_check_period is None:
status_check_period = self.config.status_check_period
params = {"filename": osp.basename(filename)}
url = self.api_map.make_endpoint_url(self.api.create_backup_endpoint.path)
uploader = Uploader(self)
response = uploader.upload_file(
url,
filename,
meta=params,
query_params=params,
pbar=pbar,
logger=self._client.logger.debug,
)
rq_id = json.loads(response.data)["rq_id"]
response = self._client.wait_for_completion(
url,
success_status=201,
positive_statuses=[202],
post_params={"rq_id": rq_id},
status_check_period=status_check_period,
)
project_id = json.loads(response.data)["id"]
self._client.logger.info(f"Project has been imported sucessfully. Project ID: {project_id}")
return self.retrieve(project_id)

@ -0,0 +1,388 @@
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from __future__ import annotations
import io
import json
import mimetypes
import os
import os.path as osp
from enum import Enum
from time import sleep
from typing import Any, Dict, List, Optional, Sequence
from PIL import Image
from cvat_sdk.api_client import apis, exceptions, models
from cvat_sdk.core import git
from cvat_sdk.core.downloading import Downloader
from cvat_sdk.core.progress import ProgressReporter
from cvat_sdk.core.proxies.annotations import AnnotationCrudMixin
from cvat_sdk.core.proxies.jobs import Job
from cvat_sdk.core.proxies.model_proxy import (
ModelCreateMixin,
ModelDeleteMixin,
ModelListMixin,
ModelRetrieveMixin,
ModelUpdateMixin,
build_model_bases,
)
from cvat_sdk.core.uploading import AnnotationUploader, DataUploader, Uploader
from cvat_sdk.core.utils import filter_dict
class ResourceType(Enum):
LOCAL = 0
SHARE = 1
REMOTE = 2
def __str__(self):
return self.name.lower()
def __repr__(self):
return str(self)
_TaskEntityBase, _TaskRepoBase = build_model_bases(
models.TaskRead, apis.TasksApi, api_member_name="tasks_api"
)
class Task(
_TaskEntityBase,
models.ITaskRead,
ModelUpdateMixin[models.IPatchedTaskWriteRequest],
ModelDeleteMixin,
AnnotationCrudMixin,
):
_model_partial_update_arg = "patched_task_write_request"
_put_annotations_data_param = "task_annotations_update_request"
def upload_data(
self,
resource_type: ResourceType,
resources: Sequence[str],
*,
pbar: Optional[ProgressReporter] = None,
params: Optional[Dict[str, Any]] = None,
) -> None:
"""
Add local, remote, or shared files to an existing task.
"""
params = params or {}
data = {}
if resource_type is ResourceType.LOCAL:
pass # handled later
elif resource_type is ResourceType.REMOTE:
data = {f"remote_files[{i}]": f for i, f in enumerate(resources)}
elif resource_type is ResourceType.SHARE:
data = {f"server_files[{i}]": f for i, f in enumerate(resources)}
data["image_quality"] = 70
data.update(
filter_dict(
params,
keep=[
"chunk_size",
"copy_data",
"image_quality",
"sorting_method",
"start_frame",
"stop_frame",
"use_cache",
"use_zip_chunks",
],
)
)
if params.get("frame_step") is not None:
data["frame_filter"] = f"step={params.get('frame_step')}"
if resource_type in [ResourceType.REMOTE, ResourceType.SHARE]:
self.api.create_data(
self.id,
data_request=models.DataRequest(**data),
_content_type="multipart/form-data",
)
elif resource_type == ResourceType.LOCAL:
url = self._client.api_map.make_endpoint_url(
self.api.create_data_endpoint.path, kwsub={"id": self.id}
)
DataUploader(self._client).upload_files(url, resources, pbar=pbar, **data)
def import_annotations(
self,
format_name: str,
filename: str,
*,
status_check_period: Optional[int] = None,
pbar: Optional[ProgressReporter] = None,
):
"""
Upload annotations for a task in the specified format (e.g. 'YOLO ZIP 1.0').
"""
AnnotationUploader(self._client).upload_file_and_wait(
self.api.create_annotations_endpoint,
filename,
format_name,
url_params={"id": self.id},
pbar=pbar,
status_check_period=status_check_period,
)
self._client.logger.info(f"Annotation file '{filename}' for task #{self.id} uploaded")
def get_frame(
self,
frame_id: int,
*,
quality: Optional[str] = None,
) -> io.RawIOBase:
params = {}
if quality:
params["quality"] = quality
(_, response) = self.api.retrieve_data(self.id, number=frame_id, **params, type="frame")
return io.BytesIO(response.data)
def get_preview(
self,
) -> io.RawIOBase:
(_, response) = self.api.retrieve_data(self.id, type="preview")
return io.BytesIO(response.data)
def download_frames(
self,
frame_ids: Sequence[int],
*,
outdir: str = "",
quality: str = "original",
filename_pattern: str = "frame_{frame_id:06d}{frame_ext}",
) -> Optional[List[Image.Image]]:
"""
Download the requested frame numbers for a task and save images as outdir/filename_pattern
"""
# TODO: add arg descriptions in schema
os.makedirs(outdir, exist_ok=True)
for frame_id in frame_ids:
frame_bytes = self.get_frame(frame_id, quality=quality)
im = Image.open(frame_bytes)
mime_type = im.get_format_mimetype() or "image/jpg"
im_ext = mimetypes.guess_extension(mime_type)
# FIXME It is better to use meta information from the server
# to determine the extension
# replace '.jpe' or '.jpeg' with a more used '.jpg'
if im_ext in (".jpe", ".jpeg", None):
im_ext = ".jpg"
outfile = filename_pattern.format(frame_id=frame_id, frame_ext=im_ext)
im.save(osp.join(outdir, outfile))
def export_dataset(
self,
format_name: str,
filename: str,
*,
pbar: Optional[ProgressReporter] = None,
status_check_period: Optional[int] = None,
include_images: bool = True,
) -> None:
"""
Download annotations for a task in the specified format (e.g. 'YOLO ZIP 1.0').
"""
if include_images:
endpoint = self.api.retrieve_dataset_endpoint
else:
endpoint = self.api.retrieve_annotations_endpoint
Downloader(self._client).prepare_and_download_file_from_endpoint(
endpoint=endpoint,
filename=filename,
url_params={"id": self.id},
query_params={"format": format_name},
pbar=pbar,
status_check_period=status_check_period,
)
self._client.logger.info(f"Dataset for task {self.id} has been downloaded to {filename}")
def download_backup(
self,
filename: str,
*,
status_check_period: int = None,
pbar: Optional[ProgressReporter] = None,
) -> None:
"""
Download a task backup
"""
Downloader(self._client).prepare_and_download_file_from_endpoint(
self.api.retrieve_backup_endpoint,
filename=filename,
pbar=pbar,
status_check_period=status_check_period,
url_params={"id": self.id},
)
self._client.logger.info(f"Backup for task {self.id} has been downloaded to {filename}")
def get_jobs(self) -> List[Job]:
return [Job(self._client, m) for m in self.api.list_jobs(id=self.id)[0]]
def get_meta(self) -> models.IDataMetaRead:
(meta, _) = self.api.retrieve_data_meta(self.id)
return meta
def get_frames_info(self) -> List[models.IFrameMeta]:
return self.get_meta().frames
def remove_frames_by_ids(self, ids: Sequence[int]) -> None:
self.api.partial_update_data_meta(
self.id,
patched_data_meta_write_request=models.PatchedDataMetaWriteRequest(deleted_frames=ids),
)
class TasksRepo(
_TaskRepoBase,
ModelCreateMixin[Task, models.ITaskWriteRequest],
ModelRetrieveMixin[Task],
ModelListMixin[Task],
ModelDeleteMixin,
):
_entity_type = Task
def create_from_data(
self,
spec: models.ITaskWriteRequest,
resource_type: ResourceType,
resources: Sequence[str],
*,
data_params: Optional[Dict[str, Any]] = None,
annotation_path: str = "",
annotation_format: str = "CVAT XML 1.1",
status_check_period: int = None,
dataset_repository_url: str = "",
use_lfs: bool = False,
pbar: Optional[ProgressReporter] = None,
) -> Task:
"""
Create a new task with the given name and labels JSON and
add the files to it.
Returns: id of the created task
"""
if status_check_period is None:
status_check_period = self._client.config.status_check_period
if getattr(spec, "project_id", None) and getattr(spec, "labels", None):
raise exceptions.ApiValueError(
"Can't set labels to a task inside a project. "
"Tasks inside a project use project's labels.",
["labels"],
)
task = self.create(spec=spec)
self._client.logger.info("Created task ID: %s NAME: %s", task.id, task.name)
task.upload_data(resource_type, resources, pbar=pbar, params=data_params)
self._client.logger.info("Awaiting for task %s creation...", task.id)
status: models.RqStatus = None
while status != models.RqStatusStateEnum.allowed_values[("value",)]["FINISHED"]:
sleep(status_check_period)
(status, response) = self.api.retrieve_status(task.id)
self._client.logger.info(
"Task %s creation status=%s, message=%s",
task.id,
status.state.value,
status.message,
)
if status.state.value == models.RqStatusStateEnum.allowed_values[("value",)]["FAILED"]:
raise exceptions.ApiException(
status=status.state.value, reason=status.message, http_resp=response
)
status = status.state.value
if annotation_path:
task.import_annotations(annotation_format, annotation_path, pbar=pbar)
if dataset_repository_url:
git.create_git_repo(
self,
task_id=task.id,
repo_url=dataset_repository_url,
status_check_period=status_check_period,
use_lfs=use_lfs,
)
task.fetch()
return task
def remove_by_ids(self, task_ids: Sequence[int]) -> None:
"""
Delete a list of tasks, ignoring those which don't exist.
"""
for task_id in task_ids:
(_, response) = self.api.destroy(task_id, _check_status=False)
if 200 <= response.status <= 299:
self._client.logger.info(f"Task ID {task_id} deleted")
elif response.status == 404:
self._client.logger.info(f"Task ID {task_id} not found")
else:
self._client.logger.warning(
f"Failed to delete task ID {task_id}: "
f"{response.msg} (status {response.status})"
)
def create_from_backup(
self,
filename: str,
*,
status_check_period: int = None,
pbar: Optional[ProgressReporter] = None,
) -> Task:
"""
Import a task from a backup file
"""
if status_check_period is None:
status_check_period = self._client.config.status_check_period
params = {"filename": osp.basename(filename)}
url = self._client.api_map.make_endpoint_url(self.api.create_backup_endpoint.path)
uploader = Uploader(self._client)
response = uploader.upload_file(
url,
filename,
meta=params,
query_params=params,
pbar=pbar,
logger=self._client.logger.debug,
)
rq_id = json.loads(response.data)["rq_id"]
response = self._client.wait_for_completion(
url,
success_status=201,
positive_statuses=[202],
post_params={"rq_id": rq_id},
status_check_period=status_check_period,
)
task_id = json.loads(response.data)["id"]
self._client.logger.info(f"Task has been imported sucessfully. Task ID: {task_id}")
return self.retrieve(task_id)

@ -0,0 +1,35 @@
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from __future__ import annotations
from cvat_sdk.api_client import apis, models
from cvat_sdk.core.proxies.model_proxy import (
ModelDeleteMixin,
ModelListMixin,
ModelRetrieveMixin,
ModelUpdateMixin,
build_model_bases,
)
_UserEntityBase, _UserRepoBase = build_model_bases(
models.User, apis.UsersApi, api_member_name="users_api"
)
class User(
models.IUser, _UserEntityBase, ModelUpdateMixin[models.IPatchedUserRequest], ModelDeleteMixin
):
_model_partial_update_arg = "patched_user_request"
class UsersRepo(
_UserRepoBase,
ModelListMixin[User],
ModelRetrieveMixin[User],
):
_entity_type = User
def retrieve_current_user(self) -> User:
return User(self._client, self.api.retrieve_self()[0])

@ -1,308 +0,0 @@
# Copyright (C) 2020-2022 Intel Corporation
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from __future__ import annotations
import io
import mimetypes
import os
import os.path as osp
from abc import ABC, abstractmethod
from io import BytesIO
from time import sleep
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
from PIL import Image
from cvat_sdk import models
from cvat_sdk.api_client.model_utils import OpenApiModel
from cvat_sdk.core.downloading import Downloader
from cvat_sdk.core.progress import ProgressReporter
from cvat_sdk.core.types import ResourceType
from cvat_sdk.core.uploading import Uploader
from cvat_sdk.core.utils import filter_dict
if TYPE_CHECKING:
from cvat_sdk.core.client import Client
class ModelProxy(ABC):
_client: Client
_model: OpenApiModel
def __init__(self, client: Client, model: OpenApiModel) -> None:
self.__dict__["_client"] = client
self.__dict__["_model"] = model
def __getattr__(self, __name: str) -> Any:
return self._model[__name]
def __setattr__(self, __name: str, __value: Any) -> None:
if __name in self.__dict__:
self.__dict__[__name] = __value
else:
self._model[__name] = __value
@abstractmethod
def fetch(self, force: bool = False):
"""Fetches model data from the server"""
...
@abstractmethod
def commit(self, force: bool = False):
"""Commits local changes to the server"""
...
def sync(self):
"""Pulls server state and commits local model changes"""
raise NotImplementedError
@abstractmethod
def update(self, **kwargs):
"""Updates multiple fields at once"""
...
class TaskProxy(ModelProxy, models.ITaskRead):
def __init__(self, client: Client, task: models.TaskRead):
ModelProxy.__init__(self, client=client, model=task)
def remove(self):
self._client.api.tasks_api.destroy(self.id)
def upload_data(
self,
resource_type: ResourceType,
resources: Sequence[str],
*,
pbar: Optional[ProgressReporter] = None,
params: Optional[Dict[str, Any]] = None,
) -> None:
"""
Add local, remote, or shared files to an existing task.
"""
client = self._client
task_id = self.id
params = params or {}
data = {}
if resource_type is ResourceType.LOCAL:
pass # handled later
elif resource_type is ResourceType.REMOTE:
data = {f"remote_files[{i}]": f for i, f in enumerate(resources)}
elif resource_type is ResourceType.SHARE:
data = {f"server_files[{i}]": f for i, f in enumerate(resources)}
data["image_quality"] = 70
data.update(
filter_dict(
params,
keep=[
"chunk_size",
"copy_data",
"image_quality",
"sorting_method",
"start_frame",
"stop_frame",
"use_cache",
"use_zip_chunks",
],
)
)
if params.get("frame_step") is not None:
data["frame_filter"] = f"step={params.get('frame_step')}"
if resource_type in [ResourceType.REMOTE, ResourceType.SHARE]:
client.api.tasks_api.create_data(
task_id,
data_request=models.DataRequest(**data),
_content_type="multipart/form-data",
)
elif resource_type == ResourceType.LOCAL:
url = client._api_map.make_endpoint_url(
client.api.tasks_api.create_data_endpoint.path, kwsub={"id": task_id}
)
uploader = Uploader(client)
uploader.upload_files(url, resources, pbar=pbar, **data)
def import_annotations(
self,
format_name: str,
filename: str,
*,
status_check_period: int = None,
pbar: Optional[ProgressReporter] = None,
):
"""
Upload annotations for a task in the specified format
(e.g. 'YOLO ZIP 1.0').
"""
client = self._client
if status_check_period is None:
status_check_period = client.config.status_check_period
task_id = self.id
url = client._api_map.make_endpoint_url(
client.api.tasks_api.create_annotations_endpoint.path,
kwsub={"id": task_id},
)
params = {"format": format_name, "filename": osp.basename(filename)}
uploader = Uploader(client)
uploader.upload_file(
url, filename, pbar=pbar, query_params=params, meta={"filename": params["filename"]}
)
while True:
response = client.api.rest_client.POST(
url, headers=client.api.get_common_headers(), query_params=params
)
if response.status == 201:
break
sleep(status_check_period)
client.logger.info(
f"Upload job for Task ID {task_id} with annotation file {filename} finished"
)
def retrieve_frame(
self,
frame_id: int,
*,
quality: Optional[str] = None,
) -> io.RawIOBase:
client = self._client
task_id = self.id
(_, response) = client.api.tasks_api.retrieve_data(task_id, frame_id, quality, type="frame")
return BytesIO(response.data)
def download_frames(
self,
frame_ids: Sequence[int],
*,
outdir: str = "",
quality: str = "original",
filename_pattern: str = "task_{task_id}_frame_{frame_id:06d}{frame_ext}",
) -> Optional[List[Image.Image]]:
"""
Download the requested frame numbers for a task and save images as
outdir/filename_pattern
"""
# TODO: add arg descriptions in schema
task_id = self.id
os.makedirs(outdir, exist_ok=True)
for frame_id in frame_ids:
frame_bytes = self.retrieve_frame(frame_id, quality=quality)
im = Image.open(frame_bytes)
mime_type = im.get_format_mimetype() or "image/jpg"
im_ext = mimetypes.guess_extension(mime_type)
# FIXME It is better to use meta information from the server
# to determine the extension
# replace '.jpe' or '.jpeg' with a more used '.jpg'
if im_ext in (".jpe", ".jpeg", None):
im_ext = ".jpg"
outfile = filename_pattern.format(task_id=task_id, frame_id=frame_id, frame_ext=im_ext)
im.save(osp.join(outdir, outfile))
def export_dataset(
self,
format_name: str,
filename: str,
*,
pbar: Optional[ProgressReporter] = None,
status_check_period: int = None,
include_images: bool = True,
) -> None:
"""
Download annotations for a task in the specified format (e.g. 'YOLO ZIP 1.0').
"""
client = self._client
if status_check_period is None:
status_check_period = client.config.status_check_period
task_id = self.id
params = {"filename": self.name, "format": format_name}
if include_images:
endpoint = client.api.tasks_api.retrieve_dataset_endpoint
else:
endpoint = client.api.tasks_api.retrieve_annotations_endpoint
client.logger.info("Waiting for the server to prepare the file...")
while True:
(_, response) = endpoint.call_with_http_info(id=task_id, **params)
client.logger.debug("STATUS {}".format(response.status))
if response.status == 201:
break
sleep(status_check_period)
params["action"] = "download"
url = client._api_map.make_endpoint_url(
endpoint.path, kwsub={"id": task_id}, query_params=params
)
downloader = Downloader(client)
downloader.download_file(url, output_path=filename, pbar=pbar)
client.logger.info(f"Dataset has been exported to {filename}")
def download_backup(
self,
filename: str,
*,
status_check_period: int = None,
pbar: Optional[ProgressReporter] = None,
):
"""
Download a task backup
"""
client = self._client
if status_check_period is None:
status_check_period = client.config.status_check_period
task_id = self.id
endpoint = client.api.tasks_api.retrieve_backup_endpoint
client.logger.info("Waiting for the server to prepare the file...")
while True:
(_, response) = endpoint.call_with_http_info(id=task_id)
client.logger.debug("STATUS {}".format(response.status))
if response.status == 201:
break
sleep(status_check_period)
url = client._api_map.make_endpoint_url(
endpoint.path, kwsub={"id": task_id}, query_params={"action": "download"}
)
downloader = Downloader(client)
downloader.download_file(url, output_path=filename, pbar=pbar)
client.logger.info(
f"Task {task_id} has been exported sucessfully to {osp.abspath(filename)}"
)
def fetch(self, force: bool = False):
# TODO: implement revision checking
model, _ = self._client.api.tasks_api.retrieve(self.id)
self._model = model
def commit(self, force: bool = False):
return super().commit(force)
def update(self, **kwargs):
return super().update(**kwargs)
def __str__(self) -> str:
return str(self._model)

@ -1,18 +0,0 @@
# Copyright (C) 2022 Intel Corporation
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from enum import Enum
class ResourceType(Enum):
LOCAL = 0
SHARE = 1
REMOTE = 2
def __str__(self):
return self.name.lower()
def __repr__(self):
return str(self)

@ -7,16 +7,15 @@ from __future__ import annotations
import os import os
import os.path as osp import os.path as osp
from contextlib import ExitStack, closing from contextlib import ExitStack, closing
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
import requests import requests
import urllib3 import urllib3
from cvat_sdk.api_client import ApiClient from cvat_sdk.api_client.api_client import ApiClient, Endpoint
from cvat_sdk.api_client.rest import RESTClientObject from cvat_sdk.api_client.rest import RESTClientObject
from cvat_sdk.core.helpers import StreamWithProgress from cvat_sdk.core.helpers import StreamWithProgress, expect_status
from cvat_sdk.core.progress import ProgressReporter from cvat_sdk.core.progress import ProgressReporter
from cvat_sdk.core.utils import assert_status
if TYPE_CHECKING: if TYPE_CHECKING:
from cvat_sdk.core.client import Client from cvat_sdk.core.client import Client
@ -25,57 +24,12 @@ MAX_REQUEST_SIZE = 100 * 2**20
class Uploader: class Uploader:
def __init__(self, client: Client): """
self.client = client Implements common uploading protocols
"""
def upload_files(
self,
url: str,
resources: List[str],
*,
pbar: Optional[ProgressReporter] = None,
**kwargs,
):
bulk_file_groups, separate_files, total_size = self._split_files_by_requests(resources)
if pbar is not None:
pbar.start(total_size, desc="Uploading data")
self._tus_start_upload(url)
for group, group_size in bulk_file_groups: def __init__(self, client: Client):
with ExitStack() as es: self._client = client
files = {}
for i, filename in enumerate(group):
files[f"client_files[{i}]"] = (
filename,
es.enter_context(closing(open(filename, "rb"))).read(),
)
response = self.client.api.rest_client.POST(
url,
post_params=dict(**kwargs, **files),
headers={
"Content-Type": "multipart/form-data",
"Upload-Multiple": "",
**self.client.api.get_common_headers(),
},
)
assert_status(200, response)
if pbar is not None:
pbar.advance(group_size)
for filename in separate_files:
# TODO: check if basename produces invalid paths here, can lead to overwriting
self._upload_file_data_with_tus(
url,
filename,
meta={"filename": osp.basename(filename)},
pbar=pbar,
logger=self.client.logger.debug,
)
self._tus_finish_upload(url, fields=kwargs)
def upload_file( def upload_file(
self, self,
@ -121,6 +75,27 @@ class Uploader:
) )
return self._tus_finish_upload(url, query_params=query_params, fields=fields) return self._tus_finish_upload(url, query_params=query_params, fields=fields)
def _wait_for_completion(
self,
url: str,
*,
success_status: int,
status_check_period: Optional[int] = None,
query_params: Optional[Dict[str, Any]] = None,
post_params: Optional[Dict[str, Any]] = None,
method: str = "POST",
positive_statuses: Optional[Sequence[int]] = None,
) -> urllib3.HTTPResponse:
return self._client.wait_for_completion(
url,
success_status=success_status,
status_check_period=status_check_period,
query_params=query_params,
post_params=post_params,
method=method,
positive_statuses=positive_statuses,
)
def _split_files_by_requests( def _split_files_by_requests(
self, filenames: List[str] self, filenames: List[str]
) -> Tuple[List[Tuple[List[str], int]], List[str], int]: ) -> Tuple[List[Tuple[List[str], int]], List[str], int]:
@ -268,7 +243,7 @@ class Uploader:
input_file = StreamWithProgress(input_file, pbar, length=file_size) input_file = StreamWithProgress(input_file, pbar, length=file_size)
tus_uploader = self._make_tus_uploader( tus_uploader = self._make_tus_uploader(
self.client.api, self._client.api,
url=url.rstrip("/") + "/", url=url.rstrip("/") + "/",
metadata=meta, metadata=meta,
file_stream=input_file, file_stream=input_file,
@ -278,26 +253,131 @@ class Uploader:
tus_uploader.upload() tus_uploader.upload()
def _tus_start_upload(self, url, *, query_params=None): def _tus_start_upload(self, url, *, query_params=None):
response = self.client.api.rest_client.POST( response = self._client.api.rest_client.POST(
url, url,
query_params=query_params, query_params=query_params,
headers={ headers={
"Upload-Start": "", "Upload-Start": "",
**self.client.api.get_common_headers(), **self._client.api.get_common_headers(),
}, },
) )
assert_status(202, response) expect_status(202, response)
return response return response
def _tus_finish_upload(self, url, *, query_params=None, fields=None): def _tus_finish_upload(self, url, *, query_params=None, fields=None):
response = self.client.api.rest_client.POST( response = self._client.api.rest_client.POST(
url, url,
headers={ headers={
"Upload-Finish": "", "Upload-Finish": "",
**self.client.api.get_common_headers(), **self._client.api.get_common_headers(),
}, },
query_params=query_params, query_params=query_params,
post_params=fields, post_params=fields,
) )
assert_status(202, response) expect_status(202, response)
return response return response
class AnnotationUploader(Uploader):
def upload_file_and_wait(
self,
endpoint: Endpoint,
filename: str,
format_name: str,
*,
url_params: Optional[Dict[str, Any]] = None,
pbar: Optional[ProgressReporter] = None,
status_check_period: Optional[int] = None,
):
url = self._client.api_map.make_endpoint_url(endpoint.path, kwsub=url_params)
params = {"format": format_name, "filename": osp.basename(filename)}
self.upload_file(
url, filename, pbar=pbar, query_params=params, meta={"filename": params["filename"]}
)
self._wait_for_completion(
url,
success_status=201,
positive_statuses=[202],
status_check_period=status_check_period,
query_params=params,
method="POST",
)
class DatasetUploader(Uploader):
def upload_file_and_wait(
self,
endpoint: Endpoint,
filename: str,
format_name: str,
*,
url_params: Optional[Dict[str, Any]] = None,
pbar: Optional[ProgressReporter] = None,
status_check_period: Optional[int] = None,
):
url = self._client.api_map.make_endpoint_url(endpoint.path, kwsub=url_params)
params = {"format": format_name, "filename": osp.basename(filename)}
self.upload_file(
url, filename, pbar=pbar, query_params=params, meta={"filename": params["filename"]}
)
self._wait_for_completion(
url,
success_status=201,
positive_statuses=[202],
status_check_period=status_check_period,
query_params=params,
method="GET",
)
class DataUploader(Uploader):
def upload_files(
self,
url: str,
resources: List[str],
*,
pbar: Optional[ProgressReporter] = None,
**kwargs,
):
bulk_file_groups, separate_files, total_size = self._split_files_by_requests(resources)
if pbar is not None:
pbar.start(total_size, desc="Uploading data")
self._tus_start_upload(url)
for group, group_size in bulk_file_groups:
with ExitStack() as es:
files = {}
for i, filename in enumerate(group):
files[f"client_files[{i}]"] = (
filename,
es.enter_context(closing(open(filename, "rb"))).read(),
)
response = self._client.api.rest_client.POST(
url,
post_params=dict(**kwargs, **files),
headers={
"Content-Type": "multipart/form-data",
"Upload-Multiple": "",
**self._client.api.get_common_headers(),
},
)
expect_status(200, response)
if pbar is not None:
pbar.advance(group_size)
for filename in separate_files:
# TODO: check if basename produces invalid paths here, can lead to overwriting
self._upload_file_data_with_tus(
url,
filename,
meta={"filename": osp.basename(filename)},
pbar=pbar,
logger=self._client.logger.debug,
)
self._tus_finish_upload(url, fields=kwargs)

@ -1,4 +1,3 @@
# Copyright (C) 2022 Intel Corporation
# Copyright (C) 2022 CVAT.ai Corporation # Copyright (C) 2022 CVAT.ai Corporation
# #
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
@ -7,13 +6,6 @@ from __future__ import annotations
from typing import Any, Dict, Sequence from typing import Any, Dict, Sequence
import urllib3
def assert_status(code: int, response: urllib3.HTTPResponse) -> None:
if response.status != code:
raise Exception(f"Unexpected status code received {response.status}")
def filter_dict( def filter_dict(
d: Dict[str, Any], *, keep: Sequence[str] = None, drop: Sequence[str] = None d: Dict[str, Any], *, keep: Sequence[str] = None, drop: Sequence[str] = None

@ -48,7 +48,7 @@ class Processor:
tokenized_path = tokenized_path[2:] tokenized_path = tokenized_path[2:]
prefix = tokenized_path[0] + "_" prefix = tokenized_path[0] + "_"
if new_name.startswith(prefix): if new_name.startswith(prefix) and tokenized_path[0] in operation["tags"]:
new_name = new_name[len(prefix) :] new_name = new_name[len(prefix) :]
return new_name return new_name

@ -345,6 +345,9 @@ class ApiClient(object):
""" """
if response_schema == (file_type,): if response_schema == (file_type,):
# TODO: response schema can be "oneOf" with a file option,
# this implementation does not cover this.
# handle file downloading # handle file downloading
# save response body into a tmp file and return the instance # save response body into a tmp file and return the instance
content_disposition = response.getheader("Content-Disposition") content_disposition = response.getheader("Content-Disposition")

@ -9,6 +9,7 @@ import sys # noqa: F401
from {{packageName}}.model_utils import ( # noqa: F401 from {{packageName}}.model_utils import ( # noqa: F401
ApiTypeError, ApiTypeError,
IModelData,
ModelComposed, ModelComposed,
ModelNormal, ModelNormal,
ModelSimple, ModelSimple,

@ -1,5 +1,5 @@
class I{{classname}}: class I{{classname}}(IModelData):
""" """
NOTE: This class is auto generated by OpenAPI Generator. NOTE: This class is auto generated by OpenAPI Generator.
Ref: https://openapi-generator.tech Ref: https://openapi-generator.tech

@ -1,5 +1,5 @@
class I{{classname}}: class I{{classname}}(IModelData):
""" """
NOTE: This class is auto generated by OpenAPI Generator. NOTE: This class is auto generated by OpenAPI Generator.
Ref: https://openapi-generator.tech Ref: https://openapi-generator.tech

@ -113,6 +113,11 @@ def composed_model_input_classes(cls):
return [] return []
class IModelData:
"""
The base class for model data. Declares model fields and their types for better introspection
"""
class OpenApiModel(object): class OpenApiModel(object):
"""The base class for all OpenAPIModels""" """The base class for all OpenAPIModels"""

@ -3,3 +3,4 @@
attrs >= 21.4.0 attrs >= 21.4.0
tqdm >= 4.64.0 tqdm >= 4.64.0
tuspy == 0.2.5 # have it pinned, because SDK has lots of patched TUS code tuspy == 0.2.5 # have it pinned, because SDK has lots of patched TUS code
typing_extensions >= 4.2.0

@ -1,6 +1,6 @@
{ {
"name": "cvat-ui", "name": "cvat-ui",
"version": "1.41.0", "version": "1.41.1",
"description": "CVAT single-page application", "description": "CVAT single-page application",
"main": "src/index.tsx", "main": "src/index.tsx",
"scripts": { "scripts": {

@ -70,7 +70,7 @@ function TasksPageComponent(props: Props): JSX.Element {
<Button <Button
type='link' type='link'
onClick={(): void => { onClick={(): void => {
dispatch(hideEmptyTasks(true)); dispatch(hideEmptyTasks(false));
message.destroy(); message.destroy();
}} }}
> >

@ -109,10 +109,8 @@ def update_git_repo(request, tid):
status=http.HTTPStatus.OK, status=http.HTTPStatus.OK,
) )
except Exception as ex: except Exception as ex:
try: with contextlib.suppress(Exception):
slogger.task[tid].error("error occurred during changing repository request", exc_info=True) slogger.task[tid].error("error occurred during changing repository request", exc_info=True)
except Exception:
pass
return HttpResponseBadRequest(str(ex)) return HttpResponseBadRequest(str(ex))

@ -15,7 +15,7 @@ from rest_framework.exceptions import ValidationError
class SearchFilter(filters.SearchFilter): class SearchFilter(filters.SearchFilter):
def get_search_fields(self, view, request): def get_search_fields(self, view, request):
search_fields = getattr(view, 'search_fields', []) search_fields = getattr(view, 'search_fields') or []
lookup_fields = {field:field for field in search_fields} lookup_fields = {field:field for field in search_fields}
view_lookup_fields = getattr(view, 'lookup_fields', {}) view_lookup_fields = getattr(view, 'lookup_fields', {})
keys_to_update = set(search_fields) & set(view_lookup_fields.keys()) keys_to_update = set(search_fields) & set(view_lookup_fields.keys())

@ -9,7 +9,7 @@ import uuid
from django.conf import settings from django.conf import settings
from django.core.cache import cache from django.core.cache import cache
from distutils.util import strtobool from distutils.util import strtobool
from rest_framework import status from rest_framework import status, mixins
from rest_framework.response import Response from rest_framework.response import Response
from cvat.apps.engine.models import Location from cvat.apps.engine.models import Location
@ -315,3 +315,17 @@ class SerializeMixin:
file_name = request.query_params.get("filename", "") file_name = request.query_params.get("filename", "")
return import_func(request, filename=file_name) return import_func(request, filename=file_name)
return self.upload_data(request) return self.upload_data(request)
class PartialUpdateModelMixin:
"""
Update fields of a model instance.
Almost the same as UpdateModelMixin, but has no public PUT / update() method.
"""
def perform_update(self, serializer):
mixins.UpdateModelMixin.perform_update(self, serializer=serializer)
def partial_update(self, request, *args, **kwargs):
kwargs['partial'] = True
return mixins.UpdateModelMixin.update(self, request=request, *args, **kwargs)

@ -2,12 +2,26 @@
# #
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
from typing import Type
from rest_framework import serializers from rest_framework import serializers
from drf_spectacular.extensions import OpenApiSerializerExtension from drf_spectacular.extensions import OpenApiSerializerExtension
from drf_spectacular.plumbing import force_instance from drf_spectacular.plumbing import force_instance, build_basic_type
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.serializers import PolymorphicProxySerializerExtension from drf_spectacular.serializers import PolymorphicProxySerializerExtension
def _copy_serializer(
instance: serializers.Serializer,
*,
_new_type: Type[serializers.Serializer] = None,
**kwargs
) -> serializers.Serializer:
_new_type = _new_type or type(instance)
instance_kwargs = instance._kwargs
instance_kwargs['partial'] = instance.partial # this can be set separately
instance_kwargs.update(kwargs)
return _new_type(*instance._args, **instance._kwargs)
class DataSerializerExtension(OpenApiSerializerExtension): class DataSerializerExtension(OpenApiSerializerExtension):
# *FileSerializer mimics a FileField # *FileSerializer mimics a FileField
# but it is mapped as an object with a file field, which # but it is mapped as an object with a file field, which
@ -23,40 +37,106 @@ class DataSerializerExtension(OpenApiSerializerExtension):
target_class = 'cvat.apps.engine.serializers.DataSerializer' target_class = 'cvat.apps.engine.serializers.DataSerializer'
def map_serializer(self, auto_schema, direction): def map_serializer(self, auto_schema, direction):
assert isinstance(self.target_class, type) assert issubclass(self.target_class, serializers.ModelSerializer)
instance = force_instance(self.target_class) instance = self.target
assert isinstance(instance, serializers.ModelSerializer) assert isinstance(instance, serializers.ModelSerializer)
def _get_field(instance, source_name, field_name): def _get_field(
instance: serializers.ModelSerializer,
source_name: str,
field_name: str
) -> serializers.ModelField:
child_instance = force_instance(instance.fields[source_name].child) child_instance = force_instance(instance.fields[source_name].child)
assert isinstance(child_instance, serializers.ModelSerializer)
child_fields = child_instance.fields child_fields = child_instance.fields
assert child_fields.keys() == {'file'} # protect from changes assert child_fields.keys() == {'file'} # protection from implementation changes
return child_fields[field_name] return child_fields[field_name]
def _sanitize_field(field): def _sanitize_field(field: serializers.ModelField) -> serializers.ModelField:
field.source = None field.source = None
field.source_attrs = [] field.source_attrs = []
return field return field
def _make_field(source_name, field_name): def _make_field(source_name: str, field_name: str) -> serializers.ModelField:
return _sanitize_field(_get_field(instance, source_name, field_name)) return _sanitize_field(_get_field(instance, source_name, field_name))
class _Override(self.target_class): # pylint: disable=inherit-non-class class _Override(self.target_class): # pylint: disable=inherit-non-class
client_files = serializers.ListField(child=_make_field('client_files', 'file'), default=[]) client_files = serializers.ListField(
server_files = serializers.ListField(child=_make_field('server_files', 'file'), default=[]) child=_make_field('client_files', 'file'), default=[])
remote_files = serializers.ListField(child=_make_field('remote_files', 'file'), default=[]) server_files = serializers.ListField(
child=_make_field('server_files', 'file'), default=[])
remote_files = serializers.ListField(
child=_make_field('remote_files', 'file'), default=[])
return auto_schema._map_serializer(
_copy_serializer(instance, _new_type=_Override, context={'view': auto_schema.view}),
direction, bypass_extensions=False)
class WriteOnceSerializerExtension(OpenApiSerializerExtension):
"""
Enables support for cvat.apps.engine.serializers.WriteOnceMixin in drf-spectacular.
Doesn't block other extensions on the target serializer.
"""
return auto_schema._map_serializer(_Override(), direction, bypass_extensions=False) match_subclasses = True
target_class = 'cvat.apps.engine.serializers.WriteOnceMixin'
_PROCESSED_INDICATOR_NAME = 'write_once_serializer_extension_processed'
class CustomProxySerializerExtension(PolymorphicProxySerializerExtension): @classmethod
""" def _matches(cls, target) -> bool:
Allows to patch PolymorphicProxySerializer-based schema. if super()._matches(target):
# protect from recursive invocations
assert isinstance(target, serializers.Serializer)
processed = target.context.get(cls._PROCESSED_INDICATOR_NAME, False)
return not processed
return False
Override "target_component" in children classes. def map_serializer(self, auto_schema, direction):
return auto_schema._map_serializer(
_copy_serializer(self.target, context={
'view': auto_schema.view,
self._PROCESSED_INDICATOR_NAME: True
}),
direction, bypass_extensions=False)
class OpenApiTypeProxySerializerExtension(PolymorphicProxySerializerExtension):
"""
Provides support for OpenApiTypes in the PolymorphicProxySerializer list
""" """
priority = 0 # restore normal priority priority = 0 # restore normal priority
def _process_serializer(self, auto_schema, serializer, direction):
if isinstance(serializer, OpenApiTypes):
schema = build_basic_type(serializer)
return (None, schema)
else:
return super()._process_serializer(auto_schema=auto_schema,
serializer=serializer, direction=direction)
def map_serializer(self, auto_schema, direction):
""" custom handling for @extend_schema's injection of PolymorphicProxySerializer """
result = super().map_serializer(auto_schema=auto_schema, direction=direction)
if isinstance(self.target.serializers, dict):
required = OpenApiTypes.NONE not in self.target.serializers.values()
else:
required = OpenApiTypes.NONE not in self.target.serializers
if not required:
result['nullable'] = True
return result
class ComponentProxySerializerExtension(OpenApiTypeProxySerializerExtension):
"""
Allows to patch PolymorphicProxySerializer-based component schema.
Override the "target_component" field in children classes.
"""
priority = 1 # higher than in the parent class
target_component: str = '' target_component: str = ''
@classmethod @classmethod
@ -69,7 +149,7 @@ class CustomProxySerializerExtension(PolymorphicProxySerializerExtension):
return target.component_name == cls.target_component return target.component_name == cls.target_component
class AnyOfProxySerializerExtension(CustomProxySerializerExtension): class AnyOfProxySerializerExtension(ComponentProxySerializerExtension):
""" """
Replaces oneOf with anyOf in the generated schema. Useful when Replaces oneOf with anyOf in the generated schema. Useful when
no disciminator field is available, and the options are no disciminator field is available, and the options are

@ -198,7 +198,9 @@ class JobReadSerializer(serializers.ModelSerializer):
class JobWriteSerializer(serializers.ModelSerializer): class JobWriteSerializer(serializers.ModelSerializer):
assignee = serializers.IntegerField(allow_null=True, required=False) assignee = serializers.IntegerField(allow_null=True, required=False)
def to_representation(self, instance): def to_representation(self, instance):
# FIXME: deal with resquest/response separation
serializer = JobReadSerializer(instance, context=self.context) serializer = JobReadSerializer(instance, context=self.context)
return serializer.data return serializer.data
@ -307,8 +309,8 @@ class RqStatusSerializer(serializers.Serializer):
progress = serializers.FloatField(max_value=100, default=0) progress = serializers.FloatField(max_value=100, default=0)
class WriteOnceMixin: class WriteOnceMixin:
"""
"""Adds support for write once fields to serializers. Adds support for write once fields to serializers.
To use it, specify a list of fields as `write_once_fields` on the To use it, specify a list of fields as `write_once_fields` on the
serializer's Meta: serializer's Meta:
@ -329,12 +331,15 @@ class WriteOnceMixin:
# We're only interested in PATCH/PUT. # We're only interested in PATCH/PUT.
if 'update' in getattr(self.context.get('view'), 'action', ''): if 'update' in getattr(self.context.get('view'), 'action', ''):
return self._set_write_once_fields(extra_kwargs) extra_kwargs = self._set_write_once_fields(extra_kwargs)
return extra_kwargs return extra_kwargs
def _set_write_once_fields(self, extra_kwargs): def _set_write_once_fields(self, extra_kwargs):
"""Set all fields in `Meta.write_once_fields` to read_only.""" """
Set all fields in `Meta.write_once_fields` to read_only.
"""
write_once_fields = getattr(self.Meta, 'write_once_fields', None) write_once_fields = getattr(self.Meta, 'write_once_fields', None)
if not write_once_fields: if not write_once_fields:
return extra_kwargs return extra_kwargs
@ -352,7 +357,7 @@ class WriteOnceMixin:
return extra_kwargs return extra_kwargs
class DataSerializer(serializers.ModelSerializer): class DataSerializer(WriteOnceMixin, serializers.ModelSerializer):
image_quality = serializers.IntegerField(min_value=0, max_value=100) image_quality = serializers.IntegerField(min_value=0, max_value=100)
use_zip_chunks = serializers.BooleanField(default=False) use_zip_chunks = serializers.BooleanField(default=False)
client_files = ClientFileSerializer(many=True, default=[]) client_files = ClientFileSerializer(many=True, default=[])
@ -876,16 +881,16 @@ class AnnotationSerializer(serializers.Serializer):
id = serializers.IntegerField(default=None, allow_null=True) id = serializers.IntegerField(default=None, allow_null=True)
frame = serializers.IntegerField(min_value=0) frame = serializers.IntegerField(min_value=0)
label_id = serializers.IntegerField(min_value=0) label_id = serializers.IntegerField(min_value=0)
group = serializers.IntegerField(min_value=0, allow_null=True) group = serializers.IntegerField(min_value=0, allow_null=True, default=None)
source = serializers.CharField(default = 'manual') source = serializers.CharField(default='manual')
class LabeledImageSerializer(AnnotationSerializer): class LabeledImageSerializer(AnnotationSerializer):
attributes = AttributeValSerializer(many=True, attributes = AttributeValSerializer(many=True,
source="labeledimageattributeval_set") source="labeledimageattributeval_set", default=[])
class ShapeSerializer(serializers.Serializer): class ShapeSerializer(serializers.Serializer):
type = serializers.ChoiceField(choices=models.ShapeType.choices()) type = serializers.ChoiceField(choices=models.ShapeType.choices())
occluded = serializers.BooleanField() occluded = serializers.BooleanField(default=False)
outside = serializers.BooleanField(default=False, required=False) outside = serializers.BooleanField(default=False, required=False)
z_order = serializers.IntegerField(default=0) z_order = serializers.IntegerField(default=0)
rotation = serializers.FloatField(default=0, min_value=0, max_value=360) rotation = serializers.FloatField(default=0, min_value=0, max_value=360)
@ -896,7 +901,7 @@ class ShapeSerializer(serializers.Serializer):
class SubLabeledShapeSerializer(ShapeSerializer, AnnotationSerializer): class SubLabeledShapeSerializer(ShapeSerializer, AnnotationSerializer):
attributes = AttributeValSerializer(many=True, attributes = AttributeValSerializer(many=True,
source="labeledshapeattributeval_set") source="labeledshapeattributeval_set", default=[])
class LabeledShapeSerializer(SubLabeledShapeSerializer): class LabeledShapeSerializer(SubLabeledShapeSerializer):
elements = SubLabeledShapeSerializer(many=True, required=False) elements = SubLabeledShapeSerializer(many=True, required=False)
@ -905,22 +910,22 @@ class TrackedShapeSerializer(ShapeSerializer):
id = serializers.IntegerField(default=None, allow_null=True) id = serializers.IntegerField(default=None, allow_null=True)
frame = serializers.IntegerField(min_value=0) frame = serializers.IntegerField(min_value=0)
attributes = AttributeValSerializer(many=True, attributes = AttributeValSerializer(many=True,
source="trackedshapeattributeval_set") source="trackedshapeattributeval_set", default=[])
class SubLabeledTrackSerializer(AnnotationSerializer): class SubLabeledTrackSerializer(AnnotationSerializer):
shapes = TrackedShapeSerializer(many=True, allow_empty=True, shapes = TrackedShapeSerializer(many=True, allow_empty=True,
source="trackedshape_set") source="trackedshape_set")
attributes = AttributeValSerializer(many=True, attributes = AttributeValSerializer(many=True,
source="labeledtrackattributeval_set") source="labeledtrackattributeval_set", default=[])
class LabeledTrackSerializer(SubLabeledTrackSerializer): class LabeledTrackSerializer(SubLabeledTrackSerializer):
elements = SubLabeledTrackSerializer(many=True, required=False) elements = SubLabeledTrackSerializer(many=True, required=False)
class LabeledDataSerializer(serializers.Serializer): class LabeledDataSerializer(serializers.Serializer):
version = serializers.IntegerField() version = serializers.IntegerField(default=0) # TODO: remove
tags = LabeledImageSerializer(many=True) tags = LabeledImageSerializer(many=True, default=[])
shapes = LabeledShapeSerializer(many=True) shapes = LabeledShapeSerializer(many=True, default=[])
tracks = LabeledTrackSerializer(many=True) tracks = LabeledTrackSerializer(many=True, default=[])
class FileInfoSerializer(serializers.Serializer): class FileInfoSerializer(serializers.Serializer):
name = serializers.CharField(max_length=1024) name = serializers.CharField(max_length=1024)
@ -991,6 +996,10 @@ class IssueReadSerializer(serializers.ModelSerializer):
fields = ('id', 'frame', 'position', 'job', 'owner', 'assignee', fields = ('id', 'frame', 'position', 'job', 'owner', 'assignee',
'created_date', 'updated_date', 'comments', 'resolved') 'created_date', 'updated_date', 'comments', 'resolved')
read_only_fields = fields read_only_fields = fields
extra_kwargs = {
'created_date': { 'allow_null': True },
'updated_date': { 'allow_null': True },
}
class IssueWriteSerializer(WriteOnceMixin, serializers.ModelSerializer): class IssueWriteSerializer(WriteOnceMixin, serializers.ModelSerializer):
@ -1010,6 +1019,12 @@ class IssueWriteSerializer(WriteOnceMixin, serializers.ModelSerializer):
message=message, owner=db_issue.owner) message=message, owner=db_issue.owner)
return db_issue return db_issue
def update(self, instance, validated_data):
message = validated_data.pop('message', None)
if message:
raise NotImplementedError('Check https://github.com/cvat-ai/cvat/issues/122')
return super().update(instance, validated_data)
class Meta: class Meta:
model = models.Issue model = models.Issue
fields = ('id', 'frame', 'position', 'job', 'owner', 'assignee', fields = ('id', 'frame', 'position', 'job', 'owner', 'assignee',

@ -313,7 +313,7 @@ class JobGetAPITestCase(APITestCase):
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
class JobUpdateAPITestCase(APITestCase): class JobPartialUpdateAPITestCase(APITestCase):
def setUp(self): def setUp(self):
self.client = APIClient() self.client = APIClient()
self.task = create_dummy_db_tasks(self)[0] self.task = create_dummy_db_tasks(self)[0]
@ -327,7 +327,7 @@ class JobUpdateAPITestCase(APITestCase):
def _run_api_v2_jobs_id(self, jid, user, data): def _run_api_v2_jobs_id(self, jid, user, data):
with ForceLogin(user, self.client): with ForceLogin(user, self.client):
response = self.client.put('/api/jobs/{}'.format(jid), data=data, format='json') response = self.client.patch('/api/jobs/{}'.format(jid), data=data, format='json')
return response return response
@ -382,22 +382,43 @@ class JobUpdateAPITestCase(APITestCase):
response = self._run_api_v2_jobs_id(self.job.id + 10, None, data) response = self._run_api_v2_jobs_id(self.job.id + 10, None, data)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
class JobPartialUpdateAPITestCase(JobUpdateAPITestCase): def test_api_v2_jobs_id_annotator_partial(self):
data = {"stage": StageChoice.ANNOTATION}
response = self._run_api_v2_jobs_id(self.job.id, self.annotator, data)
self.assertEquals(response.status_code, status.HTTP_403_FORBIDDEN, response)
def test_api_v2_jobs_id_admin_partial(self):
data = {"assignee_id": self.user.id}
response = self._run_api_v2_jobs_id(self.job.id, self.owner, data)
self._check_request(response, data)
class JobUpdateAPITestCase(APITestCase):
def setUp(self):
self.client = APIClient()
self.task = create_dummy_db_tasks(self)[0]
self.job = Job.objects.filter(segment__task_id=self.task.id).first()
self.job.assignee = self.annotator
self.job.save()
@classmethod
def setUpTestData(cls):
create_db_users(cls)
def _run_api_v2_jobs_id(self, jid, user, data): def _run_api_v2_jobs_id(self, jid, user, data):
with ForceLogin(user, self.client): with ForceLogin(user, self.client):
response = self.client.patch('/api/jobs/{}'.format(jid), data=data, format='json') response = self.client.put('/api/jobs/{}'.format(jid), data=data, format='json')
return response return response
def test_api_v2_jobs_id_annotator_partial(self): def test_api_v2_jobs_id_annotator(self):
data = {"stage": StageChoice.ANNOTATION} data = {"stage": StageChoice.ANNOTATION}
response = self._run_api_v2_jobs_id(self.job.id, self.annotator, data) response = self._run_api_v2_jobs_id(self.job.id, self.annotator, data)
self.assertEquals(response.status_code, status.HTTP_403_FORBIDDEN, response) self.assertEquals(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED, response)
def test_api_v2_jobs_id_admin_partial(self): def test_api_v2_jobs_id_admin(self):
data = {"assignee_id": self.user.id} data = {"assignee_id": self.user.id}
response = self._run_api_v2_jobs_id(self.job.id, self.owner, data) response = self._run_api_v2_jobs_id(self.job.id, self.owner, data)
self._check_request(response, data) self.assertEquals(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED, response)
class JobDataMetaPartialUpdateAPITestCase(APITestCase): class JobDataMetaPartialUpdateAPITestCase(APITestCase):
def setUp(self): def setUp(self):
@ -1987,7 +2008,6 @@ class TaskDeleteAPITestCase(APITestCase):
self.assertFalse(os.path.exists(task_dir)) self.assertFalse(os.path.exists(task_dir))
class TaskUpdateAPITestCase(APITestCase): class TaskUpdateAPITestCase(APITestCase):
def setUp(self): def setUp(self):
self.client = APIClient() self.client = APIClient()
@ -2003,6 +2023,39 @@ class TaskUpdateAPITestCase(APITestCase):
return response return response
def _check_api_v2_tasks_id(self, user, data):
for db_task in self.tasks:
response = self._run_api_v2_tasks_id(db_task.id, user, data)
if user is None:
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
else:
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
def test_api_v2_tasks_id_admin(self):
data = { "name": "new name for the task" }
self._check_api_v2_tasks_id(self.admin, data)
def test_api_v2_tasks_id_user(self):
data = { "name": "new name for the task" }
self._check_api_v2_tasks_id(self.user, data)
def test_api_v2_tasks_id_somebody(self):
data = { "name": "new name for the task" }
self._check_api_v2_tasks_id(self.somebody, data)
def test_api_v2_tasks_id_no_auth(self):
data = { "name": "new name for the task" }
self._check_api_v2_tasks_id(None, data)
class TaskPartialUpdateAPITestCase(APITestCase):
def setUp(self):
self.client = APIClient()
@classmethod
def setUpTestData(cls):
create_db_users(cls)
cls.tasks = create_dummy_db_tasks(cls)
def _check_response(self, response, db_task, data): def _check_response(self, response, db_task, data):
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
name = data.get("name", db_task.name) name = data.get("name", db_task.name)
@ -2034,6 +2087,13 @@ class TaskUpdateAPITestCase(APITestCase):
[label["name"] for label in response.data["labels"]] [label["name"] for label in response.data["labels"]]
) )
def _run_api_v2_tasks_id(self, tid, user, data):
with ForceLogin(user, self.client):
response = self.client.patch('/api/tasks/{}'.format(tid),
data=data, format="json")
return response
def _check_api_v2_tasks_id(self, user, data): def _check_api_v2_tasks_id(self, user, data):
for db_task in self.tasks: for db_task in self.tasks:
response = self._run_api_v2_tasks_id(db_task.id, user, data) response = self._run_api_v2_tasks_id(db_task.id, user, data)
@ -2077,32 +2137,6 @@ class TaskUpdateAPITestCase(APITestCase):
} }
self._check_api_v2_tasks_id(self.user, data) self._check_api_v2_tasks_id(self.user, data)
def test_api_v2_tasks_id_somebody(self):
data = {
"name": "new name for the task",
"labels": [{
"name": "test",
}]
}
self._check_api_v2_tasks_id(self.somebody, data)
def test_api_v2_tasks_id_no_auth(self):
data = {
"name": "new name for the task",
"labels": [{
"name": "test",
}]
}
self._check_api_v2_tasks_id(None, data)
class TaskPartialUpdateAPITestCase(TaskUpdateAPITestCase):
def _run_api_v2_tasks_id(self, tid, user, data):
with ForceLogin(user, self.client):
response = self.client.patch('/api/tasks/{}'.format(tid),
data=data, format="json")
return response
def test_api_v2_tasks_id_admin_partial(self): def test_api_v2_tasks_id_admin_partial(self):
data = { data = {
"name": "new name for the task #2", "name": "new name for the task #2",

@ -69,7 +69,7 @@ from cvat.apps.engine.serializers import (
from utils.dataset_manifest import ImageManifestManager from utils.dataset_manifest import ImageManifestManager
from cvat.apps.engine.utils import av_scan_paths from cvat.apps.engine.utils import av_scan_paths
from cvat.apps.engine import backup from cvat.apps.engine import backup
from cvat.apps.engine.mixins import UploadMixin, AnnotationMixin, SerializeMixin from cvat.apps.engine.mixins import PartialUpdateModelMixin, UploadMixin, AnnotationMixin, SerializeMixin
from . import models, task from . import models, task
from .log import clogger, slogger from .log import clogger, slogger
@ -237,9 +237,9 @@ class ServerViewSet(viewsets.ViewSet):
}), }),
create=extend_schema( create=extend_schema(
summary='Method creates a new project', summary='Method creates a new project',
request=ProjectWriteSerializer, # request=ProjectWriteSerializer,
responses={ responses={
'201': ProjectWriteSerializer, '201': ProjectReadSerializer, # check ProjectWriteSerializer.to_representation
}), }),
retrieve=extend_schema( retrieve=extend_schema(
summary='Method returns details of a specific project', summary='Method returns details of a specific project',
@ -253,12 +253,15 @@ class ServerViewSet(viewsets.ViewSet):
}), }),
partial_update=extend_schema( partial_update=extend_schema(
summary='Methods does a partial update of chosen fields in a project', summary='Methods does a partial update of chosen fields in a project',
request=ProjectWriteSerializer, # request=ProjectWriteSerializer,
responses={ responses={
'200': ProjectWriteSerializer, '200': ProjectReadSerializer, # check ProjectWriteSerializer.to_representation
}) })
) )
class ProjectViewSet(viewsets.ModelViewSet, UploadMixin, AnnotationMixin, SerializeMixin): class ProjectViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
mixins.RetrieveModelMixin, mixins.CreateModelMixin, mixins.DestroyModelMixin,
PartialUpdateModelMixin, UploadMixin, AnnotationMixin, SerializeMixin
):
queryset = models.Project.objects.prefetch_related(Prefetch('label_set', queryset = models.Project.objects.prefetch_related(Prefetch('label_set',
queryset=models.Label.objects.order_by('id') queryset=models.Label.objects.order_by('id')
)) ))
@ -270,7 +273,6 @@ class ProjectViewSet(viewsets.ModelViewSet, UploadMixin, AnnotationMixin, Serial
ordering_fields = filter_fields ordering_fields = filter_fields
ordering = "-id" ordering = "-id"
lookup_fields = {'owner': 'owner__username', 'assignee': 'assignee__username'} lookup_fields = {'owner': 'owner__username', 'assignee': 'assignee__username'}
http_method_names = ('get', 'post', 'head', 'patch', 'delete', 'options')
iam_organization_field = 'organization' iam_organization_field = 'organization'
def get_serializer_class(self): def get_serializer_class(self):
@ -335,7 +337,7 @@ class ProjectViewSet(viewsets.ModelViewSet, UploadMixin, AnnotationMixin, Serial
default=True), default=True),
], ],
responses={ responses={
'200': OpenApiResponse(description='Download of file started'), '200': OpenApiResponse(OpenApiTypes.BINARY, description='Download of file started'),
'201': OpenApiResponse(description='Output file is ready for downloading'), '201': OpenApiResponse(description='Output file is ready for downloading'),
'202': OpenApiResponse(description='Exporting has been started'), '202': OpenApiResponse(description='Exporting has been started'),
'405': OpenApiResponse(description='Format is not available'), '405': OpenApiResponse(description='Format is not available'),
@ -356,7 +358,10 @@ class ProjectViewSet(viewsets.ModelViewSet, UploadMixin, AnnotationMixin, Serial
OpenApiParameter('filename', description='Dataset file name', OpenApiParameter('filename', description='Dataset file name',
location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False), location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False),
], ],
request=DatasetFileSerializer(required=False), request=PolymorphicProxySerializer('DatasetWrite',
serializers=[DatasetFileSerializer, OpenApiTypes.NONE],
resource_type_field_name=None
),
responses={ responses={
'202': OpenApiResponse(description='Exporting has been started'), '202': OpenApiResponse(description='Exporting has been started'),
'400': OpenApiResponse(description='Failed to import dataset'), '400': OpenApiResponse(description='Failed to import dataset'),
@ -481,7 +486,11 @@ class ProjectViewSet(viewsets.ModelViewSet, UploadMixin, AnnotationMixin, Serial
default=True), default=True),
], ],
responses={ responses={
'200': OpenApiResponse(description='Download of file started'), '200': OpenApiResponse(PolymorphicProxySerializer(
component_name='AnnotationsRead',
serializers=[LabeledDataSerializer, OpenApiTypes.BINARY],
resource_type_field_name=None
), description='Download of file started'),
'201': OpenApiResponse(description='Annotations file is ready to download'), '201': OpenApiResponse(description='Annotations file is ready to download'),
'202': OpenApiResponse(description='Dump of annotations has been started'), '202': OpenApiResponse(description='Dump of annotations has been started'),
'401': OpenApiResponse(description='Format is not specified'), '401': OpenApiResponse(description='Format is not specified'),
@ -535,12 +544,16 @@ class ProjectViewSet(viewsets.ModelViewSet, UploadMixin, AnnotationMixin, Serial
OpenApiParameter('filename', description='Backup file name', OpenApiParameter('filename', description='Backup file name',
location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False), location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False),
], ],
request=ProjectFileSerializer(required=False), request=PolymorphicProxySerializer('BackupWrite',
serializers=[ProjectFileSerializer, OpenApiTypes.NONE],
resource_type_field_name=None
),
responses={ responses={
'201': OpenApiResponse(description='The project has been imported'), # or better specify {id: project_id} '201': OpenApiResponse(description='The project has been imported'), # or better specify {id: project_id}
'202': OpenApiResponse(description='Importing a backup file has been started'), '202': OpenApiResponse(description='Importing a backup file has been started'),
}) })
@action(detail=False, methods=['OPTIONS', 'POST'], url_path=r'backup/?$', serializer_class=ProjectFileSerializer(required=False)) @action(detail=False, methods=['OPTIONS', 'POST'], url_path=r'backup/?$',
serializer_class=ProjectFileSerializer(required=False))
def import_backup(self, request, pk=None): def import_backup(self, request, pk=None):
return self.deserialize(request, backup.import_project) return self.deserialize(request, backup.import_project)
@ -554,7 +567,8 @@ class ProjectViewSet(viewsets.ModelViewSet, UploadMixin, AnnotationMixin, Serial
@extend_schema(methods=['HEAD'], @extend_schema(methods=['HEAD'],
summary="Implements TUS file uploading protocol." summary="Implements TUS file uploading protocol."
) )
@action(detail=False, methods=['HEAD', 'PATCH'], url_path='backup/'+UploadMixin.file_id_regex) @action(detail=False, methods=['HEAD', 'PATCH'], url_path='backup/'+UploadMixin.file_id_regex,
serializer_class=None)
def append_backup_chunk(self, request, file_id): def append_backup_chunk(self, request, file_id):
return self.append_tus_chunk(request, file_id) return self.append_tus_chunk(request, file_id)
@ -661,16 +675,14 @@ class DataChunkGetter:
}), }),
create=extend_schema( create=extend_schema(
summary='Method creates a new task in a database without any attached images and videos', summary='Method creates a new task in a database without any attached images and videos',
request=TaskWriteSerializer,
responses={ responses={
'201': TaskWriteSerializer, '201': TaskReadSerializer, # check TaskWriteSerializer.to_representation
}), }),
retrieve=extend_schema( retrieve=extend_schema(
summary='Method returns details of a specific task', summary='Method returns details of a specific task',
responses=TaskReadSerializer),
update=extend_schema(
summary='Method updates a task by id',
responses={ responses={
'200': TaskWriteSerializer, '200': TaskReadSerializer
}), }),
destroy=extend_schema( destroy=extend_schema(
summary='Method deletes a specific task, all attached jobs, annotations, and data', summary='Method deletes a specific task, all attached jobs, annotations, and data',
@ -679,16 +691,19 @@ class DataChunkGetter:
}), }),
partial_update=extend_schema( partial_update=extend_schema(
summary='Methods does a partial update of chosen fields in a task', summary='Methods does a partial update of chosen fields in a task',
request=TaskWriteSerializer(partial=True),
responses={ responses={
'200': TaskWriteSerializer, '200': TaskReadSerializer, # check TaskWriteSerializer.to_representation
}) })
) )
class TaskViewSet(UploadMixin, AnnotationMixin, viewsets.ModelViewSet, SerializeMixin): class TaskViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
mixins.RetrieveModelMixin, mixins.CreateModelMixin, mixins.DestroyModelMixin,
PartialUpdateModelMixin, UploadMixin, AnnotationMixin, SerializeMixin
):
queryset = Task.objects.prefetch_related( queryset = Task.objects.prefetch_related(
Prefetch('label_set', queryset=models.Label.objects.order_by('id')), Prefetch('label_set', queryset=models.Label.objects.order_by('id')),
"label_set__attributespec_set", "label_set__attributespec_set",
"segment_set__job_set") "segment_set__job_set")
http_method_names = ('get', 'post', 'head', 'patch', 'delete', 'options', 'put')
lookup_fields = {'project_name': 'project__name', 'owner': 'owner__username', 'assignee': 'assignee__username'} lookup_fields = {'project_name': 'project__name', 'owner': 'owner__username', 'assignee': 'assignee__username'}
search_fields = ('project_name', 'name', 'owner', 'status', 'assignee', 'subset', 'mode', 'dimension') search_fields = ('project_name', 'name', 'owner', 'status', 'assignee', 'subset', 'mode', 'dimension')
filter_fields = list(search_fields) + ['id', 'project_id', 'updated_date'] filter_fields = list(search_fields) + ['id', 'project_id', 'updated_date']
@ -796,10 +811,11 @@ class TaskViewSet(UploadMixin, AnnotationMixin, viewsets.ModelViewSet, Serialize
db_project.save() db_project.save()
@extend_schema(summary='Method returns a list of jobs for a specific task', @extend_schema(summary='Method returns a list of jobs for a specific task',
responses={ responses=JobReadSerializer(many=True)) # Duplicate to still get 'list' op. name
'200': JobReadSerializer(many=True), @action(detail=True, methods=['GET'], serializer_class=JobReadSerializer(many=True),
}) # Remove regular list() parameters from swagger schema
@action(detail=True, methods=['GET'], serializer_class=JobReadSerializer) # https://drf-spectacular.readthedocs.io/en/latest/faq.html#my-action-is-erroneously-paginated-or-has-filter-parameters-that-i-do-not-want
pagination_class=None, filter_fields=None, search_fields=None, ordering_fields=None)
def jobs(self, request, pk): def jobs(self, request, pk):
self.get_object() # force to call check_object_permissions self.get_object() # force to call check_object_permissions
queryset = Job.objects.filter(segment__task_id=pk) queryset = Job.objects.filter(segment__task_id=pk)
@ -899,13 +915,13 @@ class TaskViewSet(UploadMixin, AnnotationMixin, viewsets.ModelViewSet, Serialize
}) })
@extend_schema(methods=['GET'], summary='Method returns data for a specific task', @extend_schema(methods=['GET'], summary='Method returns data for a specific task',
parameters=[ parameters=[
OpenApiParameter('type', location=OpenApiParameter.QUERY, required=True, OpenApiParameter('type', location=OpenApiParameter.QUERY, required=False,
type=OpenApiTypes.STR, enum=['chunk', 'frame', 'preview', 'context_image'], type=OpenApiTypes.STR, enum=['chunk', 'frame', 'preview', 'context_image'],
description='Specifies the type of the requested data'), description='Specifies the type of the requested data'),
OpenApiParameter('quality', location=OpenApiParameter.QUERY, required=True, OpenApiParameter('quality', location=OpenApiParameter.QUERY, required=False,
type=OpenApiTypes.STR, enum=['compressed', 'original'], type=OpenApiTypes.STR, enum=['compressed', 'original'],
description="Specifies the quality level of the requested data, doesn't matter for 'preview' type"), description="Specifies the quality level of the requested data, doesn't matter for 'preview' type"),
OpenApiParameter('number', location=OpenApiParameter.QUERY, required=True, type=OpenApiTypes.INT, OpenApiParameter('number', location=OpenApiParameter.QUERY, required=False, type=OpenApiTypes.INT,
description="A unique number value identifying chunk or frame, doesn't matter for 'preview' type"), description="A unique number value identifying chunk or frame, doesn't matter for 'preview' type"),
], ],
responses={ responses={
@ -971,7 +987,11 @@ class TaskViewSet(UploadMixin, AnnotationMixin, viewsets.ModelViewSet, Serialize
default=True), default=True),
], ],
responses={ responses={
'200': OpenApiResponse(description='Download of file started'), '200': OpenApiResponse(PolymorphicProxySerializer(
component_name='AnnotationsRead',
serializers=[LabeledDataSerializer, OpenApiTypes.BINARY],
resource_type_field_name=None
), description='Download of file started'),
'201': OpenApiResponse(description='Annotations file is ready to download'), '201': OpenApiResponse(description='Annotations file is ready to download'),
'202': OpenApiResponse(description='Dump of annotations has been started'), '202': OpenApiResponse(description='Dump of annotations has been started'),
'405': OpenApiResponse(description='Format is not available'), '405': OpenApiResponse(description='Format is not available'),
@ -981,12 +1001,17 @@ class TaskViewSet(UploadMixin, AnnotationMixin, viewsets.ModelViewSet, Serialize
OpenApiParameter('format', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, OpenApiParameter('format', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False,
description='Input format name\nYou can get the list of supported formats at:\n/server/annotation/formats'), description='Input format name\nYou can get the list of supported formats at:\n/server/annotation/formats'),
], ],
request=PolymorphicProxySerializer('TaskAnnotationsUpdate',
serializers=[LabeledDataSerializer, AnnotationFileSerializer, OpenApiTypes.NONE],
resource_type_field_name=None
),
responses={ responses={
'201': OpenApiResponse(description='Uploading has finished'), '201': OpenApiResponse(description='Uploading has finished'),
'202': OpenApiResponse(description='Uploading has been started'), '202': OpenApiResponse(description='Uploading has been started'),
'405': OpenApiResponse(description='Format is not available'), '405': OpenApiResponse(description='Format is not available'),
}) })
@extend_schema(methods=['POST'], summary='Method allows to upload task annotations from storage', @extend_schema(methods=['POST'],
summary="Method allows to upload task annotations from a local file or a cloud storage",
parameters=[ parameters=[
OpenApiParameter('format', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, OpenApiParameter('format', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False,
description='Input format name\nYou can get the list of supported formats at:\n/server/annotation/formats'), description='Input format name\nYou can get the list of supported formats at:\n/server/annotation/formats'),
@ -1001,6 +1026,10 @@ class TaskViewSet(UploadMixin, AnnotationMixin, viewsets.ModelViewSet, Serialize
OpenApiParameter('filename', description='Annotation file name', OpenApiParameter('filename', description='Annotation file name',
location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False), location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False),
], ],
request=PolymorphicProxySerializer('TaskAnnotationsWrite',
serializers=[AnnotationFileSerializer, OpenApiTypes.NONE],
resource_type_field_name=None
),
responses={ responses={
'201': OpenApiResponse(description='Uploading has finished'), '201': OpenApiResponse(description='Uploading has finished'),
'202': OpenApiResponse(description='Uploading has been started'), '202': OpenApiResponse(description='Uploading has been started'),
@ -1010,13 +1039,17 @@ class TaskViewSet(UploadMixin, AnnotationMixin, viewsets.ModelViewSet, Serialize
parameters=[ parameters=[
OpenApiParameter('action', location=OpenApiParameter.QUERY, required=True, OpenApiParameter('action', location=OpenApiParameter.QUERY, required=True,
type=OpenApiTypes.STR, enum=['create', 'update', 'delete']), type=OpenApiTypes.STR, enum=['create', 'update', 'delete']),
]) ],
request=LabeledDataSerializer,
responses={
'200': LabeledDataSerializer,
})
@extend_schema(methods=['DELETE'], summary='Method deletes all annotations for a specific task', @extend_schema(methods=['DELETE'], summary='Method deletes all annotations for a specific task',
responses={ responses={
'204': OpenApiResponse(description='The annotation has been deleted'), '204': OpenApiResponse(description='The annotation has been deleted'),
}) })
@action(detail=True, methods=['GET', 'DELETE', 'PUT', 'PATCH', 'POST', 'OPTIONS'], url_path=r'annotations/?$', @action(detail=True, methods=['GET', 'DELETE', 'PUT', 'PATCH', 'POST', 'OPTIONS'], url_path=r'annotations/?$',
serializer_class=LabeledDataSerializer(required=False)) serializer_class=None)
def annotations(self, request, pk): def annotations(self, request, pk):
self._object = self.get_object() # force to call check_object_permissions self._object = self.get_object() # force to call check_object_permissions
if request.method == 'GET': if request.method == 'GET':
@ -1122,6 +1155,7 @@ class TaskViewSet(UploadMixin, AnnotationMixin, viewsets.ModelViewSet, Serialize
'200': DataMetaReadSerializer, '200': DataMetaReadSerializer,
}) })
@extend_schema(methods=['PATCH'], summary='Method performs an update of data meta fields (deleted frames)', @extend_schema(methods=['PATCH'], summary='Method performs an update of data meta fields (deleted frames)',
request=DataMetaWriteSerializer,
responses={ responses={
'200': DataMetaReadSerializer, '200': DataMetaReadSerializer,
}) })
@ -1178,7 +1212,7 @@ class TaskViewSet(UploadMixin, AnnotationMixin, viewsets.ModelViewSet, Serialize
location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False),
], ],
responses={ responses={
'200': OpenApiResponse(description='Download of file started'), '200': OpenApiResponse(OpenApiTypes.BINARY, description='Download of file started'),
'201': OpenApiResponse(description='Output file is ready for downloading'), '201': OpenApiResponse(description='Output file is ready for downloading'),
'202': OpenApiResponse(description='Exporting has been started'), '202': OpenApiResponse(description='Exporting has been started'),
'405': OpenApiResponse(description='Format is not available'), '405': OpenApiResponse(description='Format is not available'),
@ -1208,19 +1242,16 @@ class TaskViewSet(UploadMixin, AnnotationMixin, viewsets.ModelViewSet, Serialize
responses={ responses={
'200': JobReadSerializer(many=True), '200': JobReadSerializer(many=True),
}), }),
update=extend_schema(
summary='Method updates a job by id',
responses={
'200': JobWriteSerializer,
}),
partial_update=extend_schema( partial_update=extend_schema(
summary='Methods does a partial update of chosen fields in a job', summary='Methods does a partial update of chosen fields in a job',
request=JobWriteSerializer,
responses={ responses={
'200': JobWriteSerializer, '200': JobReadSerializer, # check JobWriteSerializer.to_representation
}) })
) )
class JobViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, class JobViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
mixins.RetrieveModelMixin, mixins.UpdateModelMixin, UploadMixin, AnnotationMixin): mixins.RetrieveModelMixin, PartialUpdateModelMixin, UploadMixin, AnnotationMixin
):
queryset = Job.objects.all() queryset = Job.objects.all()
iam_organization_field = 'segment__task__organization' iam_organization_field = 'segment__task__organization'
search_fields = ('task_name', 'project_name', 'assignee', 'state', 'stage') search_fields = ('task_name', 'project_name', 'assignee', 'state', 'stage')
@ -1279,7 +1310,9 @@ class JobViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
return Response(data='Unknown upload was finished', return Response(data='Unknown upload was finished',
status=status.HTTP_400_BAD_REQUEST) status=status.HTTP_400_BAD_REQUEST)
@extend_schema(methods=['GET'], summary='Method returns annotations for a specific job', @extend_schema(methods=['GET'],
summary="Method returns annotations for a specific job as a JSON document. "
"If format is specified, a zip archive is returned.",
parameters=[ parameters=[
OpenApiParameter('format', location=OpenApiParameter.QUERY, OpenApiParameter('format', location=OpenApiParameter.QUERY,
description='Desired output format name\nYou can get the list of supported formats at:\n/server/annotation/formats', description='Desired output format name\nYou can get the list of supported formats at:\n/server/annotation/formats',
@ -1299,7 +1332,11 @@ class JobViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
default=True), default=True),
], ],
responses={ responses={
'200': LabeledDataSerializer, '200': OpenApiResponse(PolymorphicProxySerializer(
component_name='AnnotationsRead',
serializers=[LabeledDataSerializer, OpenApiTypes.BINARY],
resource_type_field_name=None
), description='Download of file started'),
'201': OpenApiResponse(description='Output file is ready for downloading'), '201': OpenApiResponse(description='Output file is ready for downloading'),
'202': OpenApiResponse(description='Exporting has been started'), '202': OpenApiResponse(description='Exporting has been started'),
'405': OpenApiResponse(description='Format is not available'), '405': OpenApiResponse(description='Format is not available'),
@ -1319,13 +1356,23 @@ class JobViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
OpenApiParameter('filename', description='Annotation file name', OpenApiParameter('filename', description='Annotation file name',
location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False), location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False),
], ],
request=AnnotationFileSerializer,
responses={ responses={
'201': OpenApiResponse(description='Uploading has finished'), '201': OpenApiResponse(description='Uploading has finished'),
'202': OpenApiResponse(description='Uploading has been started'), '202': OpenApiResponse(description='Uploading has been started'),
'405': OpenApiResponse(description='Format is not available'), '405': OpenApiResponse(description='Format is not available'),
}) })
@extend_schema(methods=['PUT'], summary='Method performs an update of all annotations in a specific job', @extend_schema(methods=['PUT'], summary='Method performs an update of all annotations in a specific job',
request=AnnotationFileSerializer, responses={ parameters=[
OpenApiParameter('format', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False,
description='Input format name\nYou can get the list of supported formats at:\n/server/annotation/formats'),
],
request=PolymorphicProxySerializer(
component_name='JobAnnotationsUpdate',
serializers=[LabeledDataSerializer, AnnotationFileSerializer],
resource_type_field_name=None
),
responses={
'201': OpenApiResponse(description='Uploading has finished'), '201': OpenApiResponse(description='Uploading has finished'),
'202': OpenApiResponse(description='Uploading has been started'), '202': OpenApiResponse(description='Uploading has been started'),
'405': OpenApiResponse(description='Format is not available'), '405': OpenApiResponse(description='Format is not available'),
@ -1335,9 +1382,9 @@ class JobViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
OpenApiParameter('action', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, OpenApiParameter('action', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR,
required=True, enum=['create', 'update', 'delete']) required=True, enum=['create', 'update', 'delete'])
], ],
request=LabeledDataSerializer,
responses={ responses={
#TODO '200': OpenApiResponse(description='Annotations successfully uploaded'),
'200': OpenApiResponse(description=''),
}) })
@extend_schema(methods=['DELETE'], summary='Method deletes all annotations for a specific job', @extend_schema(methods=['DELETE'], summary='Method deletes all annotations for a specific job',
responses={ responses={
@ -1436,7 +1483,7 @@ class JobViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False),
], ],
responses={ responses={
'200': OpenApiResponse(description='Download of file started'), '200': OpenApiResponse(OpenApiTypes.BINARY, description='Download of file started'),
'201': OpenApiResponse(description='Output file is ready for downloading'), '201': OpenApiResponse(description='Output file is ready for downloading'),
'202': OpenApiResponse(description='Exporting has been started'), '202': OpenApiResponse(description='Exporting has been started'),
'405': OpenApiResponse(description='Format is not available'), '405': OpenApiResponse(description='Format is not available'),
@ -1454,12 +1501,12 @@ class JobViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
callback=dm.views.export_job_as_dataset callback=dm.views.export_job_as_dataset
) )
@extend_schema( @extend_schema(summary='Method returns list of issues for the job',
summary='Method returns list of issues for the job', responses=IssueReadSerializer(many=True)) # Duplicate to still get 'list' op. name
responses={ @action(detail=True, methods=['GET'], serializer_class=IssueReadSerializer(many=True),
'200': IssueReadSerializer(many=True) # Remove regular list() parameters from swagger schema
}) # https://drf-spectacular.readthedocs.io/en/latest/faq.html#my-action-is-erroneously-paginated-or-has-filter-parameters-that-i-do-not-want
@action(detail=True, methods=['GET'], serializer_class=IssueReadSerializer) pagination_class=None, filter_fields=None, search_fields=None, ordering_fields=None)
def issues(self, request, pk): def issues(self, request, pk):
db_job = self.get_object() db_job = self.get_object()
queryset = db_job.issues queryset = db_job.issues
@ -1471,16 +1518,16 @@ class JobViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
@extend_schema(summary='Method returns data for a specific job', @extend_schema(summary='Method returns data for a specific job',
parameters=[ parameters=[
OpenApiParameter('type', description='Specifies the type of the requested data', OpenApiParameter('type', description='Specifies the type of the requested data',
location=OpenApiParameter.QUERY, required=True, type=OpenApiTypes.STR, location=OpenApiParameter.QUERY, required=False, type=OpenApiTypes.STR,
enum=['chunk', 'frame', 'preview', 'context_image']), enum=['chunk', 'frame', 'preview', 'context_image']),
OpenApiParameter('quality', location=OpenApiParameter.QUERY, required=True, OpenApiParameter('quality', location=OpenApiParameter.QUERY, required=False,
type=OpenApiTypes.STR, enum=['compressed', 'original'], type=OpenApiTypes.STR, enum=['compressed', 'original'],
description="Specifies the quality level of the requested data, doesn't matter for 'preview' type"), description="Specifies the quality level of the requested data, doesn't matter for 'preview' type"),
OpenApiParameter('number', location=OpenApiParameter.QUERY, required=True, type=OpenApiTypes.NUMBER, OpenApiParameter('number', location=OpenApiParameter.QUERY, required=False, type=OpenApiTypes.INT,
description="A unique number value identifying chunk or frame, doesn't matter for 'preview' type"), description="A unique number value identifying chunk or frame, doesn't matter for 'preview' type"),
], ],
responses={ responses={
'200': OpenApiResponse(description='Data of a specific type'), '200': OpenApiResponse(OpenApiTypes.BINARY, description='Data of a specific type'),
}) })
@action(detail=True, methods=['GET']) @action(detail=True, methods=['GET'])
def data(self, request, pk): def data(self, request, pk):
@ -1500,6 +1547,7 @@ class JobViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
'200': DataMetaReadSerializer, '200': DataMetaReadSerializer,
}) })
@extend_schema(methods=['PATCH'], summary='Method performs an update of data meta fields (deleted frames)', @extend_schema(methods=['PATCH'], summary='Method performs an update of data meta fields (deleted frames)',
request=DataMetaWriteSerializer,
responses={ responses={
'200': DataMetaReadSerializer, '200': DataMetaReadSerializer,
}, tags=['tasks'], versions=['2.0']) }, tags=['tasks'], versions=['2.0'])
@ -1565,11 +1613,11 @@ class JobViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
serializer = DataMetaReadSerializer(db_data) serializer = DataMetaReadSerializer(db_data)
return Response(serializer.data) return Response(serializer.data)
@extend_schema(summary='The action returns the list of tracked ' @extend_schema(summary='The action returns the list of tracked changes for the job',
'changes for the job', responses={ responses={
'200': JobCommitSerializer(many=True), '200': JobCommitSerializer(many=True),
}) })
@action(detail=True, methods=['GET'], serializer_class=JobCommitSerializer) @action(detail=True, methods=['GET'], serializer_class=None)
def commits(self, request, pk): def commits(self, request, pk):
db_job = self.get_object() db_job = self.get_object()
queryset = db_job.commits.order_by('-id') queryset = db_job.commits.order_by('-id')
@ -1594,20 +1642,17 @@ class JobViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
responses={ responses={
'200': IssueReadSerializer(many=True), '200': IssueReadSerializer(many=True),
}), }),
update=extend_schema(
summary='Method updates an issue by id',
responses={
'200': IssueWriteSerializer,
}),
partial_update=extend_schema( partial_update=extend_schema(
summary='Methods does a partial update of chosen fields in an issue', summary='Methods does a partial update of chosen fields in an issue',
request=IssueWriteSerializer,
responses={ responses={
'200': IssueWriteSerializer, '200': IssueReadSerializer, # check IssueWriteSerializer.to_representation
}), }),
create=extend_schema( create=extend_schema(
summary='Method creates an issue', summary='Method creates an issue',
request=IssueWriteSerializer,
responses={ responses={
'201': IssueWriteSerializer, '201': IssueReadSerializer, # check IssueWriteSerializer.to_representation
}), }),
destroy=extend_schema( destroy=extend_schema(
summary='Method deletes an issue', summary='Method deletes an issue',
@ -1615,9 +1660,11 @@ class JobViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
'204': OpenApiResponse(description='The issue has been deleted'), '204': OpenApiResponse(description='The issue has been deleted'),
}) })
) )
class IssueViewSet(viewsets.ModelViewSet): class IssueViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
mixins.RetrieveModelMixin, mixins.CreateModelMixin, mixins.DestroyModelMixin,
PartialUpdateModelMixin
):
queryset = Issue.objects.all().order_by('-id') queryset = Issue.objects.all().order_by('-id')
http_method_names = ['get', 'post', 'patch', 'delete', 'options']
iam_organization_field = 'job__segment__task__organization' iam_organization_field = 'job__segment__task__organization'
search_fields = ('owner', 'assignee') search_fields = ('owner', 'assignee')
filter_fields = list(search_fields) + ['id', 'job_id', 'task_id', 'resolved'] filter_fields = list(search_fields) + ['id', 'job_id', 'task_id', 'resolved']
@ -1648,11 +1695,14 @@ class IssueViewSet(viewsets.ModelViewSet):
serializer.save(owner=self.request.user) serializer.save(owner=self.request.user)
@extend_schema(summary='The action returns all comments of a specific issue', @extend_schema(summary='The action returns all comments of a specific issue',
responses={ responses=CommentReadSerializer(many=True)) # Duplicate to still get 'list' op. name
'200': CommentReadSerializer(many=True), @action(detail=True, methods=['GET'], serializer_class=CommentReadSerializer(many=True),
}) # Remove regular list() parameters from swagger schema
@action(detail=True, methods=['GET'], serializer_class=CommentReadSerializer) # https://drf-spectacular.readthedocs.io/en/latest/faq.html#my-action-is-erroneously-paginated-or-has-filter-parameters-that-i-do-not-want
pagination_class=None, filter_fields=None, search_fields=None, ordering_fields=None)
def comments(self, request, pk): def comments(self, request, pk):
# TODO: remove this endpoint? It is totally covered by issue body.
db_issue = self.get_object() db_issue = self.get_object()
queryset = db_issue.comments queryset = db_issue.comments
serializer = CommentReadSerializer(queryset, serializer = CommentReadSerializer(queryset,
@ -1672,20 +1722,17 @@ class IssueViewSet(viewsets.ModelViewSet):
responses={ responses={
'200':CommentReadSerializer(many=True), '200':CommentReadSerializer(many=True),
}), }),
update=extend_schema(
summary='Method updates a comment by id',
responses={
'200': CommentWriteSerializer,
}),
partial_update=extend_schema( partial_update=extend_schema(
summary='Methods does a partial update of chosen fields in a comment', summary='Methods does a partial update of chosen fields in a comment',
request=CommentWriteSerializer,
responses={ responses={
'200': CommentWriteSerializer, '200': CommentReadSerializer, # check CommentWriteSerializer.to_representation
}), }),
create=extend_schema( create=extend_schema(
summary='Method creates a comment', summary='Method creates a comment',
request=CommentWriteSerializer,
responses={ responses={
'201': CommentWriteSerializer, '201': CommentReadSerializer, # check CommentWriteSerializer.to_representation
}), }),
destroy=extend_schema( destroy=extend_schema(
summary='Method deletes a comment', summary='Method deletes a comment',
@ -1693,9 +1740,11 @@ class IssueViewSet(viewsets.ModelViewSet):
'204': OpenApiResponse(description='The comment has been deleted'), '204': OpenApiResponse(description='The comment has been deleted'),
}) })
) )
class CommentViewSet(viewsets.ModelViewSet): class CommentViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
mixins.RetrieveModelMixin, mixins.CreateModelMixin, mixins.DestroyModelMixin,
PartialUpdateModelMixin
):
queryset = Comment.objects.all().order_by('-id') queryset = Comment.objects.all().order_by('-id')
http_method_names = ['get', 'post', 'patch', 'delete', 'options']
iam_organization_field = 'issue__job__segment__task__organization' iam_organization_field = 'issue__job__segment__task__organization'
search_fields = ('owner',) search_fields = ('owner',)
filter_fields = list(search_fields) + ['id', 'issue_id'] filter_fields = list(search_fields) + ['id', 'issue_id']
@ -1753,9 +1802,8 @@ class CommentViewSet(viewsets.ModelViewSet):
}) })
) )
class UserViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, class UserViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
mixins.RetrieveModelMixin, mixins.UpdateModelMixin, mixins.DestroyModelMixin): mixins.RetrieveModelMixin, PartialUpdateModelMixin, mixins.DestroyModelMixin):
queryset = User.objects.prefetch_related('groups').all() queryset = User.objects.prefetch_related('groups').all()
http_method_names = ['get', 'post', 'head', 'patch', 'delete', 'options']
search_fields = ('username', 'first_name', 'last_name') search_fields = ('username', 'first_name', 'last_name')
iam_organization_field = 'memberships__organization' iam_organization_field = 'memberships__organization'
@ -1822,17 +1870,21 @@ class UserViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
}), }),
partial_update=extend_schema( partial_update=extend_schema(
summary='Methods does a partial update of chosen fields in a cloud storage instance', summary='Methods does a partial update of chosen fields in a cloud storage instance',
request=CloudStorageWriteSerializer,
responses={ responses={
'200': CloudStorageWriteSerializer, '200': CloudStorageReadSerializer, # check CloudStorageWriteSerializer.to_representation
}), }),
create=extend_schema( create=extend_schema(
summary='Method creates a cloud storage with a specified characteristics', summary='Method creates a cloud storage with a specified characteristics',
request=CloudStorageWriteSerializer,
responses={ responses={
'201': CloudStorageWriteSerializer, '201': CloudStorageReadSerializer, # check CloudStorageWriteSerializer.to_representation
}) })
) )
class CloudStorageViewSet(viewsets.ModelViewSet): class CloudStorageViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
http_method_names = ['get', 'post', 'patch', 'delete', 'options'] mixins.RetrieveModelMixin, mixins.CreateModelMixin, mixins.DestroyModelMixin,
PartialUpdateModelMixin
):
queryset = CloudStorageModel.objects.all().prefetch_related('data') queryset = CloudStorageModel.objects.all().prefetch_related('data')
search_fields = ('provider_type', 'display_name', 'resource', search_fields = ('provider_type', 'display_name', 'resource',

@ -45,10 +45,10 @@ def get_git_changeset():
so it's sufficient for generating the development version numbers. so it's sufficient for generating the development version numbers.
""" """
repo_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) repo_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
git_log = subprocess.Popen( git_log = subprocess.Popen( # nosec: B603, B607
'git log --pretty=format:%ct --quiet -1 HEAD', ['git', 'log', '--pretty=format:%ct', '--quiet', '-1', 'HEAD'],
stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
shell=True, cwd=repo_dir, universal_newlines=True, cwd=repo_dir, universal_newlines=True,
) )
timestamp = git_log.communicate()[0] timestamp = git_log.communicate()[0]
try: try:
@ -56,4 +56,3 @@ def get_git_changeset():
except ValueError: except ValueError:
return None return None
return timestamp.strftime('%Y%m%d%H%M%S') return timestamp.strftime('%Y%m%d%H%M%S')

@ -8,9 +8,9 @@ import os
from pathlib import Path from pathlib import Path
import pytest import pytest
from cvat_sdk import exceptions, make_client from cvat_sdk import make_client
from cvat_sdk.core.tasks import TaskProxy from cvat_sdk.api_client import exceptions
from cvat_sdk.core.types import ResourceType from cvat_sdk.core.proxies.tasks import ResourceType, Task
from PIL import Image from PIL import Image
from sdk.util import generate_coco_json from sdk.util import generate_coco_json
@ -41,8 +41,6 @@ class TestCLI:
yield yield
self.tmp_path = None
@pytest.fixture @pytest.fixture
def fxt_image_file(self): def fxt_image_file(self):
img_path = self.tmp_path / "img_0.png" img_path = self.tmp_path / "img_0.png"
@ -61,7 +59,7 @@ class TestCLI:
yield ann_filename yield ann_filename
@pytest.fixture @pytest.fixture
def fxt_backup_file(self, fxt_new_task: TaskProxy, fxt_coco_file: str): def fxt_backup_file(self, fxt_new_task: Task, fxt_coco_file: str):
backup_path = self.tmp_path / "backup.zip" backup_path = self.tmp_path / "backup.zip"
fxt_new_task.import_annotations("COCO 1.0", filename=fxt_coco_file) fxt_new_task.import_annotations("COCO 1.0", filename=fxt_coco_file)
@ -73,7 +71,7 @@ class TestCLI:
def fxt_new_task(self): def fxt_new_task(self):
files = generate_images(str(self.tmp_path), 5) files = generate_images(str(self.tmp_path), 5)
task = self.client.create_task( task = self.client.tasks.create_from_data(
spec={ spec={
"name": "test_task", "name": "test_task",
"labels": [{"name": "car"}, {"name": "person"}], "labels": [{"name": "car"}, {"name": "person"}],
@ -114,30 +112,28 @@ class TestCLI:
) )
task_id = int(stdout.split()[-1]) task_id = int(stdout.split()[-1])
assert self.client.retrieve_task(task_id).size == 5 assert self.client.tasks.retrieve(task_id).size == 5
def test_can_list_tasks_in_simple_format(self, fxt_new_task: TaskProxy): def test_can_list_tasks_in_simple_format(self, fxt_new_task: Task):
output = self.run_cli("ls") output = self.run_cli("ls")
results = output.split("\n") results = output.split("\n")
assert any(str(fxt_new_task.id) in r for r in results) assert any(str(fxt_new_task.id) in r for r in results)
def test_can_list_tasks_in_json_format(self, fxt_new_task: TaskProxy): def test_can_list_tasks_in_json_format(self, fxt_new_task: Task):
output = self.run_cli("ls", "--json") output = self.run_cli("ls", "--json")
results = json.loads(output) results = json.loads(output)
assert any(r["id"] == fxt_new_task.id for r in results) assert any(r["id"] == fxt_new_task.id for r in results)
def test_can_delete_task(self, fxt_new_task: TaskProxy): def test_can_delete_task(self, fxt_new_task: Task):
self.run_cli("delete", str(fxt_new_task.id)) self.run_cli("delete", str(fxt_new_task.id))
with pytest.raises(exceptions.ApiException) as capture: with pytest.raises(exceptions.NotFoundException):
fxt_new_task.fetch() fxt_new_task.fetch()
assert capture.value.status == 404 def test_can_download_task_annotations(self, fxt_new_task: Task):
filename = self.tmp_path / "task_{fxt_new_task.id}-cvat.zip"
def test_can_download_task_annotations(self, fxt_new_task: TaskProxy):
filename: Path = self.tmp_path / "task_{fxt_new_task.id}-cvat.zip"
self.run_cli( self.run_cli(
"dump", "dump",
str(fxt_new_task.id), str(fxt_new_task.id),
@ -152,8 +148,8 @@ class TestCLI:
assert 0 < filename.stat().st_size assert 0 < filename.stat().st_size
def test_can_download_task_backup(self, fxt_new_task: TaskProxy): def test_can_download_task_backup(self, fxt_new_task: Task):
filename: Path = self.tmp_path / "task_{fxt_new_task.id}-cvat.zip" filename = self.tmp_path / "task_{fxt_new_task.id}-cvat.zip"
self.run_cli( self.run_cli(
"export", "export",
str(fxt_new_task.id), str(fxt_new_task.id),
@ -165,7 +161,7 @@ class TestCLI:
assert 0 < filename.stat().st_size assert 0 < filename.stat().st_size
@pytest.mark.parametrize("quality", ("compressed", "original")) @pytest.mark.parametrize("quality", ("compressed", "original"))
def test_can_download_task_frames(self, fxt_new_task: TaskProxy, quality: str): def test_can_download_task_frames(self, fxt_new_task: Task, quality: str):
out_dir = str(self.tmp_path / "downloads") out_dir = str(self.tmp_path / "downloads")
self.run_cli( self.run_cli(
"frames", "frames",
@ -182,13 +178,13 @@ class TestCLI:
"task_{}_frame_{:06d}.jpg".format(fxt_new_task.id, i) for i in range(2) "task_{}_frame_{:06d}.jpg".format(fxt_new_task.id, i) for i in range(2)
} }
def test_can_upload_annotations(self, fxt_new_task: TaskProxy, fxt_coco_file: Path): def test_can_upload_annotations(self, fxt_new_task: Task, fxt_coco_file: Path):
self.run_cli("upload", str(fxt_new_task.id), str(fxt_coco_file), "--format", "COCO 1.0") self.run_cli("upload", str(fxt_new_task.id), str(fxt_coco_file), "--format", "COCO 1.0")
def test_can_create_from_backup(self, fxt_new_task: TaskProxy, fxt_backup_file: Path): def test_can_create_from_backup(self, fxt_new_task: Task, fxt_backup_file: Path):
stdout = self.run_cli("import", str(fxt_backup_file)) stdout = self.run_cli("import", str(fxt_backup_file))
task_id = int(stdout.split()[-1]) task_id = int(stdout.split()[-1])
assert task_id assert task_id
assert task_id != fxt_new_task.id assert task_id != fxt_new_task.id
assert self.client.retrieve_task(task_id).size == fxt_new_task.size assert self.client.tasks.retrieve(task_id).size == fxt_new_task.size

@ -0,0 +1,138 @@
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
import json
from http import HTTPStatus
import pytest
from cvat_sdk.api_client import ApiClient, Configuration, models
from shared.utils.config import BASE_URL, USER_PASS, make_api_client
@pytest.mark.usefixtures("dontchangedb")
class TestBasicAuth:
def test_can_do_basic_auth(self, admin_user: str):
username = admin_user
config = Configuration(host=BASE_URL, username=username, password=USER_PASS)
with ApiClient(config) as client:
(user, response) = client.users_api.retrieve_self()
assert response.status == HTTPStatus.OK
assert user.username == username
@pytest.mark.usefixtures("changedb")
class TestTokenAuth:
@staticmethod
def login(client: ApiClient, username: str) -> models.Token:
(auth, _) = client.auth_api.create_login(
models.LoginRequest(username=username, password=USER_PASS)
)
client.set_default_header("Authorization", "Token " + auth.key)
return auth
@classmethod
def make_client(cls, username: str) -> ApiClient:
with ApiClient(Configuration(host=BASE_URL)) as client:
cls.login(client, username)
return client
def test_can_do_token_auth_and_manage_cookies(self, admin_user: str):
username = admin_user
with ApiClient(Configuration(host=BASE_URL)) as client:
auth = self.login(client, username=username)
assert "sessionid" in client.cookies
assert "csrftoken" in client.cookies
assert auth.key
(user, response) = client.users_api.retrieve_self()
assert response.status == HTTPStatus.OK
assert user.username == username
def test_can_do_logout(self, admin_user: str):
username = admin_user
with self.make_client(username) as client:
(_, response) = client.auth_api.create_logout()
assert response.status == HTTPStatus.OK
(_, response) = client.users_api.retrieve_self(
_parse_response=False, _check_status=False
)
assert response.status == HTTPStatus.UNAUTHORIZED
@pytest.mark.usefixtures("changedb")
class TestCredentialsManagement:
def test_can_register(self):
username = "newuser"
email = "123@456.com"
with ApiClient(Configuration(host=BASE_URL)) as client:
(user, response) = client.auth_api.create_register(
models.RestrictedRegisterRequest(
username=username, password1=USER_PASS, password2=USER_PASS, email=email
)
)
assert response.status == HTTPStatus.CREATED
assert user.username == username
with make_api_client(username) as client:
(user, response) = client.users_api.retrieve_self()
assert response.status == HTTPStatus.OK
assert user.username == username
assert user.email == email
def test_can_change_password(self, admin_user: str):
username = admin_user
new_pass = "5w4knrqaW#$@gewa"
with make_api_client(username) as client:
(info, response) = client.auth_api.create_password_change(
models.PasswordChangeRequest(
old_password=USER_PASS, new_password1=new_pass, new_password2=new_pass
)
)
assert response.status == HTTPStatus.OK
assert info.detail == "New password has been saved."
(_, response) = client.users_api.retrieve_self(
_parse_response=False, _check_status=False
)
assert response.status == HTTPStatus.UNAUTHORIZED
client.configuration.password = new_pass
(user, response) = client.users_api.retrieve_self()
assert response.status == HTTPStatus.OK
assert user.username == username
def test_can_report_weak_password(self, admin_user: str):
username = admin_user
new_pass = "pass"
with make_api_client(username) as client:
(_, response) = client.auth_api.create_password_change(
models.PasswordChangeRequest(
old_password=USER_PASS, new_password1=new_pass, new_password2=new_pass
),
_parse_response=False,
_check_status=False,
)
assert response.status == HTTPStatus.BAD_REQUEST
assert json.loads(response.data) == {
"new_password2": [
"This password is too short. It must contain at least 8 characters.",
"This password is too common.",
]
}
def test_can_report_mismatching_passwords(self, admin_user: str):
username = admin_user
with make_api_client(username) as client:
(_, response) = client.auth_api.create_password_change(
models.PasswordChangeRequest(
old_password=USER_PASS, new_password1="3j4tb13/T$#", new_password2="q#@$n34g5"
),
_parse_response=False,
_check_status=False,
)
assert response.status == HTTPStatus.BAD_REQUEST
assert json.loads(response.data) == {
"new_password2": ["The two password fields didnt match."]
}

@ -3,48 +3,78 @@
# #
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
import pytest import json
from copy import deepcopy
from http import HTTPStatus from http import HTTPStatus
import pytest
from cvat_sdk import models
from deepdiff import DeepDiff from deepdiff import DeepDiff
from copy import deepcopy
from shared.utils.config import post_method, patch_method from cvat_sdk.api_client import exceptions
from shared.utils.config import make_api_client
@pytest.mark.usefixtures('changedb')
@pytest.mark.usefixtures("changedb")
class TestPostIssues: class TestPostIssues:
def _test_check_response(self, user, data, is_allow, **kwargs): def _test_check_response(self, user, data, is_allow, **kwargs):
response = post_method(user, 'issues', data, **kwargs) with make_api_client(user) as client:
(_, response) = client.issues_api.create(
models.IssueWriteRequest(**data),
**kwargs,
_parse_response=False,
_check_status=False,
)
if is_allow: if is_allow:
assert response.status_code == HTTPStatus.CREATED assert response.status == HTTPStatus.CREATED
assert user == response.json()['owner']['username'] response_json = json.loads(response.data)
assert data['message'] == response.json()['comments'][0]['message'] assert user == response_json["owner"]["username"]
assert DeepDiff(data, response.json(), assert data["message"] == response_json["comments"][0]["message"]
exclude_regex_paths=r"root\['created_date|updated_date|comments|id|owner|message'\]") == {} assert (
DeepDiff(
data,
response_json,
exclude_regex_paths=r"root\['created_date|updated_date|comments|id|owner|message'\]",
)
== {}
)
else: else:
assert response.status_code == HTTPStatus.FORBIDDEN assert response.status == HTTPStatus.FORBIDDEN
@pytest.mark.parametrize('org', ['']) @pytest.mark.parametrize("org", [""])
@pytest.mark.parametrize('privilege, job_staff, is_allow', [ @pytest.mark.parametrize(
('admin', True, True), ('admin', False, True), "privilege, job_staff, is_allow",
('business', True, True), ('business', False, False), [
('worker', True, True), ('worker', False, False), ("admin", True, True),
('user', True, True), ('user', False, False) ("admin", False, True),
]) ("business", True, True),
def test_user_create_issue(self, org, privilege, job_staff, is_allow, ("business", False, False),
find_job_staff_user, find_users, jobs_by_org): ("worker", True, True),
("worker", False, False),
("user", True, True),
("user", False, False),
],
)
def test_user_create_issue(
self, org, privilege, job_staff, is_allow, find_job_staff_user, find_users, jobs_by_org
):
users = find_users(privilege=privilege) users = find_users(privilege=privilege)
jobs = jobs_by_org[org] jobs = jobs_by_org[org]
username, jid = find_job_staff_user(jobs, users, job_staff) username, jid = find_job_staff_user(jobs, users, job_staff)
job, = filter(lambda job: job['id'] == jid, jobs) (job,) = filter(lambda job: job["id"] == jid, jobs)
data = { data = {
"assignee": None, "assignee": None,
"comments": [], "comments": [],
"job": jid, "job": jid,
"frame": job['start_frame'], "frame": job["start_frame"],
"position": [ "position": [
0., 0., 1., 1., 0.0,
0.0,
1.0,
1.0,
], ],
"resolved": False, "resolved": False,
"message": "lorem ipsum", "message": "lorem ipsum",
@ -52,16 +82,23 @@ class TestPostIssues:
self._test_check_response(username, data, is_allow) self._test_check_response(username, data, is_allow)
@pytest.mark.parametrize("org", [2])
@pytest.mark.parametrize('org', [2]) @pytest.mark.parametrize(
@pytest.mark.parametrize('role, job_staff, is_allow', [ "role, job_staff, is_allow",
('maintainer', False, True), ('owner', False, True), [
('supervisor', False, False), ('worker', False, False), ("maintainer", False, True),
('maintainer', True, True), ('owner', True, True), ("owner", False, True),
('supervisor', True, True), ('worker', True, True) ("supervisor", False, False),
]) ("worker", False, False),
def test_member_create_issue(self, org, role, job_staff, is_allow, ("maintainer", True, True),
find_job_staff_user, find_users, jobs_by_org, jobs): ("owner", True, True),
("supervisor", True, True),
("worker", True, True),
],
)
def test_member_create_issue(
self, org, role, job_staff, is_allow, find_job_staff_user, find_users, jobs_by_org, jobs
):
users = find_users(role=role, org=org) users = find_users(role=role, org=org)
username, jid = find_job_staff_user(jobs_by_org[org], users, job_staff) username, jid = find_job_staff_user(jobs_by_org[org], users, job_staff)
job = jobs[jid] job = jobs[jid]
@ -70,50 +107,85 @@ class TestPostIssues:
"assignee": None, "assignee": None,
"comments": [], "comments": [],
"job": jid, "job": jid,
"frame": job['start_frame'], "frame": job["start_frame"],
"position": [ "position": [
0., 0., 1., 1., 0.0,
0.0,
1.0,
1.0,
], ],
"resolved": False, "resolved": False,
"message": "lorem ipsum", "message": "lorem ipsum",
} }
self._test_check_response(username, data, is_allow, org_id=org) self._test_check_response(username, data, is_allow, org_id=org)
@pytest.mark.usefixtures('changedb')
@pytest.mark.usefixtures("changedb")
class TestPatchIssues: class TestPatchIssues:
def _test_check_response(self, user, issue_id, data, is_allow, **kwargs): def _test_check_response(self, user, issue_id, data, is_allow, **kwargs):
response = patch_method(user, f'issues/{issue_id}', data, with make_api_client(user) as client:
action='update', **kwargs) (_, response) = client.issues_api.partial_update(
issue_id,
patched_issue_write_request=models.PatchedIssueWriteRequest(**data),
**kwargs,
_parse_response=False,
_check_status=False,
)
if is_allow: if is_allow:
assert response.status_code == HTTPStatus.OK assert response.status == HTTPStatus.OK
assert DeepDiff(data, response.json(), assert (
exclude_regex_paths=r"root\['created_date|updated_date|comments|id|owner'\]") == {} DeepDiff(
data,
json.loads(response.data),
exclude_regex_paths=r"root\['created_date|updated_date|comments|id|owner'\]",
)
== {}
)
else: else:
assert response.status_code == HTTPStatus.FORBIDDEN assert response.status == HTTPStatus.FORBIDDEN
@pytest.fixture(scope='class') @pytest.fixture(scope="class")
def request_data(self, issues): def request_data(self, issues):
def get_data(issue_id): def get_data(issue_id):
data = deepcopy(issues[issue_id]) data = deepcopy(issues[issue_id])
data['resolved'] = not data['resolved'] data["resolved"] = not data["resolved"]
data.pop('comments') data.pop("comments")
data.pop('updated_date') data.pop("updated_date")
data.pop('id') data.pop("id")
data.pop('owner') data.pop("owner")
return data return data
return get_data return get_data
@pytest.mark.parametrize('org', ['']) @pytest.mark.parametrize("org", [""])
@pytest.mark.parametrize('privilege, issue_staff, issue_admin, is_allow', [ @pytest.mark.parametrize(
('admin', True, None, True), ('admin', False, None, True), "privilege, issue_staff, issue_admin, is_allow",
('business', True, None, True), ('business', False, None, False), [
('user', True, None, True), ('user', False, None, False), ("admin", True, None, True),
('worker', False, True, True), ('worker', True, False, False), ("admin", False, None, True),
('worker', False, False, False) ("business", True, None, True),
]) ("business", False, None, False),
def test_user_update_issue(self, org, privilege, issue_staff, issue_admin, is_allow, ("user", True, None, True),
find_issue_staff_user, find_users, issues_by_org, request_data): ("user", False, None, False),
("worker", False, True, True),
("worker", True, False, False),
("worker", False, False, False),
],
)
def test_user_update_issue(
self,
org,
privilege,
issue_staff,
issue_admin,
is_allow,
find_issue_staff_user,
find_users,
issues_by_org,
request_data,
):
users = find_users(privilege=privilege) users = find_users(privilege=privilege)
issues = issues_by_org[org] issues = issues_by_org[org]
username, issue_id = find_issue_staff_user(issues, users, issue_staff, issue_admin) username, issue_id = find_issue_staff_user(issues, users, issue_staff, issue_admin)
@ -121,19 +193,135 @@ class TestPatchIssues:
data = request_data(issue_id) data = request_data(issue_id)
self._test_check_response(username, issue_id, data, is_allow) self._test_check_response(username, issue_id, data, is_allow)
@pytest.mark.parametrize('org', [2]) @pytest.mark.parametrize("org", [2])
@pytest.mark.parametrize('role, issue_staff, issue_admin, is_allow', [ @pytest.mark.parametrize(
('maintainer', True, None, True), ('maintainer', False, None, True), "role, issue_staff, issue_admin, is_allow",
('supervisor', True, None, True), ('supervisor', False, None, False), [
('owner', True, None, True), ('owner', False, None, True), ("maintainer", True, None, True),
('worker', False, True, True), ('worker', True, False, False), ("maintainer", False, None, True),
('worker', False, False, False) ("supervisor", True, None, True),
]) ("supervisor", False, None, False),
def test_member_update_issue(self, org, role, issue_staff, issue_admin, is_allow, ("owner", True, None, True),
find_issue_staff_user, find_users, issues_by_org, request_data): ("owner", False, None, True),
("worker", False, True, True),
("worker", True, False, False),
("worker", False, False, False),
],
)
def test_member_update_issue(
self,
org,
role,
issue_staff,
issue_admin,
is_allow,
find_issue_staff_user,
find_users,
issues_by_org,
request_data,
):
users = find_users(role=role, org=org) users = find_users(role=role, org=org)
issues = issues_by_org[org] issues = issues_by_org[org]
username, issue_id = find_issue_staff_user(issues, users, issue_staff, issue_admin) username, issue_id = find_issue_staff_user(issues, users, issue_staff, issue_admin)
data = request_data(issue_id) data = request_data(issue_id)
self._test_check_response(username, issue_id, data, is_allow, org_id=org) self._test_check_response(username, issue_id, data, is_allow, org_id=org)
@pytest.mark.xfail(raises=exceptions.ServiceException,
reason="server bug, https://github.com/cvat-ai/cvat/issues/122")
def test_cant_update_message(self, admin_user: str, issues_by_org):
org = 2
issue_id = issues_by_org[org][0]['id']
with make_api_client(admin_user) as client:
client.issues_api.partial_update(
issue_id,
patched_issue_write_request=models.PatchedIssueWriteRequest(message="foo"),
org_id=org,
)
@pytest.mark.usefixtures("changedb")
class TestDeleteIssues:
def _test_check_response(self, user, issue_id, expect_success, **kwargs):
with make_api_client(user) as client:
(_, response) = client.issues_api.destroy(
issue_id,
**kwargs,
_parse_response=False,
_check_status=False,
)
if expect_success:
assert response.status == HTTPStatus.NO_CONTENT
(_, response) = client.issues_api.retrieve(
issue_id, _parse_response=False, _check_status=False
)
assert response.status == HTTPStatus.NOT_FOUND
else:
assert response.status == HTTPStatus.FORBIDDEN
@pytest.mark.parametrize("org", [""])
@pytest.mark.parametrize(
"privilege, issue_staff, issue_admin, expect_success",
[
("admin", True, None, True),
("admin", False, None, True),
("business", True, None, True),
("business", False, None, False),
("user", True, None, True),
("user", False, None, False),
("worker", False, True, True),
("worker", True, False, False),
("worker", False, False, False),
],
)
def test_user_delete_issue(
self,
org,
privilege,
issue_staff,
issue_admin,
expect_success,
find_issue_staff_user,
find_users,
issues_by_org,
):
users = find_users(privilege=privilege)
issues = issues_by_org[org]
username, issue_id = find_issue_staff_user(issues, users, issue_staff, issue_admin)
self._test_check_response(username, issue_id, expect_success)
@pytest.mark.parametrize("org", [2])
@pytest.mark.parametrize(
"role, issue_staff, issue_admin, expect_success",
[
("maintainer", True, None, True),
("maintainer", False, None, True),
("supervisor", True, None, True),
("supervisor", False, None, False),
("owner", True, None, True),
("owner", False, None, True),
("worker", False, True, True),
("worker", True, False, False),
("worker", False, False, False),
],
)
def test_org_member_delete_issue(
self,
org,
role,
issue_staff,
issue_admin,
expect_success,
find_issue_staff_user,
find_users,
issues_by_org,
):
users = find_users(role=role, org=org)
issues = issues_by_org[org]
username, issue_id = find_issue_staff_user(issues, users, issue_staff, issue_admin)
self._test_check_response(username, issue_id, expect_success, org_id=org)

@ -4,10 +4,14 @@
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
from http import HTTPStatus from http import HTTPStatus
import json
from typing import List
from cvat_sdk.core.helpers import get_paginated_collection
from deepdiff import DeepDiff from deepdiff import DeepDiff
import pytest import pytest
from copy import deepcopy from copy import deepcopy
from shared.utils.config import get_method, patch_method from shared.utils.config import make_api_client
from .utils import export_dataset
def get_job_staff(job, tasks, projects): def get_job_staff(job, tasks, projects):
job_staff = [] job_staff = []
@ -42,15 +46,17 @@ def filter_jobs(jobs, tasks, org):
@pytest.mark.usefixtures('dontchangedb') @pytest.mark.usefixtures('dontchangedb')
class TestGetJobs: class TestGetJobs:
def _test_get_job_200(self, user, jid, data, **kwargs): def _test_get_job_200(self, user, jid, data, **kwargs):
response = get_method(user, f'jobs/{jid}', **kwargs) with make_api_client(user) as client:
(_, response) = client.jobs_api.retrieve(jid, **kwargs)
assert response.status_code == HTTPStatus.OK assert response.status == HTTPStatus.OK
assert DeepDiff(data, response.json(), exclude_paths="root['updated_date']", assert DeepDiff(data, json.loads(response.data), exclude_paths="root['updated_date']",
ignore_order=True) == {} ignore_order=True) == {}
def _test_get_job_403(self, user, jid, **kwargs): def _test_get_job_403(self, user, jid, **kwargs):
response = get_method(user, f'jobs/{jid}', **kwargs) with make_api_client(user) as client:
assert response.status_code == HTTPStatus.FORBIDDEN (_, response) = client.jobs_api.retrieve(jid, **kwargs,
_check_status=False, _parse_response=False)
assert response.status == HTTPStatus.FORBIDDEN
@pytest.mark.parametrize('org', [None, '', 1, 2]) @pytest.mark.parametrize('org', [None, '', 1, 2])
def test_admin_get_job(self, jobs, tasks, org): def test_admin_get_job(self, jobs, tasks, org):
@ -82,15 +88,17 @@ class TestGetJobs:
@pytest.mark.usefixtures('dontchangedb') @pytest.mark.usefixtures('dontchangedb')
class TestListJobs: class TestListJobs:
def _test_list_jobs_200(self, user, data, **kwargs): def _test_list_jobs_200(self, user, data, **kwargs):
response = get_method(user, 'jobs', **kwargs, page_size='all') with make_api_client(user) as client:
results = get_paginated_collection(client.jobs_api.list_endpoint,
assert response.status_code == HTTPStatus.OK return_json=True, **kwargs)
assert DeepDiff(data, response.json()['results'], exclude_paths="root['updated_date']", assert DeepDiff(data, results, exclude_paths="root['updated_date']",
ignore_order=True) == {} ignore_order=True) == {}
def _test_list_jobs_403(self, user, **kwargs): def _test_list_jobs_403(self, user, **kwargs):
response = get_method(user, 'jobs', **kwargs) with make_api_client(user) as client:
assert response.status_code == HTTPStatus.FORBIDDEN (_, response) = client.jobs_api.list(**kwargs,
_check_status=False, _parse_response=False)
assert response.status == HTTPStatus.FORBIDDEN
@pytest.mark.parametrize('org', [None, '', 1, 2]) @pytest.mark.parametrize('org', [None, '', 1, 2])
def test_admin_list_jobs(self, jobs, tasks, org): def test_admin_list_jobs(self, jobs, tasks, org):
@ -119,52 +127,54 @@ class TestListJobs:
@pytest.mark.usefixtures('dontchangedb') @pytest.mark.usefixtures('dontchangedb')
class TestGetAnnotations: class TestGetAnnotations:
def _test_get_job_annotations_200(self, user, jid, data, **kwargs): def _test_get_job_annotations_200(self, user, jid, data, **kwargs):
response = get_method(user, f'jobs/{jid}/annotations', **kwargs) with make_api_client(user) as client:
(_, response) = client.jobs_api.retrieve_annotations(jid, **kwargs)
response_data = response.json() assert response.status == HTTPStatus.OK
response_data['shapes'] = sorted(response_data['shapes'], key=lambda a: a['id'])
assert response.status_code == HTTPStatus.OK response_data = json.loads(response.data)
assert DeepDiff(data, response_data, response_data['shapes'] = sorted(response_data['shapes'], key=lambda a: a['id'])
exclude_regex_paths=r"root\['version|updated_date'\]") == {} assert DeepDiff(data, response_data,
exclude_regex_paths=r"root\['version|updated_date'\]") == {}
def _test_get_job_annotations_403(self, user, jid, **kwargs): def _test_get_job_annotations_403(self, user, jid, **kwargs):
response = get_method(user, f'jobs/{jid}/annotations', **kwargs) with make_api_client(user) as client:
assert response.status_code == HTTPStatus.FORBIDDEN (_, response) = client.jobs_api.retrieve_annotations(jid, **kwargs,
_check_status=False, _parse_response=False)
assert response.status == HTTPStatus.FORBIDDEN
@pytest.mark.parametrize('org', ['']) @pytest.mark.parametrize('org', [''])
@pytest.mark.parametrize('groups, job_staff, is_allow', [ @pytest.mark.parametrize('groups, job_staff, expect_success', [
(['admin'], True, True), (['admin'], False, True), (['admin'], True, True), (['admin'], False, True),
(['business'], True, True), (['business'], False, False), (['business'], True, True), (['business'], False, False),
(['worker'], True, True), (['worker'], False, False), (['worker'], True, True), (['worker'], False, False),
(['user'], True, True), (['user'], False, False) (['user'], True, True), (['user'], False, False)
]) ])
def test_user_get_job_annotations(self, org, groups, job_staff, def test_user_get_job_annotations(self, org, groups, job_staff,
is_allow, users, jobs, tasks, annotations, find_job_staff_user): expect_success, users, jobs, tasks, annotations, find_job_staff_user):
users = [u for u in users if u['groups'] == groups] users = [u for u in users if u['groups'] == groups]
jobs, kwargs = filter_jobs(jobs, tasks, org) jobs, kwargs = filter_jobs(jobs, tasks, org)
username, job_id = find_job_staff_user(jobs, users, job_staff) username, job_id = find_job_staff_user(jobs, users, job_staff)
if is_allow: if expect_success:
self._test_get_job_annotations_200(username, self._test_get_job_annotations_200(username,
job_id, annotations['job'][str(job_id)], **kwargs) job_id, annotations['job'][str(job_id)], **kwargs)
else: else:
self._test_get_job_annotations_403(username, job_id, **kwargs) self._test_get_job_annotations_403(username, job_id, **kwargs)
@pytest.mark.parametrize('org', [2]) @pytest.mark.parametrize('org', [2])
@pytest.mark.parametrize('role, job_staff, is_allow', [ @pytest.mark.parametrize('role, job_staff, expect_success', [
('owner', True, True), ('owner', False, True), ('owner', True, True), ('owner', False, True),
('maintainer', True, True), ('maintainer', False, True), ('maintainer', True, True), ('maintainer', False, True),
('supervisor', True, True), ('supervisor', False, False), ('supervisor', True, True), ('supervisor', False, False),
('worker', True, True), ('worker', False, False), ('worker', True, True), ('worker', False, False),
]) ])
def test_member_get_job_annotations(self, org, role, job_staff, is_allow, def test_member_get_job_annotations(self, org, role, job_staff, expect_success,
jobs, tasks, find_job_staff_user, annotations, find_users): jobs, tasks, find_job_staff_user, annotations, find_users):
users = find_users(org=org, role=role) users = find_users(org=org, role=role)
jobs, kwargs = filter_jobs(jobs, tasks, org) jobs, kwargs = filter_jobs(jobs, tasks, org)
username, jid = find_job_staff_user(jobs, users, job_staff) username, jid = find_job_staff_user(jobs, users, job_staff)
if is_allow: if expect_success:
data = annotations['job'][str(jid)] data = annotations['job'][str(jid)]
data['shapes'] = sorted(data['shapes'], key=lambda a: a['id']) data['shapes'] = sorted(data['shapes'], key=lambda a: a['id'])
self._test_get_job_annotations_200(username, jid, data, **kwargs) self._test_get_job_annotations_200(username, jid, data, **kwargs)
@ -172,17 +182,17 @@ class TestGetAnnotations:
self._test_get_job_annotations_403(username, jid, **kwargs) self._test_get_job_annotations_403(username, jid, **kwargs)
@pytest.mark.parametrize('org', [1]) @pytest.mark.parametrize('org', [1])
@pytest.mark.parametrize('privilege, is_allow', [ @pytest.mark.parametrize('privilege, expect_success', [
('admin', True), ('business', False), ('worker', False), ('user', False) ('admin', True), ('business', False), ('worker', False), ('user', False)
]) ])
def test_non_member_get_job_annotations(self, org, privilege, is_allow, def test_non_member_get_job_annotations(self, org, privilege, expect_success,
jobs, tasks, find_job_staff_user, annotations, find_users): jobs, tasks, find_job_staff_user, annotations, find_users):
users = find_users(privilege=privilege, exclude_org=org) users = find_users(privilege=privilege, exclude_org=org)
jobs, kwargs = filter_jobs(jobs, tasks, org) jobs, kwargs = filter_jobs(jobs, tasks, org)
username, job_id = find_job_staff_user(jobs, users, False) username, job_id = find_job_staff_user(jobs, users, False)
kwargs = {'org_id': org} kwargs = {'org_id': org}
if is_allow: if expect_success:
self._test_get_job_annotations_200(username, self._test_get_job_annotations_200(username,
job_id, annotations['job'][str(job_id)], **kwargs) job_id, annotations['job'][str(job_id)], **kwargs)
else: else:
@ -190,15 +200,25 @@ class TestGetAnnotations:
@pytest.mark.usefixtures('changedb') @pytest.mark.usefixtures('changedb')
class TestPatchJobAnnotations: class TestPatchJobAnnotations:
_ORG = 2 def _check_respone(self, username, jid, expect_success, data=None, org=None):
kwargs = {}
if org is not None:
if isinstance(org, str):
kwargs['org'] = org
else:
kwargs['org_id'] = org
def _test_check_respone(self, is_allow, response, data=None): with make_api_client(username) as client:
if is_allow: (_, response) = client.jobs_api.partial_update_annotations(id=jid,
assert response.status_code == HTTPStatus.OK patched_labeled_data_request=deepcopy(data), action='update', **kwargs,
assert DeepDiff(data, response.json(), _parse_response=expect_success, _check_status=expect_success)
exclude_regex_paths=r"root\['version|updated_date'\]") == {}
else: if expect_success:
assert response.status_code == HTTPStatus.FORBIDDEN assert response.status == HTTPStatus.OK
assert DeepDiff(data, json.loads(response.data),
exclude_regex_paths=r"root\['version|updated_date'\]") == {}
else:
assert response.status == HTTPStatus.FORBIDDEN
@pytest.fixture(scope='class') @pytest.fixture(scope='class')
def request_data(self, annotations): def request_data(self, annotations):
@ -210,13 +230,13 @@ class TestPatchJobAnnotations:
return get_data return get_data
@pytest.mark.parametrize('org', [2]) @pytest.mark.parametrize('org', [2])
@pytest.mark.parametrize('role, job_staff, is_allow', [ @pytest.mark.parametrize('role, job_staff, expect_success', [
('maintainer', False, True), ('owner', False, True), ('maintainer', False, True), ('owner', False, True),
('supervisor', False, False), ('worker', False, False), ('supervisor', False, False), ('worker', False, False),
('maintainer', True, True), ('owner', True, True), ('maintainer', True, True), ('owner', True, True),
('supervisor', True, True), ('worker', True, True) ('supervisor', True, True), ('worker', True, True)
]) ])
def test_member_update_job_annotations(self, org, role, job_staff, is_allow, def test_member_update_job_annotations(self, org, role, job_staff, expect_success,
find_job_staff_user, find_users, request_data, jobs_by_org, filter_jobs_with_shapes): find_job_staff_user, find_users, request_data, jobs_by_org, filter_jobs_with_shapes):
users = find_users(role=role, org=org) users = find_users(role=role, org=org)
jobs = jobs_by_org[org] jobs = jobs_by_org[org]
@ -224,17 +244,13 @@ class TestPatchJobAnnotations:
username, jid = find_job_staff_user(filtered_jobs, users, job_staff) username, jid = find_job_staff_user(filtered_jobs, users, job_staff)
data = request_data(jid) data = request_data(jid)
response = patch_method(username, f'jobs/{jid}/annotations', self._check_respone(username, jid, expect_success, data, org=org)
data, org_id=org, action='update')
self._test_check_respone(is_allow, response, data)
@pytest.mark.parametrize('org', [2]) @pytest.mark.parametrize('org', [2])
@pytest.mark.parametrize('privilege, is_allow', [ @pytest.mark.parametrize('privilege, expect_success', [
('admin', True), ('business', False), ('worker', False), ('user', False) ('admin', True), ('business', False), ('worker', False), ('user', False)
]) ])
def test_non_member_update_job_annotations(self, org, privilege, is_allow, def test_non_member_update_job_annotations(self, org, privilege, expect_success,
find_job_staff_user, find_users, request_data, jobs_by_org, filter_jobs_with_shapes): find_job_staff_user, find_users, request_data, jobs_by_org, filter_jobs_with_shapes):
users = find_users(privilege=privilege, exclude_org=org) users = find_users(privilege=privilege, exclude_org=org)
jobs = jobs_by_org[org] jobs = jobs_by_org[org]
@ -242,19 +258,16 @@ class TestPatchJobAnnotations:
username, jid = find_job_staff_user(filtered_jobs, users, False) username, jid = find_job_staff_user(filtered_jobs, users, False)
data = request_data(jid) data = request_data(jid)
response = patch_method(username, f'jobs/{jid}/annotations', data, self._check_respone(username, jid, expect_success, data, org=org)
org_id=org, action='update')
self._test_check_respone(is_allow, response, data)
@pytest.mark.parametrize('org', ['']) @pytest.mark.parametrize('org', [''])
@pytest.mark.parametrize('privilege, job_staff, is_allow', [ @pytest.mark.parametrize('privilege, job_staff, expect_success', [
('admin', True, True), ('admin', False, True), ('admin', True, True), ('admin', False, True),
('business', True, True), ('business', False, False), ('business', True, True), ('business', False, False),
('worker', True, True), ('worker', False, False), ('worker', True, True), ('worker', False, False),
('user', True, True), ('user', False, False) ('user', True, True), ('user', False, False)
]) ])
def test_user_update_job_annotations(self, org, privilege, job_staff, is_allow, def test_user_update_job_annotations(self, org, privilege, job_staff, expect_success,
find_job_staff_user, find_users, request_data, jobs_by_org, filter_jobs_with_shapes): find_job_staff_user, find_users, request_data, jobs_by_org, filter_jobs_with_shapes):
users = find_users(privilege=privilege) users = find_users(privilege=privilege)
jobs = jobs_by_org[org] jobs = jobs_by_org[org]
@ -262,15 +275,10 @@ class TestPatchJobAnnotations:
username, jid = find_job_staff_user(filtered_jobs, users, job_staff) username, jid = find_job_staff_user(filtered_jobs, users, job_staff)
data = request_data(jid) data = request_data(jid)
response = patch_method(username, f'jobs/{jid}/annotations', data, self._check_respone(username, jid, expect_success, data, org=org)
org_id=org, action='update')
self._test_check_respone(is_allow, response, data)
@pytest.mark.usefixtures('changedb') @pytest.mark.usefixtures('changedb')
class TestPatchJob: class TestPatchJob:
_ORG = 2
@pytest.fixture(scope='class') @pytest.fixture(scope='class')
def find_task_staff_user(self, is_task_staff): def find_task_staff_user(self, is_task_staff):
def find(jobs, users, is_staff): def find(jobs, users, is_staff):
@ -300,24 +308,47 @@ class TestPatchJob:
return find_new_assignee return find_new_assignee
@pytest.mark.parametrize('org', [2]) @pytest.mark.parametrize('org', [2])
@pytest.mark.parametrize('role, task_staff, is_allow', [ @pytest.mark.parametrize('role, task_staff, expect_success', [
('maintainer', False, True), ('owner', False, True), ('maintainer', False, True), ('owner', False, True),
('supervisor', False, False), ('worker', False, False), ('supervisor', False, False), ('worker', False, False),
('maintainer', True, True), ('owner', True, True), ('maintainer', True, True), ('owner', True, True),
('supervisor', True, True), ('worker', True, True) ('supervisor', True, True), ('worker', True, True)
]) ])
def test_member_update_job_assignee(self, org, role, task_staff, is_allow, def test_member_update_job_assignee(self, org, role, task_staff, expect_success,
find_task_staff_user, find_users, jobs_by_org, new_assignee, expected_data): find_task_staff_user, find_users, jobs_by_org, new_assignee, expected_data):
users, jobs = find_users(role=role, org=org), jobs_by_org[org] users, jobs = find_users(role=role, org=org), jobs_by_org[org]
user, jid = find_task_staff_user(jobs, users, task_staff) user, jid = find_task_staff_user(jobs, users, task_staff)
assignee = new_assignee(jid, user['id']) assignee = new_assignee(jid, user['id'])
response = patch_method(user['username'], f'jobs/{jid}', with make_api_client(user['username']) as client:
{'assignee': assignee}, org_id=self._ORG) (_, response) = client.jobs_api.partial_update(id=jid,
patched_job_write_request={'assignee': assignee}, org_id=org,
_parse_response=expect_success, _check_status=expect_success)
if expect_success:
assert response.status == HTTPStatus.OK
assert DeepDiff(expected_data(jid, assignee), json.loads(response.data),
exclude_paths="root['updated_date']", ignore_order=True) == {}
else:
assert response.status == HTTPStatus.FORBIDDEN
if is_allow: @pytest.mark.usefixtures('dontchangedb')
assert response.status_code == HTTPStatus.OK class TestJobDataset:
assert DeepDiff(expected_data(jid, assignee), response.json(), def _export_dataset(self, username, jid, **kwargs):
exclude_paths="root['updated_date']", ignore_order=True) == {} with make_api_client(username) as api_client:
else: return export_dataset(api_client.jobs_api.retrieve_dataset_endpoint, id=jid, **kwargs)
assert response.status_code == HTTPStatus.FORBIDDEN
def _export_annotations(self, username, jid, **kwargs):
with make_api_client(username) as api_client:
return export_dataset(api_client.jobs_api.retrieve_annotations_endpoint,
id=jid, **kwargs)
def test_can_export_dataset(self, admin_user: str, jobs_with_shapes: List):
job = jobs_with_shapes[0]
response = self._export_dataset(admin_user, job['id'], format='CVAT for images 1.1')
assert response.data
def test_can_export_annotations(self, admin_user: str, jobs_with_shapes: List):
job = jobs_with_shapes[0]
response = self._export_annotations(admin_user, job['id'], format='CVAT for images 1.1')
assert response.data

@ -13,9 +13,8 @@ import pytest
from copy import deepcopy from copy import deepcopy
from deepdiff import DeepDiff from deepdiff import DeepDiff
from cvat_sdk.models import DatasetFileRequest, ProjectWriteRequest
from shared.utils.config import get_method, patch_method, make_api_client from shared.utils.config import get_method, patch_method, make_api_client
from .utils import export_dataset
@pytest.mark.usefixtures('dontchangedb') @pytest.mark.usefixtures('dontchangedb')
@ -229,12 +228,12 @@ class TestGetProjectBackup:
class TestPostProjects: class TestPostProjects:
def _test_create_project_201(self, user, spec, **kwargs): def _test_create_project_201(self, user, spec, **kwargs):
with make_api_client(user) as api_client: with make_api_client(user) as api_client:
(_, response) = api_client.projects_api.create(ProjectWriteRequest(**spec), **kwargs) (_, response) = api_client.projects_api.create(spec, **kwargs)
assert response.status == HTTPStatus.CREATED assert response.status == HTTPStatus.CREATED
def _test_create_project_403(self, user, spec, **kwargs): def _test_create_project_403(self, user, spec, **kwargs):
with make_api_client(user) as api_client: with make_api_client(user) as api_client:
(_, response) = api_client.projects_api.create(ProjectWriteRequest(**spec), **kwargs, (_, response) = api_client.projects_api.create(spec, **kwargs,
_parse_response=False, _check_status=False) _parse_response=False, _check_status=False)
assert response.status == HTTPStatus.FORBIDDEN assert response.status == HTTPStatus.FORBIDDEN
@ -316,43 +315,30 @@ class TestPostProjects:
self._test_create_project_201(user['username'], spec, org_id=user['org']) self._test_create_project_201(user['username'], spec, org_id=user['org'])
@pytest.mark.usefixtures("changedb") @pytest.mark.usefixtures("changedb")
@pytest.mark.usefixtures("restore_cvat_data")
class TestImportExportDatasetProject: class TestImportExportDatasetProject:
def _test_export_project(self, username, project_id, format_name): def _test_export_project(self, username, pid, format_name):
with make_api_client(username) as api_client: with make_api_client(username) as api_client:
while True: return export_dataset(api_client.projects_api.retrieve_dataset_endpoint,
(_, response) = api_client.projects_api.retrieve_dataset(id=project_id, id=pid, format=format_name)
format=format_name)
if response.status == HTTPStatus.CREATED:
break
(_, response) = api_client.projects_api.retrieve_dataset(id=project_id,
format=format_name, action='download')
assert response.status == HTTPStatus.OK
return response
def _test_import_project(self, username, project_id, format_name, data): def _test_import_project(self, username, project_id, format_name, data):
with make_api_client(username) as api_client: with make_api_client(username) as api_client:
(_, response) = api_client.projects_api.create_dataset(id=project_id, (_, response) = api_client.projects_api.create_dataset(id=project_id,
format=format_name, dataset_file_request=DatasetFileRequest(**data), format=format_name, dataset_write_request=deepcopy(data),
_content_type="multipart/form-data") _content_type="multipart/form-data")
assert response.status == HTTPStatus.ACCEPTED assert response.status == HTTPStatus.ACCEPTED
while True: while True:
# TODO: Request schema doesn't describe this capability. # TODO: It's better be refactored to a separate endpoint to get request status
# It's better be refactored to a separate endpoint to get request status (_, response) = api_client.projects_api.retrieve_dataset(project_id,
response = get_method(username, f'projects/{project_id}/dataset',
action='import_status') action='import_status')
response.raise_for_status() if response.status == HTTPStatus.CREATED:
if response.status_code == HTTPStatus.CREATED:
break break
def test_can_import_dataset_in_org(self): def test_can_import_dataset_in_org(self, admin_user):
username = 'admin1'
project_id = 4 project_id = 4
response = self._test_export_project(username, project_id, 'CVAT for images 1.1') response = self._test_export_project(admin_user, project_id, 'CVAT for images 1.1')
tmp_file = io.BytesIO(response.data) tmp_file = io.BytesIO(response.data)
tmp_file.name = 'dataset.zip' tmp_file.name = 'dataset.zip'
@ -361,7 +347,7 @@ class TestImportExportDatasetProject:
'dataset_file': tmp_file, 'dataset_file': tmp_file,
} }
self._test_import_project(username, project_id, 'CVAT 1.1', import_data) self._test_import_project(admin_user, project_id, 'CVAT 1.1', import_data)
@pytest.mark.usefixtures('changedb') @pytest.mark.usefixtures('changedb')
class TestPatchProjectLabel: class TestPatchProjectLabel:

@ -7,14 +7,15 @@ import json
from copy import deepcopy from copy import deepcopy
from http import HTTPStatus from http import HTTPStatus
from time import sleep from time import sleep
from cvat_sdk.api_client.apis import TasksApi from cvat_sdk.api_client import models, apis
from cvat_sdk.api_client import models from cvat_sdk.core.helpers import get_paginated_collection
import pytest import pytest
from deepdiff import DeepDiff from deepdiff import DeepDiff
from shared.utils.config import make_api_client from shared.utils.config import make_api_client
from shared.utils.helpers import generate_image_files from shared.utils.helpers import generate_image_files
from .utils import export_dataset
def get_cloud_storage_content(username, cloud_storage_id, manifest): def get_cloud_storage_content(username, cloud_storage_id, manifest):
with make_api_client(username) as api_client: with make_api_client(username) as api_client:
@ -27,12 +28,9 @@ def get_cloud_storage_content(username, cloud_storage_id, manifest):
class TestGetTasks: class TestGetTasks:
def _test_task_list_200(self, user, project_id, data, exclude_paths = '', **kwargs): def _test_task_list_200(self, user, project_id, data, exclude_paths = '', **kwargs):
with make_api_client(user) as api_client: with make_api_client(user) as api_client:
(_, response) = api_client.projects_api.list_tasks(project_id, **kwargs, results = get_paginated_collection(api_client.projects_api.list_tasks_endpoint,
_parse_response=False) return_json=True, id=project_id, **kwargs)
assert response.status == HTTPStatus.OK assert DeepDiff(data, results, ignore_order=True, exclude_paths=exclude_paths) == {}
response_data = json.loads(response.data)
assert DeepDiff(data, response_data['results'], ignore_order=True, exclude_paths=exclude_paths) == {}
def _test_task_list_403(self, user, project_id, **kwargs): def _test_task_list_403(self, user, project_id, **kwargs):
with make_api_client(user) as api_client: with make_api_client(user) as api_client:
@ -60,7 +58,7 @@ class TestGetTasks:
for user in staff_users: for user in staff_users:
with make_api_client(user['username']) as api_client: with make_api_client(user['username']) as api_client:
(_, response) = api_client.tasks_api.list(**kwargs, _parse_response=False) (_, response) = api_client.tasks_api.list(**kwargs)
assert response.status == HTTPStatus.OK assert response.status == HTTPStatus.OK
response_data = json.loads(response.data) response_data = json.loads(response.data)
@ -113,12 +111,12 @@ class TestGetTasks:
class TestPostTasks: class TestPostTasks:
def _test_create_task_201(self, user, spec, **kwargs): def _test_create_task_201(self, user, spec, **kwargs):
with make_api_client(user) as api_client: with make_api_client(user) as api_client:
(_, response) = api_client.tasks_api.create(models.TaskWriteRequest(**spec), **kwargs) (_, response) = api_client.tasks_api.create(spec, **kwargs)
assert response.status == HTTPStatus.CREATED assert response.status == HTTPStatus.CREATED
def _test_create_task_403(self, user, spec, **kwargs): def _test_create_task_403(self, user, spec, **kwargs):
with make_api_client(user) as api_client: with make_api_client(user) as api_client:
(_, response) = api_client.tasks_api.create(models.TaskWriteRequest(**spec), **kwargs, (_, response) = api_client.tasks_api.create(spec, **kwargs,
_parse_response=False, _check_status=False) _parse_response=False, _check_status=False)
assert response.status == HTTPStatus.FORBIDDEN assert response.status == HTTPStatus.FORBIDDEN
@ -210,10 +208,9 @@ class TestPatchTaskAnnotations:
data = request_data(tid) data = request_data(tid)
with make_api_client(username) as api_client: with make_api_client(username) as api_client:
patched_data = models.PatchedTaskWriteRequest(**deepcopy(data))
(_, response) = api_client.tasks_api.partial_update_annotations( (_, response) = api_client.tasks_api.partial_update_annotations(
id=tid, action='update', org=org, id=tid, action='update', org=org,
patched_task_write_request=patched_data, patched_labeled_data_request=deepcopy(data),
_parse_response=False, _check_status=False) _parse_response=False, _check_status=False)
self._test_check_response(is_allow, response, data) self._test_check_response(is_allow, response, data)
@ -233,30 +230,23 @@ class TestPatchTaskAnnotations:
data = request_data(tid) data = request_data(tid)
with make_api_client(username) as api_client: with make_api_client(username) as api_client:
patched_data = models.PatchedTaskWriteRequest(**deepcopy(data))
(_, response) = api_client.tasks_api.partial_update_annotations( (_, response) = api_client.tasks_api.partial_update_annotations(
id=tid, org_id=org, action='update', id=tid, org_id=org, action='update',
patched_task_write_request=patched_data, patched_labeled_data_request=deepcopy(data),
_parse_response=False, _check_status=False) _parse_response=False, _check_status=False)
self._test_check_response(is_allow, response, data) self._test_check_response(is_allow, response, data)
@pytest.mark.usefixtures('dontchangedb') @pytest.mark.usefixtures('dontchangedb')
class TestGetTaskDataset: class TestGetTaskDataset:
def _test_export_project(self, username, tid, **kwargs): def _test_export_task(self, username, tid, **kwargs):
with make_api_client(username) as api_client: with make_api_client(username) as api_client:
(_, response) = api_client.tasks_api.retrieve_dataset(id=tid, **kwargs) return export_dataset(api_client.tasks_api.retrieve_dataset_endpoint, id=tid, **kwargs)
assert response.status == HTTPStatus.ACCEPTED
(_, response) = api_client.tasks_api.retrieve_dataset(id=tid, **kwargs)
assert response.status == HTTPStatus.CREATED
(_, response) = api_client.tasks_api.retrieve_dataset(id=tid, **kwargs, action='download')
assert response.status == HTTPStatus.OK
def test_admin_can_export_task_dataset(self, tasks_with_shapes): def test_can_export_task_dataset(self, admin_user, tasks_with_shapes):
task = tasks_with_shapes[0] task = tasks_with_shapes[0]
self._test_export_project('admin1', task['id'], format='CVAT for images 1.1') response = self._test_export_task(admin_user, task['id'], format='CVAT for images 1.1')
assert response.data
@pytest.mark.usefixtures("changedb") @pytest.mark.usefixtures("changedb")
@pytest.mark.usefixtures("restore_cvat_data") @pytest.mark.usefixtures("restore_cvat_data")
@ -264,7 +254,7 @@ class TestPostTaskData:
_USERNAME = 'admin1' _USERNAME = 'admin1'
@staticmethod @staticmethod
def _wait_until_task_is_created(api: TasksApi, task_id: int) -> models.RqStatus: def _wait_until_task_is_created(api: apis.TasksApi, task_id: int) -> models.RqStatus:
for _ in range(100): for _ in range(100):
(status, _) = api.retrieve_status(task_id) (status, _) = api.retrieve_status(task_id)
if status.state.value in ['Finished', 'Failed']: if status.state.value in ['Finished', 'Failed']:
@ -274,11 +264,10 @@ class TestPostTaskData:
def _test_create_task(self, username, spec, data, content_type, **kwargs): def _test_create_task(self, username, spec, data, content_type, **kwargs):
with make_api_client(username) as api_client: with make_api_client(username) as api_client:
(task, response) = api_client.tasks_api.create(models.TaskWriteRequest(**spec), **kwargs) (task, response) = api_client.tasks_api.create(spec, **kwargs)
assert response.status == HTTPStatus.CREATED assert response.status == HTTPStatus.CREATED
task_data = models.DataRequest(**data) (_, response) = api_client.tasks_api.create_data(task.id, data_request=deepcopy(data),
(_, response) = api_client.tasks_api.create_data(task.id, task_data,
_content_type=content_type, **kwargs) _content_type=content_type, **kwargs)
assert response.status == HTTPStatus.ACCEPTED assert response.status == HTTPStatus.ACCEPTED

@ -0,0 +1,26 @@
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from http import HTTPStatus
from time import sleep
from cvat_sdk.api_client.api_client import Endpoint
from urllib3 import HTTPResponse
def export_dataset(
endpoint: Endpoint, *, max_retries: int = 20, interval: float = 0.1, **kwargs
) -> HTTPResponse:
for _ in range(max_retries):
(_, response) = endpoint.call_with_http_info(**kwargs, _parse_response=False)
if response.status == HTTPStatus.CREATED:
break
assert response.status == HTTPStatus.ACCEPTED
sleep(interval)
assert response.status == HTTPStatus.CREATED
(_, response) = endpoint.call_with_http_info(**kwargs, action="download", _parse_response=False)
assert response.status == HTTPStatus.OK
return response

@ -2,10 +2,16 @@
# #
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
from pathlib import Path
import pytest import pytest
from cvat_sdk import Client from cvat_sdk import Client
from PIL import Image
from shared.utils.config import BASE_URL from shared.utils.config import BASE_URL
from shared.utils.helpers import generate_image_file
from .util import generate_coco_json
@pytest.fixture @pytest.fixture
@ -20,3 +26,22 @@ def fxt_client(fxt_logger):
with client: with client:
yield client yield client
@pytest.fixture
def fxt_image_file(tmp_path: Path):
img_path = tmp_path / "img.png"
with img_path.open("wb") as f:
f.write(generate_image_file(filename=str(img_path), size=(5, 10)).getvalue())
return img_path
@pytest.fixture
def fxt_coco_file(tmp_path: Path, fxt_image_file: Path):
img_filename = fxt_image_file
img_size = Image.open(img_filename).size
ann_filename = tmp_path / "coco.json"
generate_coco_json(ann_filename, img_info=(img_filename, *img_size))
yield ann_filename

@ -0,0 +1,236 @@
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
import io
from logging import Logger
from pathlib import Path
from typing import Tuple
import pytest
from cvat_sdk import Client
from cvat_sdk.api_client import exceptions, models
from cvat_sdk.core.proxies.tasks import ResourceType, Task
from shared.utils.config import USER_PASS
class TestIssuesUsecases:
@pytest.fixture(autouse=True)
def setup(
self,
changedb, # force fixture call order to allow DB setup
tmp_path: Path,
fxt_logger: Tuple[Logger, io.StringIO],
fxt_client: Client,
fxt_stdout: io.StringIO,
admin_user: str,
):
self.tmp_path = tmp_path
_, self.logger_stream = fxt_logger
self.client = fxt_client
self.stdout = fxt_stdout
self.user = admin_user
self.client.login((self.user, USER_PASS))
yield
@pytest.fixture
def fxt_new_task(self, fxt_image_file: Path):
task = self.client.tasks.create_from_data(
spec={
"name": "test_task",
"labels": [{"name": "car"}, {"name": "person"}],
},
resource_type=ResourceType.LOCAL,
resources=[str(fxt_image_file)],
data_params={"image_quality": 80},
)
return task
def test_can_retrieve_issue(self, fxt_new_task: Task):
issue = self.client.issues.create(
models.IssueWriteRequest(
frame=0,
position=[2.0, 4.0],
job=fxt_new_task.get_jobs()[0].id,
message="hello",
)
)
retrieved_issue = self.client.issues.retrieve(issue.id)
assert issue.id == retrieved_issue.id
assert self.stdout.getvalue() == ""
def test_can_list_issues(self, fxt_new_task: Task):
issue = self.client.issues.create(
models.IssueWriteRequest(
frame=0,
position=[2.0, 4.0],
job=fxt_new_task.get_jobs()[0].id,
message="hello",
assignee=self.client.users.list()[0].id,
)
)
issues = self.client.issues.list()
assert any(issue.id == j.id for j in issues)
assert self.stdout.getvalue() == ""
def test_can_list_comments(self, fxt_new_task: Task):
issue = self.client.issues.create(
models.IssueWriteRequest(
frame=0,
position=[2.0, 4.0],
job=fxt_new_task.get_jobs()[0].id,
message="hello",
)
)
comment = self.client.comments.create(models.CommentWriteRequest(issue.id, message="hi!"))
issue.fetch()
comment_ids = {c.id for c in issue.comments}
assert len(comment_ids) == 2
assert comment.id in comment_ids
assert self.stdout.getvalue() == ""
def test_can_modify_issue(self, fxt_new_task: Task):
issue = self.client.issues.create(
models.IssueWriteRequest(
frame=0,
position=[2.0, 4.0],
job=fxt_new_task.get_jobs()[0].id,
message="hello",
)
)
issue.update(models.PatchedIssueWriteRequest(resolved=True))
retrieved_issue = self.client.issues.retrieve(issue.id)
assert retrieved_issue.resolved is True
assert issue.resolved == retrieved_issue.resolved
assert self.stdout.getvalue() == ""
def test_can_remove_issue(self, fxt_new_task: Task):
issue = self.client.issues.create(
models.IssueWriteRequest(
frame=0,
position=[2.0, 4.0],
job=fxt_new_task.get_jobs()[0].id,
message="hello",
)
)
issue.remove()
with pytest.raises(exceptions.NotFoundException):
issue.fetch()
with pytest.raises(exceptions.NotFoundException):
self.client.comments.retrieve(issue.comments[0].id)
assert self.stdout.getvalue() == ""
class TestCommentsUsecases:
@pytest.fixture(autouse=True)
def setup(
self,
changedb, # force fixture call order to allow DB setup
tmp_path: Path,
fxt_logger: Tuple[Logger, io.StringIO],
fxt_client: Client,
fxt_stdout: io.StringIO,
admin_user: str,
):
self.tmp_path = tmp_path
_, self.logger_stream = fxt_logger
self.client = fxt_client
self.stdout = fxt_stdout
self.user = admin_user
self.client.login((self.user, USER_PASS))
yield
@pytest.fixture
def fxt_new_task(self, fxt_image_file: Path):
task = self.client.tasks.create_from_data(
spec={
"name": "test_task",
"labels": [{"name": "car"}, {"name": "person"}],
},
resource_type=ResourceType.LOCAL,
resources=[str(fxt_image_file)],
data_params={"image_quality": 80},
)
return task
def test_can_retrieve_comment(self, fxt_new_task: Task):
issue = self.client.issues.create(
models.IssueWriteRequest(
frame=0,
position=[2.0, 4.0],
job=fxt_new_task.get_jobs()[0].id,
message="hello",
)
)
comment = self.client.comments.create(models.CommentWriteRequest(issue.id, message="hi!"))
retrieved_comment = self.client.comments.retrieve(comment.id)
assert comment.id == retrieved_comment.id
assert self.stdout.getvalue() == ""
def test_can_list_comments(self, fxt_new_task: Task):
issue = self.client.issues.create(
models.IssueWriteRequest(
frame=0,
position=[2.0, 4.0],
job=fxt_new_task.get_jobs()[0].id,
message="hello",
)
)
comment = self.client.comments.create(models.CommentWriteRequest(issue.id, message="hi!"))
comments = self.client.comments.list()
assert any(comment.id == c.id for c in comments)
assert self.stdout.getvalue() == ""
def test_can_modify_comment(self, fxt_new_task: Task):
issue = self.client.issues.create(
models.IssueWriteRequest(
frame=0,
position=[2.0, 4.0],
job=fxt_new_task.get_jobs()[0].id,
message="hello",
)
)
comment = self.client.comments.create(models.CommentWriteRequest(issue.id, message="hi!"))
comment.update(models.PatchedCommentWriteRequest(message="bar"))
retrieved_comment = self.client.comments.retrieve(comment.id)
assert retrieved_comment.message == "bar"
assert comment.message == retrieved_comment.message
assert self.stdout.getvalue() == ""
def test_can_remove_comment(self, fxt_new_task: Task):
issue = self.client.issues.create(
models.IssueWriteRequest(
frame=0,
position=[2.0, 4.0],
job=fxt_new_task.get_jobs()[0].id,
message="hello",
)
)
comment = self.client.comments.create(models.CommentWriteRequest(issue.id, message="hi!"))
comment.remove()
with pytest.raises(exceptions.NotFoundException):
comment.fetch()
assert self.stdout.getvalue() == ""

@ -0,0 +1,280 @@
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
import io
import os.path as osp
from logging import Logger
from pathlib import Path
from typing import Tuple
import pytest
from cvat_sdk import Client
from cvat_sdk.api_client import models
from cvat_sdk.core.proxies.tasks import ResourceType, Task
from PIL import Image
from shared.utils.config import USER_PASS
from .util import make_pbar
class TestJobUsecases:
@pytest.fixture(autouse=True)
def setup(
self,
changedb, # force fixture call order to allow DB setup
tmp_path: Path,
fxt_logger: Tuple[Logger, io.StringIO],
fxt_client: Client,
fxt_stdout: io.StringIO,
admin_user: str,
):
self.tmp_path = tmp_path
_, self.logger_stream = fxt_logger
self.client = fxt_client
self.stdout = fxt_stdout
self.user = admin_user
self.client.login((self.user, USER_PASS))
yield
@pytest.fixture
def fxt_new_task(self, fxt_image_file: Path):
task = self.client.tasks.create_from_data(
spec={
"name": "test_task",
"labels": [{"name": "car"}, {"name": "person"}],
},
resource_type=ResourceType.LOCAL,
resources=[str(fxt_image_file)],
data_params={"image_quality": 80},
)
return task
@pytest.fixture
def fxt_task_with_shapes(self, fxt_new_task: Task):
fxt_new_task.set_annotations(
models.LabeledDataRequest(
shapes=[
models.LabeledShapeRequest(
frame=0,
label_id=fxt_new_task.labels[0].id,
type="rectangle",
points=[1, 1, 2, 2],
),
],
)
)
return fxt_new_task
def test_can_retrieve_job(self, fxt_new_task: Task):
job_id = fxt_new_task.get_jobs()[0].id
job = self.client.jobs.retrieve(job_id)
assert job.id == job_id
assert self.stdout.getvalue() == ""
def test_can_list_jobs(self, fxt_new_task: Task):
task_job_ids = set(j.id for j in fxt_new_task.get_jobs())
jobs = self.client.jobs.list()
assert len(task_job_ids) != 0
assert task_job_ids.issubset(j.id for j in jobs)
assert self.stdout.getvalue() == ""
def test_can_update_job_field_directly(self, fxt_new_task: Task):
job = self.client.jobs.list()[0]
assert not job.assignee
new_assignee = self.client.users.list()[0]
job.update({"assignee": new_assignee.id})
updated_job = self.client.jobs.retrieve(job.id)
assert updated_job.assignee.id == new_assignee.id
assert self.stdout.getvalue() == ""
@pytest.mark.parametrize("include_images", (True, False))
def test_can_download_dataset(self, fxt_new_task: Task, include_images: bool):
pbar_out = io.StringIO()
pbar = make_pbar(file=pbar_out)
task_id = fxt_new_task.id
path = str(self.tmp_path / f"task_{task_id}-cvat.zip")
job = self.client.jobs.retrieve(task_id)
job.export_dataset(
format_name="CVAT for images 1.1",
filename=path,
pbar=pbar,
include_images=include_images,
)
assert "100%" in pbar_out.getvalue().strip("\r").split("\r")[-1]
assert osp.isfile(path)
assert self.stdout.getvalue() == ""
def test_can_download_preview(self, fxt_new_task: Task):
frame_encoded = fxt_new_task.get_jobs()[0].get_preview()
assert Image.open(frame_encoded).size != 0
assert self.stdout.getvalue() == ""
@pytest.mark.parametrize("quality", ("compressed", "original"))
def test_can_download_frame(self, fxt_new_task: Task, quality: str):
frame_encoded = fxt_new_task.get_jobs()[0].get_frame(0, quality=quality)
assert Image.open(frame_encoded).size != 0
assert self.stdout.getvalue() == ""
@pytest.mark.parametrize("quality", ("compressed", "original"))
def test_can_download_frames(self, fxt_new_task: Task, quality: str):
fxt_new_task.get_jobs()[0].download_frames(
[0],
quality=quality,
outdir=str(self.tmp_path),
filename_pattern="frame-{frame_id}{frame_ext}",
)
assert osp.isfile(self.tmp_path / "frame-0.jpg")
assert self.stdout.getvalue() == ""
def test_can_upload_annotations(self, fxt_new_task: Task, fxt_coco_file: Path):
pbar_out = io.StringIO()
pbar = make_pbar(file=pbar_out)
fxt_new_task.get_jobs()[0].import_annotations(
format_name="COCO 1.0", filename=str(fxt_coco_file), pbar=pbar
)
assert "uploaded" in self.logger_stream.getvalue()
assert "100%" in pbar_out.getvalue().strip("\r").split("\r")[-1]
assert self.stdout.getvalue() == ""
def test_can_get_meta(self, fxt_new_task: Task):
meta = fxt_new_task.get_jobs()[0].get_meta()
assert meta.image_quality == 80
assert meta.size == 1
assert len(meta.frames) == meta.size
assert meta.frames[0].name == "img.png"
assert meta.frames[0].width == 5
assert meta.frames[0].height == 10
assert not meta.deleted_frames
assert self.stdout.getvalue() == ""
def test_can_remove_frames(self, fxt_new_task: Task):
fxt_new_task.get_jobs()[0].remove_frames_by_ids([0])
meta = fxt_new_task.get_jobs()[0].get_meta()
assert meta.deleted_frames == [0]
assert self.stdout.getvalue() == ""
def test_can_get_issues(self, fxt_new_task: Task):
issue = self.client.issues.create(
models.IssueWriteRequest(
frame=0,
position=[2.0, 4.0],
job=fxt_new_task.get_jobs()[0].id,
message="hello",
)
)
job_issue_ids = set(j.id for j in fxt_new_task.get_jobs()[0].get_issues())
assert {issue.id} == job_issue_ids
assert self.stdout.getvalue() == ""
def test_can_get_annotations(self, fxt_task_with_shapes: Task):
anns = fxt_task_with_shapes.get_jobs()[0].get_annotations()
assert len(anns.shapes) == 1
assert anns.shapes[0].type.value == "rectangle"
assert self.stdout.getvalue() == ""
def test_can_set_annotations(self, fxt_new_task: Task):
fxt_new_task.get_jobs()[0].set_annotations(
models.LabeledDataRequest(
tags=[models.LabeledImageRequest(frame=0, label_id=fxt_new_task.labels[0].id)],
)
)
anns = fxt_new_task.get_jobs()[0].get_annotations()
assert len(anns.tags) == 1
assert self.stdout.getvalue() == ""
def test_can_clear_annotations(self, fxt_task_with_shapes: Task):
fxt_task_with_shapes.get_jobs()[0].remove_annotations()
anns = fxt_task_with_shapes.get_jobs()[0].get_annotations()
assert len(anns.tags) == 0
assert len(anns.tracks) == 0
assert len(anns.shapes) == 0
assert self.stdout.getvalue() == ""
def test_can_remove_annotations(self, fxt_new_task: Task):
fxt_new_task.get_jobs()[0].set_annotations(
models.LabeledDataRequest(
shapes=[
models.LabeledShapeRequest(
frame=0,
label_id=fxt_new_task.labels[0].id,
type="rectangle",
points=[1, 1, 2, 2],
),
models.LabeledShapeRequest(
frame=0,
label_id=fxt_new_task.labels[0].id,
type="rectangle",
points=[2, 2, 3, 3],
),
],
)
)
anns = fxt_new_task.get_jobs()[0].get_annotations()
fxt_new_task.get_jobs()[0].remove_annotations(ids=[anns.shapes[0].id])
anns = fxt_new_task.get_jobs()[0].get_annotations()
assert len(anns.tags) == 0
assert len(anns.tracks) == 0
assert len(anns.shapes) == 1
assert self.stdout.getvalue() == ""
def test_can_update_annotations(self, fxt_task_with_shapes: Task):
fxt_task_with_shapes.get_jobs()[0].update_annotations(
models.PatchedLabeledDataRequest(
shapes=[
models.LabeledShapeRequest(
frame=0,
label_id=fxt_task_with_shapes.labels[0].id,
type="rectangle",
points=[0, 1, 2, 3],
),
],
tracks=[
models.LabeledTrackRequest(
frame=0,
label_id=fxt_task_with_shapes.labels[0].id,
shapes=[
models.TrackedShapeRequest(
frame=0, type="polygon", points=[3, 2, 2, 3, 3, 4]
),
],
)
],
tags=[
models.LabeledImageRequest(frame=0, label_id=fxt_task_with_shapes.labels[0].id)
],
)
)
anns = fxt_task_with_shapes.get_jobs()[0].get_annotations()
assert len(anns.shapes) == 2
assert len(anns.tracks) == 1
assert len(anns.tags) == 1
assert self.stdout.getvalue() == ""

@ -9,15 +9,15 @@ from pathlib import Path
from typing import Tuple from typing import Tuple
import pytest import pytest
from cvat_sdk import Client, exceptions from cvat_sdk import Client, models
from cvat_sdk.core.tasks import TaskProxy from cvat_sdk.api_client import exceptions
from cvat_sdk.core.types import ResourceType from cvat_sdk.core.proxies.tasks import ResourceType, Task
from PIL import Image from PIL import Image
from shared.utils.config import USER_PASS from shared.utils.config import USER_PASS
from shared.utils.helpers import generate_image_file, generate_image_files from shared.utils.helpers import generate_image_files
from .util import generate_coco_json, make_pbar from .util import make_pbar
class TestTaskUsecases: class TestTaskUsecases:
@ -40,29 +40,8 @@ class TestTaskUsecases:
yield yield
self.tmp_path = None
self.client = None
self.stdout = None
@pytest.fixture
def fxt_image_file(self):
img_path = self.tmp_path / "img.png"
with img_path.open("wb") as f:
f.write(generate_image_file(filename=str(img_path)).getvalue())
return img_path
@pytest.fixture
def fxt_coco_file(self, fxt_image_file):
img_filename = fxt_image_file
img_size = Image.open(img_filename).size
ann_filename = self.tmp_path / "coco.json"
generate_coco_json(ann_filename, img_info=(img_filename, *img_size))
yield ann_filename
@pytest.fixture @pytest.fixture
def fxt_backup_file(self, fxt_new_task: TaskProxy, fxt_coco_file: str): def fxt_backup_file(self, fxt_new_task: Task, fxt_coco_file: str):
backup_path = self.tmp_path / "backup.zip" backup_path = self.tmp_path / "backup.zip"
fxt_new_task.import_annotations("COCO 1.0", filename=fxt_coco_file) fxt_new_task.import_annotations("COCO 1.0", filename=fxt_coco_file)
@ -71,18 +50,36 @@ class TestTaskUsecases:
yield backup_path yield backup_path
@pytest.fixture @pytest.fixture
def fxt_new_task(self, fxt_image_file): def fxt_new_task(self, fxt_image_file: Path):
task = self.client.create_task( task = self.client.tasks.create_from_data(
spec={ spec={
"name": "test_task", "name": "test_task",
"labels": [{"name": "car"}, {"name": "person"}], "labels": [{"name": "car"}, {"name": "person"}],
}, },
resource_type=ResourceType.LOCAL, resource_type=ResourceType.LOCAL,
resources=[fxt_image_file], resources=[str(fxt_image_file)],
data_params={"image_quality": 80},
) )
return task return task
@pytest.fixture
def fxt_task_with_shapes(self, fxt_new_task: Task):
fxt_new_task.set_annotations(
models.LabeledDataRequest(
shapes=[
models.LabeledShapeRequest(
frame=0,
label_id=fxt_new_task.labels[0].id,
type="rectangle",
points=[1, 1, 2, 2],
),
],
)
)
return fxt_new_task
def test_can_create_task_with_local_data(self): def test_can_create_task_with_local_data(self):
pbar_out = io.StringIO() pbar_out = io.StringIO()
pbar = make_pbar(file=pbar_out) pbar = make_pbar(file=pbar_out)
@ -117,7 +114,7 @@ class TestTaskUsecases:
fd.write(f.getvalue()) fd.write(f.getvalue())
task_files[i] = str(fname) task_files[i] = str(fname)
task = self.client.create_task( task = self.client.tasks.create_from_data(
spec=task_spec, spec=task_spec,
data_params=data_params, data_params=data_params,
resource_type=ResourceType.LOCAL, resource_type=ResourceType.LOCAL,
@ -143,7 +140,7 @@ class TestTaskUsecases:
} }
with pytest.raises(exceptions.ApiException) as capture: with pytest.raises(exceptions.ApiException) as capture:
self.client.create_task( self.client.tasks.create_from_data(
spec=task_spec, spec=task_spec,
resource_type=ResourceType.LOCAL, resource_type=ResourceType.LOCAL,
resources=[], resources=[],
@ -153,53 +150,57 @@ class TestTaskUsecases:
assert capture.match("No media data found") assert capture.match("No media data found")
assert self.stdout.getvalue() == "" assert self.stdout.getvalue() == ""
def test_can_retrieve_task(self, fxt_new_task): def test_can_retrieve_task(self, fxt_new_task: Task):
task_id = fxt_new_task.id task_id = fxt_new_task.id
task = self.client.retrieve_task(task_id) task = self.client.tasks.retrieve(task_id)
assert task.id == task_id assert task.id == task_id
assert self.stdout.getvalue() == ""
def test_can_list_tasks(self, fxt_new_task): def test_can_list_tasks(self, fxt_new_task: Task):
task_id = fxt_new_task.id task_id = fxt_new_task.id
tasks = self.client.list_tasks() tasks = self.client.tasks.list()
assert any(t.id == task_id for t in tasks) assert any(t.id == task_id for t in tasks)
assert self.stdout.getvalue() == "" assert self.stdout.getvalue() == ""
def test_can_delete_tasks_by_ids(self, fxt_new_task): def test_can_update_task(self, fxt_new_task: Task):
task_id = fxt_new_task.id fxt_new_task.update(models.PatchedTaskWriteRequest(name="foo"))
old_tasks = self.client.list_tasks()
self.client.delete_tasks([task_id]) retrieved_task = self.client.tasks.retrieve(fxt_new_task.id)
assert retrieved_task.name == "foo"
assert fxt_new_task.name == retrieved_task.name
assert self.stdout.getvalue() == ""
new_tasks = self.client.list_tasks() def test_can_delete_task(self, fxt_new_task: Task):
assert any(t.id == task_id for t in old_tasks) fxt_new_task.remove()
assert all(t.id != task_id for t in new_tasks)
assert self.logger_stream.getvalue(), f".*Task ID {task_id} deleted.*" with pytest.raises(exceptions.NotFoundException):
fxt_new_task.fetch()
assert self.stdout.getvalue() == "" assert self.stdout.getvalue() == ""
def test_can_delete_task(self, fxt_new_task): def test_can_delete_tasks_by_ids(self, fxt_new_task: Task):
task_id = fxt_new_task.id task_id = fxt_new_task.id
task = self.client.retrieve_task(task_id) old_tasks = self.client.tasks.list()
old_tasks = self.client.list_tasks()
task.remove() self.client.tasks.remove_by_ids([task_id])
new_tasks = self.client.list_tasks() new_tasks = self.client.tasks.list()
assert any(t.id == task_id for t in old_tasks) assert any(t.id == task_id for t in old_tasks)
assert all(t.id != task_id for t in new_tasks) assert all(t.id != task_id for t in new_tasks)
assert self.logger_stream.getvalue(), f".*Task ID {task_id} deleted.*"
assert self.stdout.getvalue() == "" assert self.stdout.getvalue() == ""
@pytest.mark.parametrize("include_images", (True, False)) @pytest.mark.parametrize("include_images", (True, False))
def test_can_download_dataset(self, fxt_new_task: TaskProxy, include_images: bool): def test_can_download_dataset(self, fxt_new_task: Task, include_images: bool):
pbar_out = io.StringIO() pbar_out = io.StringIO()
pbar = make_pbar(file=pbar_out) pbar = make_pbar(file=pbar_out)
task_id = fxt_new_task.id task_id = fxt_new_task.id
path = str(self.tmp_path / f"task_{task_id}-cvat.zip") path = str(self.tmp_path / f"task_{task_id}-cvat.zip")
task = self.client.retrieve_task(task_id) task = self.client.tasks.retrieve(task_id)
task.export_dataset( task.export_dataset(
format_name="CVAT for images 1.1", format_name="CVAT for images 1.1",
filename=path, filename=path,
@ -211,28 +212,34 @@ class TestTaskUsecases:
assert osp.isfile(path) assert osp.isfile(path)
assert self.stdout.getvalue() == "" assert self.stdout.getvalue() == ""
def test_can_download_backup(self, fxt_new_task): def test_can_download_backup(self, fxt_new_task: Task):
pbar_out = io.StringIO() pbar_out = io.StringIO()
pbar = make_pbar(file=pbar_out) pbar = make_pbar(file=pbar_out)
task_id = fxt_new_task.id task_id = fxt_new_task.id
path = str(self.tmp_path / f"task_{task_id}-backup.zip") path = str(self.tmp_path / f"task_{task_id}-backup.zip")
task = self.client.retrieve_task(task_id) task = self.client.tasks.retrieve(task_id)
task.download_backup(filename=path, pbar=pbar) task.download_backup(filename=path, pbar=pbar)
assert "100%" in pbar_out.getvalue().strip("\r").split("\r")[-1] assert "100%" in pbar_out.getvalue().strip("\r").split("\r")[-1]
assert osp.isfile(path) assert osp.isfile(path)
assert self.stdout.getvalue() == "" assert self.stdout.getvalue() == ""
def test_can_download_preview(self, fxt_new_task: Task):
frame_encoded = fxt_new_task.get_preview()
assert Image.open(frame_encoded).size != 0
assert self.stdout.getvalue() == ""
@pytest.mark.parametrize("quality", ("compressed", "original")) @pytest.mark.parametrize("quality", ("compressed", "original"))
def test_can_download_frame(self, fxt_new_task: TaskProxy, quality: str): def test_can_download_frame(self, fxt_new_task: Task, quality: str):
frame_encoded = fxt_new_task.retrieve_frame(0, quality=quality) frame_encoded = fxt_new_task.get_frame(0, quality=quality)
assert Image.open(frame_encoded).size != 0 assert Image.open(frame_encoded).size != 0
assert self.stdout.getvalue() == "" assert self.stdout.getvalue() == ""
@pytest.mark.parametrize("quality", ("compressed", "original")) @pytest.mark.parametrize("quality", ("compressed", "original"))
def test_can_download_frames(self, fxt_new_task: TaskProxy, quality: str): def test_can_download_frames(self, fxt_new_task: Task, quality: str):
fxt_new_task.download_frames( fxt_new_task.download_frames(
[0], [0],
quality=quality, quality=quality,
@ -243,7 +250,7 @@ class TestTaskUsecases:
assert osp.isfile(self.tmp_path / "frame-0.jpg") assert osp.isfile(self.tmp_path / "frame-0.jpg")
assert self.stdout.getvalue() == "" assert self.stdout.getvalue() == ""
def test_can_upload_annotations(self, fxt_new_task: TaskProxy, fxt_coco_file: Path): def test_can_upload_annotations(self, fxt_new_task: Task, fxt_coco_file: Path):
pbar_out = io.StringIO() pbar_out = io.StringIO()
pbar = make_pbar(file=pbar_out) pbar = make_pbar(file=pbar_out)
@ -251,19 +258,135 @@ class TestTaskUsecases:
format_name="COCO 1.0", filename=str(fxt_coco_file), pbar=pbar format_name="COCO 1.0", filename=str(fxt_coco_file), pbar=pbar
) )
assert str(fxt_coco_file) in self.logger_stream.getvalue() assert "uploaded" in self.logger_stream.getvalue()
assert "100%" in pbar_out.getvalue().strip("\r").split("\r")[-1] assert "100%" in pbar_out.getvalue().strip("\r").split("\r")[-1]
assert self.stdout.getvalue() == "" assert self.stdout.getvalue() == ""
def test_can_create_from_backup(self, fxt_new_task: TaskProxy, fxt_backup_file: Path): def test_can_create_from_backup(self, fxt_new_task: Task, fxt_backup_file: Path):
pbar_out = io.StringIO() pbar_out = io.StringIO()
pbar = make_pbar(file=pbar_out) pbar = make_pbar(file=pbar_out)
task = self.client.create_task_from_backup(str(fxt_backup_file), pbar=pbar) task = self.client.tasks.create_from_backup(str(fxt_backup_file), pbar=pbar)
assert task.id assert task.id
assert task.id != fxt_new_task.id assert task.id != fxt_new_task.id
assert task.size == fxt_new_task.size assert task.size == fxt_new_task.size
assert "exported sucessfully" in self.logger_stream.getvalue() assert "imported sucessfully" in self.logger_stream.getvalue()
assert "100%" in pbar_out.getvalue().strip("\r").split("\r")[-1] assert "100%" in pbar_out.getvalue().strip("\r").split("\r")[-1]
assert self.stdout.getvalue() == "" assert self.stdout.getvalue() == ""
def test_can_get_jobs(self, fxt_new_task: Task):
jobs = fxt_new_task.get_jobs()
assert len(jobs) != 0
assert self.stdout.getvalue() == ""
def test_can_get_meta(self, fxt_new_task: Task):
meta = fxt_new_task.get_meta()
assert meta.image_quality == 80
assert meta.size == 1
assert len(meta.frames) == meta.size
assert meta.frames[0].name == "img.png"
assert meta.frames[0].width == 5
assert meta.frames[0].height == 10
assert not meta.deleted_frames
assert self.stdout.getvalue() == ""
def test_can_remove_frames(self, fxt_new_task: Task):
fxt_new_task.remove_frames_by_ids([0])
meta = fxt_new_task.get_meta()
assert meta.deleted_frames == [0]
assert self.stdout.getvalue() == ""
def test_can_get_annotations(self, fxt_task_with_shapes: Task):
anns = fxt_task_with_shapes.get_annotations()
assert len(anns.shapes) == 1
assert anns.shapes[0].type.value == "rectangle"
assert self.stdout.getvalue() == ""
def test_can_set_annotations(self, fxt_new_task: Task):
fxt_new_task.set_annotations(
models.LabeledDataRequest(
tags=[models.LabeledImageRequest(frame=0, label_id=fxt_new_task.labels[0].id)],
)
)
anns = fxt_new_task.get_annotations()
assert len(anns.tags) == 1
assert self.stdout.getvalue() == ""
def test_can_clear_annotations(self, fxt_task_with_shapes: Task):
fxt_task_with_shapes.remove_annotations()
anns = fxt_task_with_shapes.get_annotations()
assert len(anns.tags) == 0
assert len(anns.tracks) == 0
assert len(anns.shapes) == 0
assert self.stdout.getvalue() == ""
def test_can_remove_annotations(self, fxt_new_task: Task):
fxt_new_task.set_annotations(
models.LabeledDataRequest(
shapes=[
models.LabeledShapeRequest(
frame=0,
label_id=fxt_new_task.labels[0].id,
type="rectangle",
points=[1, 1, 2, 2],
),
models.LabeledShapeRequest(
frame=0,
label_id=fxt_new_task.labels[0].id,
type="rectangle",
points=[2, 2, 3, 3],
),
],
)
)
anns = fxt_new_task.get_annotations()
fxt_new_task.remove_annotations(ids=[anns.shapes[0].id])
anns = fxt_new_task.get_annotations()
assert len(anns.tags) == 0
assert len(anns.tracks) == 0
assert len(anns.shapes) == 1
assert self.stdout.getvalue() == ""
def test_can_update_annotations(self, fxt_task_with_shapes: Task):
fxt_task_with_shapes.update_annotations(
models.PatchedLabeledDataRequest(
shapes=[
models.LabeledShapeRequest(
frame=0,
label_id=fxt_task_with_shapes.labels[0].id,
type="rectangle",
points=[0, 1, 2, 3],
),
],
tracks=[
models.LabeledTrackRequest(
frame=0,
label_id=fxt_task_with_shapes.labels[0].id,
shapes=[
models.TrackedShapeRequest(
frame=0, type="polygon", points=[3, 2, 2, 3, 3, 4]
),
],
)
],
tags=[
models.LabeledImageRequest(frame=0, label_id=fxt_task_with_shapes.labels[0].id)
],
)
)
anns = fxt_task_with_shapes.get_annotations()
assert len(anns.shapes) == 2
assert len(anns.tracks) == 1
assert len(anns.tags) == 1
assert self.stdout.getvalue() == ""

@ -0,0 +1,73 @@
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
import io
from logging import Logger
from pathlib import Path
from typing import Tuple
import pytest
from cvat_sdk import Client, models
from cvat_sdk.api_client import exceptions
from shared.utils.config import USER_PASS
class TestUserUsecases:
@pytest.fixture(autouse=True)
def setup(
self,
changedb, # force fixture call order to allow DB setup
tmp_path: Path,
fxt_logger: Tuple[Logger, io.StringIO],
fxt_client: Client,
fxt_stdout: io.StringIO,
admin_user: str,
):
self.tmp_path = tmp_path
_, self.logger_stream = fxt_logger
self.client = fxt_client
self.stdout = fxt_stdout
self.user = admin_user
self.client.login((self.user, USER_PASS))
yield
def test_can_retrieve_user(self):
me = self.client.users.retrieve_current_user()
user = self.client.users.retrieve(me.id)
assert user.id == me.id
assert user.username == self.user
assert self.stdout.getvalue() == ""
def test_can_list_users(self):
users = self.client.users.list()
assert self.user in set(u.username for u in users)
assert self.stdout.getvalue() == ""
def test_can_update_user(self):
user = self.client.users.retrieve_current_user()
user.update(models.PatchedUserRequest(first_name="foo", last_name="bar"))
retrieved_user = self.client.users.retrieve(user.id)
assert retrieved_user.first_name == "foo"
assert retrieved_user.last_name == "bar"
assert user.first_name == retrieved_user.first_name
assert user.last_name == retrieved_user.last_name
assert self.stdout.getvalue() == ""
def test_can_remove_user(self):
users = self.client.users.list()
removed_user = next(u for u in users if u.username != self.user)
removed_user.remove()
with pytest.raises(exceptions.NotFoundException):
removed_user.fetch()
assert self.stdout.getvalue() == ""

@ -33,44 +33,44 @@ def generate_coco_anno(image_path: str, image_width: int, image_height: int) ->
{ {
"categories": [ "categories": [
{ {
"id": 1, "id": 1,
"name": "car", "name": "car",
"supercategory": "" "supercategory": ""
}, },
{ {
"id": 2, "id": 2,
"name": "person", "name": "person",
"supercategory": "" "supercategory": ""
} }
], ],
"images": [ "images": [
{ {
"coco_url": "", "coco_url": "",
"date_captured": "", "date_captured": "",
"flickr_url": "", "flickr_url": "",
"license": 0, "license": 0,
"id": 0, "id": 0,
"file_name": "%(image_path)s", "file_name": "%(image_path)s",
"height": %(image_height)d, "height": %(image_height)d,
"width": %(image_width)d "width": %(image_width)d
} }
], ],
"annotations": [ "annotations": [
{ {
"category_id": 1, "category_id": 1,
"id": 1, "id": 1,
"image_id": 0, "image_id": 0,
"iscrowd": 0, "iscrowd": 0,
"segmentation": [ "segmentation": [
[] []
], ],
"area": 17702.0, "area": 17702.0,
"bbox": [ "bbox": [
574.0, 574.0,
407.0, 407.0,
167.0, 167.0,
106.0 106.0
] ]
} }
] ]
} }

@ -279,6 +279,10 @@ def filter_tasks_with_shapes(annotations):
return list(filter(lambda t: annotations['task'][str(t['id'])]['shapes'], tasks)) return list(filter(lambda t: annotations['task'][str(t['id'])]['shapes'], tasks))
return find return find
@pytest.fixture(scope='session')
def jobs_with_shapes(jobs, filter_jobs_with_shapes):
return filter_jobs_with_shapes(jobs)
@pytest.fixture(scope='session') @pytest.fixture(scope='session')
def tasks_with_shapes(tasks, filter_tasks_with_shapes): def tasks_with_shapes(tasks, filter_tasks_with_shapes):
return filter_tasks_with_shapes(tasks) return filter_tasks_with_shapes(tasks)

@ -48,5 +48,6 @@ def post_files_method(username, endpoint, data, files, **kwargs):
def server_get(username, endpoint, **kwargs): def server_get(username, endpoint, **kwargs):
return requests.get(get_server_url(endpoint, **kwargs), auth=(username, USER_PASS)) return requests.get(get_server_url(endpoint, **kwargs), auth=(username, USER_PASS))
def make_api_client(user: str) -> ApiClient: def make_api_client(user: str, *, password: str = None) -> ApiClient:
return ApiClient(configuration=Configuration(host=BASE_URL, username=user, password=USER_PASS)) return ApiClient(configuration=Configuration(host=BASE_URL,
username=user, password=password or USER_PASS))

Loading…
Cancel
Save