diff --git a/cvat/apps/dataset_manager/project.py b/cvat/apps/dataset_manager/project.py index a649ba22..c3c8d217 100644 --- a/cvat/apps/dataset_manager/project.py +++ b/cvat/apps/dataset_manager/project.py @@ -8,7 +8,7 @@ from typing import Any, Callable, List, Mapping, Tuple from django.db import transaction from cvat.apps.engine import models -from cvat.apps.engine.serializers import DataSerializer, TaskSerializer +from cvat.apps.engine.serializers import DataSerializer, TaskWriteSerializer from cvat.apps.engine.task import _create_thread as create_task from cvat.apps.dataset_manager.task import TaskAnnotation @@ -80,7 +80,7 @@ class ProjectAnnotationAndData: }) data_serializer.is_valid(raise_exception=True) db_data = data_serializer.save() - db_task = TaskSerializer.create(None, { + db_task = TaskWriteSerializer.create(None, { **task_fields, 'data_id': db_data.id, 'project_id': self.db_project.id @@ -161,4 +161,4 @@ def import_dataset_as_project(project_id, dataset_file, format_name): importer = make_importer(format_name) with open(dataset_file, 'rb') as f: - project.import_dataset(f, importer) \ No newline at end of file + project.import_dataset(f, importer) diff --git a/cvat/apps/dataset_manager/views.py b/cvat/apps/dataset_manager/views.py index a0127f3b..ec92c808 100644 --- a/cvat/apps/dataset_manager/views.py +++ b/cvat/apps/dataset_manager/views.py @@ -15,7 +15,7 @@ from django.utils import timezone import cvat.apps.dataset_manager.task as task import cvat.apps.dataset_manager.project as project from cvat.apps.engine.log import slogger -from cvat.apps.engine.models import Project, Task +from cvat.apps.engine.models import Project, Task, Job from .formats.registry import EXPORT_FORMATS, IMPORT_FORMATS from .util import current_function_name @@ -30,29 +30,36 @@ def log_exception(logger=None, exc_info=True): def get_export_cache_dir(db_instance): - base_dir = osp.abspath(db_instance.get_project_dirname() if isinstance(db_instance, Project) else db_instance.get_task_dirname()) + base_dir = osp.abspath(db_instance.get_dirname()) + if osp.isdir(base_dir): return osp.join(base_dir, 'export_cache') else: - raise Exception('{} dir {} does not exist'.format("Project" if isinstance(db_instance, Project) else "Task", base_dir)) + raise FileNotFoundError('{} dir {} does not exist'.format(db_instance.__class__.__name__, base_dir)) DEFAULT_CACHE_TTL = timedelta(hours=10) TASK_CACHE_TTL = DEFAULT_CACHE_TTL PROJECT_CACHE_TTL = DEFAULT_CACHE_TTL / 3 +JOB_CACHE_TTL = DEFAULT_CACHE_TTL -def export(dst_format, task_id=None, project_id=None, server_url=None, save_images=False): +def export(dst_format, project_id=None, task_id=None, job_id=None, server_url=None, save_images=False): try: if task_id is not None: db_instance = Task.objects.get(pk=task_id) logger = slogger.task[task_id] cache_ttl = TASK_CACHE_TTL export_fn = task.export_task - else: + elif project_id is not None: db_instance = Project.objects.get(pk=project_id) logger = slogger.project[project_id] cache_ttl = PROJECT_CACHE_TTL export_fn = project.export_project + else: + db_instance = Job.objects.get(pk=job_id) + logger = slogger.job[job_id] + cache_ttl = JOB_CACHE_TTL + export_fn = task.export_job cache_dir = get_export_cache_dir(db_instance) @@ -86,8 +93,9 @@ def export(dst_format, task_id=None, project_id=None, server_url=None, save_imag "The {} '{}' is exported as '{}' at '{}' " "and available for downloading for the next {}. 
" "Export cache cleaning job is enqueued, id '{}'".format( - "project" if isinstance(db_instance, Project) else 'task', - db_instance.name, dst_format, output_path, cache_ttl, + db_instance.__class__.__name__.lower(), + db_instance.name if isinstance(db_instance, (Project, Task)) else db_instance.id, + dst_format, output_path, cache_ttl, cleaning_job.id )) @@ -96,6 +104,12 @@ def export(dst_format, task_id=None, project_id=None, server_url=None, save_imag log_exception(logger) raise +def export_job_annotations(job_id, dst_format=None, server_url=None): + return export(dst_format,job_id=job_id, server_url=server_url, save_images=False) + +def export_job_as_dataset(job_id, dst_format=None, server_url=None): + return export(dst_format, job_id=job_id, server_url=server_url, save_images=True) + def export_task_as_dataset(task_id, dst_format=None, server_url=None): return export(dst_format, task_id=task_id, server_url=server_url, save_images=True) diff --git a/cvat/apps/dataset_repo/dataset_repo.py b/cvat/apps/dataset_repo/dataset_repo.py index 2786532f..7b1d5b56 100644 --- a/cvat/apps/dataset_repo/dataset_repo.py +++ b/cvat/apps/dataset_repo/dataset_repo.py @@ -26,7 +26,7 @@ from cvat.apps.engine.plugins import add_plugin def _have_no_access_exception(ex): if 'Permission denied' in ex.stderr or 'Could not read from remote repository' in ex.stderr: - keys = subprocess.run(['ssh-add -L'], shell = True, + keys = subprocess.run(['ssh-add', '-L'], #nosec stdout = subprocess.PIPE).stdout.decode('utf-8').split('\n') keys = list(filter(len, list(map(lambda x: x.strip(), keys)))) raise Exception( @@ -268,7 +268,7 @@ class Git: # Dump an annotation timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S") - dump_name = os.path.join(db_task.get_task_dirname(), + dump_name = os.path.join(db_task.get_dirname(), "git_annotation_{}_{}.zip".format(self._format, timestamp)) export_task( @@ -303,7 +303,7 @@ class Git: } old_diffs_dir = os.path.join(os.path.dirname(self._diffs_dir), 'repos_diffs') - if (os.path.isdir(old_diffs_dir)): + if os.path.isdir(old_diffs_dir): _read_old_diffs(old_diffs_dir, summary_diff) for diff_name in list(map(lambda x: os.path.join(self._diffs_dir, x), os.listdir(self._diffs_dir))): diff --git a/cvat/apps/engine/backup.py b/cvat/apps/engine/backup.py index 739d70f9..58a01184 100644 --- a/cvat/apps/engine/backup.py +++ b/cvat/apps/engine/backup.py @@ -17,27 +17,35 @@ import django_rq from django.conf import settings from django.db import transaction from django.utils import timezone +from django.shortcuts import get_object_or_404 from rest_framework import serializers, status from rest_framework.parsers import JSONParser from rest_framework.renderers import JSONRenderer from rest_framework.response import Response from django_sendfile import sendfile +from distutils.util import strtobool import cvat.apps.dataset_manager as dm from cvat.apps.engine import models from cvat.apps.engine.log import slogger from cvat.apps.engine.serializers import (AttributeSerializer, DataSerializer, - LabeledDataSerializer, SegmentSerializer, SimpleJobSerializer, TaskSerializer, - ProjectSerializer, ProjectFileSerializer, TaskFileSerializer) + LabeledDataSerializer, SegmentSerializer, SimpleJobSerializer, TaskReadSerializer, + ProjectReadSerializer, ProjectFileSerializer, TaskFileSerializer) from cvat.apps.engine.utils import av_scan_paths -from cvat.apps.engine.models import StorageChoice, StorageMethodChoice, DataChoice, Task, Project +from cvat.apps.engine.models import ( + StorageChoice, 
StorageMethodChoice, DataChoice, Task, Project, Location, + CloudStorage as CloudStorageModel) from cvat.apps.engine.task import _create_thread from cvat.apps.dataset_manager.views import TASK_CACHE_TTL, PROJECT_CACHE_TTL, get_export_cache_dir, clear_export_cache, log_exception from cvat.apps.dataset_manager.bindings import CvatImportError +from cvat.apps.engine.cloud_provider import ( + db_storage_to_storage_instance, validate_bucket_status +) +from cvat.apps.engine.location import StorageType, get_location_configuration class Version(Enum): - V1 = '1.0' + V1 = '1.0' def _get_label_mapping(db_labels): @@ -266,7 +274,7 @@ class TaskExporter(_ExporterBase, _TaskBackupBase): raise NotImplementedError() def _write_task(self, zip_object, target_dir=None): - task_dir = self._db_task.get_task_dirname() + task_dir = self._db_task.get_dirname() target_task_dir = os.path.join(target_dir, self.TASK_DIRNAME) if target_dir else self.TASK_DIRNAME self._write_directory( source_dir=task_dir, @@ -277,7 +285,7 @@ class TaskExporter(_ExporterBase, _TaskBackupBase): def _write_manifest(self, zip_object, target_dir=None): def serialize_task(): - task_serializer = TaskSerializer(self._db_task) + task_serializer = TaskReadSerializer(self._db_task) for field in ('url', 'owner', 'assignee', 'segments'): task_serializer.fields.pop(field) @@ -348,8 +356,8 @@ class TaskExporter(_ExporterBase, _TaskBackupBase): def export_to(self, file, target_dir=None): if self._db_task.data.storage_method == StorageMethodChoice.FILE_SYSTEM and \ - self._db_task.data.storage == StorageChoice.SHARE: - raise Exception('The task cannot be exported because it does not contain any raw data') + self._db_task.data.storage == StorageChoice.SHARE: + raise Exception('The task cannot be exported because it does not contain any raw data') if isinstance(file, str): with ZipFile(file, 'w') as zf: @@ -484,7 +492,7 @@ class TaskImporter(_ImporterBase, _TaskBackupBase): self._manifest['project_id'] = self._project_id self._db_task = models.Task.objects.create(**self._manifest, organization_id=self._org_id) - task_path = self._db_task.get_task_dirname() + task_path = self._db_task.get_dirname() if os.path.isdir(task_path): shutil.rmtree(task_path) @@ -569,7 +577,7 @@ class ProjectExporter(_ExporterBase, _ProjectBackupBase): def _write_manifest(self, zip_object): def serialize_project(): - project_serializer = ProjectSerializer(self._db_project) + project_serializer = ProjectReadSerializer(self._db_project) for field in ('assignee', 'owner', 'tasks', 'url'): project_serializer.fields.pop(field) @@ -591,7 +599,7 @@ class ProjectExporter(_ExporterBase, _ProjectBackupBase): self._write_manifest(output_file) class ProjectImporter(_ImporterBase, _ProjectBackupBase): - TASKNAME_RE = 'task_(\d+)/' + TASKNAME_RE = r'task_(\d+)/' def __init__(self, filename, user_id, org_id=None): super().__init__(logger=slogger.glob) @@ -616,7 +624,7 @@ class ProjectImporter(_ImporterBase, _ProjectBackupBase): self._manifest["owner_id"] = self._user_id self._db_project = models.Project.objects.create(**self._manifest, organization_id=self._org_id) - project_path = self._db_project.get_project_dirname() + project_path = self._db_project.get_dirname() if os.path.isdir(project_path): shutil.rmtree(project_path) os.makedirs(self._db_project.get_project_logs_dirname()) @@ -702,14 +710,23 @@ def export(db_instance, request): logger = slogger.task[db_instance.pk] Exporter = TaskExporter cache_ttl = TASK_CACHE_TTL + use_target_storage_conf = 
request.query_params.get('use_default_location', True) elif isinstance(db_instance, Project): filename_prefix = 'project' logger = slogger.project[db_instance.pk] Exporter = ProjectExporter cache_ttl = PROJECT_CACHE_TTL + use_target_storage_conf = request.query_params.get('use_default_location', True) else: raise Exception( "Unexpected type of db_isntance: {}".format(type(db_instance))) + use_settings = strtobool(str(use_target_storage_conf)) + obj = db_instance if use_settings else request.query_params + location_conf = get_location_configuration( + obj=obj, + use_settings=use_settings, + field_name=StorageType.TARGET + ) queue = django_rq.get_queue("default") rq_id = "/api/{}s/{}/backup".format(filename_prefix, db_instance.pk) @@ -731,8 +748,30 @@ def export(db_instance, request): filename = "{}_{}_backup_{}{}".format( filename_prefix, db_instance.name, timestamp, os.path.splitext(file_path)[1]) - return sendfile(request, file_path, attachment=True, - attachment_filename=filename.lower()) + + location = location_conf.get('location') + if location == Location.LOCAL: + return sendfile(request, file_path, attachment=True, + attachment_filename=filename.lower()) + elif location == Location.CLOUD_STORAGE: + + @validate_bucket_status + def _export_to_cloud_storage(storage, file_path, file_name): + storage.upload_file(file_path, file_name) + + try: + storage_id = location_conf['storage_id'] + except KeyError: + raise serializers.ValidationError( + 'Cloud storage location was selected for destination' + ' but cloud storage id was not specified') + db_storage = get_object_or_404(CloudStorageModel, pk=storage_id) + storage = db_storage_to_storage_instance(db_storage) + + _export_to_cloud_storage(storage, file_path, filename) + return Response(status=status.HTTP_200_OK) + else: + raise NotImplementedError() else: if os.path.exists(file_path): return Response(status=status.HTTP_201_CREATED) @@ -753,21 +792,47 @@ def export(db_instance, request): result_ttl=ttl, failure_ttl=ttl) return Response(status=status.HTTP_202_ACCEPTED) -def _import(importer, request, rq_id, Serializer, file_field_name, filename=None): +def _import(importer, request, rq_id, Serializer, file_field_name, location_conf, filename=None): queue = django_rq.get_queue("default") rq_job = queue.fetch_job(rq_id) if not rq_job: org_id = getattr(request.iam_context['organization'], 'id', None) fd = None - if not filename: - serializer = Serializer(data=request.data) - serializer.is_valid(raise_exception=True) - payload_file = serializer.validated_data[file_field_name] + + location = location_conf.get('location') + if location == Location.LOCAL: + if not filename: + serializer = Serializer(data=request.data) + serializer.is_valid(raise_exception=True) + payload_file = serializer.validated_data[file_field_name] + fd, filename = mkstemp(prefix='cvat_') + with open(filename, 'wb+') as f: + for chunk in payload_file.chunks(): + f.write(chunk) + else: + @validate_bucket_status + def _import_from_cloud_storage(storage, file_name): + return storage.download_fileobj(file_name) + + file_name = request.query_params.get('filename') + assert file_name + + # download file from cloud storage + try: + storage_id = location_conf['storage_id'] + except KeyError: + raise serializers.ValidationError( + 'Cloud storage location was selected as the source' + ' but cloud storage id was not specified') + db_storage = get_object_or_404(CloudStorageModel, pk=storage_id) + storage = db_storage_to_storage_instance(db_storage) + + data = 
_import_from_cloud_storage(storage, file_name) + fd, filename = mkstemp(prefix='cvat_') with open(filename, 'wb+') as f: - for chunk in payload_file.chunks(): - f.write(chunk) + f.write(data.getbuffer()) rq_job = queue.enqueue_call( func=importer, args=(filename, request.user.id, org_id), @@ -814,12 +879,18 @@ def import_project(request, filename=None): Serializer = ProjectFileSerializer file_field_name = 'project_file' + location_conf = get_location_configuration( + obj=request.query_params, + field_name=StorageType.SOURCE, + ) + return _import( importer=_import_project, request=request, rq_id=rq_id, Serializer=Serializer, file_field_name=file_field_name, + location_conf=location_conf, filename=filename ) @@ -831,11 +902,17 @@ def import_task(request, filename=None): Serializer = TaskFileSerializer file_field_name = 'task_file' + location_conf = get_location_configuration( + obj=request.query_params, + field_name=StorageType.SOURCE + ) + return _import( importer=_import_task, request=request, rq_id=rq_id, Serializer=Serializer, file_field_name=file_field_name, + location_conf=location_conf, filename=filename ) diff --git a/cvat/apps/engine/cloud_provider.py b/cvat/apps/engine/cloud_provider.py index 4e2e1b26..0861ce92 100644 --- a/cvat/apps/engine/cloud_provider.py +++ b/cvat/apps/engine/cloud_provider.py @@ -4,10 +4,13 @@ import os import boto3 +import functools +import json from abc import ABC, abstractmethod, abstractproperty from enum import Enum from io import BytesIO +from rest_framework import serializers from boto3.s3.transfer import TransferConfig from botocore.exceptions import ClientError @@ -35,6 +38,14 @@ class Status(str, Enum): def __str__(self): return self.value +class Permissions(str, Enum): + READ = 'read' + WRITE = 'write' + + @classmethod + def all(cls): + return {i.value for i in cls} + class _CloudStorage(ABC): def __init__(self): @@ -86,7 +97,11 @@ class _CloudStorage(ABC): raise NotImplementedError("Unsupported type {} was found".format(type(file_obj))) @abstractmethod - def upload_file(self, file_obj, file_name): + def upload_fileobj(self, file_obj, file_name): + pass + + @abstractmethod + def upload_file(self, file_path, file_name=None): pass def __contains__(self, file_name): @@ -99,6 +114,18 @@ class _CloudStorage(ABC): def content(self): return list(map(lambda x: x['name'] , self._files)) + @abstractproperty + def supported_actions(self): + pass + + @property + def read_access(self): + return Permissions.READ in self.access + + @property + def write_access(self): + return Permissions.WRITE in self.access + def get_cloud_storage_instance(cloud_provider, resource, credentials, specific_attributes=None, endpoint=None): instance = None if cloud_provider == CloudProviderChoice.AWS_S3: @@ -133,6 +160,12 @@ class AWS_S3(_CloudStorage): transfer_config = { 'max_io_queue': 10, } + + class Effect(str, Enum): + ALLOW = 'Allow' + DENY = 'Deny' + + def __init__(self, bucket, region, @@ -209,13 +242,27 @@ class AWS_S3(_CloudStorage): def get_file_last_modified(self, key): return self._head_file(key).get('LastModified') - def upload_file(self, file_obj, file_name): + def upload_fileobj(self, file_obj, file_name): self._bucket.upload_fileobj( Fileobj=file_obj, Key=file_name, Config=TransferConfig(max_io_queue=self.transfer_config['max_io_queue']) ) + def upload_file(self, file_path, file_name=None): + if not file_name: + file_name = os.path.basename(file_path) + try: + self._bucket.upload_file( + file_path, + file_name, + 
Config=TransferConfig(max_io_queue=self.transfer_config['max_io_queue']) + ) + except ClientError as ex: + msg = str(ex) + slogger.glob.error(msg) + raise Exception(msg) + def initialize_content(self): files = self._bucket.objects.all() self._files = [{ @@ -251,8 +298,45 @@ class AWS_S3(_CloudStorage): slogger.glob.info(msg) raise Exception(msg) + def delete_file(self, file_name: str): + try: + self._client_s3.delete_object(Bucket=self.name, Key=file_name) + except Exception as ex: + msg = str(ex) + slogger.glob.info(msg) + raise + + @property + def supported_actions(self): + allowed_actions = set() + try: + bucket_policy = self._bucket.Policy().policy + except ClientError as ex: + if 'NoSuchBucketPolicy' in str(ex): + return Permissions.all() + else: + raise Exception(str(ex)) + bucket_policy = json.loads(bucket_policy) if isinstance(bucket_policy, str) else bucket_policy + for statement in bucket_policy['Statement']: + effect = statement.get('Effect') # Allow | Deny + actions = statement.get('Action', set()) + if effect == self.Effect.ALLOW: + allowed_actions.update(actions) + access = { + 's3:GetObject': Permissions.READ, + 's3:PutObject': Permissions.WRITE, + } + allowed_actions = Permissions.all() & {access.get(i) for i in allowed_actions} + + return allowed_actions + class AzureBlobContainer(_CloudStorage): MAX_CONCURRENCY = 3 + + + class Effect: + pass + def __init__(self, container, account_name, sas_token=None): super().__init__() self._account_name = account_name @@ -317,9 +401,18 @@ class AzureBlobContainer(_CloudStorage): else: return Status.NOT_FOUND - def upload_file(self, file_obj, file_name): + def upload_fileobj(self, file_obj, file_name): self._container_client.upload_blob(name=file_name, data=file_obj) + def upload_file(self, file_path, file_name=None): + if not file_name: + file_name = os.path.basename(file_path) + try: + with open(file_path, 'rb') as f: + self.upload_fileobj(f, file_name) + except Exception as ex: + slogger.glob.error(str(ex)) + raise # TODO: # def multipart_upload(self, file_obj): @@ -342,6 +435,10 @@ class AzureBlobContainer(_CloudStorage): buf.seek(0) return buf + @property + def supported_actions(self): + pass + class GOOGLE_DRIVE(_CloudStorage): pass @@ -361,6 +458,9 @@ def _define_gcs_status(func): class GoogleCloudStorage(_CloudStorage): + class Effect: + pass + def __init__(self, bucket_name, prefix=None, service_account_json=None, anonymous_access=False, project=None, location=None): super().__init__() if service_account_json: @@ -416,9 +516,18 @@ class GoogleCloudStorage(_CloudStorage): buf.seek(0) return buf - def upload_file(self, file_obj, file_name): + def upload_fileobj(self, file_obj, file_name): self.bucket.blob(file_name).upload_from_file(file_obj) + def upload_file(self, file_path, file_name=None): + if not file_name: + file_name = os.path.basename(file_path) + try: + self.bucket.blob(file_name).upload_from_filename(file_path) + except Exception as ex: + slogger.glob.info(str(ex)) + raise + def create(self): try: self._bucket = self._storage_client.create_bucket( @@ -441,6 +550,10 @@ class GoogleCloudStorage(_CloudStorage): blob.reload() return blob.updated + @property + def supported_actions(self): + pass + class Credentials: __slots__ = ('key', 'secret_key', 'session_token', 'account_name', 'key_file_path', 'credentials_type') @@ -502,3 +615,36 @@ class Credentials: def values(self): return [self.key, self.secret_key, self.session_token, self.account_name, self.key_file_path] + + +def validate_bucket_status(func): + 
@functools.wraps(func) + def wrapper(storage, *args, **kwargs): + try: + res = func(storage, *args, **kwargs) + except Exception as ex: + # check that cloud storage exists + storage_status = storage.get_status() if storage is not None else None + if storage_status == Status.FORBIDDEN: + msg = 'The resource {} is no longer available. Access forbidden.'.format(storage.name) + elif storage_status == Status.NOT_FOUND: + msg = 'The resource {} not found. It may have been deleted.'.format(storage.name) + else: + msg = str(ex) + raise serializers.ValidationError(msg) + return res + return wrapper + + +def db_storage_to_storage_instance(db_storage): + credentials = Credentials() + credentials.convert_from_db({ + 'type': db_storage.credentials_type, + 'value': db_storage.credentials, + }) + details = { + 'resource': db_storage.resource, + 'credentials': credentials, + 'specific_attributes': db_storage.get_specific_attributes() + } + return get_cloud_storage_instance(cloud_provider=db_storage.provider_type, **details) diff --git a/cvat/apps/engine/location.py b/cvat/apps/engine/location.py new file mode 100644 index 00000000..fd3fadf0 --- /dev/null +++ b/cvat/apps/engine/location.py @@ -0,0 +1,36 @@ +from enum import Enum + +from cvat.apps.engine.models import Location + +class StorageType(str, Enum): + TARGET = 'target_storage' + SOURCE = 'source_storage' + + def __str__(self): + return self.value + +def get_location_configuration(obj, field_name, use_settings=False): + location_conf = dict() + if use_settings: + storage = getattr(obj, field_name) + if storage is None: + location_conf['location'] = Location.LOCAL + else: + location_conf['location'] = storage.location + sid = storage.cloud_storage_id + if sid: + location_conf['storage_id'] = sid + else: + # obj is query_params + # FIXME when ui part will be done + location_conf['location'] = obj.get('location', Location.LOCAL) + # try: + # location_conf['location'] = obj['location'] + # except KeyError: + # raise ValidationError("Custom settings were selected but no location was specified") + + sid = obj.get('cloud_storage_id') + if sid: + location_conf['storage_id'] = int(sid) + + return location_conf diff --git a/cvat/apps/engine/migrations/0054_auto_20220610_1829.py b/cvat/apps/engine/migrations/0054_auto_20220610_1829.py new file mode 100644 index 00000000..1c7ae1a8 --- /dev/null +++ b/cvat/apps/engine/migrations/0054_auto_20220610_1829.py @@ -0,0 +1,51 @@ +# Generated by Django 3.2.12 on 2022-06-10 18:29 + +import cvat.apps.engine.models +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('engine', '0053_data_deleted_frames'), + ] + + operations = [ + migrations.CreateModel( + name='Storage', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('location', models.CharField(choices=[('cloud_storage', 'CLOUD_STORAGE'), ('local', 'LOCAL')], default=cvat.apps.engine.models.Location['LOCAL'], max_length=16)), + ('cloud_storage_id', models.IntegerField(blank=True, default=None, null=True)), + ], + options={ + 'default_permissions': (), + }, + ), + migrations.AddField( + model_name='job', + name='updated_date', + field=models.DateTimeField(auto_now=True), + ), + migrations.AddField( + model_name='project', + name='source_storage', + field=models.ForeignKey(blank=True, default=None, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='+', to='engine.storage'), + ), + 
migrations.AddField( + model_name='project', + name='target_storage', + field=models.ForeignKey(blank=True, default=None, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='+', to='engine.storage'), + ), + migrations.AddField( + model_name='task', + name='source_storage', + field=models.ForeignKey(blank=True, default=None, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='+', to='engine.storage'), + ), + migrations.AddField( + model_name='task', + name='target_storage', + field=models.ForeignKey(blank=True, default=None, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='+', to='engine.storage'), + ), + ] diff --git a/cvat/apps/engine/migrations/0055_jobs_directories.py b/cvat/apps/engine/migrations/0055_jobs_directories.py new file mode 100644 index 00000000..ec97f2c8 --- /dev/null +++ b/cvat/apps/engine/migrations/0055_jobs_directories.py @@ -0,0 +1,53 @@ +# Generated by Django 3.2.12 on 2022-06-10 18:29 + +import os +import shutil + +from django.db import migrations +from django.conf import settings +from cvat.apps.engine.log import get_logger + +MIGRATION_NAME = os.path.splitext(os.path.basename(__file__))[0] +MIGRATION_LOG = os.path.join(settings.MIGRATIONS_LOGS_ROOT, f"{MIGRATION_NAME}.log") + +def _get_query_set(apps): + Job = apps.get_model("engine", "Job") + query_set = Job.objects.all() + return query_set + +def _get_job_dir_path(jid): + return os.path.join(settings.JOBS_ROOT, str(jid)) + +def create_directories(apps, schema_editor): + logger = get_logger(MIGRATION_NAME, MIGRATION_LOG) + query_set = _get_query_set(apps) + logger.info(f'Migration has been started. Need to create {query_set.count()} directories.') + + for db_job in query_set: + jid = db_job.id + os.makedirs(_get_job_dir_path(jid), exist_ok=True) + logger.info(f'Migration has been finished successfully.') + +def delete_directories(apps, schema_editor): + logger = get_logger(MIGRATION_NAME, MIGRATION_LOG) + query_set = _get_query_set(apps) + logger.info(f'Reverse migration has been started. 
Need to delete {query_set.count()} directories.') + for db_job in query_set: + jid = db_job.id + job_dir = _get_job_dir_path(jid) + if os.path.isdir(job_dir): + shutil.rmtree(job_dir) + logger.info(f'Migration has been reversed successfully.') + +class Migration(migrations.Migration): + + dependencies = [ + ('engine', '0054_auto_20220610_1829'), + ] + + operations = [ + migrations.RunPython( + code=create_directories, + reverse_code=delete_directories + ) + ] diff --git a/cvat/apps/engine/mixins.py b/cvat/apps/engine/mixins.py index 295e4ece..42f164ff 100644 --- a/cvat/apps/engine/mixins.py +++ b/cvat/apps/engine/mixins.py @@ -8,10 +8,13 @@ import uuid from django.conf import settings from django.core.cache import cache +from distutils.util import strtobool from rest_framework import status from rest_framework.response import Response -from cvat.apps.engine.serializers import DataSerializer +from cvat.apps.engine.models import Location +from cvat.apps.engine.location import StorageType, get_location_configuration +from cvat.apps.engine.serializers import DataSerializer, LabeledDataSerializer class TusFile: _tus_cache_timeout = 3600 @@ -90,7 +93,7 @@ class TusChunk: # This upload mixin is implemented using tus # tus is open protocol for file uploads (see more https://tus.io/) -class UploadMixin(object): +class UploadMixin: _tus_api_version = '1.0.0' _tus_api_version_supported = ['1.0.0'] _tus_api_extensions = [] @@ -238,3 +241,80 @@ class UploadMixin(object): # override this to do stuff after upload def upload_finished(self, request): raise NotImplementedError('You need to implement upload_finished in UploadMixin') + +class AnnotationMixin: + def export_annotations(self, request, pk, db_obj, export_func, callback, get_data=None): + format_name = request.query_params.get("format") + action = request.query_params.get("action", "").lower() + filename = request.query_params.get("filename", "") + + use_default_location = request.query_params.get("use_default_location", True) + use_settings = strtobool(str(use_default_location)) + obj = db_obj if use_settings else request.query_params + location_conf = get_location_configuration( + obj=obj, + use_settings=use_settings, + field_name=StorageType.TARGET, + ) + + rq_id = "/api/{}/{}/annotations/{}".format(self._object.__class__.__name__.lower(), pk, format_name) + + if format_name: + return export_func(db_instance=self._object, + rq_id=rq_id, + request=request, + action=action, + callback=callback, + format_name=format_name, + filename=filename, + location_conf=location_conf, + ) + + if not get_data: + return Response("Format is not specified",status=status.HTTP_400_BAD_REQUEST) + + data = get_data(pk) + serializer = LabeledDataSerializer(data=data) + if serializer.is_valid(raise_exception=True): + return Response(serializer.data) + + def import_annotations(self, request, pk, db_obj, import_func, rq_func): + use_default_location = request.query_params.get('use_default_location', True) + use_settings = strtobool(str(use_default_location)) + obj = db_obj if use_settings else request.query_params + location_conf = get_location_configuration( + obj=obj, + use_settings=use_settings, + field_name=StorageType.SOURCE, + ) + + if location_conf['location'] == Location.CLOUD_STORAGE: + format_name = request.query_params.get('format') + file_name = request.query_params.get('filename') + rq_id = "{}@/api/{}/{}/annotations/upload".format( + self._object.__class__.__name__.lower(), request.user, pk + ) + + return import_func( + request=request, + rq_id=rq_id, + 
rq_func=rq_func, + pk=pk, + format_name=format_name, + location_conf=location_conf, + filename=file_name, + ) + + return self.upload_data(request) + +class SerializeMixin: + def serialize(self, request, export_func): + db_object = self.get_object() # force to call check_object_permissions + return export_func(db_object, request) + + def deserialize(self, request, import_func): + location = request.query_params.get("location", Location.LOCAL) + if location == Location.CLOUD_STORAGE: + file_name = request.query_params.get("filename", "") + return import_func(request, filename=file_name) + return self.upload_data(request) diff --git a/cvat/apps/engine/models.py b/cvat/apps/engine/models.py index 5af453f1..af84195c 100644 --- a/cvat/apps/engine/models.py +++ b/cvat/apps/engine/models.py @@ -181,7 +181,7 @@ class Data(models.Model): default_permissions = () def get_frame_step(self): - match = re.search("step\s*=\s*([1-9]\d*)", self.frame_filter) + match = re.search(r"step\s*=\s*([1-9]\d*)", self.frame_filter) return int(match.group(1)) if match else 1 def get_data_dirname(self): @@ -265,7 +265,6 @@ class Image(models.Model): default_permissions = () class Project(models.Model): - name = SafeCharField(max_length=256) owner = models.ForeignKey(User, null=True, blank=True, on_delete=models.SET_NULL, related_name="+") @@ -278,15 +277,19 @@ class Project(models.Model): default=StatusChoice.ANNOTATION) organization = models.ForeignKey(Organization, null=True, default=None, blank=True, on_delete=models.SET_NULL, related_name="projects") + source_storage = models.ForeignKey('Storage', null=True, default=None, + blank=True, on_delete=models.SET_NULL, related_name='+') + target_storage = models.ForeignKey('Storage', null=True, default=None, + blank=True, on_delete=models.SET_NULL, related_name='+') - def get_project_dirname(self): + def get_dirname(self): return os.path.join(settings.PROJECTS_ROOT, str(self.id)) def get_project_logs_dirname(self): - return os.path.join(self.get_project_dirname(), 'logs') + return os.path.join(self.get_dirname(), 'logs') def get_tmp_dirname(self): - return os.path.join(self.get_project_dirname(), "tmp") + return os.path.join(self.get_dirname(), "tmp") def get_client_log_path(self): return os.path.join(self.get_project_logs_dirname(), "client.log") @@ -324,17 +327,20 @@ class Task(models.Model): subset = models.CharField(max_length=64, blank=True, default="") organization = models.ForeignKey(Organization, null=True, default=None, blank=True, on_delete=models.SET_NULL, related_name="tasks") - + source_storage = models.ForeignKey('Storage', null=True, default=None, + blank=True, on_delete=models.SET_NULL, related_name='+') + target_storage = models.ForeignKey('Storage', null=True, default=None, + blank=True, on_delete=models.SET_NULL, related_name='+') # Extend default permission model class Meta: default_permissions = () - def get_task_dirname(self): + def get_dirname(self): return os.path.join(settings.TASKS_ROOT, str(self.id)) def get_task_logs_dirname(self): - return os.path.join(self.get_task_dirname(), 'logs') + return os.path.join(self.get_dirname(), 'logs') def get_client_log_path(self): return os.path.join(self.get_task_logs_dirname(), "client.log") @@ -343,10 +349,10 @@ class Task(models.Model): return os.path.join(self.get_task_logs_dirname(), "task.log") def get_task_artifacts_dirname(self): - return os.path.join(self.get_task_dirname(), 'artifacts') + return os.path.join(self.get_dirname(), 'artifacts') def get_tmp_dirname(self): - return 
os.path.join(self.get_task_dirname(), "tmp") + return os.path.join(self.get_dirname(), "tmp") def __str__(self): return self.name @@ -414,6 +420,7 @@ class Segment(models.Model): class Job(models.Model): segment = models.ForeignKey(Segment, on_delete=models.CASCADE) assignee = models.ForeignKey(User, null=True, blank=True, on_delete=models.SET_NULL) + updated_date = models.DateTimeField(auto_now=True) # TODO: it has to be deleted in Job, Task, Project and replaced by (stage, state) # The stage field cannot be changed by an assignee, but state field can be. For # now status is read only and it will be updated by (stage, state). Thus we don't @@ -425,6 +432,9 @@ class Job(models.Model): state = models.CharField(max_length=32, choices=StateChoice.choices(), default=StateChoice.NEW) + def get_dirname(self): + return os.path.join(settings.JOBS_ROOT, str(self.id)) + def get_project_id(self): project = self.segment.task.project return project.id if project else None @@ -524,8 +534,8 @@ class SourceType(str, Enum): MANUAL = 'manual' @classmethod - def choices(self): - return tuple((x.value, x.name) for x in self) + def choices(cls): + return tuple((x.value, x.name) for x in cls) def __str__(self): return self.value @@ -669,6 +679,21 @@ class Manifest(models.Model): def __str__(self): return '{}'.format(self.filename) +class Location(str, Enum): + CLOUD_STORAGE = 'cloud_storage' + LOCAL = 'local' + + @classmethod + def choices(cls): + return tuple((x.value, x.name) for x in cls) + + def __str__(self): + return self.value + + @classmethod + def list(cls): + return [i.value for i in cls] + class CloudStorage(models.Model): # restrictions: # AWS bucket name, Azure container name - 63, Google bucket name - 63 without dots and 222 with dots @@ -696,7 +721,6 @@ class CloudStorage(models.Model): organization = models.ForeignKey(Organization, null=True, default=None, blank=True, on_delete=models.SET_NULL, related_name="cloudstorages") - class Meta: default_permissions = () unique_together = ('provider_type', 'resource', 'credentials') @@ -721,3 +745,10 @@ class CloudStorage(models.Model): def get_key_file_path(self): return os.path.join(self.get_storage_dirname(), 'key.json') + +class Storage(models.Model): + location = models.CharField(max_length=16, choices=Location.choices(), default=Location.LOCAL) + cloud_storage_id = models.IntegerField(null=True, blank=True, default=None) + + class Meta: + default_permissions = () diff --git a/cvat/apps/engine/serializers.py b/cvat/apps/engine/serializers.py index ed01e5a7..2e95c357 100644 --- a/cvat/apps/engine/serializers.py +++ b/cvat/apps/engine/serializers.py @@ -20,7 +20,7 @@ from cvat.apps.engine.utils import parse_specific_attributes from drf_spectacular.utils import OpenApiExample, extend_schema_serializer class BasicUserSerializer(serializers.ModelSerializer): - def validate(self, data): + def validate(self, attrs): if hasattr(self, 'initial_data'): unknown_keys = set(self.initial_data.keys()) - set(self.fields.keys()) if unknown_keys: @@ -30,7 +30,7 @@ class BasicUserSerializer(serializers.ModelSerializer): else: message = 'Got unknown fields: {}'.format(unknown_keys) raise serializers.ValidationError(message) - return data + return attrs class Meta: model = User @@ -81,7 +81,7 @@ class LabelSerializer(serializers.ModelSerializer): fields = ('id', 'name', 'color', 'attributes', 'deleted') def validate(self, attrs): - if attrs.get('deleted') == True and attrs.get('id') is None: + if attrs.get('deleted') and attrs.get('id') is None: raise 
serializers.ValidationError('Deleted label must have an ID') return attrs @@ -107,7 +107,7 @@ class LabelSerializer(serializers.ModelSerializer): else: db_label = models.Label.objects.create(name=validated_data.get('name'), **instance) logger.info("New {} label was created".format(db_label.name)) - if validated_data.get('deleted') == True: + if validated_data.get('deleted'): db_label.delete() return if not validated_data.get('color', None): @@ -159,7 +159,8 @@ class JobReadSerializer(serializers.ModelSerializer): model = models.Job fields = ('url', 'id', 'task_id', 'project_id', 'assignee', 'dimension', 'labels', 'bug_tracker', 'status', 'stage', 'state', 'mode', - 'start_frame', 'stop_frame', 'data_chunk_size', 'data_compressed_chunk_type') + 'start_frame', 'stop_frame', 'data_chunk_size', 'data_compressed_chunk_type', + 'updated_date',) read_only_fields = fields class JobWriteSerializer(serializers.ModelSerializer): @@ -336,7 +337,7 @@ class DataSerializer(serializers.ModelSerializer): # pylint: disable=no-self-use def validate_frame_filter(self, value): - match = re.search("step\s*=\s*([1-9]\d*)", value) + match = re.search(r"step\s*=\s*([1-9]\d*)", value) if not match: raise serializers.ValidationError("Invalid frame filter expression") return value @@ -348,11 +349,11 @@ class DataSerializer(serializers.ModelSerializer): return value # pylint: disable=no-self-use - def validate(self, data): - if 'start_frame' in data and 'stop_frame' in data \ - and data['start_frame'] > data['stop_frame']: + def validate(self, attrs): + if 'start_frame' in attrs and 'stop_frame' in attrs \ + and attrs['start_frame'] > attrs['stop_frame']: raise serializers.ValidationError('Stop frame must be more or equal start frame') - return data + return attrs def create(self, validated_data): files = self._pop_data(validated_data) @@ -404,8 +405,12 @@ class DataSerializer(serializers.ModelSerializer): remote_file = models.RemoteFile(data=instance, **f) remote_file.save() +class StorageSerializer(serializers.ModelSerializer): + class Meta: + model = models.Storage + fields = ('id', 'location', 'cloud_storage_id') -class TaskSerializer(WriteOnceMixin, serializers.ModelSerializer): +class TaskReadSerializer(serializers.ModelSerializer): labels = LabelSerializer(many=True, source='label_set', partial=True, required=False) segments = SegmentSerializer(many=True, source='segment_set', read_only=True) data_chunk_size = serializers.ReadOnlyField(source='data.chunk_size') @@ -415,24 +420,47 @@ class TaskSerializer(WriteOnceMixin, serializers.ModelSerializer): image_quality = serializers.ReadOnlyField(source='data.image_quality') data = serializers.ReadOnlyField(source='data.id') owner = BasicUserSerializer(required=False) - owner_id = serializers.IntegerField(write_only=True, allow_null=True, required=False) assignee = BasicUserSerializer(allow_null=True, required=False) - assignee_id = serializers.IntegerField(write_only=True, allow_null=True, required=False) project_id = serializers.IntegerField(required=False, allow_null=True) dimension = serializers.CharField(allow_blank=True, required=False) + target_storage = StorageSerializer(required=False) + source_storage = StorageSerializer(required=False) class Meta: model = models.Task fields = ('url', 'id', 'name', 'project_id', 'mode', 'owner', 'assignee', - 'owner_id', 'assignee_id', 'bug_tracker', 'created_date', 'updated_date', - 'overlap', 'segment_size', 'status', 'labels', 'segments', - 'data_chunk_size', 'data_compressed_chunk_type', 'data_original_chunk_type', - 
'size', 'image_quality', 'data', 'dimension', 'subset', 'organization') - read_only_fields = ('mode', 'created_date', 'updated_date', 'status', - 'data_chunk_size', 'owner', 'assignee', 'data_compressed_chunk_type', - 'data_original_chunk_type', 'size', 'image_quality', 'data', - 'organization') - write_once_fields = ('overlap', 'segment_size', 'project_id') + 'bug_tracker', 'created_date', 'updated_date', 'overlap', 'segment_size', + 'status', 'labels', 'segments', 'data_chunk_size', 'data_compressed_chunk_type', + 'data_original_chunk_type', 'size', 'image_quality', 'data', 'dimension', + 'subset', 'organization', 'target_storage', 'source_storage', + ) + read_only_fields = fields + + def to_representation(self, instance): + response = super().to_representation(instance) + if instance.project_id: + response["labels"] = LabelSerializer(many=True).to_representation(instance.project.label_set) + return response + +class TaskWriteSerializer(WriteOnceMixin, serializers.ModelSerializer): + labels = LabelSerializer(many=True, source='label_set', partial=True, required=False) + owner_id = serializers.IntegerField(write_only=True, allow_null=True, required=False) + assignee_id = serializers.IntegerField(write_only=True, allow_null=True, required=False) + project_id = serializers.IntegerField(required=False, allow_null=True) + target_storage = StorageSerializer(required=False) + source_storage = StorageSerializer(required=False) + + class Meta: + model = models.Task + fields = ('url', 'id', 'name', 'project_id', 'owner_id', 'assignee_id', + 'bug_tracker', 'overlap', 'segment_size', 'labels', 'subset', + 'target_storage', 'source_storage', + ) + write_once_fields = ('overlap', 'segment_size', 'project_id', 'owner_id', 'labels') + + def to_representation(self, instance): + serializer = TaskReadSerializer(instance, context=self.context) + return serializer.data # pylint: disable=no-self-use def create(self, validated_data): @@ -453,7 +481,17 @@ class TaskSerializer(WriteOnceMixin, serializers.ModelSerializer): raise serializers.ValidationError(f'The task and its project should be in the same organization.') labels = validated_data.pop('label_set', []) - db_task = models.Task.objects.create(**validated_data) + + # configure source/target storages for import/export + storages = _configure_related_storages({ + 'source_storage': validated_data.pop('source_storage', None), + 'target_storage': validated_data.pop('target_storage', None), + }) + + db_task = models.Task.objects.create( + **storages, + **validated_data) + label_colors = list() for label in labels: attributes = label.pop('attributespec_set') @@ -468,7 +506,7 @@ class TaskSerializer(WriteOnceMixin, serializers.ModelSerializer): del attr['id'] models.AttributeSpec.objects.create(label=db_label, **attr) - task_path = db_task.get_task_dirname() + task_path = db_task.get_dirname() if os.path.isdir(task_path): shutil.rmtree(task_path) @@ -478,12 +516,6 @@ class TaskSerializer(WriteOnceMixin, serializers.ModelSerializer): db_task.save() return db_task - def to_representation(self, instance): - response = super().to_representation(instance) - if instance.project_id: - response["labels"] = LabelSerializer(many=True).to_representation(instance.project.label_set) - return response - # pylint: disable=no-self-use def update(self, instance, validated_data): instance.name = validated_data.get('name', instance.name) @@ -500,7 +532,7 @@ class TaskSerializer(WriteOnceMixin, serializers.ModelSerializer): if validated_project_id is not None and 
validated_project_id != instance.project_id: project = models.Project.objects.get(id=validated_project_id) if project.tasks.count() and project.tasks.first().dimension != instance.dimension: - raise serializers.ValidationError(f'Dimension ({instance.dimension}) of the task must be the same as other tasks in project ({project.tasks.first().dimension})') + raise serializers.ValidationError(f'Dimension ({instance.dimension}) of the task must be the same as other tasks in project ({project.tasks.first().dimension})') if instance.project_id is None: for old_label in instance.label_set.all(): try: @@ -536,6 +568,9 @@ class TaskSerializer(WriteOnceMixin, serializers.ModelSerializer): ) instance.project = project + # update source and target storages + _update_related_storages(instance, validated_data) + instance.save() return instance @@ -547,6 +582,7 @@ class TaskSerializer(WriteOnceMixin, serializers.ModelSerializer): project = models.Project.objects.filter(id=project_id).first() if project is None: raise serializers.ValidationError(f'Cannot find project with ID {project_id}') + # Check that all labels can be mapped new_label_names = set() old_labels = self.instance.project.label_set.all() if self.instance.project_id else self.instance.label_set.all() @@ -577,22 +613,26 @@ class ProjectSearchSerializer(serializers.ModelSerializer): fields = ('id', 'name') read_only_fields = ('name',) -class ProjectSerializer(serializers.ModelSerializer): +class ProjectReadSerializer(serializers.ModelSerializer): labels = LabelSerializer(many=True, source='label_set', partial=True, default=[]) owner = BasicUserSerializer(required=False, read_only=True) - owner_id = serializers.IntegerField(write_only=True, allow_null=True, required=False) - assignee = BasicUserSerializer(allow_null=True, required=False) - assignee_id = serializers.IntegerField(write_only=True, allow_null=True, required=False) + assignee = BasicUserSerializer(allow_null=True, required=False, read_only=True) task_subsets = serializers.ListField(child=serializers.CharField(), required=False) dimension = serializers.CharField(max_length=16, required=False, read_only=True) + target_storage = StorageSerializer(required=False) + source_storage = StorageSerializer(required=False) class Meta: model = models.Project fields = ('url', 'id', 'name', 'labels', 'tasks', 'owner', 'assignee', - 'owner_id', 'assignee_id', 'bug_tracker', 'task_subsets', - 'created_date', 'updated_date', 'status', 'dimension', 'organization') + 'bug_tracker', 'task_subsets', # 'owner_id', 'assignee_id', + 'created_date', 'updated_date', 'status', 'dimension', 'organization', + 'target_storage', 'source_storage', + ) read_only_fields = ('created_date', 'updated_date', 'status', 'owner', - 'assignee', 'task_subsets', 'dimension', 'organization', 'tasks') + 'assignee', 'task_subsets', 'dimension', 'organization', 'tasks', + 'target_storage', 'source_storage', + ) def to_representation(self, instance): response = super().to_representation(instance) @@ -602,10 +642,38 @@ class ProjectSerializer(serializers.ModelSerializer): response['dimension'] = instance.tasks.first().dimension if instance.tasks.count() else None return response +class ProjectWriteSerializer(serializers.ModelSerializer): + labels = LabelSerializer(many=True, source='label_set', partial=True, default=[]) + owner_id = serializers.IntegerField(write_only=True, allow_null=True, required=False) + assignee_id = serializers.IntegerField(write_only=True, allow_null=True, required=False) + task_subsets = 
serializers.ListField(child=serializers.CharField(), required=False) + + target_storage = StorageSerializer(required=False) + source_storage = StorageSerializer(required=False) + + class Meta: + model = models.Project + fields = ('name', 'labels', 'owner_id', 'assignee_id', 'bug_tracker', + 'target_storage', 'source_storage', 'task_subsets', + ) + + def to_representation(self, instance): + serializer = ProjectReadSerializer(instance, context=self.context) + return serializer.data + # pylint: disable=no-self-use def create(self, validated_data): labels = validated_data.pop('label_set') - db_project = models.Project.objects.create(**validated_data) + + # configure source/target storages for import/export + storages = _configure_related_storages({ + 'source_storage': validated_data.pop('source_storage', None), + 'target_storage': validated_data.pop('target_storage', None), + }) + + db_project = models.Project.objects.create( + **storages, + **validated_data) label_colors = list() for label in labels: if label.get('id', None): @@ -620,7 +688,7 @@ class ProjectSerializer(serializers.ModelSerializer): del attr['id'] models.AttributeSpec.objects.create(label=db_label, **attr) - project_path = db_project.get_project_dirname() + project_path = db_project.get_dirname() if os.path.isdir(project_path): shutil.rmtree(project_path) os.makedirs(db_project.get_project_logs_dirname()) @@ -637,6 +705,9 @@ class ProjectSerializer(serializers.ModelSerializer): for label in labels: LabelSerializer.update_instance(label, instance) + # update source and target storages + _update_related_storages(instance, validated_data) + instance.save() return instance @@ -976,19 +1047,19 @@ class CloudStorageWriteSerializer(serializers.ModelSerializer): @staticmethod def _manifests_validation(storage, manifests): - # check manifest files availability - for manifest in manifests: - file_status = storage.get_file_status(manifest) - if file_status == Status.NOT_FOUND: - raise serializers.ValidationError({ - 'manifests': "The '{}' file does not exist on '{}' cloud storage" \ - .format(manifest, storage.name) - }) - elif file_status == Status.FORBIDDEN: - raise serializers.ValidationError({ - 'manifests': "The '{}' file does not available on '{}' cloud storage. Access denied" \ - .format(manifest, storage.name) - }) + # check manifest files availability + for manifest in manifests: + file_status = storage.get_file_status(manifest) + if file_status == Status.NOT_FOUND: + raise serializers.ValidationError({ + 'manifests': "The '{}' file does not exist on '{}' cloud storage" \ + .format(manifest, storage.name) + }) + elif file_status == Status.FORBIDDEN: + raise serializers.ValidationError({ + 'manifests': "The '{}' file is not available on '{}' cloud storage. 
Access denied" \ + .format(manifest, storage.name) + }) def create(self, validated_data): provider_type = validated_data.get('provider_type') @@ -1140,3 +1211,61 @@ class RelatedFileSerializer(serializers.ModelSerializer): model = models.RelatedFile fields = '__all__' read_only_fields = ('path',) + + +def _update_related_storages(instance, validated_data): + for storage in ('source_storage', 'target_storage'): + new_conf = validated_data.pop(storage, None) + + if not new_conf: + continue + + cloud_storage_id = new_conf.get('cloud_storage_id') + if cloud_storage_id: + _validate_existence_of_cloud_storage(cloud_storage_id) + + # storage_instance maybe None + storage_instance = getattr(instance, storage) + if not storage_instance: + storage_instance = models.Storage(**new_conf) + storage_instance.save() + setattr(instance, storage, storage_instance) + continue + + new_location = new_conf.get('location') + storage_instance.location = new_location or storage_instance.location + storage_instance.cloud_storage_id = new_conf.get('cloud_storage_id', \ + storage_instance.cloud_storage_id if not new_location else None) + + cloud_storage_id = storage_instance.cloud_storage_id + if cloud_storage_id: + try: + _ = models.CloudStorage.objects.get(id=cloud_storage_id) + except models.CloudStorage.DoesNotExist: + raise serializers.ValidationError(f'The specified cloud storage {cloud_storage_id} does not exist.') + + storage_instance.save() + +def _configure_related_storages(validated_data): + + storages = { + 'source_storage': None, + 'target_storage': None, + } + + for i in storages: + storage_conf = validated_data.get(i) + if storage_conf: + cloud_storage_id = storage_conf.get('cloud_storage_id') + if cloud_storage_id: + _validate_existence_of_cloud_storage(cloud_storage_id) + storage_instance = models.Storage(**storage_conf) + storage_instance.save() + storages[i] = storage_instance + return storages + +def _validate_existence_of_cloud_storage(cloud_storage_id): + try: + _ = models.CloudStorage.objects.get(id=cloud_storage_id) + except models.CloudStorage.DoesNotExist: + raise serializers.ValidationError(f'The specified cloud storage {cloud_storage_id} does not exist.') diff --git a/cvat/apps/engine/signals.py b/cvat/apps/engine/signals.py index 5ef6e5f3..c86c58e5 100644 --- a/cvat/apps/engine/signals.py +++ b/cvat/apps/engine/signals.py @@ -39,7 +39,7 @@ def create_profile(instance, **kwargs): @receiver(post_delete, sender=Task, dispatch_uid="delete_task_files_on_delete_task") def delete_task_files_on_delete_task(instance, **kwargs): - shutil.rmtree(instance.get_task_dirname(), ignore_errors=True) + shutil.rmtree(instance.get_dirname(), ignore_errors=True) @receiver(post_delete, sender=Data, dispatch_uid="delete_data_files_on_delete_data") diff --git a/cvat/apps/engine/task.py b/cvat/apps/engine/task.py index 33464493..5267010c 100644 --- a/cvat/apps/engine/task.py +++ b/cvat/apps/engine/task.py @@ -112,6 +112,12 @@ def _save_task_to_db(db_task): db_job = models.Job(segment=db_segment) db_job.save() + # create job directory + job_path = db_job.get_dirname() + if os.path.isdir(job_path): + shutil.rmtree(job_path) + os.makedirs(job_path) + db_task.data.save() db_task.save() @@ -489,7 +495,7 @@ def _create_thread(db_task, data, isBackupRestore=False, isDatasetImport=False): # calculate chunk size if it isn't specified if db_data.chunk_size is None: if isinstance(compressed_chunk_writer, ZipCompressedChunkWriter): - if not (db_data.storage == models.StorageChoice.CLOUD_STORAGE): + if not db_data.storage == 
models.StorageChoice.CLOUD_STORAGE: w, h = extractor.get_image_size(0) else: img_properties = manifest[0] @@ -507,7 +513,7 @@ def _create_thread(db_task, data, isBackupRestore=False, isDatasetImport=False): job.save_meta() if settings.USE_CACHE and db_data.storage_method == models.StorageMethodChoice.CACHE: - for media_type, media_files in media.items(): + for media_type, media_files in media.items(): if not media_files: continue diff --git a/cvat/apps/engine/tests/test_rest_api.py b/cvat/apps/engine/tests/test_rest_api.py index 1059dc1f..bdc0787f 100644 --- a/cvat/apps/engine/tests/test_rest_api.py +++ b/cvat/apps/engine/tests/test_rest_api.py @@ -80,8 +80,8 @@ def create_db_task(data): labels = data.pop('labels', None) db_task = Task.objects.create(**data) - shutil.rmtree(db_task.get_task_dirname(), ignore_errors=True) - os.makedirs(db_task.get_task_dirname()) + shutil.rmtree(db_task.get_dirname(), ignore_errors=True) + os.makedirs(db_task.get_dirname()) os.makedirs(db_task.get_task_logs_dirname()) os.makedirs(db_task.get_task_artifacts_dirname()) db_task.data = db_data @@ -117,8 +117,8 @@ def create_db_task(data): def create_db_project(data): labels = data.pop('labels', None) db_project = Project.objects.create(**data) - shutil.rmtree(db_project.get_project_dirname(), ignore_errors=True) - os.makedirs(db_project.get_project_dirname()) + shutil.rmtree(db_project.get_dirname(), ignore_errors=True) + os.makedirs(db_project.get_dirname()) os.makedirs(db_project.get_project_logs_dirname()) if not labels is None: @@ -1979,11 +1979,11 @@ class TaskDeleteAPITestCase(APITestCase): def test_api_v2_tasks_delete_task_data_after_delete_task(self): for task in self.tasks: - task_dir = task.get_task_dirname() + task_dir = task.get_dirname() self.assertTrue(os.path.exists(task_dir)) self._check_api_v2_tasks_id(self.admin) for task in self.tasks: - task_dir = task.get_task_dirname() + task_dir = task.get_dirname() self.assertFalse(os.path.exists(task_dir)) class TaskUpdateAPITestCase(APITestCase): @@ -2418,7 +2418,7 @@ class TaskMoveAPITestCase(APITestCase): def _check_api_v2_tasks(self, tid, data, expected_status=status.HTTP_200_OK): response = self._run_api_v2_tasks_id(tid, data) self.assertEqual(response.status_code, expected_status) - if (expected_status == status.HTTP_200_OK): + if expected_status == status.HTTP_200_OK: self._check_response(response, data) def test_move_task_bad_request(self): @@ -2936,6 +2936,8 @@ class TaskImportExportAPITestCase(APITestCase): "created_date", "updated_date", "data", + "source_storage", + "target_storage", ), ) diff --git a/cvat/apps/engine/views.py b/cvat/apps/engine/views.py index 2999f718..d48f4136 100644 --- a/cvat/apps/engine/views.py +++ b/cvat/apps/engine/views.py @@ -15,6 +15,7 @@ from tempfile import mkstemp, NamedTemporaryFile import cv2 from django.db.models.query import Prefetch +from django.shortcuts import get_object_or_404 import django_rq from django.apps import apps from django.conf import settings @@ -40,7 +41,8 @@ from django_sendfile import sendfile import cvat.apps.dataset_manager as dm import cvat.apps.dataset_manager.views # pylint: disable=unused-import -from cvat.apps.engine.cloud_provider import get_cloud_storage_instance, Credentials, Status as CloudStorageStatus +from cvat.apps.engine.cloud_provider import ( + db_storage_to_storage_instance, validate_bucket_status, Status as CloudStorageStatus) from cvat.apps.dataset_manager.bindings import CvatImportError from cvat.apps.dataset_manager.serializers import DatasetFormatsSerializer from 
cvat.apps.engine.frame_provider import FrameProvider @@ -49,22 +51,23 @@ from cvat.apps.engine.mime_types import mimetypes from cvat.apps.engine.models import ( Job, Task, Project, Issue, Data, Comment, StorageMethodChoice, StorageChoice, Image, - CloudProviderChoice + CloudProviderChoice, Location ) from cvat.apps.engine.models import CloudStorage as CloudStorageModel from cvat.apps.engine.serializers import ( AboutSerializer, AnnotationFileSerializer, BasicUserSerializer, DataMetaReadSerializer, DataMetaWriteSerializer, DataSerializer, ExceptionSerializer, FileInfoSerializer, JobReadSerializer, JobWriteSerializer, LabeledDataSerializer, - LogEventSerializer, ProjectSerializer, ProjectSearchSerializer, - RqStatusSerializer, TaskSerializer, UserSerializer, PluginsSerializer, IssueReadSerializer, + LogEventSerializer, ProjectReadSerializer, ProjectWriteSerializer, ProjectSearchSerializer, + RqStatusSerializer, TaskReadSerializer, TaskWriteSerializer, UserSerializer, PluginsSerializer, IssueReadSerializer, IssueWriteSerializer, CommentReadSerializer, CommentWriteSerializer, CloudStorageWriteSerializer, - CloudStorageReadSerializer, DatasetFileSerializer, JobCommitSerializer) + CloudStorageReadSerializer, DatasetFileSerializer, JobCommitSerializer, + ProjectFileSerializer, TaskFileSerializer) from utils.dataset_manifest import ImageManifestManager from cvat.apps.engine.utils import av_scan_paths from cvat.apps.engine import backup -from cvat.apps.engine.mixins import UploadMixin +from cvat.apps.engine.mixins import UploadMixin, AnnotationMixin, SerializeMixin from . import models, task from .log import clogger, slogger @@ -72,6 +75,7 @@ from cvat.apps.iam.permissions import (CloudStoragePermission, CommentPermission, IssuePermission, JobPermission, ProjectPermission, TaskPermission, UserPermission) + @extend_schema(tags=['server']) class ServerViewSet(viewsets.ViewSet): serializer_class = None @@ -226,18 +230,18 @@ class ServerViewSet(viewsets.ViewSet): responses={ '200': PolymorphicProxySerializer(component_name='PolymorphicProject', serializers=[ - ProjectSerializer, ProjectSearchSerializer, + ProjectReadSerializer, ProjectSearchSerializer, ], resource_type_field_name='name', many=True), }), create=extend_schema( summary='Method creates a new project', responses={ - '201': ProjectSerializer, + '201': ProjectWriteSerializer, }), retrieve=extend_schema( summary='Method returns details of a specific project', responses={ - '200': ProjectSerializer, + '200': ProjectReadSerializer, }), destroy=extend_schema( summary='Method deletes a specific project', @@ -247,10 +251,10 @@ class ServerViewSet(viewsets.ViewSet): partial_update=extend_schema( summary='Methods does a partial update of chosen fields in a project', responses={ - '200': ProjectSerializer, + '200': ProjectWriteSerializer, }) ) -class ProjectViewSet(viewsets.ModelViewSet, UploadMixin): +class ProjectViewSet(viewsets.ModelViewSet, UploadMixin, AnnotationMixin, SerializeMixin): queryset = models.Project.objects.prefetch_related(Prefetch('label_set', queryset=models.Label.objects.order_by('id') )) @@ -267,9 +271,12 @@ class ProjectViewSet(viewsets.ModelViewSet, UploadMixin): def get_serializer_class(self): if self.request.path.endswith('tasks'): - return TaskSerializer + return TaskReadSerializer else: - return ProjectSerializer + if self.request.method in SAFE_METHODS: + return ProjectReadSerializer + else: + return ProjectWriteSerializer def get_queryset(self): queryset = super().get_queryset() @@ -285,9 +292,9 @@ class 
ProjectViewSet(viewsets.ModelViewSet, UploadMixin): @extend_schema( summary='Method returns information of the tasks of the project with the selected id', responses={ - '200': TaskSerializer(many=True), + '200': TaskReadSerializer(many=True), }) - @action(detail=True, methods=['GET'], serializer_class=TaskSerializer) + @action(detail=True, methods=['GET'], serializer_class=TaskReadSerializer) def tasks(self, request, pk): self.get_object() # force to call check_object_permissions queryset = Task.objects.filter(project_id=pk).order_by('-id') @@ -311,7 +318,15 @@ class ProjectViewSet(viewsets.ModelViewSet, UploadMixin): OpenApiParameter('filename', description='Desired output file name', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False), OpenApiParameter('action', description='Used to start downloading process after annotation file had been created', - location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, enum=['download', 'import_status']) + location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, enum=['download', 'import_status']), + OpenApiParameter('location', description='Where need to save downloaded dataset', + location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, + enum=Location.list()), + OpenApiParameter('cloud_storage_id', description='Storage id', + location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), + OpenApiParameter('use_default_location', description='Use the location that was configured in project to import dataset', + location=OpenApiParameter.QUERY, type=OpenApiTypes.BOOL, required=False, + default=True), ], responses={ '200': OpenApiResponse(description='Download of file started'), @@ -323,7 +338,17 @@ class ProjectViewSet(viewsets.ModelViewSet, UploadMixin): parameters=[ OpenApiParameter('format', description='Desired dataset format name\n' 'You can get the list of supported formats at:\n/server/annotation/formats', - location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=True) + location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=True), + OpenApiParameter('location', description='Where to import the dataset from', + location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, + enum=Location.list()), + OpenApiParameter('cloud_storage_id', description='Storage id', + location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), + OpenApiParameter('use_default_location', description='Use the location that was configured in the project to import annotations', + location=OpenApiParameter.QUERY, type=OpenApiTypes.BOOL, required=False, + default=True), + OpenApiParameter('filename', description='Dataset file name', + location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False), ], responses={ '202': OpenApiResponse(description='Exporting has been started'), @@ -335,8 +360,15 @@ class ProjectViewSet(viewsets.ModelViewSet, UploadMixin): def dataset(self, request, pk): self._object = self.get_object() # force to call check_object_permissions - if request.method == 'POST' or request.method == 'OPTIONS': - return self.upload_data(request) + if request.method in {'POST', 'OPTIONS'}: + + return self.import_annotations( + request=request, + pk=pk, + db_obj=self._object, + import_func=_import_project_dataset, + rq_func=dm.project.import_dataset_as_project + ) else: action = request.query_params.get("action", "").lower() if action in ("import_status",): @@ -363,15 +395,12 @@ class ProjectViewSet(viewsets.ModelViewSet, UploadMixin): 
status=status.HTTP_202_ACCEPTED ) else: - format_name = request.query_params.get("format", "") - return _export_annotations( - db_instance=self._object, - rq_id="/api/project/{}/dataset/{}".format(pk, format_name), + return self.export_annotations( request=request, - action=action, - callback=dm.views.export_project_as_dataset, - format_name=format_name, - filename=request.query_params.get("filename", "").lower(), + pk=pk, + db_obj=self._object, + export_func=_export_annotations, + callback=dm.views.export_project_as_dataset ) @action(detail=True, methods=['HEAD', 'PATCH'], url_path='dataset/'+UploadMixin.file_id_regex) @@ -423,7 +452,15 @@ class ProjectViewSet(viewsets.ModelViewSet, UploadMixin): OpenApiParameter('filename', description='Desired output file name', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False), OpenApiParameter('action', description='Used to start downloading process after annotation file had been created', - location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, enum=['download']) + location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, enum=['download']), + OpenApiParameter('location', description='Where need to save downloaded dataset', + location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, + enum=Location.list()), + OpenApiParameter('cloud_storage_id', description='Storage id', + location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), + OpenApiParameter('use_default_location', description='Use the location that was configured in project to export annotation', + location=OpenApiParameter.QUERY, type=OpenApiTypes.BOOL, required=False, + default=True), ], responses={ '200': OpenApiResponse(description='Download of file started'), @@ -435,21 +472,30 @@ class ProjectViewSet(viewsets.ModelViewSet, UploadMixin): @action(detail=True, methods=['GET'], serializer_class=LabeledDataSerializer) def annotations(self, request, pk): - db_project = self.get_object() # force to call check_object_permissions - format_name = request.query_params.get('format') - if format_name: - return _export_annotations(db_instance=db_project, - rq_id="/api/projects/{}/annotations/{}".format(pk, format_name), - request=request, - action=request.query_params.get("action", "").lower(), - callback=dm.views.export_project_annotations, - format_name=format_name, - filename=request.query_params.get("filename", "").lower(), - ) - else: - return Response("Format is not specified",status=status.HTTP_400_BAD_REQUEST) + self._object = self.get_object() # force to call check_object_permissions + return self.export_annotations( + request=request, + pk=pk, + db_obj=self._object, + export_func=_export_annotations, + callback=dm.views.export_project_annotations, + get_data=dm.task.get_job_data, + ) @extend_schema(summary='Methods creates a backup copy of a project', + parameters=[ + OpenApiParameter('action', location=OpenApiParameter.QUERY, + description='Used to start downloading process after backup file had been created', + type=OpenApiTypes.STR, required=False, enum=['download']), + OpenApiParameter('location', description='Where need to save downloaded backup', + location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, + enum=Location.list()), + OpenApiParameter('cloud_storage_id', description='Storage id', + location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), + OpenApiParameter('use_default_location', description='Use the location that was configured in project to export 
backup', + location=OpenApiParameter.QUERY, type=OpenApiTypes.BOOL, required=False, + default=True), + ], responses={ '200': OpenApiResponse(description='Download of file started'), '201': OpenApiResponse(description='Output backup file is ready for downloading'), @@ -457,17 +503,26 @@ class ProjectViewSet(viewsets.ModelViewSet, UploadMixin): }) @action(methods=['GET'], detail=True, url_path='backup') def export_backup(self, request, pk=None): - db_project = self.get_object() # force to call check_object_permissions - return backup.export(db_project, request) + return self.serialize(request, backup.export) @extend_schema(summary='Methods create a project from a backup', + parameters=[ + OpenApiParameter('location', description='Where to import the backup file from', + location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, + enum=Location.list(), default=Location.LOCAL), + OpenApiParameter('cloud_storage_id', description='Storage id', + location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), + OpenApiParameter('filename', description='Backup file name', + location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False), + ], + request=ProjectFileSerializer(required=False), responses={ '201': OpenApiResponse(description='The project has been imported'), # or better specify {id: project_id} '202': OpenApiResponse(description='Importing a backup file has been started'), }) - @action(detail=False, methods=['OPTIONS', 'POST'], url_path=r'backup/?$') + @action(detail=False, methods=['OPTIONS', 'POST'], url_path=r'backup/?$', serializer_class=ProjectFileSerializer(required=False)) def import_backup(self, request, pk=None): - return self.upload_data(request) + return self.deserialize(request, backup.import_project) @action(detail=False, methods=['HEAD', 'PATCH'], url_path='backup/'+UploadMixin.file_id_regex) def append_backup_chunk(self, request, file_id): @@ -522,6 +577,7 @@ class DataChunkGetter: if self.type == 'chunk': start_chunk = frame_provider.get_chunk_number(start) stop_chunk = frame_provider.get_chunk_number(stop) + # pylint: disable=superfluous-parens if not (start_chunk <= self.number <= stop_chunk): raise ValidationError('The chunk number should be in ' + f'[{start_chunk}, {stop_chunk}] range') @@ -571,20 +627,20 @@ class DataChunkGetter: list=extend_schema( summary='Returns a paginated list of tasks according to query parameters (10 tasks per page)', responses={ - '200': TaskSerializer(many=True), + '200': TaskReadSerializer(many=True), }), create=extend_schema( summary='Method creates a new task in a database without any attached images and videos', responses={ - '201': TaskSerializer, + '201': TaskWriteSerializer, }), retrieve=extend_schema( summary='Method returns details of a specific task', - responses=TaskSerializer), + responses=TaskReadSerializer), update=extend_schema( summary='Method updates a task by id', responses={ - '200': TaskSerializer, + '200': TaskWriteSerializer, }), destroy=extend_schema( summary='Method deletes a specific task, all attached jobs, annotations, and data', @@ -594,15 +650,14 @@ class DataChunkGetter: partial_update=extend_schema( summary='Methods does a partial update of chosen fields in a task', responses={ - '200': TaskSerializer, + '200': TaskWriteSerializer, }) ) -class TaskViewSet(UploadMixin, viewsets.ModelViewSet): +class TaskViewSet(UploadMixin, AnnotationMixin, viewsets.ModelViewSet, SerializeMixin): queryset = Task.objects.prefetch_related( Prefetch('label_set', 
queryset=models.Label.objects.order_by('id')), "label_set__attributespec_set", "segment_set__job_set") - serializer_class = TaskSerializer lookup_fields = {'project_name': 'project__name', 'owner': 'owner__username', 'assignee': 'assignee__username'} search_fields = ('project_name', 'name', 'owner', 'status', 'assignee', 'subset', 'mode', 'dimension') filter_fields = list(search_fields) + ['id', 'project_id', 'updated_date'] @@ -610,6 +665,12 @@ class TaskViewSet(UploadMixin, viewsets.ModelViewSet): ordering = "-id" iam_organization_field = 'organization' + def get_serializer_class(self): + if self.request.method in SAFE_METHODS: + return TaskReadSerializer + else: + return TaskWriteSerializer + def get_queryset(self): queryset = super().get_queryset() if self.action == 'list': @@ -619,19 +680,42 @@ class TaskViewSet(UploadMixin, viewsets.ModelViewSet): return queryset @extend_schema(summary='Method recreates a task from an attached task backup file', + parameters=[ + OpenApiParameter('location', description='Where to import the backup file from', + location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, + enum=Location.list(), default=Location.LOCAL), + OpenApiParameter('cloud_storage_id', description='Storage id', + location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), + OpenApiParameter('filename', description='Backup file name', + location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False), + ], + request=TaskFileSerializer(required=False), responses={ '201': OpenApiResponse(description='The task has been imported'), # or better specify {id: task_id} '202': OpenApiResponse(description='Importing a backup file has been started'), }) - @action(detail=False, methods=['OPTIONS', 'POST'], url_path=r'backup/?$') + @action(detail=False, methods=['OPTIONS', 'POST'], url_path=r'backup/?$', serializer_class=TaskFileSerializer(required=False)) def import_backup(self, request, pk=None): - return self.upload_data(request) + return self.deserialize(request, backup.import_task) @action(detail=False, methods=['HEAD', 'PATCH'], url_path='backup/'+UploadMixin.file_id_regex) def append_backup_chunk(self, request, file_id): return self.append_tus_chunk(request, file_id) @extend_schema(summary='Method backup a specified task', + parameters=[ + OpenApiParameter('action', location=OpenApiParameter.QUERY, + description='Used to start downloading process after backup file had been created', + type=OpenApiTypes.STR, required=False, enum=['download']), + OpenApiParameter('location', description='Where need to save downloaded backup', + location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, + enum=Location.list()), + OpenApiParameter('cloud_storage_id', description='Storage id', + location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), + OpenApiParameter('use_default_location', description='Use the location that was configured in the task to export backup', + location=OpenApiParameter.QUERY, type=OpenApiTypes.BOOL, required=False, + default=True), + ], responses={ '200': OpenApiResponse(description='Download of file started'), '201': OpenApiResponse(description='Output backup file is ready for downloading'), @@ -639,8 +723,7 @@ class TaskViewSet(UploadMixin, viewsets.ModelViewSet): }) @action(methods=['GET'], detail=True, url_path='backup') def export_backup(self, request, pk=None): - db_task = self.get_object() # force to call check_object_permissions - return backup.export(db_task, request) + return self.serialize(request, 
backup.export) def perform_update(self, serializer): instance = serializer.instance @@ -659,7 +742,7 @@ class TaskViewSet(UploadMixin, viewsets.ModelViewSet): assert instance.organization == db_project.organization def perform_destroy(self, instance): - task_dirname = instance.get_task_dirname() + task_dirname = instance.get_dirname() super().perform_destroy(instance) shutil.rmtree(task_dirname, ignore_errors=True) if instance.data and not instance.data.tasks.all(): @@ -695,6 +778,7 @@ class TaskViewSet(UploadMixin, viewsets.ModelViewSet): # UploadMixin method def upload_finished(self, request): if self.action == 'annotations': + # db_task = self.get_object() format_name = request.query_params.get("format", "") filename = request.query_params.get("filename", "") tmp_dir = self._object.get_tmp_dirname() @@ -823,7 +907,15 @@ class TaskViewSet(UploadMixin, viewsets.ModelViewSet): location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False), OpenApiParameter('action', location=OpenApiParameter.QUERY, description='Used to start downloading process after annotation file had been created', - type=OpenApiTypes.STR, required=False, enum=['download']) + type=OpenApiTypes.STR, required=False, enum=['download']), + OpenApiParameter('location', description='Where need to save downloaded dataset', + location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, + enum=Location.list()), + OpenApiParameter('cloud_storage_id', description='Storage id', + location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), + OpenApiParameter('use_default_location', description='Use the location that was configured in the task to export annotation', + location=OpenApiParameter.QUERY, type=OpenApiTypes.BOOL, required=False, + default=True), ], responses={ '200': OpenApiResponse(description='Download of file started'), @@ -841,6 +933,26 @@ class TaskViewSet(UploadMixin, viewsets.ModelViewSet): '202': OpenApiResponse(description='Uploading has been started'), '405': OpenApiResponse(description='Format is not available'), }) + @extend_schema(methods=['POST'], summary='Method allows to upload task annotations from storage', + parameters=[ + OpenApiParameter('format', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, + description='Input format name\nYou can get the list of supported formats at:\n/server/annotation/formats'), + OpenApiParameter('location', description='where to import the annotation from', + location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, + enum=Location.list()), + OpenApiParameter('cloud_storage_id', description='Storage id', + location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), + OpenApiParameter('use_default_location', description='Use the location that was configured in task to import annotations', + location=OpenApiParameter.QUERY, type=OpenApiTypes.BOOL, required=False, + default=True), + OpenApiParameter('filename', description='Annotation file name', + location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False), + ], + responses={ + '201': OpenApiResponse(description='Uploading has finished'), + '202': OpenApiResponse(description='Uploading has been started'), + '405': OpenApiResponse(description='Format is not available'), + }) @extend_schema(methods=['PATCH'], summary='Method performs a partial update of annotations in a specific task', parameters=[ OpenApiParameter('action', location=OpenApiParameter.QUERY, required=True, @@ -851,27 +963,26 @@ class TaskViewSet(UploadMixin, 
viewsets.ModelViewSet): '204': OpenApiResponse(description='The annotation has been deleted'), }) @action(detail=True, methods=['GET', 'DELETE', 'PUT', 'PATCH', 'POST', 'OPTIONS'], url_path=r'annotations/?$', - serializer_class=LabeledDataSerializer) + serializer_class=LabeledDataSerializer(required=False)) def annotations(self, request, pk): self._object = self.get_object() # force to call check_object_permissions if request.method == 'GET': - format_name = request.query_params.get('format') - if format_name: - return _export_annotations(db_instance=self._object, - rq_id="/api/tasks/{}/annotations/{}".format(pk, format_name), - request=request, - action=request.query_params.get("action", "").lower(), - callback=dm.views.export_task_annotations, - format_name=format_name, - filename=request.query_params.get("filename", "").lower(), - ) - else: - data = dm.task.get_task_data(pk) - serializer = LabeledDataSerializer(data=data) - if serializer.is_valid(raise_exception=True): - return Response(serializer.data) + return self.export_annotations( + request=request, + pk=pk, + db_obj=self._object, + export_func=_export_annotations, + callback=dm.views.export_task_annotations, + get_data=dm.task.get_task_data, + ) elif request.method == 'POST' or request.method == 'OPTIONS': - return self.upload_data(request) + return self.import_annotations( + request=request, + pk=pk, + db_obj=self._object, + import_func=_import_annotations, + rq_func=dm.task.import_task_annotations, + ) elif request.method == 'PUT': format_name = request.query_params.get('format') if format_name: @@ -991,7 +1102,15 @@ class TaskViewSet(UploadMixin, viewsets.ModelViewSet): location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False), OpenApiParameter('action', location=OpenApiParameter.QUERY, description='Used to start downloading process after annotation file had been created', - type=OpenApiTypes.STR, required=False, enum=['download']) + type=OpenApiTypes.STR, required=False, enum=['download']), + OpenApiParameter('use_default_location', description='Use the location that was configured in task to export annotations', + location=OpenApiParameter.QUERY, type=OpenApiTypes.BOOL, required=False, + default=True), + OpenApiParameter('location', description='Where need to save downloaded dataset', + location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, + enum=Location.list()), + OpenApiParameter('cloud_storage_id', description='Storage id', + location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), ], responses={ '200': OpenApiResponse(description='Download of file started'), @@ -1002,16 +1121,14 @@ class TaskViewSet(UploadMixin, viewsets.ModelViewSet): @action(detail=True, methods=['GET'], serializer_class=None, url_path='dataset') def dataset_export(self, request, pk): - db_task = self.get_object() # force to call check_object_permissions + self._object = self.get_object() # force to call check_object_permissions - format_name = request.query_params.get("format", "") - return _export_annotations(db_instance=db_task, - rq_id="/api/tasks/{}/dataset/{}".format(pk, format_name), + return self.export_annotations( request=request, - action=request.query_params.get("action", "").lower(), - callback=dm.views.export_task_as_dataset, - format_name=format_name, - filename=request.query_params.get("filename", "").lower(), + pk=pk, + db_obj=self._object, + export_func=_export_annotations, + callback=dm.views.export_task_as_dataset ) @extend_schema(tags=['jobs']) @@ -1038,7 +1155,7 @@ class 
TaskViewSet(UploadMixin, viewsets.ModelViewSet): }) ) class JobViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, - mixins.RetrieveModelMixin, mixins.UpdateModelMixin, UploadMixin): + mixins.RetrieveModelMixin, mixins.UpdateModelMixin, UploadMixin, AnnotationMixin): queryset = Job.objects.all() iam_organization_field = 'segment__task__organization' search_fields = ('task_name', 'project_name', 'assignee', 'state', 'stage') @@ -1051,7 +1168,6 @@ class JobViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, 'project_id': 'segment__task__project_id', 'task_name': 'segment__task__name', 'project_name': 'segment__task__project__name', - 'updated_date': 'segment__task__updated_date', 'assignee': 'assignee__username' } @@ -1099,8 +1215,49 @@ class JobViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, status=status.HTTP_400_BAD_REQUEST) @extend_schema(methods=['GET'], summary='Method returns annotations for a specific job', + parameters=[ + OpenApiParameter('format', location=OpenApiParameter.QUERY, + description='Desired output format name\nYou can get the list of supported formats at:\n/server/annotation/formats', + type=OpenApiTypes.STR, required=False), + OpenApiParameter('filename', description='Desired output file name', + location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False), + OpenApiParameter('action', location=OpenApiParameter.QUERY, + description='Used to start downloading process after annotation file had been created', + type=OpenApiTypes.STR, required=False, enum=['download']), + OpenApiParameter('location', description='Where need to save downloaded annotation', + location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, + enum=Location.list()), + OpenApiParameter('cloud_storage_id', description='Storage id', + location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), + OpenApiParameter('use_default_location', description='Use the location that was configured in the task to export annotation', + location=OpenApiParameter.QUERY, type=OpenApiTypes.BOOL, required=False, + default=True), + ], + responses={ + '200': LabeledDataSerializer, + '201': OpenApiResponse(description='Output file is ready for downloading'), + '202': OpenApiResponse(description='Exporting has been started'), + '405': OpenApiResponse(description='Format is not available'), + }) + @extend_schema(methods=['POST'], summary='Method allows to upload job annotations', + parameters=[ + OpenApiParameter('format', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, + description='Input format name\nYou can get the list of supported formats at:\n/server/annotation/formats'), + OpenApiParameter('location', description='where to import the annotation from', + location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, + enum=Location.list()), + OpenApiParameter('cloud_storage_id', description='Storage id', + location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), + OpenApiParameter('use_default_location', description='Use the location that was configured in the task to import annotation', + location=OpenApiParameter.QUERY, type=OpenApiTypes.BOOL, required=False, + default=True), + OpenApiParameter('filename', description='Annotation file name', + location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False), + ], responses={ - '200': LabeledDataSerializer(many=True), + '201': OpenApiResponse(description='Uploading has finished'), + '202': OpenApiResponse(description='Uploading has been started'), + 
'405': OpenApiResponse(description='Format is not available'), }) @extend_schema(methods=['PUT'], summary='Method performs an update of all annotations in a specific job', request=AnnotationFileSerializer, responses={ @@ -1126,10 +1283,24 @@ class JobViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, def annotations(self, request, pk): self._object = self.get_object() # force to call check_object_permissions if request.method == 'GET': - data = dm.task.get_job_data(pk) - return Response(data) + return self.export_annotations( + request=request, + pk=pk, + db_obj=self._object.segment.task, + export_func=_export_annotations, + callback=dm.views.export_job_annotations, + get_data=dm.task.get_job_data, + ) + elif request.method == 'POST' or request.method == 'OPTIONS': - return self.upload_data(request) + return self.import_annotations( + request=request, + pk=pk, + db_obj=self._object.segment.task, + import_func=_import_annotations, + rq_func=dm.task.import_job_annotations, + ) + elif request.method == 'PUT': format_name = request.query_params.get('format', '') if format_name: @@ -1169,6 +1340,44 @@ class JobViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, self._object = self.get_object() return self.append_tus_chunk(request, file_id) + @extend_schema(summary='Export job as a dataset in a specific format', + parameters=[ + OpenApiParameter('format', location=OpenApiParameter.QUERY, + description='Desired output format name\nYou can get the list of supported formats at:\n/server/annotation/formats', + type=OpenApiTypes.STR, required=True), + OpenApiParameter('filename', description='Desired output file name', + location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False), + OpenApiParameter('action', location=OpenApiParameter.QUERY, + description='Used to start downloading process after annotation file had been created', + type=OpenApiTypes.STR, required=False, enum=['download']), + OpenApiParameter('use_default_location', description='Use the location that was configured in the task to export dataset', + location=OpenApiParameter.QUERY, type=OpenApiTypes.BOOL, required=False, + default=True), + OpenApiParameter('location', description='Where need to save downloaded dataset', + location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, + enum=Location.list()), + OpenApiParameter('cloud_storage_id', description='Storage id', + location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), + ], + responses={ + '200': OpenApiResponse(description='Download of file started'), + '201': OpenApiResponse(description='Output file is ready for downloading'), + '202': OpenApiResponse(description='Exporting has been started'), + '405': OpenApiResponse(description='Format is not available'), + }) + @action(detail=True, methods=['GET'], serializer_class=None, + url_path='dataset') + def dataset_export(self, request, pk): + self._object = self.get_object() # force to call check_object_permissions + + return self.export_annotations( + request=request, + pk=pk, + db_obj=self._object.segment.task, + export_func=_export_annotations, + callback=dm.views.export_job_as_dataset + ) + @extend_schema( summary='Method returns list of issues for the job', responses={ @@ -1589,12 +1798,12 @@ class CloudStorageViewSet(viewsets.ModelViewSet): except IntegrityError: response = HttpResponseBadRequest('Same storage already exists') except ValidationError as exceptions: - msg_body = "" - for ex in exceptions.args: - for field, ex_msg in ex.items(): - msg_body += ': '.join([field, 
ex_msg if isinstance(ex_msg, str) else str(ex_msg[0])]) - msg_body += '\n' - return HttpResponseBadRequest(msg_body) + msg_body = "" + for ex in exceptions.args: + for field, ex_msg in ex.items(): + msg_body += ': '.join([field, ex_msg if isinstance(ex_msg, str) else str(ex_msg[0])]) + msg_body += '\n' + return HttpResponseBadRequest(msg_body) except APIException as ex: return Response(data=ex.get_full_details(), status=ex.status_code) except Exception as ex: @@ -1614,17 +1823,7 @@ class CloudStorageViewSet(viewsets.ModelViewSet): storage = None try: db_storage = self.get_object() - credentials = Credentials() - credentials.convert_from_db({ - 'type': db_storage.credentials_type, - 'value': db_storage.credentials, - }) - details = { - 'resource': db_storage.resource, - 'credentials': credentials, - 'specific_attributes': db_storage.get_specific_attributes() - } - storage = get_cloud_storage_instance(cloud_provider=db_storage.provider_type, **details) + storage = db_storage_to_storage_instance(db_storage) if not db_storage.manifests.count(): raise Exception('There is no manifest file') manifest_path = request.query_params.get('manifest_path', db_storage.manifests.first().filename) @@ -1675,17 +1874,7 @@ class CloudStorageViewSet(viewsets.ModelViewSet): try: db_storage = self.get_object() if not os.path.exists(db_storage.get_preview_path()): - credentials = Credentials() - credentials.convert_from_db({ - 'type': db_storage.credentials_type, - 'value': db_storage.credentials, - }) - details = { - 'resource': db_storage.resource, - 'credentials': credentials, - 'specific_attributes': db_storage.get_specific_attributes() - } - storage = get_cloud_storage_instance(cloud_provider=db_storage.provider_type, **details) + storage = db_storage_to_storage_instance(db_storage) if not db_storage.manifests.count(): raise Exception('Cannot get the cloud storage preview. There is no manifest file') preview_path = None @@ -1749,17 +1938,7 @@ class CloudStorageViewSet(viewsets.ModelViewSet): def status(self, request, pk): try: db_storage = self.get_object() - credentials = Credentials() - credentials.convert_from_db({ - 'type': db_storage.credentials_type, - 'value': db_storage.credentials, - }) - details = { - 'resource': db_storage.resource, - 'credentials': credentials, - 'specific_attributes': db_storage.get_specific_attributes() - } - storage = get_cloud_storage_instance(cloud_provider=db_storage.provider_type, **details) + storage = db_storage_to_storage_instance(db_storage) storage_status = storage.get_status() return HttpResponse(storage_status) except CloudStorageModel.DoesNotExist: @@ -1770,6 +1949,28 @@ class CloudStorageViewSet(viewsets.ModelViewSet): msg = str(ex) return HttpResponseBadRequest(msg) + @extend_schema(summary='Method returns allowed actions for the cloud storage', + responses={ + '200': OpenApiResponse(response=OpenApiTypes.STR, description='Cloud Storage actions (GET | PUT | DELETE)'), + }) + @action(detail=True, methods=['GET'], url_path='actions') + def actions(self, request, pk): + ''' + Method return allowed actions for cloud storage. 
It's required for reading/writing. + ''' + try: + db_storage = self.get_object() + storage = db_storage_to_storage_instance(db_storage) + actions = storage.supported_actions + return Response(actions, content_type="text/plain") + except CloudStorageModel.DoesNotExist: + message = f"Storage {pk} does not exist" + slogger.glob.error(message) + return HttpResponseNotFound(message) + except Exception as ex: + msg = str(ex) + return HttpResponseBadRequest(msg) + def rq_handler(job, exc_type, exc_value, tb): job.exc_info = "".join( traceback.format_exception_only(exc_type, exc_value)) @@ -1779,7 +1980,16 @@ def rq_handler(job, exc_type, exc_value, tb): return True -def _import_annotations(request, rq_id, rq_func, pk, format_name, filename=None): +@validate_bucket_status +def _export_to_cloud_storage(storage, file_path, file_name): + storage.upload_file(file_path, file_name) + +@validate_bucket_status +def _import_from_cloud_storage(storage, file_name): + return storage.download_fileobj(file_name) + +def _import_annotations(request, rq_id, rq_func, pk, format_name, + filename=None, location_conf=None): format_desc = {f.DISPLAY_NAME: f for f in dm.views.get_import_formats()}.get(format_name) if format_desc is None: @@ -1794,15 +2004,36 @@ def _import_annotations(request, rq_id, rq_func, pk, format_name, filename=None) if not rq_job: # If filename is specified we consider that file was uploaded via TUS, so it exists in filesystem # Then we dont need to create temporary file + # Or the filename specifies a key in a cloud storage, so we need to download the file fd = None - if not filename: - serializer = AnnotationFileSerializer(data=request.data) - if serializer.is_valid(raise_exception=True): - anno_file = serializer.validated_data['annotation_file'] - fd, filename = mkstemp(prefix='cvat_{}'.format(pk)) + location = location_conf.get('location') if location_conf else Location.LOCAL + + if not filename or location == Location.CLOUD_STORAGE: + if location != Location.CLOUD_STORAGE: + serializer = AnnotationFileSerializer(data=request.data) + if serializer.is_valid(raise_exception=True): + anno_file = serializer.validated_data['annotation_file'] + fd, filename = mkstemp(prefix='cvat_{}'.format(pk)) + with open(filename, 'wb+') as f: + for chunk in anno_file.chunks(): + f.write(chunk) + else: + # download annotation file from cloud storage + try: + storage_id = location_conf['storage_id'] + except KeyError: + raise serializers.ValidationError( + 'Cloud storage location was selected for destination' + ' but cloud storage id was not specified') + db_storage = get_object_or_404(CloudStorageModel, pk=storage_id) + storage = db_storage_to_storage_instance(db_storage) + assert filename, 'filename was not specified' + + data = _import_from_cloud_storage(storage, filename) + + fd, filename = mkstemp(prefix='cvat_') with open(filename, 'wb+') as f: - for chunk in anno_file.chunks(): - f.write(chunk) + f.write(data.getbuffer()) av_scan_paths(filename) rq_job = queue.enqueue_call( @@ -1838,7 +2069,8 @@ def _import_annotations(request, rq_id, rq_func, pk, format_name, filename=None) return Response(status=status.HTTP_202_ACCEPTED) -def _export_annotations(db_instance, rq_id, request, format_name, action, callback, filename): +def _export_annotations(db_instance, rq_id, request, format_name, action, callback, + filename, location_conf): if action not in {"", "download"}: raise serializers.ValidationError( "Unexpected action specified for the request") @@ -1873,12 +2105,31 @@ def _export_annotations(db_instance, rq_id, request,
format_name, action, callba "%Y_%m_%d_%H_%M_%S") filename = filename or \ "{}_{}-{}-{}{}".format( - "project" if isinstance(db_instance, models.Project) else "task", - db_instance.name, timestamp, - format_name, osp.splitext(file_path)[1] + db_instance.__class__.__name__.lower(), + db_instance.name if isinstance(db_instance, (Task, Project)) else db_instance.id, + timestamp, format_name, osp.splitext(file_path)[1] ) - return sendfile(request, file_path, attachment=True, - attachment_filename=filename.lower()) + + # save annotation to specified location + location = location_conf.get('location') + if location == Location.LOCAL: + return sendfile(request, file_path, attachment=True, + attachment_filename=filename.lower()) + elif location == Location.CLOUD_STORAGE: + try: + storage_id = location_conf['storage_id'] + except KeyError: + return HttpResponseBadRequest( + 'Cloud storage location was selected for destination' + ' but cloud storage id was not specified') + + db_storage = get_object_or_404(CloudStorageModel, pk=storage_id) + storage = db_storage_to_storage_instance(db_storage) + + _export_to_cloud_storage(storage, file_path, filename) + return Response(status=status.HTTP_200_OK) + else: + raise NotImplementedError() else: if osp.exists(file_path): return Response(status=status.HTTP_201_CREATED) @@ -1897,14 +2148,19 @@ def _export_annotations(db_instance, rq_id, request, format_name, action, callba except Exception: server_address = None - ttl = (dm.views.PROJECT_CACHE_TTL if isinstance(db_instance, Project) else dm.views.TASK_CACHE_TTL).total_seconds() + TTL_CONSTS = { + 'project': dm.views.PROJECT_CACHE_TTL, + 'task': dm.views.TASK_CACHE_TTL, + 'job': dm.views.JOB_CACHE_TTL, + } + ttl = TTL_CONSTS[db_instance.__class__.__name__.lower()].total_seconds() queue.enqueue_call(func=callback, args=(db_instance.id, format_name, server_address), job_id=rq_id, meta={ 'request_time': timezone.localtime() }, result_ttl=ttl, failure_ttl=ttl) return Response(status=status.HTTP_202_ACCEPTED) -def _import_project_dataset(request, rq_id, rq_func, pk, format_name, filename=None): +def _import_project_dataset(request, rq_id, rq_func, pk, format_name, filename=None, location_conf=None): format_desc = {f.DISPLAY_NAME: f for f in dm.views.get_import_formats()}.get(format_name) if format_desc is None: @@ -1918,7 +2174,8 @@ def _import_project_dataset(request, rq_id, rq_func, pk, format_name, filename=N if not rq_job: fd = None - if not filename: + location = location_conf.get('location') if location_conf else None + if not filename and location != Location.CLOUD_STORAGE: serializer = DatasetFileSerializer(data=request.data) if serializer.is_valid(raise_exception=True): dataset_file = serializer.validated_data['dataset_file'] @@ -1926,6 +2183,24 @@ def _import_project_dataset(request, rq_id, rq_func, pk, format_name, filename=N with open(filename, 'wb+') as f: for chunk in dataset_file.chunks(): f.write(chunk) + elif location == Location.CLOUD_STORAGE: + assert filename + + # download project file from cloud storage + try: + storage_id = location_conf['storage_id'] + except KeyError: + raise serializers.ValidationError( + 'Cloud storage location was selected for destination' + ' but cloud storage id was not specified') + db_storage = get_object_or_404(CloudStorageModel, pk=storage_id) + storage = db_storage_to_storage_instance(db_storage) + + data = _import_from_cloud_storage(storage, filename) + + fd, filename = mkstemp(prefix='cvat_') + with open(filename, 'wb+') as f: + f.write(data.getbuffer()) rq_job 
= queue.enqueue_call( func=rq_func, diff --git a/cvat/apps/iam/permissions.py b/cvat/apps/iam/permissions.py index 04d14abe..fe5a6d15 100644 --- a/cvat/apps/iam/permissions.py +++ b/cvat/apps/iam/permissions.py @@ -63,7 +63,7 @@ class OpenPolicyAgentPermission(metaclass=ABCMeta): 'user': { 'role': self.org_role, }, - } if self.org_id != None else None + } if self.org_id is not None else None } } } @@ -210,7 +210,7 @@ class InvitationPermission(OpenPolicyAgentPermission): 'role': self.role, 'organization': { 'id': self.org_id - } if self.org_id != None else None + } if self.org_id is not None else None } return data @@ -417,7 +417,8 @@ class CloudStoragePermission(OpenPolicyAgentPermission): 'destroy': 'delete', 'content': 'list:content', 'preview': 'view', - 'status': 'view' + 'status': 'view', + 'actions': 'view', }.get(view.action)] def get_resource(self): @@ -427,7 +428,7 @@ class CloudStoragePermission(OpenPolicyAgentPermission): 'owner': { 'id': self.user_id }, 'organization': { 'id': self.org_id - } if self.org_id != None else None, + } if self.org_id is not None else None, 'user': { 'num_resources': Organization.objects.filter( owner=self.user_id).count() @@ -620,9 +621,9 @@ class TaskPermission(OpenPolicyAgentPermission): perm = TaskPermission.create_scope_create(request, org_id) # We don't create a project, just move it. Thus need to decrease # the number of resources. - if obj != None: + if obj is not None: perm.payload['input']['resource']['user']['num_resources'] -= 1 - if obj.project != None: + if obj.project is not None: ValidationError('Cannot change the organization for ' 'a task inside a project') permissions.append(perm) diff --git a/cvat/settings/base.py b/cvat/settings/base.py index ba4587dd..173a0c15 100644 --- a/cvat/settings/base.py +++ b/cvat/settings/base.py @@ -352,6 +352,9 @@ os.makedirs(MEDIA_DATA_ROOT, exist_ok=True) CACHE_ROOT = os.path.join(DATA_ROOT, 'cache') os.makedirs(CACHE_ROOT, exist_ok=True) +JOBS_ROOT = os.path.join(DATA_ROOT, 'jobs') +os.makedirs(JOBS_ROOT, exist_ok=True) + TASKS_ROOT = os.path.join(DATA_ROOT, 'tasks') os.makedirs(TASKS_ROOT, exist_ok=True) diff --git a/cvat/settings/testing.py b/cvat/settings/testing.py index b8659aa4..71d9060c 100644 --- a/cvat/settings/testing.py +++ b/cvat/settings/testing.py @@ -22,6 +22,9 @@ os.makedirs(MEDIA_DATA_ROOT, exist_ok=True) CACHE_ROOT = os.path.join(DATA_ROOT, 'cache') os.makedirs(CACHE_ROOT, exist_ok=True) +JOBS_ROOT = os.path.join(DATA_ROOT, 'jobs') +os.makedirs(JOBS_ROOT, exist_ok=True) + TASKS_ROOT = os.path.join(DATA_ROOT, 'tasks') os.makedirs(TASKS_ROOT, exist_ok=True) diff --git a/tests/rest_api/assets/cvat_db/cvat_data.tar.bz2 b/tests/rest_api/assets/cvat_db/cvat_data.tar.bz2 index 38ae1ac1..2dd5de83 100644 Binary files a/tests/rest_api/assets/cvat_db/cvat_data.tar.bz2 and b/tests/rest_api/assets/cvat_db/cvat_data.tar.bz2 differ diff --git a/tests/rest_api/assets/cvat_db/data.json b/tests/rest_api/assets/cvat_db/data.json index f09c6d77..d6a70063 100644 --- a/tests/rest_api/assets/cvat_db/data.json +++ b/tests/rest_api/assets/cvat_db/data.json @@ -1413,7 +1413,7 @@ "pk": 1, "fields": { "password": "pbkdf2_sha256$260000$DevmxlmLwciP1P6sZs2Qag$U9DFtjTWx96Sk95qY6UXVcvpdQEP2LcoFBftk5D2RKY=", - "last_login": "2022-06-08T08:32:30.152Z", + "last_login": "2022-06-22T09:20:25.189Z", "is_superuser": true, "username": "admin1", "first_name": "Admin", @@ -2184,6 +2184,14 @@ "model": "gitdata" } }, +{ + "model": "contenttypes.contenttype", + "pk": 49, + "fields": { + "app_label": "engine", + "model": 
"storage" + } +}, { "model": "sessions.session", "pk": "5x9v6r58e4l9if78anupog0ittsq2w3j", @@ -3688,7 +3696,9 @@ "created_date": "2021-12-14T19:46:37.969Z", "updated_date": "2022-03-05T09:47:49.679Z", "status": "annotation", - "organization": null + "organization": null, + "source_storage": null, + "target_storage": null } }, { @@ -3702,7 +3712,9 @@ "created_date": "2021-12-14T19:52:37.278Z", "updated_date": "2022-03-28T13:04:54.669Z", "status": "annotation", - "organization": 2 + "organization": 2, + "source_storage": null, + "target_storage": null } }, { @@ -3716,7 +3728,9 @@ "created_date": "2022-03-28T13:05:24.659Z", "updated_date": "2022-03-28T13:06:09.283Z", "status": "annotation", - "organization": 2 + "organization": 2, + "source_storage": null, + "target_storage": null } }, { @@ -3730,7 +3744,9 @@ "created_date": "2022-06-08T08:32:45.521Z", "updated_date": "2022-06-08T08:33:20.759Z", "status": "annotation", - "organization": 2 + "organization": 2, + "source_storage": null, + "target_storage": null } }, { @@ -3751,7 +3767,9 @@ "data": 2, "dimension": "2d", "subset": "", - "organization": 1 + "organization": 1, + "source_storage": null, + "target_storage": null } }, { @@ -3772,7 +3790,9 @@ "data": 5, "dimension": "2d", "subset": "", - "organization": null + "organization": null, + "source_storage": null, + "target_storage": null } }, { @@ -3793,7 +3813,9 @@ "data": 6, "dimension": "3d", "subset": "", - "organization": null + "organization": null, + "source_storage": null, + "target_storage": null } }, { @@ -3814,7 +3836,9 @@ "data": 7, "dimension": "2d", "subset": "", - "organization": 2 + "organization": 2, + "source_storage": null, + "target_storage": null } }, { @@ -3835,7 +3859,9 @@ "data": 8, "dimension": "2d", "subset": "", - "organization": null + "organization": null, + "source_storage": null, + "target_storage": null } }, { @@ -3856,7 +3882,9 @@ "data": 9, "dimension": "2d", "subset": "", - "organization": null + "organization": null, + "source_storage": null, + "target_storage": null } }, { @@ -3877,7 +3905,9 @@ "data": 11, "dimension": "2d", "subset": "Train", - "organization": 2 + "organization": 2, + "source_storage": null, + "target_storage": null } }, { @@ -3898,7 +3928,9 @@ "data": null, "dimension": "2d", "subset": "", - "organization": null + "organization": null, + "source_storage": null, + "target_storage": null } }, { @@ -3919,7 +3951,9 @@ "data": 12, "dimension": "2d", "subset": "", - "organization": 2 + "organization": 2, + "source_storage": null, + "target_storage": null } }, { @@ -4724,6 +4758,7 @@ "fields": { "segment": 2, "assignee": 6, + "updated_date": "2022-06-22T09:18:45.296Z", "status": "annotation", "stage": "annotation", "state": "new" @@ -4735,6 +4770,7 @@ "fields": { "segment": 7, "assignee": 9, + "updated_date": "2022-06-22T09:18:45.296Z", "status": "annotation", "stage": "annotation", "state": "in progress" @@ -4746,6 +4782,7 @@ "fields": { "segment": 8, "assignee": null, + "updated_date": "2022-06-22T09:18:45.296Z", "status": "annotation", "stage": "annotation", "state": "new" @@ -4757,6 +4794,7 @@ "fields": { "segment": 9, "assignee": null, + "updated_date": "2022-06-22T09:18:45.296Z", "status": "annotation", "stage": "annotation", "state": "in progress" @@ -4768,6 +4806,7 @@ "fields": { "segment": 10, "assignee": 1, + "updated_date": "2022-06-22T09:18:45.296Z", "status": "annotation", "stage": "annotation", "state": "in progress" @@ -4779,6 +4818,7 @@ "fields": { "segment": 11, "assignee": 9, + "updated_date": "2022-06-22T09:18:45.296Z", 
"status": "annotation", "stage": "annotation", "state": "in progress" @@ -4790,6 +4830,7 @@ "fields": { "segment": 12, "assignee": null, + "updated_date": "2022-06-22T09:18:45.296Z", "status": "validation", "stage": "validation", "state": "new" @@ -4801,6 +4842,7 @@ "fields": { "segment": 13, "assignee": null, + "updated_date": "2022-06-22T09:18:45.296Z", "status": "validation", "stage": "acceptance", "state": "new" @@ -4812,6 +4854,7 @@ "fields": { "segment": 14, "assignee": null, + "updated_date": "2022-06-22T09:18:45.296Z", "status": "annotation", "stage": "annotation", "state": "new" @@ -4823,6 +4866,7 @@ "fields": { "segment": 16, "assignee": 7, + "updated_date": "2022-06-22T09:18:45.296Z", "status": "annotation", "stage": "annotation", "state": "in progress" @@ -4834,6 +4878,7 @@ "fields": { "segment": 17, "assignee": null, + "updated_date": "2022-06-22T09:18:45.296Z", "status": "annotation", "stage": "annotation", "state": "in progress" diff --git a/tests/rest_api/assets/jobs.json b/tests/rest_api/assets/jobs.json index 1282226c..c578c1c2 100644 --- a/tests/rest_api/assets/jobs.json +++ b/tests/rest_api/assets/jobs.json @@ -32,6 +32,7 @@ "status": "annotation", "stop_frame": 4, "task_id": 13, + "updated_date": "2022-06-22T09:18:45.296000Z", "url": "http://localhost:8080/api/jobs/17" }, { @@ -69,6 +70,7 @@ "status": "annotation", "stop_frame": 10, "task_id": 11, + "updated_date": "2022-06-22T09:18:45.296000Z", "url": "http://localhost:8080/api/jobs/16" }, { @@ -113,6 +115,7 @@ "status": "annotation", "stop_frame": 19, "task_id": 9, + "updated_date": "2022-06-22T09:18:45.296000Z", "url": "http://localhost:8080/api/jobs/14" }, { @@ -157,6 +160,7 @@ "status": "validation", "stop_frame": 14, "task_id": 9, + "updated_date": "2022-06-22T09:18:45.296000Z", "url": "http://localhost:8080/api/jobs/13" }, { @@ -201,6 +205,7 @@ "status": "validation", "stop_frame": 9, "task_id": 9, + "updated_date": "2022-06-22T09:18:45.296000Z", "url": "http://localhost:8080/api/jobs/12" }, { @@ -251,6 +256,7 @@ "status": "annotation", "stop_frame": 4, "task_id": 9, + "updated_date": "2022-06-22T09:18:45.296000Z", "url": "http://localhost:8080/api/jobs/11" }, { @@ -288,6 +294,7 @@ "status": "annotation", "stop_frame": 13, "task_id": 8, + "updated_date": "2022-06-22T09:18:45.296000Z", "url": "http://localhost:8080/api/jobs/10" }, { @@ -319,6 +326,7 @@ "status": "annotation", "stop_frame": 10, "task_id": 7, + "updated_date": "2022-06-22T09:18:45.296000Z", "url": "http://localhost:8080/api/jobs/9" }, { @@ -344,6 +352,7 @@ "status": "annotation", "stop_frame": 0, "task_id": 6, + "updated_date": "2022-06-22T09:18:45.296000Z", "url": "http://localhost:8080/api/jobs/8" }, { @@ -375,6 +384,7 @@ "status": "annotation", "stop_frame": 24, "task_id": 5, + "updated_date": "2022-06-22T09:18:45.296000Z", "url": "http://localhost:8080/api/jobs/7" }, { @@ -412,6 +422,7 @@ "status": "annotation", "stop_frame": 22, "task_id": 2, + "updated_date": "2022-06-22T09:18:45.296000Z", "url": "http://localhost:8080/api/jobs/2" } ] diff --git a/tests/rest_api/assets/projects.json b/tests/rest_api/assets/projects.json index 3ad5e4b6..6d5a2cbc 100644 --- a/tests/rest_api/assets/projects.json +++ b/tests/rest_api/assets/projects.json @@ -32,7 +32,9 @@ "url": "http://localhost:8080/api/users/1", "username": "admin1" }, + "source_storage": null, "status": "annotation", + "target_storage": null, "task_subsets": [], "tasks": [ 13 @@ -62,7 +64,9 @@ "url": "http://localhost:8080/api/users/3", "username": "user2" }, + "source_storage": null, 
"status": "annotation", + "target_storage": null, "task_subsets": [], "tasks": [], "updated_date": "2022-03-28T13:06:09.283000Z", @@ -103,7 +107,9 @@ "url": "http://localhost:8080/api/users/10", "username": "business1" }, + "source_storage": null, "status": "annotation", + "target_storage": null, "task_subsets": [ "Train" ], @@ -161,7 +167,9 @@ "url": "http://localhost:8080/api/users/10", "username": "business1" }, + "source_storage": null, "status": "annotation", + "target_storage": null, "task_subsets": [], "tasks": [ 9 diff --git a/tests/rest_api/assets/tasks.json b/tests/rest_api/assets/tasks.json index dd4bd249..a3d21bea 100644 --- a/tests/rest_api/assets/tasks.json +++ b/tests/rest_api/assets/tasks.json @@ -58,8 +58,10 @@ } ], "size": 5, + "source_storage": null, "status": "annotation", "subset": "", + "target_storage": null, "updated_date": "2022-06-08T08:33:20.808000Z", "url": "http://localhost:8080/api/tasks/13" }, @@ -91,8 +93,10 @@ "project_id": null, "segment_size": 0, "segments": [], + "source_storage": null, "status": "annotation", "subset": "", + "target_storage": null, "updated_date": "2022-03-14T13:24:05.861000Z", "url": "http://localhost:8080/api/tasks/12" }, @@ -163,8 +167,10 @@ } ], "size": 11, + "source_storage": null, "status": "annotation", "subset": "Train", + "target_storage": null, "updated_date": "2022-03-05T10:32:35.568000Z", "url": "http://localhost:8080/api/tasks/11" }, @@ -290,8 +296,10 @@ } ], "size": 20, + "source_storage": null, "status": "annotation", "subset": "", + "target_storage": null, "updated_date": "2022-03-05T09:47:49.667000Z", "url": "http://localhost:8080/api/tasks/9" }, @@ -362,8 +370,10 @@ } ], "size": 14, + "source_storage": null, "status": "annotation", "subset": "", + "target_storage": null, "updated_date": "2022-03-05T08:52:34.908000Z", "url": "http://localhost:8080/api/tasks/8" }, @@ -428,8 +438,10 @@ } ], "size": 11, + "source_storage": null, "status": "annotation", "subset": "", + "target_storage": null, "updated_date": "2022-02-21T10:41:38.540000Z", "url": "http://localhost:8080/api/tasks/7" }, @@ -482,8 +494,10 @@ } ], "size": 1, + "source_storage": null, "status": "annotation", "subset": "", + "target_storage": null, "updated_date": "2022-02-16T06:26:54.836000Z", "url": "http://localhost:8080/api/tasks/6" }, @@ -548,8 +562,10 @@ } ], "size": 25, + "source_storage": null, "status": "annotation", "subset": "", + "target_storage": null, "updated_date": "2022-02-21T10:40:21.257000Z", "url": "http://localhost:8080/api/tasks/5" }, @@ -620,8 +636,10 @@ } ], "size": 23, + "source_storage": null, "status": "annotation", "subset": "", + "target_storage": null, "updated_date": "2021-12-22T07:14:15.234000Z", "url": "http://localhost:8080/api/tasks/2" } diff --git a/tests/rest_api/assets/users.json b/tests/rest_api/assets/users.json index 9fda64e0..d14cf1b0 100644 --- a/tests/rest_api/assets/users.json +++ b/tests/rest_api/assets/users.json @@ -310,7 +310,7 @@ "is_active": true, "is_staff": true, "is_superuser": true, - "last_login": "2022-06-08T08:32:30.152708Z", + "last_login": "2022-06-22T09:20:25.189000Z", "last_name": "First", "url": "http://localhost:8080/api/users/1", "username": "admin1" diff --git a/tests/rest_api/test_cloud_storages.py b/tests/rest_api/test_cloud_storages.py index 1069c8e6..147959ab 100644 --- a/tests/rest_api/test_cloud_storages.py +++ b/tests/rest_api/test_cloud_storages.py @@ -180,7 +180,7 @@ class TestPatchCloudStorage: ('maintainer', False, True), ('supervisor', False, False), ]) - def 
test_org_user_update_coud_storage(self, org_id, storage_id, role, is_owner, is_allow, find_users, cloud_storages): + def test_org_user_update_cloud_storage(self, org_id, storage_id, role, is_owner, is_allow, find_users, cloud_storages): cloud_storage = cloud_storages[storage_id] username = cloud_storage['owner']['username'] if is_owner else \ next((u for u in find_users(role=role, org=org_id) if u['id'] != cloud_storage['owner']['id']))['username'] diff --git a/tests/rest_api/test_jobs.py b/tests/rest_api/test_jobs.py index c336f3e1..e807aef2 100644 --- a/tests/rest_api/test_jobs.py +++ b/tests/rest_api/test_jobs.py @@ -44,7 +44,7 @@ class TestGetJobs: response = get_method(user, f'jobs/{jid}', **kwargs) assert response.status_code == HTTPStatus.OK - assert DeepDiff(data, response.json()) == {} + assert DeepDiff(data, response.json(), exclude_paths="root['updated_date']") == {} def _test_get_job_403(self, user, jid, **kwargs): response = get_method(user, f'jobs/{jid}', **kwargs) @@ -83,7 +83,7 @@ class TestListJobs: response = get_method(user, 'jobs', **kwargs, page_size='all') assert response.status_code == HTTPStatus.OK - assert DeepDiff(data, response.json()['results']) == {} + assert DeepDiff(data, response.json()['results'], exclude_paths="root['updated_date']") == {} def _test_list_jobs_403(self, user, **kwargs): response = get_method(user, 'jobs', **kwargs) @@ -123,7 +123,7 @@ class TestGetAnnotations: assert response.status_code == HTTPStatus.OK assert DeepDiff(data, response_data, - exclude_paths="root['version']") == {} + exclude_regex_paths=r"root\['version|updated_date'\]") == {} def _test_get_job_annotations_403(self, user, jid, **kwargs): response = get_method(user, f'jobs/{jid}/annotations', **kwargs) @@ -193,7 +193,7 @@ class TestPatchJobAnnotations: if is_allow: assert response.status_code == HTTPStatus.OK assert DeepDiff(data, response.json(), - exclude_paths="root['version']") == {} + exclude_regex_paths=r"root\['version|updated_date'\]") == {} else: assert response.status_code == HTTPStatus.FORBIDDEN @@ -313,6 +313,7 @@ class TestPatchJob: if is_allow: assert response.status_code == HTTPStatus.OK - assert DeepDiff(expected_data(jid, assignee), response.json()) == {} + assert DeepDiff(expected_data(jid, assignee), response.json(), + exclude_paths="root['updated_date']") == {} else: assert response.status_code == HTTPStatus.FORBIDDEN
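
Note (not part of the patch): a minimal client-side sketch of the export-to-cloud-storage flow added in this diff. The endpoint, the query parameters (format, action, location, cloud_storage_id, use_default_location, filename) and the 200/201/202 status codes come from the views and OpenAPI annotations above; the host, the credentials, the example format name and the serialized value 'cloud_storage' for Location.CLOUD_STORAGE are assumptions made for illustration only.

# Hypothetical usage of the new export parameters on GET /api/tasks/{id}/annotations.
import time
import requests

CVAT_HOST = 'http://localhost:8080'   # assumed local instance, as in the test assets
AUTH = ('admin1', 'password')         # hypothetical credentials
TASK_ID = 9                           # any existing task id

params = {
    'format': 'CVAT for images 1.1',  # any format listed at /server/annotation/formats
    'action': 'download',             # trigger the save step once the file is ready
    'location': 'cloud_storage',      # assumed serialized value of Location.CLOUD_STORAGE
    'cloud_storage_id': 1,            # id of an attached cloud storage
    'use_default_location': 'false',  # ignore the task's configured target_storage
    'filename': 'task_9_annotations.zip',
}

# 202 means the export RQ job is still running; with a cloud storage location the
# server uploads the resulting file to the bucket and answers 200 instead of
# streaming it back.
while True:
    response = requests.get(f'{CVAT_HOST}/api/tasks/{TASK_ID}/annotations',
                            params=params, auth=AUTH)
    if response.status_code != 202:
        break
    time.sleep(1)
response.raise_for_status()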