Add support for source & target storages (#28)

Branch: main. Committed by Maria Khrustaleva 4 years ago via GitHub. Parent cf1f6faed9, commit 54c15830a9.

@@ -8,7 +8,7 @@ from typing import Any, Callable, List, Mapping, Tuple
from django.db import transaction
from cvat.apps.engine import models
-from cvat.apps.engine.serializers import DataSerializer, TaskSerializer
+from cvat.apps.engine.serializers import DataSerializer, TaskWriteSerializer
from cvat.apps.engine.task import _create_thread as create_task
from cvat.apps.dataset_manager.task import TaskAnnotation
@@ -80,7 +80,7 @@ class ProjectAnnotationAndData:
})
data_serializer.is_valid(raise_exception=True)
db_data = data_serializer.save()
-db_task = TaskSerializer.create(None, {
+db_task = TaskWriteSerializer.create(None, {
**task_fields,
'data_id': db_data.id,
'project_id': self.db_project.id

@@ -15,7 +15,7 @@ from django.utils import timezone
import cvat.apps.dataset_manager.task as task
import cvat.apps.dataset_manager.project as project
from cvat.apps.engine.log import slogger
-from cvat.apps.engine.models import Project, Task
+from cvat.apps.engine.models import Project, Task, Job
from .formats.registry import EXPORT_FORMATS, IMPORT_FORMATS
from .util import current_function_name
@@ -30,29 +30,36 @@ def log_exception(logger=None, exc_info=True):
def get_export_cache_dir(db_instance):
-base_dir = osp.abspath(db_instance.get_project_dirname() if isinstance(db_instance, Project) else db_instance.get_task_dirname())
+base_dir = osp.abspath(db_instance.get_dirname())
if osp.isdir(base_dir):
return osp.join(base_dir, 'export_cache')
else:
-raise Exception('{} dir {} does not exist'.format("Project" if isinstance(db_instance, Project) else "Task", base_dir))
+raise FileNotFoundError('{} dir {} does not exist'.format(db_instance.__class__.__name__, base_dir))
DEFAULT_CACHE_TTL = timedelta(hours=10)
TASK_CACHE_TTL = DEFAULT_CACHE_TTL
PROJECT_CACHE_TTL = DEFAULT_CACHE_TTL / 3
JOB_CACHE_TTL = DEFAULT_CACHE_TTL
-def export(dst_format, task_id=None, project_id=None, server_url=None, save_images=False):
+def export(dst_format, project_id=None, task_id=None, job_id=None, server_url=None, save_images=False):
try:
if task_id is not None:
db_instance = Task.objects.get(pk=task_id)
logger = slogger.task[task_id]
cache_ttl = TASK_CACHE_TTL
export_fn = task.export_task
-else:
+elif project_id is not None:
db_instance = Project.objects.get(pk=project_id)
logger = slogger.project[project_id]
cache_ttl = PROJECT_CACHE_TTL
export_fn = project.export_project
else:
db_instance = Job.objects.get(pk=job_id)
logger = slogger.job[job_id]
cache_ttl = JOB_CACHE_TTL
export_fn = task.export_job
cache_dir = get_export_cache_dir(db_instance)
@@ -86,8 +93,9 @@ def export(dst_format, task_id=None, project_id=None, server_url=None, save_images=False):
"The {} '{}' is exported as '{}' at '{}' "
"and available for downloading for the next {}. "
"Export cache cleaning job is enqueued, id '{}'".format(
-"project" if isinstance(db_instance, Project) else 'task',
-db_instance.name, dst_format, output_path, cache_ttl,
+db_instance.__class__.__name__.lower(),
+db_instance.name if isinstance(db_instance, (Project, Task)) else db_instance.id,
+dst_format, output_path, cache_ttl,
cleaning_job.id
))
@@ -96,6 +104,12 @@ def export(dst_format, task_id=None, project_id=None, server_url=None, save_images=False):
log_exception(logger)
raise
def export_job_annotations(job_id, dst_format=None, server_url=None):
return export(dst_format, job_id=job_id, server_url=server_url, save_images=False)
def export_job_as_dataset(job_id, dst_format=None, server_url=None):
return export(dst_format, job_id=job_id, server_url=server_url, save_images=True)
def export_task_as_dataset(task_id, dst_format=None, server_url=None):
return export(dst_format, task_id=task_id, server_url=server_url, save_images=True)
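The two job-level helpers above mirror the existing task entry point. A minimal usage sketch (hypothetical job id and format names, assuming a configured CVAT Django environment in which these functions are normally run by an RQ worker):

from cvat.apps.dataset_manager.views import export_job_annotations, export_job_as_dataset

# annotations only (save_images=False under the hood)
export_job_annotations(job_id=42, dst_format='CVAT for images 1.1')
# full dataset with images (save_images=True)
export_job_as_dataset(job_id=42, dst_format='COCO 1.0')

Both delegate to export(), which writes the archive into the directory returned by get_export_cache_dir() and keeps it for JOB_CACHE_TTL.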

@@ -26,7 +26,7 @@ from cvat.apps.engine.plugins import add_plugin
def _have_no_access_exception(ex):
if 'Permission denied' in ex.stderr or 'Could not read from remote repository' in ex.stderr:
-keys = subprocess.run(['ssh-add -L'], shell = True,
+keys = subprocess.run(['ssh-add', '-L'], #nosec
stdout = subprocess.PIPE).stdout.decode('utf-8').split('\n')
keys = list(filter(len, list(map(lambda x: x.strip(), keys))))
raise Exception(
@@ -268,7 +268,7 @@ class Git:
# Dump an annotation
timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
-dump_name = os.path.join(db_task.get_task_dirname(),
+dump_name = os.path.join(db_task.get_dirname(),
"git_annotation_{}_{}.zip".format(self._format, timestamp))
export_task(
@@ -303,7 +303,7 @@ class Git:
}
old_diffs_dir = os.path.join(os.path.dirname(self._diffs_dir), 'repos_diffs')
-if (os.path.isdir(old_diffs_dir)):
+if os.path.isdir(old_diffs_dir):
_read_old_diffs(old_diffs_dir, summary_diff)
for diff_name in list(map(lambda x: os.path.join(self._diffs_dir, x), os.listdir(self._diffs_dir))):

@@ -17,27 +17,35 @@ import django_rq
from django.conf import settings
from django.db import transaction
from django.utils import timezone
from django.shortcuts import get_object_or_404
from rest_framework import serializers, status
from rest_framework.parsers import JSONParser
from rest_framework.renderers import JSONRenderer
from rest_framework.response import Response
from django_sendfile import sendfile
from distutils.util import strtobool
import cvat.apps.dataset_manager as dm
from cvat.apps.engine import models
from cvat.apps.engine.log import slogger
from cvat.apps.engine.serializers import (AttributeSerializer, DataSerializer,
-LabeledDataSerializer, SegmentSerializer, SimpleJobSerializer, TaskSerializer,
-ProjectSerializer, ProjectFileSerializer, TaskFileSerializer)
+LabeledDataSerializer, SegmentSerializer, SimpleJobSerializer, TaskReadSerializer,
+ProjectReadSerializer, ProjectFileSerializer, TaskFileSerializer)
from cvat.apps.engine.utils import av_scan_paths
-from cvat.apps.engine.models import StorageChoice, StorageMethodChoice, DataChoice, Task, Project
+from cvat.apps.engine.models import (
+StorageChoice, StorageMethodChoice, DataChoice, Task, Project, Location,
+CloudStorage as CloudStorageModel)
from cvat.apps.engine.task import _create_thread
from cvat.apps.dataset_manager.views import TASK_CACHE_TTL, PROJECT_CACHE_TTL, get_export_cache_dir, clear_export_cache, log_exception
from cvat.apps.dataset_manager.bindings import CvatImportError
from cvat.apps.engine.cloud_provider import (
db_storage_to_storage_instance, validate_bucket_status
)
from cvat.apps.engine.location import StorageType, get_location_configuration
class Version(Enum):
V1 = '1.0'
def _get_label_mapping(db_labels):
@@ -266,7 +274,7 @@ class TaskExporter(_ExporterBase, _TaskBackupBase):
raise NotImplementedError()
def _write_task(self, zip_object, target_dir=None):
-task_dir = self._db_task.get_task_dirname()
+task_dir = self._db_task.get_dirname()
target_task_dir = os.path.join(target_dir, self.TASK_DIRNAME) if target_dir else self.TASK_DIRNAME
self._write_directory(
source_dir=task_dir,
@@ -277,7 +285,7 @@ class TaskExporter(_ExporterBase, _TaskBackupBase):
def _write_manifest(self, zip_object, target_dir=None):
def serialize_task():
-task_serializer = TaskSerializer(self._db_task)
+task_serializer = TaskReadSerializer(self._db_task)
for field in ('url', 'owner', 'assignee', 'segments'):
task_serializer.fields.pop(field)
@@ -348,8 +356,8 @@ class TaskExporter(_ExporterBase, _TaskBackupBase):
def export_to(self, file, target_dir=None):
if self._db_task.data.storage_method == StorageMethodChoice.FILE_SYSTEM and \
self._db_task.data.storage == StorageChoice.SHARE:
raise Exception('The task cannot be exported because it does not contain any raw data')
if isinstance(file, str):
with ZipFile(file, 'w') as zf:
@@ -484,7 +492,7 @@ class TaskImporter(_ImporterBase, _TaskBackupBase):
self._manifest['project_id'] = self._project_id
self._db_task = models.Task.objects.create(**self._manifest, organization_id=self._org_id)
-task_path = self._db_task.get_task_dirname()
+task_path = self._db_task.get_dirname()
if os.path.isdir(task_path):
shutil.rmtree(task_path)
@@ -569,7 +577,7 @@ class ProjectExporter(_ExporterBase, _ProjectBackupBase):
def _write_manifest(self, zip_object):
def serialize_project():
-project_serializer = ProjectSerializer(self._db_project)
+project_serializer = ProjectReadSerializer(self._db_project)
for field in ('assignee', 'owner', 'tasks', 'url'):
project_serializer.fields.pop(field)
@@ -591,7 +599,7 @@ class ProjectExporter(_ExporterBase, _ProjectBackupBase):
self._write_manifest(output_file)
class ProjectImporter(_ImporterBase, _ProjectBackupBase):
-TASKNAME_RE = 'task_(\d+)/'
+TASKNAME_RE = r'task_(\d+)/'
def __init__(self, filename, user_id, org_id=None):
super().__init__(logger=slogger.glob)
@@ -616,7 +624,7 @@ class ProjectImporter(_ImporterBase, _ProjectBackupBase):
self._manifest["owner_id"] = self._user_id
self._db_project = models.Project.objects.create(**self._manifest, organization_id=self._org_id)
-project_path = self._db_project.get_project_dirname()
+project_path = self._db_project.get_dirname()
if os.path.isdir(project_path):
shutil.rmtree(project_path)
os.makedirs(self._db_project.get_project_logs_dirname())
@@ -702,14 +710,23 @@ def export(db_instance, request):
logger = slogger.task[db_instance.pk]
Exporter = TaskExporter
cache_ttl = TASK_CACHE_TTL
use_target_storage_conf = request.query_params.get('use_default_location', True)
elif isinstance(db_instance, Project):
filename_prefix = 'project'
logger = slogger.project[db_instance.pk]
Exporter = ProjectExporter
cache_ttl = PROJECT_CACHE_TTL
use_target_storage_conf = request.query_params.get('use_default_location', True)
else:
raise Exception(
"Unexpected type of db_isntance: {}".format(type(db_instance)))
use_settings = strtobool(str(use_target_storage_conf))
obj = db_instance if use_settings else request.query_params
location_conf = get_location_configuration(
obj=obj,
use_settings=use_settings,
field_name=StorageType.TARGET
)
queue = django_rq.get_queue("default")
rq_id = "/api/{}s/{}/backup".format(filename_prefix, db_instance.pk)
@@ -731,8 +748,30 @@ def export(db_instance, request):
filename = "{}_{}_backup_{}{}".format(
filename_prefix, db_instance.name, timestamp,
os.path.splitext(file_path)[1])
-return sendfile(request, file_path, attachment=True,
-attachment_filename=filename.lower())
location = location_conf.get('location')
if location == Location.LOCAL:
return sendfile(request, file_path, attachment=True,
attachment_filename=filename.lower())
elif location == Location.CLOUD_STORAGE:
@validate_bucket_status
def _export_to_cloud_storage(storage, file_path, file_name):
storage.upload_file(file_path, file_name)
try:
storage_id = location_conf['storage_id']
except KeyError:
raise serializers.ValidationError(
'Cloud storage location was selected for destination'
' but cloud storage id was not specified')
db_storage = get_object_or_404(CloudStorageModel, pk=storage_id)
storage = db_storage_to_storage_instance(db_storage)
_export_to_cloud_storage(storage, file_path, filename)
return Response(status=status.HTTP_200_OK)
else:
raise NotImplementedError()
else:
if os.path.exists(file_path):
return Response(status=status.HTTP_201_CREATED)
@@ -753,21 +792,47 @@ def export(db_instance, request):
result_ttl=ttl, failure_ttl=ttl)
return Response(status=status.HTTP_202_ACCEPTED)
-def _import(importer, request, rq_id, Serializer, file_field_name, filename=None):
+def _import(importer, request, rq_id, Serializer, file_field_name, location_conf, filename=None):
queue = django_rq.get_queue("default")
rq_job = queue.fetch_job(rq_id)
if not rq_job:
org_id = getattr(request.iam_context['organization'], 'id', None)
fd = None
-if not filename:
-serializer = Serializer(data=request.data)
-serializer.is_valid(raise_exception=True)
-payload_file = serializer.validated_data[file_field_name]
-fd, filename = mkstemp(prefix='cvat_')
-with open(filename, 'wb+') as f:
-for chunk in payload_file.chunks():
-f.write(chunk)
location = location_conf.get('location')
if location == Location.LOCAL:
if not filename:
serializer = Serializer(data=request.data)
serializer.is_valid(raise_exception=True)
payload_file = serializer.validated_data[file_field_name]
fd, filename = mkstemp(prefix='cvat_')
with open(filename, 'wb+') as f:
for chunk in payload_file.chunks():
f.write(chunk)
else:
@validate_bucket_status
def _import_from_cloud_storage(storage, file_name):
return storage.download_fileobj(file_name)
file_name = request.query_params.get('filename')
assert file_name
# download file from cloud storage
try:
storage_id = location_conf['storage_id']
except KeyError:
raise serializers.ValidationError(
'Cloud storage location was selected for destination'
' but cloud storage id was not specified')
db_storage = get_object_or_404(CloudStorageModel, pk=storage_id)
storage = db_storage_to_storage_instance(db_storage)
data = _import_from_cloud_storage(storage, file_name)
fd, filename = mkstemp(prefix='cvat_')
with open(filename, 'wb+') as f:
f.write(data.getbuffer())
rq_job = queue.enqueue_call(
func=importer,
args=(filename, request.user.id, org_id),
@@ -814,12 +879,18 @@ def import_project(request, filename=None):
Serializer = ProjectFileSerializer
file_field_name = 'project_file'
location_conf = get_location_configuration(
obj=request.query_params,
field_name=StorageType.SOURCE,
)
return _import(
importer=_import_project,
request=request,
rq_id=rq_id,
Serializer=Serializer,
file_field_name=file_field_name,
location_conf=location_conf,
filename=filename
)
@@ -831,11 +902,17 @@ def import_task(request, filename=None):
Serializer = TaskFileSerializer
file_field_name = 'task_file'
location_conf = get_location_configuration(
obj=request.query_params,
field_name=StorageType.SOURCE
)
return _import(
importer=_import_task,
request=request,
rq_id=rq_id,
Serializer=Serializer,
file_field_name=file_field_name,
location_conf=location_conf,
filename=filename
)

@@ -4,10 +4,13 @@
import os
import boto3
import functools
import json
from abc import ABC, abstractmethod, abstractproperty
from enum import Enum
from io import BytesIO
from rest_framework import serializers
from boto3.s3.transfer import TransferConfig
from botocore.exceptions import ClientError
@@ -35,6 +38,14 @@ class Status(str, Enum):
def __str__(self):
return self.value
class Permissions(str, Enum):
READ = 'read'
WRITE = 'write'
@classmethod
def all(cls):
return {i.value for i in cls}
class _CloudStorage(ABC):
def __init__(self):
@@ -86,7 +97,11 @@ class _CloudStorage(ABC):
raise NotImplementedError("Unsupported type {} was found".format(type(file_obj)))
@abstractmethod
-def upload_file(self, file_obj, file_name):
+def upload_fileobj(self, file_obj, file_name):
pass
@abstractmethod
def upload_file(self, file_path, file_name=None):
pass
def __contains__(self, file_name):
@@ -99,6 +114,18 @@ class _CloudStorage(ABC):
def content(self):
return list(map(lambda x: x['name'] , self._files))
@abstractproperty
def supported_actions(self):
pass
@property
def read_access(self):
return Permissions.READ in self.access
@property
def write_access(self):
return Permissions.WRITE in self.access
def get_cloud_storage_instance(cloud_provider, resource, credentials, specific_attributes=None, endpoint=None):
instance = None
if cloud_provider == CloudProviderChoice.AWS_S3:
@@ -133,6 +160,12 @@ class AWS_S3(_CloudStorage):
transfer_config = {
'max_io_queue': 10,
}
class Effect(str, Enum):
ALLOW = 'Allow'
DENY = 'Deny'
def __init__(self,
bucket,
region,
@@ -209,13 +242,27 @@ class AWS_S3(_CloudStorage):
def get_file_last_modified(self, key):
return self._head_file(key).get('LastModified')
-def upload_file(self, file_obj, file_name):
+def upload_fileobj(self, file_obj, file_name):
self._bucket.upload_fileobj(
Fileobj=file_obj,
Key=file_name,
Config=TransferConfig(max_io_queue=self.transfer_config['max_io_queue'])
)
def upload_file(self, file_path, file_name=None):
if not file_name:
file_name = os.path.basename(file_path)
try:
self._bucket.upload_file(
file_path,
file_name,
Config=TransferConfig(max_io_queue=self.transfer_config['max_io_queue'])
)
except ClientError as ex:
msg = str(ex)
slogger.glob.error(msg)
raise Exception(msg)
def initialize_content(self):
files = self._bucket.objects.all()
self._files = [{
@@ -251,8 +298,45 @@ class AWS_S3(_CloudStorage):
slogger.glob.info(msg)
raise Exception(msg)
def delete_file(self, file_name: str):
try:
self._client_s3.delete_object(Bucket=self.name, Key=file_name)
except Exception as ex:
msg = str(ex)
slogger.glob.info(msg)
raise
@property
def supported_actions(self):
allowed_actions = set()
try:
bucket_policy = self._bucket.Policy().policy
except ClientError as ex:
if 'NoSuchBucketPolicy' in str(ex):
return Permissions.all()
else:
raise Exception(str(ex))
bucket_policy = json.loads(bucket_policy) if isinstance(bucket_policy, str) else bucket_policy
for statement in bucket_policy['Statement']:
effect = statement.get('Effect') # Allow | Deny
actions = statement.get('Action', set())
if effect == self.Effect.ALLOW:
allowed_actions.update(actions)
access = {
's3:GetObject': Permissions.READ,
's3:PutObject': Permissions.WRITE,
}
allowed_actions = Permissions.all() & {access.get(i) for i in allowed_actions}
return allowed_actions
class AzureBlobContainer(_CloudStorage):
MAX_CONCURRENCY = 3
class Effect:
pass
def __init__(self, container, account_name, sas_token=None):
super().__init__()
self._account_name = account_name
@@ -317,9 +401,18 @@ class AzureBlobContainer(_CloudStorage):
else:
return Status.NOT_FOUND
-def upload_file(self, file_obj, file_name):
+def upload_fileobj(self, file_obj, file_name):
self._container_client.upload_blob(name=file_name, data=file_obj)
def upload_file(self, file_path, file_name=None):
if not file_name:
file_name = os.path.basename(file_path)
try:
with open(file_path, 'r') as f:
self.upload_fileobj(f, file_name)
except Exception as ex:
slogger.glob.error(str(ex))
raise
# TODO:
# def multipart_upload(self, file_obj):
@@ -342,6 +435,10 @@ class AzureBlobContainer(_CloudStorage):
buf.seek(0)
return buf
@property
def supported_actions(self):
pass
class GOOGLE_DRIVE(_CloudStorage):
pass
@@ -361,6 +458,9 @@ def _define_gcs_status(func):
class GoogleCloudStorage(_CloudStorage):
class Effect:
pass
def __init__(self, bucket_name, prefix=None, service_account_json=None, anonymous_access=False, project=None, location=None):
super().__init__()
if service_account_json:
@@ -416,9 +516,18 @@ class GoogleCloudStorage(_CloudStorage):
buf.seek(0)
return buf
-def upload_file(self, file_obj, file_name):
+def upload_fileobj(self, file_obj, file_name):
self.bucket.blob(file_name).upload_from_file(file_obj)
def upload_file(self, file_path, file_name=None):
if not file_name:
file_name = os.path.basename(file_path)
try:
self.bucket.blob(file_name).upload_from_filename(file_path)
except Exception as ex:
slogger.glob.info(str(ex))
raise
def create(self):
try:
self._bucket = self._storage_client.create_bucket(
@@ -441,6 +550,10 @@ class GoogleCloudStorage(_CloudStorage):
blob.reload()
return blob.updated
@property
def supported_actions(self):
pass
class Credentials:
__slots__ = ('key', 'secret_key', 'session_token', 'account_name', 'key_file_path', 'credentials_type')
@@ -502,3 +615,36 @@ class Credentials:
def values(self):
return [self.key, self.secret_key, self.session_token, self.account_name, self.key_file_path]
def validate_bucket_status(func):
@functools.wraps(func)
def wrapper(storage, *args, **kwargs):
try:
res = func(storage, *args, **kwargs)
except Exception as ex:
# check that cloud storage exists
storage_status = storage.get_status() if storage is not None else None
if storage_status == Status.FORBIDDEN:
msg = 'The resource {} is no longer available. Access forbidden.'.format(storage.name)
elif storage_status == Status.NOT_FOUND:
msg = 'The resource {} not found. It may have been deleted.'.format(storage.name)
else:
msg = str(ex)
raise serializers.ValidationError(msg)
return res
return wrapper
def db_storage_to_storage_instance(db_storage):
credentials = Credentials()
credentials.convert_from_db({
'type': db_storage.credentials_type,
'value': db_storage.credentials,
})
details = {
'resource': db_storage.resource,
'credentials': credentials,
'specific_attributes': db_storage.get_specific_attributes()
}
return get_cloud_storage_instance(cloud_provider=db_storage.provider_type, **details)
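Together, validate_bucket_status and db_storage_to_storage_instance are how the backup and import paths above talk to a bucket: resolve the CloudStorage row into a provider client, then wrap the actual transfer so that a missing or forbidden bucket surfaces as a DRF ValidationError instead of a bare exception. A minimal sketch of that pattern (hypothetical storage id and file names, assuming a configured CVAT Django environment):

from django.shortcuts import get_object_or_404
from cvat.apps.engine.models import CloudStorage as CloudStorageModel
from cvat.apps.engine.cloud_provider import db_storage_to_storage_instance, validate_bucket_status

@validate_bucket_status
def _upload_backup(storage, file_path, file_name):
    # delegates to the provider-specific upload_file() added above
    storage.upload_file(file_path, file_name)

db_storage = get_object_or_404(CloudStorageModel, pk=3)  # hypothetical cloud storage id
storage = db_storage_to_storage_instance(db_storage)
_upload_backup(storage, '/tmp/task_7_backup.zip', 'task_7_backup.zip')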

@@ -0,0 +1,36 @@
from enum import Enum
from cvat.apps.engine.models import Location
class StorageType(str, Enum):
TARGET = 'target_storage'
SOURCE = 'source_storage'
def __str__(self):
return self.value
def get_location_configuration(obj, field_name, use_settings=False):
location_conf = dict()
if use_settings:
storage = getattr(obj, field_name)
if storage is None:
location_conf['location'] = Location.LOCAL
else:
location_conf['location'] = storage.location
sid = storage.cloud_storage_id
if sid:
location_conf['storage_id'] = sid
else:
# obj is query_params
# FIXME when ui part will be done
location_conf['location'] = obj.get('location', Location.LOCAL)
# try:
# location_conf['location'] = obj['location']
# except KeyError:
# raise ValidationError("Custom settings were selected but no location was specified")
sid = obj.get('cloud_storage_id')
if sid:
location_conf['storage_id'] = int(sid)
return location_conf
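This helper is called in two modes: with use_settings=True it reads the source/target Storage attached to the project or task, otherwise it reads the request query parameters directly. A short sketch of both calls (db_task and request are hypothetical names for a Task instance and a DRF request):

from cvat.apps.engine.location import StorageType, get_location_configuration

# default location: follow the task's own target_storage setting
conf = get_location_configuration(obj=db_task, use_settings=True, field_name=StorageType.TARGET)

# explicit override: take 'location' and 'cloud_storage_id' from the query string
conf = get_location_configuration(obj=request.query_params, field_name=StorageType.SOURCE)
# e.g. {'location': 'cloud_storage', 'storage_id': 3}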

@@ -0,0 +1,51 @@
# Generated by Django 3.2.12 on 2022-06-10 18:29
import cvat.apps.engine.models
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
('engine', '0053_data_deleted_frames'),
]
operations = [
migrations.CreateModel(
name='Storage',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('location', models.CharField(choices=[('cloud_storage', 'CLOUD_STORAGE'), ('local', 'LOCAL')], default=cvat.apps.engine.models.Location['LOCAL'], max_length=16)),
('cloud_storage_id', models.IntegerField(blank=True, default=None, null=True)),
],
options={
'default_permissions': (),
},
),
migrations.AddField(
model_name='job',
name='updated_date',
field=models.DateTimeField(auto_now=True),
),
migrations.AddField(
model_name='project',
name='source_storage',
field=models.ForeignKey(blank=True, default=None, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='+', to='engine.storage'),
),
migrations.AddField(
model_name='project',
name='target_storage',
field=models.ForeignKey(blank=True, default=None, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='+', to='engine.storage'),
),
migrations.AddField(
model_name='task',
name='source_storage',
field=models.ForeignKey(blank=True, default=None, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='+', to='engine.storage'),
),
migrations.AddField(
model_name='task',
name='target_storage',
field=models.ForeignKey(blank=True, default=None, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='+', to='engine.storage'),
),
]

@@ -0,0 +1,53 @@
# Generated by Django 3.2.12 on 2022-06-10 18:29
import os
import shutil
from django.db import migrations
from django.conf import settings
from cvat.apps.engine.log import get_logger
MIGRATION_NAME = os.path.splitext(os.path.basename(__file__))[0]
MIGRATION_LOG = os.path.join(settings.MIGRATIONS_LOGS_ROOT, f"{MIGRATION_NAME}.log")
def _get_query_set(apps):
Job = apps.get_model("engine", "Job")
query_set = Job.objects.all()
return query_set
def _get_job_dir_path(jid):
return os.path.join(settings.JOBS_ROOT, str(jid))
def create_directories(apps, schema_editor):
logger = get_logger(MIGRATION_NAME, MIGRATION_LOG)
query_set = _get_query_set(apps)
logger.info(f'Migration has been started. Need to create {query_set.count()} directories.')
for db_job in query_set:
jid = db_job.id
os.makedirs(_get_job_dir_path(jid), exist_ok=True)
logger.info(f'Migration has been finished successfully.')
def delete_directories(apps, schema_editor):
logger = get_logger(MIGRATION_NAME, MIGRATION_LOG)
query_set = _get_query_set(apps)
logger.info(f'Reverse migration has been started. Need to delete {query_set.count()} directories.')
for db_job in query_set:
jid = db_job.id
job_dir = _get_job_dir_path(jid)
if os.path.isdir(job_dir):
shutil.rmtree(job_dir)
logger.info(f'Migration has been reversed successfully.')
class Migration(migrations.Migration):
dependencies = [
('engine', '0054_auto_20220610_1829'),
]
operations = [
migrations.RunPython(
code=create_directories,
reverse_code=delete_directories
)
]

@@ -8,10 +8,13 @@ import uuid
from django.conf import settings
from django.core.cache import cache
from distutils.util import strtobool
from rest_framework import status
from rest_framework.response import Response
-from cvat.apps.engine.serializers import DataSerializer
+from cvat.apps.engine.models import Location
+from cvat.apps.engine.location import StorageType, get_location_configuration
+from cvat.apps.engine.serializers import DataSerializer, LabeledDataSerializer
class TusFile:
_tus_cache_timeout = 3600
@@ -90,7 +93,7 @@ class TusChunk:
# This upload mixin is implemented using tus
# tus is open protocol for file uploads (see more https://tus.io/)
-class UploadMixin(object):
+class UploadMixin:
_tus_api_version = '1.0.0'
_tus_api_version_supported = ['1.0.0']
_tus_api_extensions = []
@@ -238,3 +241,80 @@ class UploadMixin(object):
# override this to do stuff after upload
def upload_finished(self, request):
raise NotImplementedError('You need to implement upload_finished in UploadMixin')
class AnnotationMixin:
def export_annotations(self, request, pk, db_obj, export_func, callback, get_data=None):
format_name = request.query_params.get("format")
action = request.query_params.get("action", "").lower()
filename = request.query_params.get("filename", "")
use_default_location = request.query_params.get("use_default_location", True)
use_settings = strtobool(str(use_default_location))
obj = db_obj if use_settings else request.query_params
location_conf = get_location_configuration(
obj=obj,
use_settings=use_settings,
field_name=StorageType.TARGET,
)
rq_id = "/api/{}/{}/annotations/{}".format(self._object.__class__.__name__.lower(), pk, format_name)
if format_name:
return export_func(db_instance=self._object,
rq_id=rq_id,
request=request,
action=action,
callback=callback,
format_name=format_name,
filename=filename,
location_conf=location_conf,
)
if not get_data:
return Response("Format is not specified",status=status.HTTP_400_BAD_REQUEST)
data = get_data(pk)
serializer = LabeledDataSerializer(data=data)
if serializer.is_valid(raise_exception=True):
return Response(serializer.data)
def import_annotations(self, request, pk, db_obj, import_func, rq_func):
use_default_location = request.query_params.get('use_default_location', True)
use_settings = strtobool(str(use_default_location))
obj = db_obj if use_settings else request.query_params
location_conf = get_location_configuration(
obj=obj,
use_settings=use_settings,
field_name=StorageType.SOURCE,
)
if location_conf['location'] == Location.CLOUD_STORAGE:
format_name = request.query_params.get('format')
file_name = request.query_params.get('filename')
rq_id = "{}@/api/{}/{}/annotations/upload".format(
self._object.__class__.__name__.lower(), request.user, pk
)
return import_func(
request=request,
rq_id=rq_id,
rq_func=rq_func,
pk=pk,
format_name=format_name,
location_conf=location_conf,
filename=file_name,
)
return self.upload_data(request)
class SerializeMixin:
def serialize(self, request, export_func):
db_object = self.get_object() # force to call check_object_permissions
return export_func(db_object, request)
def deserialize(self, request, import_func):
location = request.query_params.get("location", Location.LOCAL)
if location == Location.CLOUD_STORAGE:
file_name = request.query_params.get("filename", "")
return import_func(request, filename=file_name)
return self.upload_data(request)
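From the client side, the behaviour added by these mixins is driven entirely by query parameters (use_default_location, location, cloud_storage_id, filename, format). A hedged sketch with the requests library; the host, credentials, ids and the exact endpoint path are assumptions for illustration, following the rq_id pattern used in export_annotations above:

import requests

# export annotations of task 7, letting the server write the result to cloud storage 3
# instead of the location configured in the task's target_storage
requests.get(
    'http://localhost:8080/api/tasks/7/annotations',
    params={
        'format': 'CVAT for images 1.1',
        'use_default_location': 'false',
        'location': 'cloud_storage',
        'cloud_storage_id': 3,
    },
    auth=('user', 'password'),
)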

@@ -181,7 +181,7 @@ class Data(models.Model):
default_permissions = ()
def get_frame_step(self):
-match = re.search("step\s*=\s*([1-9]\d*)", self.frame_filter)
+match = re.search(r"step\s*=\s*([1-9]\d*)", self.frame_filter)
return int(match.group(1)) if match else 1
def get_data_dirname(self):
@@ -265,7 +265,6 @@ class Image(models.Model):
default_permissions = ()
class Project(models.Model):
name = SafeCharField(max_length=256)
owner = models.ForeignKey(User, null=True, blank=True,
on_delete=models.SET_NULL, related_name="+")
@@ -278,15 +277,19 @@ class Project(models.Model):
default=StatusChoice.ANNOTATION)
organization = models.ForeignKey(Organization, null=True, default=None,
blank=True, on_delete=models.SET_NULL, related_name="projects")
source_storage = models.ForeignKey('Storage', null=True, default=None,
blank=True, on_delete=models.SET_NULL, related_name='+')
target_storage = models.ForeignKey('Storage', null=True, default=None,
blank=True, on_delete=models.SET_NULL, related_name='+')
-def get_project_dirname(self):
+def get_dirname(self):
return os.path.join(settings.PROJECTS_ROOT, str(self.id))
def get_project_logs_dirname(self):
-return os.path.join(self.get_project_dirname(), 'logs')
+return os.path.join(self.get_dirname(), 'logs')
def get_tmp_dirname(self):
-return os.path.join(self.get_project_dirname(), "tmp")
+return os.path.join(self.get_dirname(), "tmp")
def get_client_log_path(self):
return os.path.join(self.get_project_logs_dirname(), "client.log")
@@ -324,17 +327,20 @@ class Task(models.Model):
subset = models.CharField(max_length=64, blank=True, default="")
organization = models.ForeignKey(Organization, null=True, default=None,
blank=True, on_delete=models.SET_NULL, related_name="tasks")
source_storage = models.ForeignKey('Storage', null=True, default=None,
blank=True, on_delete=models.SET_NULL, related_name='+')
target_storage = models.ForeignKey('Storage', null=True, default=None,
blank=True, on_delete=models.SET_NULL, related_name='+')
# Extend default permission model
class Meta:
default_permissions = ()
-def get_task_dirname(self):
+def get_dirname(self):
return os.path.join(settings.TASKS_ROOT, str(self.id))
def get_task_logs_dirname(self):
-return os.path.join(self.get_task_dirname(), 'logs')
+return os.path.join(self.get_dirname(), 'logs')
def get_client_log_path(self):
return os.path.join(self.get_task_logs_dirname(), "client.log")
@@ -343,10 +349,10 @@ class Task(models.Model):
return os.path.join(self.get_task_logs_dirname(), "task.log")
def get_task_artifacts_dirname(self):
-return os.path.join(self.get_task_dirname(), 'artifacts')
+return os.path.join(self.get_dirname(), 'artifacts')
def get_tmp_dirname(self):
-return os.path.join(self.get_task_dirname(), "tmp")
+return os.path.join(self.get_dirname(), "tmp")
def __str__(self):
return self.name
@@ -414,6 +420,7 @@ class Segment(models.Model):
class Job(models.Model):
segment = models.ForeignKey(Segment, on_delete=models.CASCADE)
assignee = models.ForeignKey(User, null=True, blank=True, on_delete=models.SET_NULL)
updated_date = models.DateTimeField(auto_now=True)
# TODO: it has to be deleted in Job, Task, Project and replaced by (stage, state)
# The stage field cannot be changed by an assignee, but state field can be. For
# now status is read only and it will be updated by (stage, state). Thus we don't
@@ -425,6 +432,9 @@ class Job(models.Model):
state = models.CharField(max_length=32, choices=StateChoice.choices(),
default=StateChoice.NEW)
def get_dirname(self):
return os.path.join(settings.JOBS_ROOT, str(self.id))
def get_project_id(self):
project = self.segment.task.project
return project.id if project else None
@@ -524,8 +534,8 @@ class SourceType(str, Enum):
MANUAL = 'manual'
@classmethod
-def choices(self):
-return tuple((x.value, x.name) for x in self)
+def choices(cls):
+return tuple((x.value, x.name) for x in cls)
def __str__(self):
return self.value
@@ -669,6 +679,21 @@ class Manifest(models.Model):
def __str__(self):
return '{}'.format(self.filename)
class Location(str, Enum):
CLOUD_STORAGE = 'cloud_storage'
LOCAL = 'local'
@classmethod
def choices(cls):
return tuple((x.value, x.name) for x in cls)
def __str__(self):
return self.value
@classmethod
def list(cls):
return [i.value for i in cls]
class CloudStorage(models.Model):
# restrictions:
# AWS bucket name, Azure container name - 63, Google bucket name - 63 without dots and 222 with dots
@@ -696,7 +721,6 @@ class CloudStorage(models.Model):
organization = models.ForeignKey(Organization, null=True, default=None,
blank=True, on_delete=models.SET_NULL, related_name="cloudstorages")
class Meta:
default_permissions = ()
unique_together = ('provider_type', 'resource', 'credentials')
@@ -721,3 +745,10 @@ class CloudStorage(models.Model):
def get_key_file_path(self):
return os.path.join(self.get_storage_dirname(), 'key.json')
class Storage(models.Model):
location = models.CharField(max_length=16, choices=Location.choices(), default=Location.LOCAL)
cloud_storage_id = models.IntegerField(null=True, blank=True, default=None)
class Meta:
default_permissions = ()
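A Storage row is what get_location_configuration reads back as the default source or target location for a project or task. A small sketch of attaching a default export target to a project (hypothetical ids, Django shell context assumed):

from cvat.apps.engine.models import Location, Project, Storage

storage = Storage.objects.create(location=Location.CLOUD_STORAGE, cloud_storage_id=3)  # hypothetical cloud storage id
db_project = Project.objects.get(pk=5)
db_project.target_storage = storage
db_project.save()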

@@ -20,7 +20,7 @@ from cvat.apps.engine.utils import parse_specific_attributes
from drf_spectacular.utils import OpenApiExample, extend_schema_serializer
class BasicUserSerializer(serializers.ModelSerializer):
-def validate(self, data):
+def validate(self, attrs):
if hasattr(self, 'initial_data'):
unknown_keys = set(self.initial_data.keys()) - set(self.fields.keys())
if unknown_keys:
@@ -30,7 +30,7 @@ class BasicUserSerializer(serializers.ModelSerializer):
else:
message = 'Got unknown fields: {}'.format(unknown_keys)
raise serializers.ValidationError(message)
-return data
+return attrs
class Meta:
model = User
@@ -81,7 +81,7 @@ class LabelSerializer(serializers.ModelSerializer):
fields = ('id', 'name', 'color', 'attributes', 'deleted')
def validate(self, attrs):
-if attrs.get('deleted') == True and attrs.get('id') is None:
+if attrs.get('deleted') and attrs.get('id') is None:
raise serializers.ValidationError('Deleted label must have an ID')
return attrs
@@ -107,7 +107,7 @@ class LabelSerializer(serializers.ModelSerializer):
else:
db_label = models.Label.objects.create(name=validated_data.get('name'), **instance)
logger.info("New {} label was created".format(db_label.name))
-if validated_data.get('deleted') == True:
+if validated_data.get('deleted'):
db_label.delete()
return
if not validated_data.get('color', None):
@@ -159,7 +159,8 @@ class JobReadSerializer(serializers.ModelSerializer):
model = models.Job
fields = ('url', 'id', 'task_id', 'project_id', 'assignee',
'dimension', 'labels', 'bug_tracker', 'status', 'stage', 'state', 'mode',
-'start_frame', 'stop_frame', 'data_chunk_size', 'data_compressed_chunk_type')
+'start_frame', 'stop_frame', 'data_chunk_size', 'data_compressed_chunk_type',
+'updated_date',)
read_only_fields = fields
class JobWriteSerializer(serializers.ModelSerializer):
@@ -336,7 +337,7 @@ class DataSerializer(serializers.ModelSerializer):
# pylint: disable=no-self-use
def validate_frame_filter(self, value):
-match = re.search("step\s*=\s*([1-9]\d*)", value)
+match = re.search(r"step\s*=\s*([1-9]\d*)", value)
if not match:
raise serializers.ValidationError("Invalid frame filter expression")
return value
@@ -348,11 +349,11 @@ class DataSerializer(serializers.ModelSerializer):
return value
# pylint: disable=no-self-use
-def validate(self, data):
-if 'start_frame' in data and 'stop_frame' in data \
-and data['start_frame'] > data['stop_frame']:
+def validate(self, attrs):
+if 'start_frame' in attrs and 'stop_frame' in attrs \
+and attrs['start_frame'] > attrs['stop_frame']:
raise serializers.ValidationError('Stop frame must be more or equal start frame')
-return data
+return attrs
def create(self, validated_data):
files = self._pop_data(validated_data)
@@ -404,8 +405,12 @@ class DataSerializer(serializers.ModelSerializer):
remote_file = models.RemoteFile(data=instance, **f)
remote_file.save()
class StorageSerializer(serializers.ModelSerializer):
class Meta:
model = models.Storage
fields = ('id', 'location', 'cloud_storage_id')
-class TaskSerializer(WriteOnceMixin, serializers.ModelSerializer):
+class TaskReadSerializer(serializers.ModelSerializer):
labels = LabelSerializer(many=True, source='label_set', partial=True, required=False)
segments = SegmentSerializer(many=True, source='segment_set', read_only=True)
data_chunk_size = serializers.ReadOnlyField(source='data.chunk_size')
@@ -415,24 +420,47 @@ class TaskSerializer(WriteOnceMixin, serializers.ModelSerializer):
image_quality = serializers.ReadOnlyField(source='data.image_quality')
data = serializers.ReadOnlyField(source='data.id')
owner = BasicUserSerializer(required=False)
-owner_id = serializers.IntegerField(write_only=True, allow_null=True, required=False)
assignee = BasicUserSerializer(allow_null=True, required=False)
-assignee_id = serializers.IntegerField(write_only=True, allow_null=True, required=False)
project_id = serializers.IntegerField(required=False, allow_null=True)
dimension = serializers.CharField(allow_blank=True, required=False)
target_storage = StorageSerializer(required=False)
source_storage = StorageSerializer(required=False)
class Meta:
model = models.Task
fields = ('url', 'id', 'name', 'project_id', 'mode', 'owner', 'assignee',
-'owner_id', 'assignee_id', 'bug_tracker', 'created_date', 'updated_date',
-'overlap', 'segment_size', 'status', 'labels', 'segments',
-'data_chunk_size', 'data_compressed_chunk_type', 'data_original_chunk_type',
-'size', 'image_quality', 'data', 'dimension', 'subset', 'organization')
-read_only_fields = ('mode', 'created_date', 'updated_date', 'status',
-'data_chunk_size', 'owner', 'assignee', 'data_compressed_chunk_type',
-'data_original_chunk_type', 'size', 'image_quality', 'data',
-'organization')
-write_once_fields = ('overlap', 'segment_size', 'project_id')
+'bug_tracker', 'created_date', 'updated_date', 'overlap', 'segment_size',
+'status', 'labels', 'segments', 'data_chunk_size', 'data_compressed_chunk_type',
+'data_original_chunk_type', 'size', 'image_quality', 'data', 'dimension',
+'subset', 'organization', 'target_storage', 'source_storage',
+)
+read_only_fields = fields
def to_representation(self, instance):
response = super().to_representation(instance)
if instance.project_id:
response["labels"] = LabelSerializer(many=True).to_representation(instance.project.label_set)
return response
class TaskWriteSerializer(WriteOnceMixin, serializers.ModelSerializer):
labels = LabelSerializer(many=True, source='label_set', partial=True, required=False)
owner_id = serializers.IntegerField(write_only=True, allow_null=True, required=False)
assignee_id = serializers.IntegerField(write_only=True, allow_null=True, required=False)
project_id = serializers.IntegerField(required=False, allow_null=True)
target_storage = StorageSerializer(required=False)
source_storage = StorageSerializer(required=False)
class Meta:
model = models.Task
fields = ('url', 'id', 'name', 'project_id', 'owner_id', 'assignee_id',
'bug_tracker', 'overlap', 'segment_size', 'labels', 'subset',
'target_storage', 'source_storage',
)
write_once_fields = ('overlap', 'segment_size', 'project_id', 'owner_id', 'labels')
def to_representation(self, instance):
serializer = TaskReadSerializer(instance, context=self.context)
return serializer.data
# pylint: disable=no-self-use
def create(self, validated_data):
@@ -453,7 +481,17 @@ class TaskSerializer(WriteOnceMixin, serializers.ModelSerializer):
raise serializers.ValidationError(f'The task and its project should be in the same organization.')
labels = validated_data.pop('label_set', [])
-db_task = models.Task.objects.create(**validated_data)
# configure source/target storages for import/export
storages = _configure_related_storages({
'source_storage': validated_data.pop('source_storage', None),
'target_storage': validated_data.pop('target_storage', None),
})
db_task = models.Task.objects.create(
**storages,
**validated_data)
label_colors = list()
for label in labels:
attributes = label.pop('attributespec_set')
@@ -468,7 +506,7 @@ class TaskSerializer(WriteOnceMixin, serializers.ModelSerializer):
del attr['id']
models.AttributeSpec.objects.create(label=db_label, **attr)
-task_path = db_task.get_task_dirname()
+task_path = db_task.get_dirname()
if os.path.isdir(task_path):
shutil.rmtree(task_path)
@@ -478,12 +516,6 @@ class TaskSerializer(WriteOnceMixin, serializers.ModelSerializer):
db_task.save()
return db_task
-def to_representation(self, instance):
-response = super().to_representation(instance)
-if instance.project_id:
-response["labels"] = LabelSerializer(many=True).to_representation(instance.project.label_set)
-return response
# pylint: disable=no-self-use
def update(self, instance, validated_data):
instance.name = validated_data.get('name', instance.name)
@@ -500,7 +532,7 @@ class TaskSerializer(WriteOnceMixin, serializers.ModelSerializer):
if validated_project_id is not None and validated_project_id != instance.project_id:
project = models.Project.objects.get(id=validated_project_id)
if project.tasks.count() and project.tasks.first().dimension != instance.dimension:
raise serializers.ValidationError(f'Dimension ({instance.dimension}) of the task must be the same as other tasks in project ({project.tasks.first().dimension})')
if instance.project_id is None:
for old_label in instance.label_set.all():
try:
@@ -536,6 +568,9 @@ class TaskSerializer(WriteOnceMixin, serializers.ModelSerializer):
)
instance.project = project
# update source and target storages
_update_related_storages(instance, validated_data)
instance.save()
return instance
@@ -547,6 +582,7 @@ class TaskSerializer(WriteOnceMixin, serializers.ModelSerializer):
project = models.Project.objects.filter(id=project_id).first()
if project is None:
raise serializers.ValidationError(f'Cannot find project with ID {project_id}')
# Check that all labels can be mapped
new_label_names = set()
old_labels = self.instance.project.label_set.all() if self.instance.project_id else self.instance.label_set.all()
@@ -577,22 +613,26 @@ class ProjectSearchSerializer(serializers.ModelSerializer):
fields = ('id', 'name')
read_only_fields = ('name',)
-class ProjectSerializer(serializers.ModelSerializer):
+class ProjectReadSerializer(serializers.ModelSerializer):
labels = LabelSerializer(many=True, source='label_set', partial=True, default=[])
owner = BasicUserSerializer(required=False, read_only=True)
-owner_id = serializers.IntegerField(write_only=True, allow_null=True, required=False)
-assignee = BasicUserSerializer(allow_null=True, required=False)
-assignee_id = serializers.IntegerField(write_only=True, allow_null=True, required=False)
+assignee = BasicUserSerializer(allow_null=True, required=False, read_only=True)
task_subsets = serializers.ListField(child=serializers.CharField(), required=False)
dimension = serializers.CharField(max_length=16, required=False, read_only=True)
target_storage = StorageSerializer(required=False)
source_storage = StorageSerializer(required=False)
class Meta:
model = models.Project
fields = ('url', 'id', 'name', 'labels', 'tasks', 'owner', 'assignee',
-'owner_id', 'assignee_id', 'bug_tracker', 'task_subsets',
-'created_date', 'updated_date', 'status', 'dimension', 'organization')
+'bug_tracker', 'task_subsets', # 'owner_id', 'assignee_id',
+'created_date', 'updated_date', 'status', 'dimension', 'organization',
+'target_storage', 'source_storage',
+)
read_only_fields = ('created_date', 'updated_date', 'status', 'owner',
-'assignee', 'task_subsets', 'dimension', 'organization', 'tasks')
+'assignee', 'task_subsets', 'dimension', 'organization', 'tasks',
+'target_storage', 'source_storage',
+)
def to_representation(self, instance):
response = super().to_representation(instance)
@ -602,10 +642,38 @@ class ProjectSerializer(serializers.ModelSerializer):
response['dimension'] = instance.tasks.first().dimension if instance.tasks.count() else None response['dimension'] = instance.tasks.first().dimension if instance.tasks.count() else None
return response return response
class ProjectWriteSerializer(serializers.ModelSerializer):
labels = LabelSerializer(many=True, source='label_set', partial=True, default=[])
owner_id = serializers.IntegerField(write_only=True, allow_null=True, required=False)
assignee_id = serializers.IntegerField(write_only=True, allow_null=True, required=False)
task_subsets = serializers.ListField(child=serializers.CharField(), required=False)
target_storage = StorageSerializer(required=False)
source_storage = StorageSerializer(required=False)
class Meta:
model = models.Project
fields = ('name', 'labels', 'owner_id', 'assignee_id', 'bug_tracker',
'target_storage', 'source_storage', 'task_subsets',
)
def to_representation(self, instance):
serializer = ProjectReadSerializer(instance, context=self.context)
return serializer.data
# pylint: disable=no-self-use # pylint: disable=no-self-use
def create(self, validated_data): def create(self, validated_data):
labels = validated_data.pop('label_set') labels = validated_data.pop('label_set')
db_project = models.Project.objects.create(**validated_data)
# configure source/target storages for import/export
storages = _configure_related_storages({
'source_storage': validated_data.pop('source_storage', None),
'target_storage': validated_data.pop('target_storage', None),
})
db_project = models.Project.objects.create(
**storages,
**validated_data)
label_colors = list() label_colors = list()
for label in labels: for label in labels:
if label.get('id', None): if label.get('id', None):
@ -620,7 +688,7 @@ class ProjectSerializer(serializers.ModelSerializer):
del attr['id'] del attr['id']
models.AttributeSpec.objects.create(label=db_label, **attr) models.AttributeSpec.objects.create(label=db_label, **attr)
project_path = db_project.get_project_dirname() project_path = db_project.get_dirname()
if os.path.isdir(project_path): if os.path.isdir(project_path):
shutil.rmtree(project_path) shutil.rmtree(project_path)
os.makedirs(db_project.get_project_logs_dirname()) os.makedirs(db_project.get_project_logs_dirname())
@ -637,6 +705,9 @@ class ProjectSerializer(serializers.ModelSerializer):
for label in labels: for label in labels:
LabelSerializer.update_instance(label, instance) LabelSerializer.update_instance(label, instance)
# update source and target storages
_update_related_storages(instance, validated_data)
instance.save() instance.save()
return instance return instance
@ -976,19 +1047,19 @@ class CloudStorageWriteSerializer(serializers.ModelSerializer):
@staticmethod @staticmethod
def _manifests_validation(storage, manifests): def _manifests_validation(storage, manifests):
# check manifest files availability # check manifest files availability
for manifest in manifests: for manifest in manifests:
file_status = storage.get_file_status(manifest) file_status = storage.get_file_status(manifest)
if file_status == Status.NOT_FOUND: if file_status == Status.NOT_FOUND:
raise serializers.ValidationError({ raise serializers.ValidationError({
'manifests': "The '{}' file does not exist on '{}' cloud storage" \ 'manifests': "The '{}' file does not exist on '{}' cloud storage" \
.format(manifest, storage.name) .format(manifest, storage.name)
}) })
elif file_status == Status.FORBIDDEN: elif file_status == Status.FORBIDDEN:
raise serializers.ValidationError({ raise serializers.ValidationError({
'manifests': "The '{}' file does not available on '{}' cloud storage. Access denied" \ 'manifests': "The '{}' file does not available on '{}' cloud storage. Access denied" \
.format(manifest, storage.name) .format(manifest, storage.name)
}) })
def create(self, validated_data): def create(self, validated_data):
provider_type = validated_data.get('provider_type') provider_type = validated_data.get('provider_type')
@ -1140,3 +1211,61 @@ class RelatedFileSerializer(serializers.ModelSerializer):
model = models.RelatedFile model = models.RelatedFile
fields = '__all__' fields = '__all__'
read_only_fields = ('path',) read_only_fields = ('path',)
def _update_related_storages(instance, validated_data):
for storage in ('source_storage', 'target_storage'):
new_conf = validated_data.pop(storage, None)
if not new_conf:
continue
cloud_storage_id = new_conf.get('cloud_storage_id')
if cloud_storage_id:
_validate_existence_of_cloud_storage(cloud_storage_id)
# storage_instance may be None
storage_instance = getattr(instance, storage)
if not storage_instance:
storage_instance = models.Storage(**new_conf)
storage_instance.save()
setattr(instance, storage, storage_instance)
continue
new_location = new_conf.get('location')
storage_instance.location = new_location or storage_instance.location
storage_instance.cloud_storage_id = new_conf.get('cloud_storage_id', \
storage_instance.cloud_storage_id if not new_location else None)
cloud_storage_id = storage_instance.cloud_storage_id
if cloud_storage_id:
try:
_ = models.CloudStorage.objects.get(id=cloud_storage_id)
except models.CloudStorage.DoesNotExist:
raise serializers.ValidationError(f'The specified cloud storage {cloud_storage_id} does not exist.')
storage_instance.save()
def _configure_related_storages(validated_data):
storages = {
'source_storage': None,
'target_storage': None,
}
for i in storages:
storage_conf = validated_data.get(i)
if storage_conf:
cloud_storage_id = storage_conf.get('cloud_storage_id')
if cloud_storage_id:
_validate_existence_of_cloud_storage(cloud_storage_id)
storage_instance = models.Storage(**storage_conf)
storage_instance.save()
storages[i] = storage_instance
return storages
def _validate_existence_of_cloud_storage(cloud_storage_id):
try:
_ = models.CloudStorage.objects.get(id=cloud_storage_id)
except models.CloudStorage.DoesNotExist:
raise serializers.ValidationError(f'The specified cloud storage {cloud_storage_id} does not exist.')
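A rough sketch (illustration only, not from this patch) of how the helpers above fit together: create() materializes Storage rows from the validated nested dicts, while a later update mutates or replaces the rows already linked to the instance. The literal values are placeholders.

validated_data = {
    'name': 'demo project',
    'source_storage': {'location': 'cloud_storage', 'cloud_storage_id': 3},
    'target_storage': {'location': 'local'},
}
# On create: build Storage rows up front and link them to the new project
storages = _configure_related_storages({
    'source_storage': validated_data.pop('source_storage', None),
    'target_storage': validated_data.pop('target_storage', None),
})
db_project = models.Project.objects.create(**storages, **validated_data)

# On update: switch the existing target to a cloud bucket; the helper keeps or drops
# cloud_storage_id depending on whether a new location was supplied
_update_related_storages(db_project, {'target_storage': {'location': 'cloud_storage', 'cloud_storage_id': 3}})
db_project.save()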

@ -39,7 +39,7 @@ def create_profile(instance, **kwargs):
@receiver(post_delete, sender=Task, dispatch_uid="delete_task_files_on_delete_task") @receiver(post_delete, sender=Task, dispatch_uid="delete_task_files_on_delete_task")
def delete_task_files_on_delete_task(instance, **kwargs): def delete_task_files_on_delete_task(instance, **kwargs):
shutil.rmtree(instance.get_task_dirname(), ignore_errors=True) shutil.rmtree(instance.get_dirname(), ignore_errors=True)
@receiver(post_delete, sender=Data, dispatch_uid="delete_data_files_on_delete_data") @receiver(post_delete, sender=Data, dispatch_uid="delete_data_files_on_delete_data")

@ -112,6 +112,12 @@ def _save_task_to_db(db_task):
db_job = models.Job(segment=db_segment) db_job = models.Job(segment=db_segment)
db_job.save() db_job.save()
# create job directory
job_path = db_job.get_dirname()
if os.path.isdir(job_path):
shutil.rmtree(job_path)
os.makedirs(job_path)
db_task.data.save() db_task.data.save()
db_task.save() db_task.save()
@ -489,7 +495,7 @@ def _create_thread(db_task, data, isBackupRestore=False, isDatasetImport=False):
# calculate chunk size if it isn't specified # calculate chunk size if it isn't specified
if db_data.chunk_size is None: if db_data.chunk_size is None:
if isinstance(compressed_chunk_writer, ZipCompressedChunkWriter): if isinstance(compressed_chunk_writer, ZipCompressedChunkWriter):
if not (db_data.storage == models.StorageChoice.CLOUD_STORAGE): if not db_data.storage == models.StorageChoice.CLOUD_STORAGE:
w, h = extractor.get_image_size(0) w, h = extractor.get_image_size(0)
else: else:
img_properties = manifest[0] img_properties = manifest[0]
@ -507,7 +513,7 @@ def _create_thread(db_task, data, isBackupRestore=False, isDatasetImport=False):
job.save_meta() job.save_meta()
if settings.USE_CACHE and db_data.storage_method == models.StorageMethodChoice.CACHE: if settings.USE_CACHE and db_data.storage_method == models.StorageMethodChoice.CACHE:
for media_type, media_files in media.items(): for media_type, media_files in media.items():
if not media_files: if not media_files:
continue continue

@ -80,8 +80,8 @@ def create_db_task(data):
labels = data.pop('labels', None) labels = data.pop('labels', None)
db_task = Task.objects.create(**data) db_task = Task.objects.create(**data)
shutil.rmtree(db_task.get_task_dirname(), ignore_errors=True) shutil.rmtree(db_task.get_dirname(), ignore_errors=True)
os.makedirs(db_task.get_task_dirname()) os.makedirs(db_task.get_dirname())
os.makedirs(db_task.get_task_logs_dirname()) os.makedirs(db_task.get_task_logs_dirname())
os.makedirs(db_task.get_task_artifacts_dirname()) os.makedirs(db_task.get_task_artifacts_dirname())
db_task.data = db_data db_task.data = db_data
@ -117,8 +117,8 @@ def create_db_task(data):
def create_db_project(data): def create_db_project(data):
labels = data.pop('labels', None) labels = data.pop('labels', None)
db_project = Project.objects.create(**data) db_project = Project.objects.create(**data)
shutil.rmtree(db_project.get_project_dirname(), ignore_errors=True) shutil.rmtree(db_project.get_dirname(), ignore_errors=True)
os.makedirs(db_project.get_project_dirname()) os.makedirs(db_project.get_dirname())
os.makedirs(db_project.get_project_logs_dirname()) os.makedirs(db_project.get_project_logs_dirname())
if not labels is None: if not labels is None:
@ -1979,11 +1979,11 @@ class TaskDeleteAPITestCase(APITestCase):
def test_api_v2_tasks_delete_task_data_after_delete_task(self): def test_api_v2_tasks_delete_task_data_after_delete_task(self):
for task in self.tasks: for task in self.tasks:
task_dir = task.get_task_dirname() task_dir = task.get_dirname()
self.assertTrue(os.path.exists(task_dir)) self.assertTrue(os.path.exists(task_dir))
self._check_api_v2_tasks_id(self.admin) self._check_api_v2_tasks_id(self.admin)
for task in self.tasks: for task in self.tasks:
task_dir = task.get_task_dirname() task_dir = task.get_dirname()
self.assertFalse(os.path.exists(task_dir)) self.assertFalse(os.path.exists(task_dir))
class TaskUpdateAPITestCase(APITestCase): class TaskUpdateAPITestCase(APITestCase):
@ -2418,7 +2418,7 @@ class TaskMoveAPITestCase(APITestCase):
def _check_api_v2_tasks(self, tid, data, expected_status=status.HTTP_200_OK): def _check_api_v2_tasks(self, tid, data, expected_status=status.HTTP_200_OK):
response = self._run_api_v2_tasks_id(tid, data) response = self._run_api_v2_tasks_id(tid, data)
self.assertEqual(response.status_code, expected_status) self.assertEqual(response.status_code, expected_status)
if (expected_status == status.HTTP_200_OK): if expected_status == status.HTTP_200_OK:
self._check_response(response, data) self._check_response(response, data)
def test_move_task_bad_request(self): def test_move_task_bad_request(self):
@ -2936,6 +2936,8 @@ class TaskImportExportAPITestCase(APITestCase):
"created_date", "created_date",
"updated_date", "updated_date",
"data", "data",
"source_storage",
"target_storage",
), ),
) )

@ -15,6 +15,7 @@ from tempfile import mkstemp, NamedTemporaryFile
import cv2 import cv2
from django.db.models.query import Prefetch from django.db.models.query import Prefetch
from django.shortcuts import get_object_or_404
import django_rq import django_rq
from django.apps import apps from django.apps import apps
from django.conf import settings from django.conf import settings
@ -40,7 +41,8 @@ from django_sendfile import sendfile
import cvat.apps.dataset_manager as dm import cvat.apps.dataset_manager as dm
import cvat.apps.dataset_manager.views # pylint: disable=unused-import import cvat.apps.dataset_manager.views # pylint: disable=unused-import
from cvat.apps.engine.cloud_provider import get_cloud_storage_instance, Credentials, Status as CloudStorageStatus from cvat.apps.engine.cloud_provider import (
db_storage_to_storage_instance, validate_bucket_status, Status as CloudStorageStatus)
from cvat.apps.dataset_manager.bindings import CvatImportError from cvat.apps.dataset_manager.bindings import CvatImportError
from cvat.apps.dataset_manager.serializers import DatasetFormatsSerializer from cvat.apps.dataset_manager.serializers import DatasetFormatsSerializer
from cvat.apps.engine.frame_provider import FrameProvider from cvat.apps.engine.frame_provider import FrameProvider
@ -49,22 +51,23 @@ from cvat.apps.engine.mime_types import mimetypes
from cvat.apps.engine.models import ( from cvat.apps.engine.models import (
Job, Task, Project, Issue, Data, Job, Task, Project, Issue, Data,
Comment, StorageMethodChoice, StorageChoice, Image, Comment, StorageMethodChoice, StorageChoice, Image,
CloudProviderChoice CloudProviderChoice, Location
) )
from cvat.apps.engine.models import CloudStorage as CloudStorageModel from cvat.apps.engine.models import CloudStorage as CloudStorageModel
from cvat.apps.engine.serializers import ( from cvat.apps.engine.serializers import (
AboutSerializer, AnnotationFileSerializer, BasicUserSerializer, AboutSerializer, AnnotationFileSerializer, BasicUserSerializer,
DataMetaReadSerializer, DataMetaWriteSerializer, DataSerializer, ExceptionSerializer, DataMetaReadSerializer, DataMetaWriteSerializer, DataSerializer, ExceptionSerializer,
FileInfoSerializer, JobReadSerializer, JobWriteSerializer, LabeledDataSerializer, FileInfoSerializer, JobReadSerializer, JobWriteSerializer, LabeledDataSerializer,
LogEventSerializer, ProjectSerializer, ProjectSearchSerializer, LogEventSerializer, ProjectReadSerializer, ProjectWriteSerializer, ProjectSearchSerializer,
RqStatusSerializer, TaskSerializer, UserSerializer, PluginsSerializer, IssueReadSerializer, RqStatusSerializer, TaskReadSerializer, TaskWriteSerializer, UserSerializer, PluginsSerializer, IssueReadSerializer,
IssueWriteSerializer, CommentReadSerializer, CommentWriteSerializer, CloudStorageWriteSerializer, IssueWriteSerializer, CommentReadSerializer, CommentWriteSerializer, CloudStorageWriteSerializer,
CloudStorageReadSerializer, DatasetFileSerializer, JobCommitSerializer) CloudStorageReadSerializer, DatasetFileSerializer, JobCommitSerializer,
ProjectFileSerializer, TaskFileSerializer)
from utils.dataset_manifest import ImageManifestManager from utils.dataset_manifest import ImageManifestManager
from cvat.apps.engine.utils import av_scan_paths from cvat.apps.engine.utils import av_scan_paths
from cvat.apps.engine import backup from cvat.apps.engine import backup
from cvat.apps.engine.mixins import UploadMixin from cvat.apps.engine.mixins import UploadMixin, AnnotationMixin, SerializeMixin
from . import models, task from . import models, task
from .log import clogger, slogger from .log import clogger, slogger
@ -72,6 +75,7 @@ from cvat.apps.iam.permissions import (CloudStoragePermission,
CommentPermission, IssuePermission, JobPermission, ProjectPermission, CommentPermission, IssuePermission, JobPermission, ProjectPermission,
TaskPermission, UserPermission) TaskPermission, UserPermission)
@extend_schema(tags=['server']) @extend_schema(tags=['server'])
class ServerViewSet(viewsets.ViewSet): class ServerViewSet(viewsets.ViewSet):
serializer_class = None serializer_class = None
@ -226,18 +230,18 @@ class ServerViewSet(viewsets.ViewSet):
responses={ responses={
'200': PolymorphicProxySerializer(component_name='PolymorphicProject', '200': PolymorphicProxySerializer(component_name='PolymorphicProject',
serializers=[ serializers=[
ProjectSerializer, ProjectSearchSerializer, ProjectReadSerializer, ProjectSearchSerializer,
], resource_type_field_name='name', many=True), ], resource_type_field_name='name', many=True),
}), }),
create=extend_schema( create=extend_schema(
summary='Method creates a new project', summary='Method creates a new project',
responses={ responses={
'201': ProjectSerializer, '201': ProjectWriteSerializer,
}), }),
retrieve=extend_schema( retrieve=extend_schema(
summary='Method returns details of a specific project', summary='Method returns details of a specific project',
responses={ responses={
'200': ProjectSerializer, '200': ProjectReadSerializer,
}), }),
destroy=extend_schema( destroy=extend_schema(
summary='Method deletes a specific project', summary='Method deletes a specific project',
@ -247,10 +251,10 @@ class ServerViewSet(viewsets.ViewSet):
partial_update=extend_schema( partial_update=extend_schema(
summary='Method does a partial update of chosen fields in a project', summary='Method does a partial update of chosen fields in a project',
responses={ responses={
'200': ProjectSerializer, '200': ProjectWriteSerializer,
}) })
) )
class ProjectViewSet(viewsets.ModelViewSet, UploadMixin): class ProjectViewSet(viewsets.ModelViewSet, UploadMixin, AnnotationMixin, SerializeMixin):
queryset = models.Project.objects.prefetch_related(Prefetch('label_set', queryset = models.Project.objects.prefetch_related(Prefetch('label_set',
queryset=models.Label.objects.order_by('id') queryset=models.Label.objects.order_by('id')
)) ))
@ -267,9 +271,12 @@ class ProjectViewSet(viewsets.ModelViewSet, UploadMixin):
def get_serializer_class(self): def get_serializer_class(self):
if self.request.path.endswith('tasks'): if self.request.path.endswith('tasks'):
return TaskSerializer return TaskReadSerializer
else: else:
return ProjectSerializer if self.request.method in SAFE_METHODS:
return ProjectReadSerializer
else:
return ProjectWriteSerializer
def get_queryset(self): def get_queryset(self):
queryset = super().get_queryset() queryset = super().get_queryset()
@ -285,9 +292,9 @@ class ProjectViewSet(viewsets.ModelViewSet, UploadMixin):
@extend_schema( @extend_schema(
summary='Method returns information of the tasks of the project with the selected id', summary='Method returns information of the tasks of the project with the selected id',
responses={ responses={
'200': TaskSerializer(many=True), '200': TaskReadSerializer(many=True),
}) })
@action(detail=True, methods=['GET'], serializer_class=TaskSerializer) @action(detail=True, methods=['GET'], serializer_class=TaskReadSerializer)
def tasks(self, request, pk): def tasks(self, request, pk):
self.get_object() # force to call check_object_permissions self.get_object() # force to call check_object_permissions
queryset = Task.objects.filter(project_id=pk).order_by('-id') queryset = Task.objects.filter(project_id=pk).order_by('-id')
@ -311,7 +318,15 @@ class ProjectViewSet(viewsets.ModelViewSet, UploadMixin):
OpenApiParameter('filename', description='Desired output file name', OpenApiParameter('filename', description='Desired output file name',
location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False), location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False),
OpenApiParameter('action', description='Used to start downloading process after annotation file had been created', OpenApiParameter('action', description='Used to start downloading process after annotation file had been created',
location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, enum=['download', 'import_status']) location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, enum=['download', 'import_status']),
OpenApiParameter('location', description='Where to save the downloaded dataset',
location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False,
enum=Location.list()),
OpenApiParameter('cloud_storage_id', description='Storage id',
location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False),
OpenApiParameter('use_default_location', description='Use the location that was configured in the project to import the dataset',
location=OpenApiParameter.QUERY, type=OpenApiTypes.BOOL, required=False,
default=True),
], ],
responses={ responses={
'200': OpenApiResponse(description='Download of file started'), '200': OpenApiResponse(description='Download of file started'),
@ -323,7 +338,17 @@ class ProjectViewSet(viewsets.ModelViewSet, UploadMixin):
parameters=[ parameters=[
OpenApiParameter('format', description='Desired dataset format name\n' OpenApiParameter('format', description='Desired dataset format name\n'
'You can get the list of supported formats at:\n/server/annotation/formats', 'You can get the list of supported formats at:\n/server/annotation/formats',
location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=True) location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=True),
OpenApiParameter('location', description='Where to import the dataset from',
location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False,
enum=Location.list()),
OpenApiParameter('cloud_storage_id', description='Storage id',
location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False),
OpenApiParameter('use_default_location', description='Use the location that was configured in the project to import annotations',
location=OpenApiParameter.QUERY, type=OpenApiTypes.BOOL, required=False,
default=True),
OpenApiParameter('filename', description='Dataset file name',
location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False),
], ],
responses={ responses={
'202': OpenApiResponse(description='Exporting has been started'), '202': OpenApiResponse(description='Exporting has been started'),
@ -335,8 +360,15 @@ class ProjectViewSet(viewsets.ModelViewSet, UploadMixin):
def dataset(self, request, pk): def dataset(self, request, pk):
self._object = self.get_object() # force to call check_object_permissions self._object = self.get_object() # force to call check_object_permissions
if request.method == 'POST' or request.method == 'OPTIONS': if request.method in {'POST', 'OPTIONS'}:
return self.upload_data(request)
return self.import_annotations(
request=request,
pk=pk,
db_obj=self._object,
import_func=_import_project_dataset,
rq_func=dm.project.import_dataset_as_project
)
else: else:
action = request.query_params.get("action", "").lower() action = request.query_params.get("action", "").lower()
if action in ("import_status",): if action in ("import_status",):
@ -363,15 +395,12 @@ class ProjectViewSet(viewsets.ModelViewSet, UploadMixin):
status=status.HTTP_202_ACCEPTED status=status.HTTP_202_ACCEPTED
) )
else: else:
format_name = request.query_params.get("format", "") return self.export_annotations(
return _export_annotations(
db_instance=self._object,
rq_id="/api/project/{}/dataset/{}".format(pk, format_name),
request=request, request=request,
action=action, pk=pk,
callback=dm.views.export_project_as_dataset, db_obj=self._object,
format_name=format_name, export_func=_export_annotations,
filename=request.query_params.get("filename", "").lower(), callback=dm.views.export_project_as_dataset
) )
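A hedged client-side sketch of the new export path (host, token, ids and format name are placeholders): with location=cloud_storage the produced dataset archive is pushed to the selected bucket instead of being offered for download.

import requests, time

BASE = "http://localhost:8080/api"            # placeholder
HEADERS = {"Authorization": "Token <token>"}  # placeholder

params = {
    "format": "COCO 1.0",
    "location": "cloud_storage",     # assumes Location.CLOUD_STORAGE == 'cloud_storage'
    "cloud_storage_id": 3,           # placeholder storage id
    "use_default_location": "false",
    "filename": "project-1-dataset.zip",
}
# Repeat the request until the export RQ job finishes: 202 means in progress,
# 200/201 means the archive has been created and handed to the bucket.
while True:
    r = requests.get(f"{BASE}/projects/1/dataset", params=params, headers=HEADERS)
    if r.status_code in (200, 201):
        break
    time.sleep(2)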
@action(detail=True, methods=['HEAD', 'PATCH'], url_path='dataset/'+UploadMixin.file_id_regex) @action(detail=True, methods=['HEAD', 'PATCH'], url_path='dataset/'+UploadMixin.file_id_regex)
@ -423,7 +452,15 @@ class ProjectViewSet(viewsets.ModelViewSet, UploadMixin):
OpenApiParameter('filename', description='Desired output file name', OpenApiParameter('filename', description='Desired output file name',
location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False), location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False),
OpenApiParameter('action', description='Used to start downloading process after annotation file had been created', OpenApiParameter('action', description='Used to start downloading process after annotation file had been created',
location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, enum=['download']) location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, enum=['download']),
OpenApiParameter('location', description='Where to save the downloaded dataset',
location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False,
enum=Location.list()),
OpenApiParameter('cloud_storage_id', description='Storage id',
location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False),
OpenApiParameter('use_default_location', description='Use the location that was configured in the project to export annotations',
location=OpenApiParameter.QUERY, type=OpenApiTypes.BOOL, required=False,
default=True),
], ],
responses={ responses={
'200': OpenApiResponse(description='Download of file started'), '200': OpenApiResponse(description='Download of file started'),
@ -435,21 +472,30 @@ class ProjectViewSet(viewsets.ModelViewSet, UploadMixin):
@action(detail=True, methods=['GET'], @action(detail=True, methods=['GET'],
serializer_class=LabeledDataSerializer) serializer_class=LabeledDataSerializer)
def annotations(self, request, pk): def annotations(self, request, pk):
db_project = self.get_object() # force to call check_object_permissions self._object = self.get_object() # force to call check_object_permissions
format_name = request.query_params.get('format') return self.export_annotations(
if format_name: request=request,
return _export_annotations(db_instance=db_project, pk=pk,
rq_id="/api/projects/{}/annotations/{}".format(pk, format_name), db_obj=self._object,
request=request, export_func=_export_annotations,
action=request.query_params.get("action", "").lower(), callback=dm.views.export_project_annotations,
callback=dm.views.export_project_annotations, get_data=dm.task.get_job_data,
format_name=format_name, )
filename=request.query_params.get("filename", "").lower(),
)
else:
return Response("Format is not specified",status=status.HTTP_400_BAD_REQUEST)
@extend_schema(summary='Method creates a backup copy of a project', @extend_schema(summary='Method creates a backup copy of a project',
parameters=[
OpenApiParameter('action', location=OpenApiParameter.QUERY,
description='Used to start downloading process after backup file had been created',
type=OpenApiTypes.STR, required=False, enum=['download']),
OpenApiParameter('location', description='Where to save the downloaded backup',
location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False,
enum=Location.list()),
OpenApiParameter('cloud_storage_id', description='Storage id',
location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False),
OpenApiParameter('use_default_location', description='Use the location that was configured in the project to export the backup',
location=OpenApiParameter.QUERY, type=OpenApiTypes.BOOL, required=False,
default=True),
],
responses={ responses={
'200': OpenApiResponse(description='Download of file started'), '200': OpenApiResponse(description='Download of file started'),
'201': OpenApiResponse(description='Output backup file is ready for downloading'), '201': OpenApiResponse(description='Output backup file is ready for downloading'),
@ -457,17 +503,26 @@ class ProjectViewSet(viewsets.ModelViewSet, UploadMixin):
}) })
@action(methods=['GET'], detail=True, url_path='backup') @action(methods=['GET'], detail=True, url_path='backup')
def export_backup(self, request, pk=None): def export_backup(self, request, pk=None):
db_project = self.get_object() # force to call check_object_permissions return self.serialize(request, backup.export)
return backup.export(db_project, request)
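Backups follow the same pattern; a sketch under the same placeholder assumptions: because use_default_location defaults to true, a project that already has a target_storage configured can be backed up straight to that bucket without extra query parameters.

import requests

r = requests.get(
    "http://localhost:8080/api/projects/1/backup",   # placeholder host and project id
    headers={"Authorization": "Token <token>"},       # placeholder credentials
)
# 202 while the backup is being built; repeat the request until 200/201,
# at which point the archive lands in the project's configured target storage.
print(r.status_code)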
@extend_schema(summary='Method creates a project from a backup', @extend_schema(summary='Method creates a project from a backup',
parameters=[
OpenApiParameter('location', description='Where to import the backup file from',
location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False,
enum=Location.list(), default=Location.LOCAL),
OpenApiParameter('cloud_storage_id', description='Storage id',
location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False),
OpenApiParameter('filename', description='Backup file name',
location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False),
],
request=ProjectFileSerializer(required=False),
responses={ responses={
'201': OpenApiResponse(description='The project has been imported'), # or better specify {id: project_id} '201': OpenApiResponse(description='The project has been imported'), # or better specify {id: project_id}
'202': OpenApiResponse(description='Importing a backup file has been started'), '202': OpenApiResponse(description='Importing a backup file has been started'),
}) })
@action(detail=False, methods=['OPTIONS', 'POST'], url_path=r'backup/?$') @action(detail=False, methods=['OPTIONS', 'POST'], url_path=r'backup/?$', serializer_class=ProjectFileSerializer(required=False))
def import_backup(self, request, pk=None): def import_backup(self, request, pk=None):
return self.upload_data(request) return self.deserialize(request, backup.import_project)
@action(detail=False, methods=['HEAD', 'PATCH'], url_path='backup/'+UploadMixin.file_id_regex) @action(detail=False, methods=['HEAD', 'PATCH'], url_path='backup/'+UploadMixin.file_id_regex)
def append_backup_chunk(self, request, file_id): def append_backup_chunk(self, request, file_id):
@ -522,6 +577,7 @@ class DataChunkGetter:
if self.type == 'chunk': if self.type == 'chunk':
start_chunk = frame_provider.get_chunk_number(start) start_chunk = frame_provider.get_chunk_number(start)
stop_chunk = frame_provider.get_chunk_number(stop) stop_chunk = frame_provider.get_chunk_number(stop)
# pylint: disable=superfluous-parens
if not (start_chunk <= self.number <= stop_chunk): if not (start_chunk <= self.number <= stop_chunk):
raise ValidationError('The chunk number should be in ' + raise ValidationError('The chunk number should be in ' +
f'[{start_chunk}, {stop_chunk}] range') f'[{start_chunk}, {stop_chunk}] range')
@ -571,20 +627,20 @@ class DataChunkGetter:
list=extend_schema( list=extend_schema(
summary='Returns a paginated list of tasks according to query parameters (10 tasks per page)', summary='Returns a paginated list of tasks according to query parameters (10 tasks per page)',
responses={ responses={
'200': TaskSerializer(many=True), '200': TaskReadSerializer(many=True),
}), }),
create=extend_schema( create=extend_schema(
summary='Method creates a new task in a database without any attached images and videos', summary='Method creates a new task in a database without any attached images and videos',
responses={ responses={
'201': TaskSerializer, '201': TaskWriteSerializer,
}), }),
retrieve=extend_schema( retrieve=extend_schema(
summary='Method returns details of a specific task', summary='Method returns details of a specific task',
responses=TaskSerializer), responses=TaskReadSerializer),
update=extend_schema( update=extend_schema(
summary='Method updates a task by id', summary='Method updates a task by id',
responses={ responses={
'200': TaskSerializer, '200': TaskWriteSerializer,
}), }),
destroy=extend_schema( destroy=extend_schema(
summary='Method deletes a specific task, all attached jobs, annotations, and data', summary='Method deletes a specific task, all attached jobs, annotations, and data',
@ -594,15 +650,14 @@ class DataChunkGetter:
partial_update=extend_schema( partial_update=extend_schema(
summary='Method does a partial update of chosen fields in a task', summary='Method does a partial update of chosen fields in a task',
responses={ responses={
'200': TaskSerializer, '200': TaskWriteSerializer,
}) })
) )
class TaskViewSet(UploadMixin, viewsets.ModelViewSet): class TaskViewSet(UploadMixin, AnnotationMixin, viewsets.ModelViewSet, SerializeMixin):
queryset = Task.objects.prefetch_related( queryset = Task.objects.prefetch_related(
Prefetch('label_set', queryset=models.Label.objects.order_by('id')), Prefetch('label_set', queryset=models.Label.objects.order_by('id')),
"label_set__attributespec_set", "label_set__attributespec_set",
"segment_set__job_set") "segment_set__job_set")
serializer_class = TaskSerializer
lookup_fields = {'project_name': 'project__name', 'owner': 'owner__username', 'assignee': 'assignee__username'} lookup_fields = {'project_name': 'project__name', 'owner': 'owner__username', 'assignee': 'assignee__username'}
search_fields = ('project_name', 'name', 'owner', 'status', 'assignee', 'subset', 'mode', 'dimension') search_fields = ('project_name', 'name', 'owner', 'status', 'assignee', 'subset', 'mode', 'dimension')
filter_fields = list(search_fields) + ['id', 'project_id', 'updated_date'] filter_fields = list(search_fields) + ['id', 'project_id', 'updated_date']
@ -610,6 +665,12 @@ class TaskViewSet(UploadMixin, viewsets.ModelViewSet):
ordering = "-id" ordering = "-id"
iam_organization_field = 'organization' iam_organization_field = 'organization'
def get_serializer_class(self):
if self.request.method in SAFE_METHODS:
return TaskReadSerializer
else:
return TaskWriteSerializer
def get_queryset(self): def get_queryset(self):
queryset = super().get_queryset() queryset = super().get_queryset()
if self.action == 'list': if self.action == 'list':
@ -619,19 +680,42 @@ class TaskViewSet(UploadMixin, viewsets.ModelViewSet):
return queryset return queryset
@extend_schema(summary='Method recreates a task from an attached task backup file', @extend_schema(summary='Method recreates a task from an attached task backup file',
parameters=[
OpenApiParameter('location', description='Where to import the backup file from',
location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False,
enum=Location.list(), default=Location.LOCAL),
OpenApiParameter('cloud_storage_id', description='Storage id',
location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False),
OpenApiParameter('filename', description='Backup file name',
location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False),
],
request=TaskFileSerializer(required=False),
responses={ responses={
'201': OpenApiResponse(description='The task has been imported'), # or better specify {id: task_id} '201': OpenApiResponse(description='The task has been imported'), # or better specify {id: task_id}
'202': OpenApiResponse(description='Importing a backup file has been started'), '202': OpenApiResponse(description='Importing a backup file has been started'),
}) })
@action(detail=False, methods=['OPTIONS', 'POST'], url_path=r'backup/?$') @action(detail=False, methods=['OPTIONS', 'POST'], url_path=r'backup/?$', serializer_class=TaskFileSerializer(required=False))
def import_backup(self, request, pk=None): def import_backup(self, request, pk=None):
return self.upload_data(request) return self.deserialize(request, backup.import_task)
@action(detail=False, methods=['HEAD', 'PATCH'], url_path='backup/'+UploadMixin.file_id_regex) @action(detail=False, methods=['HEAD', 'PATCH'], url_path='backup/'+UploadMixin.file_id_regex)
def append_backup_chunk(self, request, file_id): def append_backup_chunk(self, request, file_id):
return self.append_tus_chunk(request, file_id) return self.append_tus_chunk(request, file_id)
@extend_schema(summary='Method backs up a specified task', @extend_schema(summary='Method backs up a specified task',
parameters=[
OpenApiParameter('action', location=OpenApiParameter.QUERY,
description='Used to start downloading process after backup file had been created',
type=OpenApiTypes.STR, required=False, enum=['download']),
OpenApiParameter('location', description='Where to save the downloaded backup',
location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False,
enum=Location.list()),
OpenApiParameter('cloud_storage_id', description='Storage id',
location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False),
OpenApiParameter('use_default_location', description='Use the location that was configured in the task to export backup',
location=OpenApiParameter.QUERY, type=OpenApiTypes.BOOL, required=False,
default=True),
],
responses={ responses={
'200': OpenApiResponse(description='Download of file started'), '200': OpenApiResponse(description='Download of file started'),
'201': OpenApiResponse(description='Output backup file is ready for downloading'), '201': OpenApiResponse(description='Output backup file is ready for downloading'),
@ -639,8 +723,7 @@ class TaskViewSet(UploadMixin, viewsets.ModelViewSet):
}) })
@action(methods=['GET'], detail=True, url_path='backup') @action(methods=['GET'], detail=True, url_path='backup')
def export_backup(self, request, pk=None): def export_backup(self, request, pk=None):
db_task = self.get_object() # force to call check_object_permissions return self.serialize(request, backup.export)
return backup.export(db_task, request)
def perform_update(self, serializer): def perform_update(self, serializer):
instance = serializer.instance instance = serializer.instance
@ -659,7 +742,7 @@ class TaskViewSet(UploadMixin, viewsets.ModelViewSet):
assert instance.organization == db_project.organization assert instance.organization == db_project.organization
def perform_destroy(self, instance): def perform_destroy(self, instance):
task_dirname = instance.get_task_dirname() task_dirname = instance.get_dirname()
super().perform_destroy(instance) super().perform_destroy(instance)
shutil.rmtree(task_dirname, ignore_errors=True) shutil.rmtree(task_dirname, ignore_errors=True)
if instance.data and not instance.data.tasks.all(): if instance.data and not instance.data.tasks.all():
@ -695,6 +778,7 @@ class TaskViewSet(UploadMixin, viewsets.ModelViewSet):
# UploadMixin method # UploadMixin method
def upload_finished(self, request): def upload_finished(self, request):
if self.action == 'annotations': if self.action == 'annotations':
# db_task = self.get_object()
format_name = request.query_params.get("format", "") format_name = request.query_params.get("format", "")
filename = request.query_params.get("filename", "") filename = request.query_params.get("filename", "")
tmp_dir = self._object.get_tmp_dirname() tmp_dir = self._object.get_tmp_dirname()
@ -823,7 +907,15 @@ class TaskViewSet(UploadMixin, viewsets.ModelViewSet):
location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False), location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False),
OpenApiParameter('action', location=OpenApiParameter.QUERY, OpenApiParameter('action', location=OpenApiParameter.QUERY,
description='Used to start downloading process after annotation file had been created', description='Used to start downloading process after annotation file had been created',
type=OpenApiTypes.STR, required=False, enum=['download']) type=OpenApiTypes.STR, required=False, enum=['download']),
OpenApiParameter('location', description='Where to save the downloaded dataset',
location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False,
enum=Location.list()),
OpenApiParameter('cloud_storage_id', description='Storage id',
location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False),
OpenApiParameter('use_default_location', description='Use the location that was configured in the task to export annotation',
location=OpenApiParameter.QUERY, type=OpenApiTypes.BOOL, required=False,
default=True),
], ],
responses={ responses={
'200': OpenApiResponse(description='Download of file started'), '200': OpenApiResponse(description='Download of file started'),
@ -841,6 +933,26 @@ class TaskViewSet(UploadMixin, viewsets.ModelViewSet):
'202': OpenApiResponse(description='Uploading has been started'), '202': OpenApiResponse(description='Uploading has been started'),
'405': OpenApiResponse(description='Format is not available'), '405': OpenApiResponse(description='Format is not available'),
}) })
@extend_schema(methods=['POST'], summary='Method allows uploading task annotations from storage',
parameters=[
OpenApiParameter('format', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False,
description='Input format name\nYou can get the list of supported formats at:\n/server/annotation/formats'),
OpenApiParameter('location', description='Where to import the annotation from',
location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False,
enum=Location.list()),
OpenApiParameter('cloud_storage_id', description='Storage id',
location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False),
OpenApiParameter('use_default_location', description='Use the location that was configured in the task to import annotations',
location=OpenApiParameter.QUERY, type=OpenApiTypes.BOOL, required=False,
default=True),
OpenApiParameter('filename', description='Annotation file name',
location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False),
],
responses={
'201': OpenApiResponse(description='Uploading has finished'),
'202': OpenApiResponse(description='Uploading has been started'),
'405': OpenApiResponse(description='Format is not available'),
})
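A sketch of the POST flow just documented (host, token, ids, format and file name are placeholders): annotations that already sit in a bucket can be imported by object name, without uploading a file in the request body.

import requests

params = {
    "format": "CVAT 1.1",
    "location": "cloud_storage",     # assumes Location.CLOUD_STORAGE == 'cloud_storage'
    "cloud_storage_id": 3,           # placeholder storage id
    "use_default_location": "false",
    "filename": "task-12-annotations.zip",   # object key inside the bucket
}
r = requests.post(
    "http://localhost:8080/api/tasks/12/annotations",
    params=params,
    headers={"Authorization": "Token <token>"},
)
# 202: the import has been started; 201: it has finished
print(r.status_code)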
@extend_schema(methods=['PATCH'], summary='Method performs a partial update of annotations in a specific task', @extend_schema(methods=['PATCH'], summary='Method performs a partial update of annotations in a specific task',
parameters=[ parameters=[
OpenApiParameter('action', location=OpenApiParameter.QUERY, required=True, OpenApiParameter('action', location=OpenApiParameter.QUERY, required=True,
@ -851,27 +963,26 @@ class TaskViewSet(UploadMixin, viewsets.ModelViewSet):
'204': OpenApiResponse(description='The annotation has been deleted'), '204': OpenApiResponse(description='The annotation has been deleted'),
}) })
@action(detail=True, methods=['GET', 'DELETE', 'PUT', 'PATCH', 'POST', 'OPTIONS'], url_path=r'annotations/?$', @action(detail=True, methods=['GET', 'DELETE', 'PUT', 'PATCH', 'POST', 'OPTIONS'], url_path=r'annotations/?$',
serializer_class=LabeledDataSerializer) serializer_class=LabeledDataSerializer(required=False))
def annotations(self, request, pk): def annotations(self, request, pk):
self._object = self.get_object() # force to call check_object_permissions self._object = self.get_object() # force to call check_object_permissions
if request.method == 'GET': if request.method == 'GET':
format_name = request.query_params.get('format') return self.export_annotations(
if format_name: request=request,
return _export_annotations(db_instance=self._object, pk=pk,
rq_id="/api/tasks/{}/annotations/{}".format(pk, format_name), db_obj=self._object,
request=request, export_func=_export_annotations,
action=request.query_params.get("action", "").lower(), callback=dm.views.export_task_annotations,
callback=dm.views.export_task_annotations, get_data=dm.task.get_task_data,
format_name=format_name, )
filename=request.query_params.get("filename", "").lower(),
)
else:
data = dm.task.get_task_data(pk)
serializer = LabeledDataSerializer(data=data)
if serializer.is_valid(raise_exception=True):
return Response(serializer.data)
elif request.method == 'POST' or request.method == 'OPTIONS': elif request.method == 'POST' or request.method == 'OPTIONS':
return self.upload_data(request) return self.import_annotations(
request=request,
pk=pk,
db_obj=self._object,
import_func=_import_annotations,
rq_func=dm.task.import_task_annotations,
)
elif request.method == 'PUT': elif request.method == 'PUT':
format_name = request.query_params.get('format') format_name = request.query_params.get('format')
if format_name: if format_name:
@ -991,7 +1102,15 @@ class TaskViewSet(UploadMixin, viewsets.ModelViewSet):
location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False), location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False),
OpenApiParameter('action', location=OpenApiParameter.QUERY, OpenApiParameter('action', location=OpenApiParameter.QUERY,
description='Used to start downloading process after annotation file had been created', description='Used to start downloading process after annotation file had been created',
type=OpenApiTypes.STR, required=False, enum=['download']) type=OpenApiTypes.STR, required=False, enum=['download']),
OpenApiParameter('use_default_location', description='Use the location that was configured in the task to export annotations',
location=OpenApiParameter.QUERY, type=OpenApiTypes.BOOL, required=False,
default=True),
OpenApiParameter('location', description='Where to save the downloaded dataset',
location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False,
enum=Location.list()),
OpenApiParameter('cloud_storage_id', description='Storage id',
location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False),
], ],
responses={ responses={
'200': OpenApiResponse(description='Download of file started'), '200': OpenApiResponse(description='Download of file started'),
@ -1002,16 +1121,14 @@ class TaskViewSet(UploadMixin, viewsets.ModelViewSet):
@action(detail=True, methods=['GET'], serializer_class=None, @action(detail=True, methods=['GET'], serializer_class=None,
url_path='dataset') url_path='dataset')
def dataset_export(self, request, pk): def dataset_export(self, request, pk):
db_task = self.get_object() # force to call check_object_permissions self._object = self.get_object() # force to call check_object_permissions
format_name = request.query_params.get("format", "") return self.export_annotations(
return _export_annotations(db_instance=db_task,
rq_id="/api/tasks/{}/dataset/{}".format(pk, format_name),
request=request, request=request,
action=request.query_params.get("action", "").lower(), pk=pk,
callback=dm.views.export_task_as_dataset, db_obj=self._object,
format_name=format_name, export_func=_export_annotations,
filename=request.query_params.get("filename", "").lower(), callback=dm.views.export_task_as_dataset
) )
@extend_schema(tags=['jobs']) @extend_schema(tags=['jobs'])
@ -1038,7 +1155,7 @@ class TaskViewSet(UploadMixin, viewsets.ModelViewSet):
}) })
) )
class JobViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, class JobViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
mixins.RetrieveModelMixin, mixins.UpdateModelMixin, UploadMixin): mixins.RetrieveModelMixin, mixins.UpdateModelMixin, UploadMixin, AnnotationMixin):
queryset = Job.objects.all() queryset = Job.objects.all()
iam_organization_field = 'segment__task__organization' iam_organization_field = 'segment__task__organization'
search_fields = ('task_name', 'project_name', 'assignee', 'state', 'stage') search_fields = ('task_name', 'project_name', 'assignee', 'state', 'stage')
@ -1051,7 +1168,6 @@ class JobViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
'project_id': 'segment__task__project_id', 'project_id': 'segment__task__project_id',
'task_name': 'segment__task__name', 'task_name': 'segment__task__name',
'project_name': 'segment__task__project__name', 'project_name': 'segment__task__project__name',
'updated_date': 'segment__task__updated_date',
'assignee': 'assignee__username' 'assignee': 'assignee__username'
} }
@ -1099,8 +1215,49 @@ class JobViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
status=status.HTTP_400_BAD_REQUEST) status=status.HTTP_400_BAD_REQUEST)
@extend_schema(methods=['GET'], summary='Method returns annotations for a specific job', @extend_schema(methods=['GET'], summary='Method returns annotations for a specific job',
parameters=[
OpenApiParameter('format', location=OpenApiParameter.QUERY,
description='Desired output format name\nYou can get the list of supported formats at:\n/server/annotation/formats',
type=OpenApiTypes.STR, required=False),
OpenApiParameter('filename', description='Desired output file name',
location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False),
OpenApiParameter('action', location=OpenApiParameter.QUERY,
description='Used to start downloading process after annotation file had been created',
type=OpenApiTypes.STR, required=False, enum=['download']),
OpenApiParameter('location', description='Where to save the downloaded annotation',
location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False,
enum=Location.list()),
OpenApiParameter('cloud_storage_id', description='Storage id',
location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False),
OpenApiParameter('use_default_location', description='Use the location that was configured in the task to export annotation',
location=OpenApiParameter.QUERY, type=OpenApiTypes.BOOL, required=False,
default=True),
],
responses={
'200': LabeledDataSerializer,
'201': OpenApiResponse(description='Output file is ready for downloading'),
'202': OpenApiResponse(description='Exporting has been started'),
'405': OpenApiResponse(description='Format is not available'),
})
@extend_schema(methods=['POST'], summary='Method allows uploading job annotations',
parameters=[
OpenApiParameter('format', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False,
description='Input format name\nYou can get the list of supported formats at:\n/server/annotation/formats'),
OpenApiParameter('location', description='Where to import the annotation from',
location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False,
enum=Location.list()),
OpenApiParameter('cloud_storage_id', description='Storage id',
location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False),
OpenApiParameter('use_default_location', description='Use the location that was configured in the task to import annotation',
location=OpenApiParameter.QUERY, type=OpenApiTypes.BOOL, required=False,
default=True),
OpenApiParameter('filename', description='Annotation file name',
location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False),
],
responses={ responses={
'200': LabeledDataSerializer(many=True), '201': OpenApiResponse(description='Uploading has finished'),
'202': OpenApiResponse(description='Uploading has been started'),
'405': OpenApiResponse(description='Format is not available'),
}) })
@extend_schema(methods=['PUT'], summary='Method performs an update of all annotations in a specific job', @extend_schema(methods=['PUT'], summary='Method performs an update of all annotations in a specific job',
request=AnnotationFileSerializer, responses={ request=AnnotationFileSerializer, responses={
@ -1126,10 +1283,24 @@ class JobViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
def annotations(self, request, pk): def annotations(self, request, pk):
self._object = self.get_object() # force to call check_object_permissions self._object = self.get_object() # force to call check_object_permissions
if request.method == 'GET': if request.method == 'GET':
data = dm.task.get_job_data(pk) return self.export_annotations(
return Response(data) request=request,
pk=pk,
db_obj=self._object.segment.task,
export_func=_export_annotations,
callback=dm.views.export_job_annotations,
get_data=dm.task.get_job_data,
)
elif request.method == 'POST' or request.method == 'OPTIONS': elif request.method == 'POST' or request.method == 'OPTIONS':
return self.upload_data(request) return self.import_annotations(
request=request,
pk=pk,
db_obj=self._object.segment.task,
import_func=_import_annotations,
rq_func=dm.task.import_job_annotations,
)
elif request.method == 'PUT': elif request.method == 'PUT':
format_name = request.query_params.get('format', '') format_name = request.query_params.get('format', '')
if format_name: if format_name:
@ -1169,6 +1340,44 @@ class JobViewSet(viewsets.GenericViewSet, mixins.ListModelMixin,
self._object = self.get_object() self._object = self.get_object()
return self.append_tus_chunk(request, file_id) return self.append_tus_chunk(request, file_id)
@extend_schema(summary='Export job as a dataset in a specific format',
parameters=[
OpenApiParameter('format', location=OpenApiParameter.QUERY,
description='Desired output format name\nYou can get the list of supported formats at:\n/server/annotation/formats',
type=OpenApiTypes.STR, required=True),
OpenApiParameter('filename', description='Desired output file name',
location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False),
OpenApiParameter('action', location=OpenApiParameter.QUERY,
description='Used to start downloading process after annotation file had been created',
type=OpenApiTypes.STR, required=False, enum=['download']),
OpenApiParameter('use_default_location', description='Use the location that was configured in the task to export dataset',
location=OpenApiParameter.QUERY, type=OpenApiTypes.BOOL, required=False,
default=True),
OpenApiParameter('location', description='Where to save the downloaded dataset',
location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False,
enum=Location.list()),
OpenApiParameter('cloud_storage_id', description='Storage id',
location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False),
],
responses={
'200': OpenApiResponse(description='Download of file started'),
'201': OpenApiResponse(description='Output file is ready for downloading'),
'202': OpenApiResponse(description='Exporting has been started'),
'405': OpenApiResponse(description='Format is not available'),
})
@action(detail=True, methods=['GET'], serializer_class=None,
url_path='dataset')
def dataset_export(self, request, pk):
self._object = self.get_object() # force to call check_object_permissions
return self.export_annotations(
request=request,
pk=pk,
db_obj=self._object.segment.task,
export_func=_export_annotations,
callback=dm.views.export_job_as_dataset
)
@extend_schema( @extend_schema(
summary='Method returns list of issues for the job', summary='Method returns list of issues for the job',
responses={ responses={
@ -1589,12 +1798,12 @@ class CloudStorageViewSet(viewsets.ModelViewSet):
except IntegrityError: except IntegrityError:
response = HttpResponseBadRequest('Same storage already exists') response = HttpResponseBadRequest('Same storage already exists')
except ValidationError as exceptions: except ValidationError as exceptions:
msg_body = "" msg_body = ""
for ex in exceptions.args: for ex in exceptions.args:
for field, ex_msg in ex.items(): for field, ex_msg in ex.items():
msg_body += ': '.join([field, ex_msg if isinstance(ex_msg, str) else str(ex_msg[0])]) msg_body += ': '.join([field, ex_msg if isinstance(ex_msg, str) else str(ex_msg[0])])
msg_body += '\n' msg_body += '\n'
return HttpResponseBadRequest(msg_body) return HttpResponseBadRequest(msg_body)
except APIException as ex: except APIException as ex:
return Response(data=ex.get_full_details(), status=ex.status_code) return Response(data=ex.get_full_details(), status=ex.status_code)
except Exception as ex: except Exception as ex:
@ -1614,17 +1823,7 @@ class CloudStorageViewSet(viewsets.ModelViewSet):
storage = None storage = None
try: try:
db_storage = self.get_object() db_storage = self.get_object()
credentials = Credentials() storage = db_storage_to_storage_instance(db_storage)
credentials.convert_from_db({
'type': db_storage.credentials_type,
'value': db_storage.credentials,
})
details = {
'resource': db_storage.resource,
'credentials': credentials,
'specific_attributes': db_storage.get_specific_attributes()
}
storage = get_cloud_storage_instance(cloud_provider=db_storage.provider_type, **details)
if not db_storage.manifests.count(): if not db_storage.manifests.count():
raise Exception('There is no manifest file') raise Exception('There is no manifest file')
manifest_path = request.query_params.get('manifest_path', db_storage.manifests.first().filename) manifest_path = request.query_params.get('manifest_path', db_storage.manifests.first().filename)
@ -1675,17 +1874,7 @@ class CloudStorageViewSet(viewsets.ModelViewSet):
try: try:
db_storage = self.get_object() db_storage = self.get_object()
if not os.path.exists(db_storage.get_preview_path()): if not os.path.exists(db_storage.get_preview_path()):
credentials = Credentials() storage = db_storage_to_storage_instance(db_storage)
credentials.convert_from_db({
'type': db_storage.credentials_type,
'value': db_storage.credentials,
})
details = {
'resource': db_storage.resource,
'credentials': credentials,
'specific_attributes': db_storage.get_specific_attributes()
}
storage = get_cloud_storage_instance(cloud_provider=db_storage.provider_type, **details)
if not db_storage.manifests.count(): if not db_storage.manifests.count():
raise Exception('Cannot get the cloud storage preview. There is no manifest file') raise Exception('Cannot get the cloud storage preview. There is no manifest file')
preview_path = None preview_path = None
@ -1749,17 +1938,7 @@ class CloudStorageViewSet(viewsets.ModelViewSet):
def status(self, request, pk): def status(self, request, pk):
try: try:
db_storage = self.get_object() db_storage = self.get_object()
credentials = Credentials() storage = db_storage_to_storage_instance(db_storage)
credentials.convert_from_db({
'type': db_storage.credentials_type,
'value': db_storage.credentials,
})
details = {
'resource': db_storage.resource,
'credentials': credentials,
'specific_attributes': db_storage.get_specific_attributes()
}
storage = get_cloud_storage_instance(cloud_provider=db_storage.provider_type, **details)
storage_status = storage.get_status() storage_status = storage.get_status()
return HttpResponse(storage_status) return HttpResponse(storage_status)
except CloudStorageModel.DoesNotExist: except CloudStorageModel.DoesNotExist:
@ -1770,6 +1949,28 @@ class CloudStorageViewSet(viewsets.ModelViewSet):
msg = str(ex) msg = str(ex)
return HttpResponseBadRequest(msg) return HttpResponseBadRequest(msg)
@extend_schema(summary='Method returns allowed actions for the cloud storage',
responses={
'200': OpenApiResponse(response=OpenApiTypes.STR, description='Cloud Storage actions (GET | PUT | DELETE)'),
})
@action(detail=True, methods=['GET'], url_path='actions')
def actions(self, request, pk):
'''
Method returns the allowed actions for the cloud storage. It is required for reading/writing
'''
try:
db_storage = self.get_object()
storage = db_storage_to_storage_instance(db_storage)
actions = storage.supported_actions
return Response(actions, content_type="text/plain")
except CloudStorageModel.DoesNotExist:
message = f"Storage {pk} does not exist"
slogger.glob.error(message)
return HttpResponseNotFound(message)
except Exception as ex:
msg = str(ex)
return HttpResponseBadRequest(msg)
def rq_handler(job, exc_type, exc_value, tb): def rq_handler(job, exc_type, exc_value, tb):
job.exc_info = "".join( job.exc_info = "".join(
traceback.format_exception_only(exc_type, exc_value)) traceback.format_exception_only(exc_type, exc_value))
@ -1779,7 +1980,16 @@ def rq_handler(job, exc_type, exc_value, tb):
return True return True
def _import_annotations(request, rq_id, rq_func, pk, format_name, filename=None): @validate_bucket_status
def _export_to_cloud_storage(storage, file_path, file_name):
storage.upload_file(file_path, file_name)
@validate_bucket_status
def _import_from_cloud_storage(storage, file_name):
return storage.download_fileobj(file_name)
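The two helpers above are driven by a location_conf mapping that the viewset builds elsewhere in this change; here is a minimal sketch of the shape the import/export code below expects, inferred from its usage (the key names come from the code, the concrete values are illustrative assumptions).

# Inferred shape of location_conf; values are illustrative only.
local_conf = {'location': Location.LOCAL}
cloud_conf = {
    'location': Location.CLOUD_STORAGE,
    'storage_id': 1,  # pk of a CloudStorageModel, required for cloud locations
}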
def _import_annotations(request, rq_id, rq_func, pk, format_name,
filename=None, location_conf=None):
format_desc = {f.DISPLAY_NAME: f format_desc = {f.DISPLAY_NAME: f
for f in dm.views.get_import_formats()}.get(format_name) for f in dm.views.get_import_formats()}.get(format_name)
if format_desc is None: if format_desc is None:
@ -1794,15 +2004,36 @@ def _import_annotations(request, rq_id, rq_func, pk, format_name, filename=None)
if not rq_job: if not rq_job:
# If filename is specified we consider that file was uploaded via TUS, so it exists in filesystem # If filename is specified we consider that file was uploaded via TUS, so it exists in filesystem
# Then we dont need to create temporary file # Then we dont need to create temporary file
# Or the filename specifies a key in the cloud storage, so we need to download the file
fd = None fd = None
if not filename: location = location_conf.get('location') if location_conf else Location.LOCAL
serializer = AnnotationFileSerializer(data=request.data)
if serializer.is_valid(raise_exception=True): if not filename or location == Location.CLOUD_STORAGE:
anno_file = serializer.validated_data['annotation_file'] if location != Location.CLOUD_STORAGE:
fd, filename = mkstemp(prefix='cvat_{}'.format(pk)) serializer = AnnotationFileSerializer(data=request.data)
if serializer.is_valid(raise_exception=True):
anno_file = serializer.validated_data['annotation_file']
fd, filename = mkstemp(prefix='cvat_{}'.format(pk))
with open(filename, 'wb+') as f:
for chunk in anno_file.chunks():
f.write(chunk)
else:
# download annotation file from cloud storage
try:
storage_id = location_conf['storage_id']
except KeyError:
raise serializers.ValidationError(
'Cloud storage location was selected as the source'
' but cloud storage id was not specified')
db_storage = get_object_or_404(CloudStorageModel, pk=storage_id)
storage = db_storage_to_storage_instance(db_storage)
assert filename, 'filename was not specified'
data = _import_from_cloud_storage(storage, filename)
fd, filename = mkstemp(prefix='cvat_')
with open(filename, 'wb+') as f: with open(filename, 'wb+') as f:
for chunk in anno_file.chunks(): f.write(data.getbuffer())
f.write(chunk)
av_scan_paths(filename) av_scan_paths(filename)
rq_job = queue.enqueue_call( rq_job = queue.enqueue_call(
@ -1838,7 +2069,8 @@ def _import_annotations(request, rq_id, rq_func, pk, format_name, filename=None)
return Response(status=status.HTTP_202_ACCEPTED) return Response(status=status.HTTP_202_ACCEPTED)
def _export_annotations(db_instance, rq_id, request, format_name, action, callback, filename): def _export_annotations(db_instance, rq_id, request, format_name, action, callback,
filename, location_conf):
if action not in {"", "download"}: if action not in {"", "download"}:
raise serializers.ValidationError( raise serializers.ValidationError(
"Unexpected action specified for the request") "Unexpected action specified for the request")
@ -1873,12 +2105,31 @@ def _export_annotations(db_instance, rq_id, request, format_name, action, callba
"%Y_%m_%d_%H_%M_%S") "%Y_%m_%d_%H_%M_%S")
filename = filename or \ filename = filename or \
"{}_{}-{}-{}{}".format( "{}_{}-{}-{}{}".format(
"project" if isinstance(db_instance, models.Project) else "task", db_instance.__class__.__name__.lower(),
db_instance.name, timestamp, db_instance.name if isinstance(db_instance, (Task, Project)) else db_instance.id,
format_name, osp.splitext(file_path)[1] timestamp, format_name, osp.splitext(file_path)[1]
) )
return sendfile(request, file_path, attachment=True,
attachment_filename=filename.lower()) # save annotation to specified location
location = location_conf.get('location')
if location == Location.LOCAL:
return sendfile(request, file_path, attachment=True,
attachment_filename=filename.lower())
elif location == Location.CLOUD_STORAGE:
try:
storage_id = location_conf['storage_id']
except KeyError:
return HttpResponseBadRequest(
'Cloud storage location was selected as the destination'
' but cloud storage id was not specified')
db_storage = get_object_or_404(CloudStorageModel, pk=storage_id)
storage = db_storage_to_storage_instance(db_storage)
_export_to_cloud_storage(storage, file_path, filename)
return Response(status=status.HTTP_200_OK)
else:
raise NotImplementedError()
else: else:
if osp.exists(file_path): if osp.exists(file_path):
return Response(status=status.HTTP_201_CREATED) return Response(status=status.HTTP_201_CREATED)
@ -1897,14 +2148,19 @@ def _export_annotations(db_instance, rq_id, request, format_name, action, callba
except Exception: except Exception:
server_address = None server_address = None
ttl = (dm.views.PROJECT_CACHE_TTL if isinstance(db_instance, Project) else dm.views.TASK_CACHE_TTL).total_seconds() TTL_CONSTS = {
'project': dm.views.PROJECT_CACHE_TTL,
'task': dm.views.TASK_CACHE_TTL,
'job': dm.views.JOB_CACHE_TTL,
}
ttl = TTL_CONSTS[db_instance.__class__.__name__.lower()].total_seconds()
queue.enqueue_call(func=callback, queue.enqueue_call(func=callback,
args=(db_instance.id, format_name, server_address), job_id=rq_id, args=(db_instance.id, format_name, server_address), job_id=rq_id,
meta={ 'request_time': timezone.localtime() }, meta={ 'request_time': timezone.localtime() },
result_ttl=ttl, failure_ttl=ttl) result_ttl=ttl, failure_ttl=ttl)
return Response(status=status.HTTP_202_ACCEPTED) return Response(status=status.HTTP_202_ACCEPTED)
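For completeness, a hedged sketch of the client-side polling implied by the 202/201 responses returned here; the URL, credentials and parameter values are assumptions for illustration.

import time
import requests

# Hypothetical polling loop around the export contract: 202 while the RQ job runs,
# 201 when the file is ready, then 'action=download' to fetch (or upload) the result.
url = 'http://localhost:8080/api/tasks/9/annotations'
params = {'format': 'CVAT for images 1.1'}   # assumed format name
auth = ('admin1', 'password')                # placeholder credentials

while True:
    resp = requests.get(url, params=params, auth=auth)
    if resp.status_code == 202:   # exporting has been started, keep polling
        time.sleep(1)
        continue
    if resp.status_code == 201:   # output file is ready, request the download
        resp = requests.get(url, params={**params, 'action': 'download'}, auth=auth)
    break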
def _import_project_dataset(request, rq_id, rq_func, pk, format_name, filename=None): def _import_project_dataset(request, rq_id, rq_func, pk, format_name, filename=None, location_conf=None):
format_desc = {f.DISPLAY_NAME: f format_desc = {f.DISPLAY_NAME: f
for f in dm.views.get_import_formats()}.get(format_name) for f in dm.views.get_import_formats()}.get(format_name)
if format_desc is None: if format_desc is None:
@ -1918,7 +2174,8 @@ def _import_project_dataset(request, rq_id, rq_func, pk, format_name, filename=N
if not rq_job: if not rq_job:
fd = None fd = None
if not filename: location = location_conf.get('location') if location_conf else None
if not filename and location != Location.CLOUD_STORAGE:
serializer = DatasetFileSerializer(data=request.data) serializer = DatasetFileSerializer(data=request.data)
if serializer.is_valid(raise_exception=True): if serializer.is_valid(raise_exception=True):
dataset_file = serializer.validated_data['dataset_file'] dataset_file = serializer.validated_data['dataset_file']
@ -1926,6 +2183,24 @@ def _import_project_dataset(request, rq_id, rq_func, pk, format_name, filename=N
with open(filename, 'wb+') as f: with open(filename, 'wb+') as f:
for chunk in dataset_file.chunks(): for chunk in dataset_file.chunks():
f.write(chunk) f.write(chunk)
elif location == Location.CLOUD_STORAGE:
assert filename
# download project file from cloud storage
try:
storage_id = location_conf['storage_id']
except KeyError:
raise serializers.ValidationError(
'Cloud storage location was selected as the source'
' but cloud storage id was not specified')
db_storage = get_object_or_404(CloudStorageModel, pk=storage_id)
storage = db_storage_to_storage_instance(db_storage)
data = _import_from_cloud_storage(storage, filename)
fd, filename = mkstemp(prefix='cvat_')
with open(filename, 'wb+') as f:
f.write(data.getbuffer())
rq_job = queue.enqueue_call( rq_job = queue.enqueue_call(
func=rq_func, func=rq_func,

@ -63,7 +63,7 @@ class OpenPolicyAgentPermission(metaclass=ABCMeta):
'user': { 'user': {
'role': self.org_role, 'role': self.org_role,
}, },
} if self.org_id != None else None } if self.org_id is not None else None
} }
} }
} }
@ -210,7 +210,7 @@ class InvitationPermission(OpenPolicyAgentPermission):
'role': self.role, 'role': self.role,
'organization': { 'organization': {
'id': self.org_id 'id': self.org_id
} if self.org_id != None else None } if self.org_id is not None else None
} }
return data return data
@ -417,7 +417,8 @@ class CloudStoragePermission(OpenPolicyAgentPermission):
'destroy': 'delete', 'destroy': 'delete',
'content': 'list:content', 'content': 'list:content',
'preview': 'view', 'preview': 'view',
'status': 'view' 'status': 'view',
'actions': 'view',
}.get(view.action)] }.get(view.action)]
def get_resource(self): def get_resource(self):
@ -427,7 +428,7 @@ class CloudStoragePermission(OpenPolicyAgentPermission):
'owner': { 'id': self.user_id }, 'owner': { 'id': self.user_id },
'organization': { 'organization': {
'id': self.org_id 'id': self.org_id
} if self.org_id != None else None, } if self.org_id is not None else None,
'user': { 'user': {
'num_resources': Organization.objects.filter( 'num_resources': Organization.objects.filter(
owner=self.user_id).count() owner=self.user_id).count()
@ -620,9 +621,9 @@ class TaskPermission(OpenPolicyAgentPermission):
perm = TaskPermission.create_scope_create(request, org_id) perm = TaskPermission.create_scope_create(request, org_id)
# We don't create a project, just move it. Thus need to decrease # We don't create a project, just move it. Thus need to decrease
# the number of resources. # the number of resources.
if obj != None: if obj is not None:
perm.payload['input']['resource']['user']['num_resources'] -= 1 perm.payload['input']['resource']['user']['num_resources'] -= 1
if obj.project != None: if obj.project is not None:
ValidationError('Cannot change the organization for ' ValidationError('Cannot change the organization for '
'a task inside a project') 'a task inside a project')
permissions.append(perm) permissions.append(perm)

@ -352,6 +352,9 @@ os.makedirs(MEDIA_DATA_ROOT, exist_ok=True)
CACHE_ROOT = os.path.join(DATA_ROOT, 'cache') CACHE_ROOT = os.path.join(DATA_ROOT, 'cache')
os.makedirs(CACHE_ROOT, exist_ok=True) os.makedirs(CACHE_ROOT, exist_ok=True)
JOBS_ROOT = os.path.join(DATA_ROOT, 'jobs')
os.makedirs(JOBS_ROOT, exist_ok=True)
TASKS_ROOT = os.path.join(DATA_ROOT, 'tasks') TASKS_ROOT = os.path.join(DATA_ROOT, 'tasks')
os.makedirs(TASKS_ROOT, exist_ok=True) os.makedirs(TASKS_ROOT, exist_ok=True)

@ -22,6 +22,9 @@ os.makedirs(MEDIA_DATA_ROOT, exist_ok=True)
CACHE_ROOT = os.path.join(DATA_ROOT, 'cache') CACHE_ROOT = os.path.join(DATA_ROOT, 'cache')
os.makedirs(CACHE_ROOT, exist_ok=True) os.makedirs(CACHE_ROOT, exist_ok=True)
JOBS_ROOT = os.path.join(DATA_ROOT, 'jobs')
os.makedirs(JOBS_ROOT, exist_ok=True)
TASKS_ROOT = os.path.join(DATA_ROOT, 'tasks') TASKS_ROOT = os.path.join(DATA_ROOT, 'tasks')
os.makedirs(TASKS_ROOT, exist_ok=True) os.makedirs(TASKS_ROOT, exist_ok=True)

@ -1413,7 +1413,7 @@
"pk": 1, "pk": 1,
"fields": { "fields": {
"password": "pbkdf2_sha256$260000$DevmxlmLwciP1P6sZs2Qag$U9DFtjTWx96Sk95qY6UXVcvpdQEP2LcoFBftk5D2RKY=", "password": "pbkdf2_sha256$260000$DevmxlmLwciP1P6sZs2Qag$U9DFtjTWx96Sk95qY6UXVcvpdQEP2LcoFBftk5D2RKY=",
"last_login": "2022-06-08T08:32:30.152Z", "last_login": "2022-06-22T09:20:25.189Z",
"is_superuser": true, "is_superuser": true,
"username": "admin1", "username": "admin1",
"first_name": "Admin", "first_name": "Admin",
@ -2184,6 +2184,14 @@
"model": "gitdata" "model": "gitdata"
} }
}, },
{
"model": "contenttypes.contenttype",
"pk": 49,
"fields": {
"app_label": "engine",
"model": "storage"
}
},
{ {
"model": "sessions.session", "model": "sessions.session",
"pk": "5x9v6r58e4l9if78anupog0ittsq2w3j", "pk": "5x9v6r58e4l9if78anupog0ittsq2w3j",
@ -3688,7 +3696,9 @@
"created_date": "2021-12-14T19:46:37.969Z", "created_date": "2021-12-14T19:46:37.969Z",
"updated_date": "2022-03-05T09:47:49.679Z", "updated_date": "2022-03-05T09:47:49.679Z",
"status": "annotation", "status": "annotation",
"organization": null "organization": null,
"source_storage": null,
"target_storage": null
} }
}, },
{ {
@ -3702,7 +3712,9 @@
"created_date": "2021-12-14T19:52:37.278Z", "created_date": "2021-12-14T19:52:37.278Z",
"updated_date": "2022-03-28T13:04:54.669Z", "updated_date": "2022-03-28T13:04:54.669Z",
"status": "annotation", "status": "annotation",
"organization": 2 "organization": 2,
"source_storage": null,
"target_storage": null
} }
}, },
{ {
@ -3716,7 +3728,9 @@
"created_date": "2022-03-28T13:05:24.659Z", "created_date": "2022-03-28T13:05:24.659Z",
"updated_date": "2022-03-28T13:06:09.283Z", "updated_date": "2022-03-28T13:06:09.283Z",
"status": "annotation", "status": "annotation",
"organization": 2 "organization": 2,
"source_storage": null,
"target_storage": null
} }
}, },
{ {
@ -3730,7 +3744,9 @@
"created_date": "2022-06-08T08:32:45.521Z", "created_date": "2022-06-08T08:32:45.521Z",
"updated_date": "2022-06-08T08:33:20.759Z", "updated_date": "2022-06-08T08:33:20.759Z",
"status": "annotation", "status": "annotation",
"organization": 2 "organization": 2,
"source_storage": null,
"target_storage": null
} }
}, },
{ {
@ -3751,7 +3767,9 @@
"data": 2, "data": 2,
"dimension": "2d", "dimension": "2d",
"subset": "", "subset": "",
"organization": 1 "organization": 1,
"source_storage": null,
"target_storage": null
} }
}, },
{ {
@ -3772,7 +3790,9 @@
"data": 5, "data": 5,
"dimension": "2d", "dimension": "2d",
"subset": "", "subset": "",
"organization": null "organization": null,
"source_storage": null,
"target_storage": null
} }
}, },
{ {
@ -3793,7 +3813,9 @@
"data": 6, "data": 6,
"dimension": "3d", "dimension": "3d",
"subset": "", "subset": "",
"organization": null "organization": null,
"source_storage": null,
"target_storage": null
} }
}, },
{ {
@ -3814,7 +3836,9 @@
"data": 7, "data": 7,
"dimension": "2d", "dimension": "2d",
"subset": "", "subset": "",
"organization": 2 "organization": 2,
"source_storage": null,
"target_storage": null
} }
}, },
{ {
@ -3835,7 +3859,9 @@
"data": 8, "data": 8,
"dimension": "2d", "dimension": "2d",
"subset": "", "subset": "",
"organization": null "organization": null,
"source_storage": null,
"target_storage": null
} }
}, },
{ {
@ -3856,7 +3882,9 @@
"data": 9, "data": 9,
"dimension": "2d", "dimension": "2d",
"subset": "", "subset": "",
"organization": null "organization": null,
"source_storage": null,
"target_storage": null
} }
}, },
{ {
@ -3877,7 +3905,9 @@
"data": 11, "data": 11,
"dimension": "2d", "dimension": "2d",
"subset": "Train", "subset": "Train",
"organization": 2 "organization": 2,
"source_storage": null,
"target_storage": null
} }
}, },
{ {
@ -3898,7 +3928,9 @@
"data": null, "data": null,
"dimension": "2d", "dimension": "2d",
"subset": "", "subset": "",
"organization": null "organization": null,
"source_storage": null,
"target_storage": null
} }
}, },
{ {
@ -3919,7 +3951,9 @@
"data": 12, "data": 12,
"dimension": "2d", "dimension": "2d",
"subset": "", "subset": "",
"organization": 2 "organization": 2,
"source_storage": null,
"target_storage": null
} }
}, },
{ {
@ -4724,6 +4758,7 @@
"fields": { "fields": {
"segment": 2, "segment": 2,
"assignee": 6, "assignee": 6,
"updated_date": "2022-06-22T09:18:45.296Z",
"status": "annotation", "status": "annotation",
"stage": "annotation", "stage": "annotation",
"state": "new" "state": "new"
@ -4735,6 +4770,7 @@
"fields": { "fields": {
"segment": 7, "segment": 7,
"assignee": 9, "assignee": 9,
"updated_date": "2022-06-22T09:18:45.296Z",
"status": "annotation", "status": "annotation",
"stage": "annotation", "stage": "annotation",
"state": "in progress" "state": "in progress"
@ -4746,6 +4782,7 @@
"fields": { "fields": {
"segment": 8, "segment": 8,
"assignee": null, "assignee": null,
"updated_date": "2022-06-22T09:18:45.296Z",
"status": "annotation", "status": "annotation",
"stage": "annotation", "stage": "annotation",
"state": "new" "state": "new"
@ -4757,6 +4794,7 @@
"fields": { "fields": {
"segment": 9, "segment": 9,
"assignee": null, "assignee": null,
"updated_date": "2022-06-22T09:18:45.296Z",
"status": "annotation", "status": "annotation",
"stage": "annotation", "stage": "annotation",
"state": "in progress" "state": "in progress"
@ -4768,6 +4806,7 @@
"fields": { "fields": {
"segment": 10, "segment": 10,
"assignee": 1, "assignee": 1,
"updated_date": "2022-06-22T09:18:45.296Z",
"status": "annotation", "status": "annotation",
"stage": "annotation", "stage": "annotation",
"state": "in progress" "state": "in progress"
@ -4779,6 +4818,7 @@
"fields": { "fields": {
"segment": 11, "segment": 11,
"assignee": 9, "assignee": 9,
"updated_date": "2022-06-22T09:18:45.296Z",
"status": "annotation", "status": "annotation",
"stage": "annotation", "stage": "annotation",
"state": "in progress" "state": "in progress"
@ -4790,6 +4830,7 @@
"fields": { "fields": {
"segment": 12, "segment": 12,
"assignee": null, "assignee": null,
"updated_date": "2022-06-22T09:18:45.296Z",
"status": "validation", "status": "validation",
"stage": "validation", "stage": "validation",
"state": "new" "state": "new"
@ -4801,6 +4842,7 @@
"fields": { "fields": {
"segment": 13, "segment": 13,
"assignee": null, "assignee": null,
"updated_date": "2022-06-22T09:18:45.296Z",
"status": "validation", "status": "validation",
"stage": "acceptance", "stage": "acceptance",
"state": "new" "state": "new"
@ -4812,6 +4854,7 @@
"fields": { "fields": {
"segment": 14, "segment": 14,
"assignee": null, "assignee": null,
"updated_date": "2022-06-22T09:18:45.296Z",
"status": "annotation", "status": "annotation",
"stage": "annotation", "stage": "annotation",
"state": "new" "state": "new"
@ -4823,6 +4866,7 @@
"fields": { "fields": {
"segment": 16, "segment": 16,
"assignee": 7, "assignee": 7,
"updated_date": "2022-06-22T09:18:45.296Z",
"status": "annotation", "status": "annotation",
"stage": "annotation", "stage": "annotation",
"state": "in progress" "state": "in progress"
@ -4834,6 +4878,7 @@
"fields": { "fields": {
"segment": 17, "segment": 17,
"assignee": null, "assignee": null,
"updated_date": "2022-06-22T09:18:45.296Z",
"status": "annotation", "status": "annotation",
"stage": "annotation", "stage": "annotation",
"state": "in progress" "state": "in progress"

@ -32,6 +32,7 @@
"status": "annotation", "status": "annotation",
"stop_frame": 4, "stop_frame": 4,
"task_id": 13, "task_id": 13,
"updated_date": "2022-06-22T09:18:45.296000Z",
"url": "http://localhost:8080/api/jobs/17" "url": "http://localhost:8080/api/jobs/17"
}, },
{ {
@ -69,6 +70,7 @@
"status": "annotation", "status": "annotation",
"stop_frame": 10, "stop_frame": 10,
"task_id": 11, "task_id": 11,
"updated_date": "2022-06-22T09:18:45.296000Z",
"url": "http://localhost:8080/api/jobs/16" "url": "http://localhost:8080/api/jobs/16"
}, },
{ {
@ -113,6 +115,7 @@
"status": "annotation", "status": "annotation",
"stop_frame": 19, "stop_frame": 19,
"task_id": 9, "task_id": 9,
"updated_date": "2022-06-22T09:18:45.296000Z",
"url": "http://localhost:8080/api/jobs/14" "url": "http://localhost:8080/api/jobs/14"
}, },
{ {
@ -157,6 +160,7 @@
"status": "validation", "status": "validation",
"stop_frame": 14, "stop_frame": 14,
"task_id": 9, "task_id": 9,
"updated_date": "2022-06-22T09:18:45.296000Z",
"url": "http://localhost:8080/api/jobs/13" "url": "http://localhost:8080/api/jobs/13"
}, },
{ {
@ -201,6 +205,7 @@
"status": "validation", "status": "validation",
"stop_frame": 9, "stop_frame": 9,
"task_id": 9, "task_id": 9,
"updated_date": "2022-06-22T09:18:45.296000Z",
"url": "http://localhost:8080/api/jobs/12" "url": "http://localhost:8080/api/jobs/12"
}, },
{ {
@ -251,6 +256,7 @@
"status": "annotation", "status": "annotation",
"stop_frame": 4, "stop_frame": 4,
"task_id": 9, "task_id": 9,
"updated_date": "2022-06-22T09:18:45.296000Z",
"url": "http://localhost:8080/api/jobs/11" "url": "http://localhost:8080/api/jobs/11"
}, },
{ {
@ -288,6 +294,7 @@
"status": "annotation", "status": "annotation",
"stop_frame": 13, "stop_frame": 13,
"task_id": 8, "task_id": 8,
"updated_date": "2022-06-22T09:18:45.296000Z",
"url": "http://localhost:8080/api/jobs/10" "url": "http://localhost:8080/api/jobs/10"
}, },
{ {
@ -319,6 +326,7 @@
"status": "annotation", "status": "annotation",
"stop_frame": 10, "stop_frame": 10,
"task_id": 7, "task_id": 7,
"updated_date": "2022-06-22T09:18:45.296000Z",
"url": "http://localhost:8080/api/jobs/9" "url": "http://localhost:8080/api/jobs/9"
}, },
{ {
@ -344,6 +352,7 @@
"status": "annotation", "status": "annotation",
"stop_frame": 0, "stop_frame": 0,
"task_id": 6, "task_id": 6,
"updated_date": "2022-06-22T09:18:45.296000Z",
"url": "http://localhost:8080/api/jobs/8" "url": "http://localhost:8080/api/jobs/8"
}, },
{ {
@ -375,6 +384,7 @@
"status": "annotation", "status": "annotation",
"stop_frame": 24, "stop_frame": 24,
"task_id": 5, "task_id": 5,
"updated_date": "2022-06-22T09:18:45.296000Z",
"url": "http://localhost:8080/api/jobs/7" "url": "http://localhost:8080/api/jobs/7"
}, },
{ {
@ -412,6 +422,7 @@
"status": "annotation", "status": "annotation",
"stop_frame": 22, "stop_frame": 22,
"task_id": 2, "task_id": 2,
"updated_date": "2022-06-22T09:18:45.296000Z",
"url": "http://localhost:8080/api/jobs/2" "url": "http://localhost:8080/api/jobs/2"
} }
] ]

@ -32,7 +32,9 @@
"url": "http://localhost:8080/api/users/1", "url": "http://localhost:8080/api/users/1",
"username": "admin1" "username": "admin1"
}, },
"source_storage": null,
"status": "annotation", "status": "annotation",
"target_storage": null,
"task_subsets": [], "task_subsets": [],
"tasks": [ "tasks": [
13 13
@ -62,7 +64,9 @@
"url": "http://localhost:8080/api/users/3", "url": "http://localhost:8080/api/users/3",
"username": "user2" "username": "user2"
}, },
"source_storage": null,
"status": "annotation", "status": "annotation",
"target_storage": null,
"task_subsets": [], "task_subsets": [],
"tasks": [], "tasks": [],
"updated_date": "2022-03-28T13:06:09.283000Z", "updated_date": "2022-03-28T13:06:09.283000Z",
@ -103,7 +107,9 @@
"url": "http://localhost:8080/api/users/10", "url": "http://localhost:8080/api/users/10",
"username": "business1" "username": "business1"
}, },
"source_storage": null,
"status": "annotation", "status": "annotation",
"target_storage": null,
"task_subsets": [ "task_subsets": [
"Train" "Train"
], ],
@ -161,7 +167,9 @@
"url": "http://localhost:8080/api/users/10", "url": "http://localhost:8080/api/users/10",
"username": "business1" "username": "business1"
}, },
"source_storage": null,
"status": "annotation", "status": "annotation",
"target_storage": null,
"task_subsets": [], "task_subsets": [],
"tasks": [ "tasks": [
9 9

@ -58,8 +58,10 @@
} }
], ],
"size": 5, "size": 5,
"source_storage": null,
"status": "annotation", "status": "annotation",
"subset": "", "subset": "",
"target_storage": null,
"updated_date": "2022-06-08T08:33:20.808000Z", "updated_date": "2022-06-08T08:33:20.808000Z",
"url": "http://localhost:8080/api/tasks/13" "url": "http://localhost:8080/api/tasks/13"
}, },
@ -91,8 +93,10 @@
"project_id": null, "project_id": null,
"segment_size": 0, "segment_size": 0,
"segments": [], "segments": [],
"source_storage": null,
"status": "annotation", "status": "annotation",
"subset": "", "subset": "",
"target_storage": null,
"updated_date": "2022-03-14T13:24:05.861000Z", "updated_date": "2022-03-14T13:24:05.861000Z",
"url": "http://localhost:8080/api/tasks/12" "url": "http://localhost:8080/api/tasks/12"
}, },
@ -163,8 +167,10 @@
} }
], ],
"size": 11, "size": 11,
"source_storage": null,
"status": "annotation", "status": "annotation",
"subset": "Train", "subset": "Train",
"target_storage": null,
"updated_date": "2022-03-05T10:32:35.568000Z", "updated_date": "2022-03-05T10:32:35.568000Z",
"url": "http://localhost:8080/api/tasks/11" "url": "http://localhost:8080/api/tasks/11"
}, },
@ -290,8 +296,10 @@
} }
], ],
"size": 20, "size": 20,
"source_storage": null,
"status": "annotation", "status": "annotation",
"subset": "", "subset": "",
"target_storage": null,
"updated_date": "2022-03-05T09:47:49.667000Z", "updated_date": "2022-03-05T09:47:49.667000Z",
"url": "http://localhost:8080/api/tasks/9" "url": "http://localhost:8080/api/tasks/9"
}, },
@ -362,8 +370,10 @@
} }
], ],
"size": 14, "size": 14,
"source_storage": null,
"status": "annotation", "status": "annotation",
"subset": "", "subset": "",
"target_storage": null,
"updated_date": "2022-03-05T08:52:34.908000Z", "updated_date": "2022-03-05T08:52:34.908000Z",
"url": "http://localhost:8080/api/tasks/8" "url": "http://localhost:8080/api/tasks/8"
}, },
@ -428,8 +438,10 @@
} }
], ],
"size": 11, "size": 11,
"source_storage": null,
"status": "annotation", "status": "annotation",
"subset": "", "subset": "",
"target_storage": null,
"updated_date": "2022-02-21T10:41:38.540000Z", "updated_date": "2022-02-21T10:41:38.540000Z",
"url": "http://localhost:8080/api/tasks/7" "url": "http://localhost:8080/api/tasks/7"
}, },
@ -482,8 +494,10 @@
} }
], ],
"size": 1, "size": 1,
"source_storage": null,
"status": "annotation", "status": "annotation",
"subset": "", "subset": "",
"target_storage": null,
"updated_date": "2022-02-16T06:26:54.836000Z", "updated_date": "2022-02-16T06:26:54.836000Z",
"url": "http://localhost:8080/api/tasks/6" "url": "http://localhost:8080/api/tasks/6"
}, },
@ -548,8 +562,10 @@
} }
], ],
"size": 25, "size": 25,
"source_storage": null,
"status": "annotation", "status": "annotation",
"subset": "", "subset": "",
"target_storage": null,
"updated_date": "2022-02-21T10:40:21.257000Z", "updated_date": "2022-02-21T10:40:21.257000Z",
"url": "http://localhost:8080/api/tasks/5" "url": "http://localhost:8080/api/tasks/5"
}, },
@ -620,8 +636,10 @@
} }
], ],
"size": 23, "size": 23,
"source_storage": null,
"status": "annotation", "status": "annotation",
"subset": "", "subset": "",
"target_storage": null,
"updated_date": "2021-12-22T07:14:15.234000Z", "updated_date": "2021-12-22T07:14:15.234000Z",
"url": "http://localhost:8080/api/tasks/2" "url": "http://localhost:8080/api/tasks/2"
} }

@ -310,7 +310,7 @@
"is_active": true, "is_active": true,
"is_staff": true, "is_staff": true,
"is_superuser": true, "is_superuser": true,
"last_login": "2022-06-08T08:32:30.152708Z", "last_login": "2022-06-22T09:20:25.189000Z",
"last_name": "First", "last_name": "First",
"url": "http://localhost:8080/api/users/1", "url": "http://localhost:8080/api/users/1",
"username": "admin1" "username": "admin1"

@ -180,7 +180,7 @@ class TestPatchCloudStorage:
('maintainer', False, True), ('maintainer', False, True),
('supervisor', False, False), ('supervisor', False, False),
]) ])
def test_org_user_update_coud_storage(self, org_id, storage_id, role, is_owner, is_allow, find_users, cloud_storages): def test_org_user_update_cloud_storage(self, org_id, storage_id, role, is_owner, is_allow, find_users, cloud_storages):
cloud_storage = cloud_storages[storage_id] cloud_storage = cloud_storages[storage_id]
username = cloud_storage['owner']['username'] if is_owner else \ username = cloud_storage['owner']['username'] if is_owner else \
next((u for u in find_users(role=role, org=org_id) if u['id'] != cloud_storage['owner']['id']))['username'] next((u for u in find_users(role=role, org=org_id) if u['id'] != cloud_storage['owner']['id']))['username']

@ -44,7 +44,7 @@ class TestGetJobs:
response = get_method(user, f'jobs/{jid}', **kwargs) response = get_method(user, f'jobs/{jid}', **kwargs)
assert response.status_code == HTTPStatus.OK assert response.status_code == HTTPStatus.OK
assert DeepDiff(data, response.json()) == {} assert DeepDiff(data, response.json(), exclude_paths="root['updated_date']") == {}
def _test_get_job_403(self, user, jid, **kwargs): def _test_get_job_403(self, user, jid, **kwargs):
response = get_method(user, f'jobs/{jid}', **kwargs) response = get_method(user, f'jobs/{jid}', **kwargs)
@ -83,7 +83,7 @@ class TestListJobs:
response = get_method(user, 'jobs', **kwargs, page_size='all') response = get_method(user, 'jobs', **kwargs, page_size='all')
assert response.status_code == HTTPStatus.OK assert response.status_code == HTTPStatus.OK
assert DeepDiff(data, response.json()['results']) == {} assert DeepDiff(data, response.json()['results'], exclude_paths="root['updated_date']") == {}
def _test_list_jobs_403(self, user, **kwargs): def _test_list_jobs_403(self, user, **kwargs):
response = get_method(user, 'jobs', **kwargs) response = get_method(user, 'jobs', **kwargs)
@ -123,7 +123,7 @@ class TestGetAnnotations:
assert response.status_code == HTTPStatus.OK assert response.status_code == HTTPStatus.OK
assert DeepDiff(data, response_data, assert DeepDiff(data, response_data,
exclude_paths="root['version']") == {} exclude_regex_paths=r"root\['version|updated_date'\]") == {}
def _test_get_job_annotations_403(self, user, jid, **kwargs): def _test_get_job_annotations_403(self, user, jid, **kwargs):
response = get_method(user, f'jobs/{jid}/annotations', **kwargs) response = get_method(user, f'jobs/{jid}/annotations', **kwargs)
@ -193,7 +193,7 @@ class TestPatchJobAnnotations:
if is_allow: if is_allow:
assert response.status_code == HTTPStatus.OK assert response.status_code == HTTPStatus.OK
assert DeepDiff(data, response.json(), assert DeepDiff(data, response.json(),
exclude_paths="root['version']") == {} exclude_regex_paths=r"root\['version|updated_date'\]") == {}
else: else:
assert response.status_code == HTTPStatus.FORBIDDEN assert response.status_code == HTTPStatus.FORBIDDEN
@ -313,6 +313,7 @@ class TestPatchJob:
if is_allow: if is_allow:
assert response.status_code == HTTPStatus.OK assert response.status_code == HTTPStatus.OK
assert DeepDiff(expected_data(jid, assignee), response.json()) == {} assert DeepDiff(expected_data(jid, assignee), response.json(),
exclude_paths="root['updated_date']") == {}
else: else:
assert response.status_code == HTTPStatus.FORBIDDEN assert response.status_code == HTTPStatus.FORBIDDEN
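The exclude_regex_paths pattern used in these tests drops both 'version' and 'updated_date' from the comparison in one expression; a standalone sketch of the same call:

from deepdiff import DeepDiff

a = {'version': 1, 'updated_date': '2022-06-08', 'status': 'annotation'}
b = {'version': 2, 'updated_date': '2022-06-22', 'status': 'annotation'}

# Only 'status' is compared; the regex matches both excluded keys.
assert DeepDiff(a, b, exclude_regex_paths=r"root\['version|updated_date'\]") == {}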
