diff --git a/cvat/apps/dataset_manager/task.py b/cvat/apps/dataset_manager/task.py index 80aa529d..d1d6af67 100644 --- a/cvat/apps/dataset_manager/task.py +++ b/cvat/apps/dataset_manager/task.py @@ -1,9 +1,11 @@ + # Copyright (C) 2019-2022 Intel Corporation # Copyright (C) 2022 CVAT.ai Corporation # # SPDX-License-Identifier: MIT from collections import OrderedDict +from enum import Enum from django.db import transaction from django.db.models.query import Prefetch @@ -12,7 +14,6 @@ from django.utils import timezone from cvat.apps.engine import models, serializers from cvat.apps.engine.plugins import plugin_decorator from cvat.apps.profiler import silk_profile -from cvat.apps.engine.utils import DjangoEnum, StrEnum from .annotation import AnnotationIR, AnnotationManager from .bindings import TaskData, JobData @@ -28,11 +29,18 @@ class dotdict(OrderedDict): __eq__ = lambda self, other: self.id == other.id __hash__ = lambda self: self.id -class PatchAction(DjangoEnum, StrEnum): +class PatchAction(str, Enum): CREATE = "create" UPDATE = "update" DELETE = "delete" + @classmethod + def values(cls): + return [item.value for item in cls] + + def __str__(self): + return self.value + def _merge_table_rows(rows, keys_for_merge, field_id): # It is necessary to keep a stable order of original rows # (e.g. for tracked boxes). Otherwise prev_box.frame can be bigger diff --git a/cvat/apps/dataset_repo/models.py b/cvat/apps/dataset_repo/models.py index 62fbb292..b57ed264 100644 --- a/cvat/apps/dataset_repo/models.py +++ b/cvat/apps/dataset_repo/models.py @@ -1,18 +1,21 @@ # Copyright (C) 2018-2022 Intel Corporation -# Copyright (C) 2022 CVAT.ai Corporation # # SPDX-License-Identifier: MIT from django.db import models from cvat.apps.engine.models import Task -from cvat.apps.engine.utils import DjangoEnum, StrEnum +from enum import Enum -class GitStatusChoice(DjangoEnum, StrEnum): +class GitStatusChoice(Enum): NON_SYNCED = '!sync' SYNCING = 'syncing' SYNCED = 'sync' MERGED = 'merged' + def __str__(self): + return self.value + + class GitData(models.Model): task = models.OneToOneField(Task, on_delete = models.CASCADE, primary_key = True) url = models.URLField(max_length = 2000) diff --git a/cvat/apps/engine/backup.py b/cvat/apps/engine/backup.py index 7a22a5e1..b7661d8c 100644 --- a/cvat/apps/engine/backup.py +++ b/cvat/apps/engine/backup.py @@ -3,9 +3,9 @@ # # SPDX-License-Identifier: MIT -from enum import Enum import io import os +from enum import Enum import re import shutil import tempfile diff --git a/cvat/apps/engine/cloud_provider.py b/cvat/apps/engine/cloud_provider.py index 180d2b2b..7135225c 100644 --- a/cvat/apps/engine/cloud_provider.py +++ b/cvat/apps/engine/cloud_provider.py @@ -1,5 +1,4 @@ # Copyright (C) 2021-2022 Intel Corporation -# Copyright (C) 2022 CVAT.ai Corporation # # SPDX-License-Identifier: MIT @@ -26,18 +25,29 @@ from google.cloud.exceptions import NotFound as GoogleCloudNotFound, Forbidden a from cvat.apps.engine.log import slogger from cvat.apps.engine.models import CredentialsTypeChoice, CloudProviderChoice -from cvat.apps.engine.utils import DjangoEnum, StrEnum -class Status(DjangoEnum, StrEnum): +class Status(str, Enum): AVAILABLE = 'AVAILABLE' NOT_FOUND = 'NOT_FOUND' FORBIDDEN = 'FORBIDDEN' -class Permissions(DjangoEnum, StrEnum): + @classmethod + def choices(cls): + return tuple((x.value, x.name) for x in cls) + + def __str__(self): + return self.value + +class Permissions(str, Enum): READ = 'read' WRITE = 'write' + @classmethod + def all(cls): + return {i.value for i in cls} + class _CloudStorage(ABC): + def __init__(self): self._files = [] @@ -303,7 +313,7 @@ class AWS_S3(_CloudStorage): bucket_policy = self._bucket.Policy().policy except ClientError as ex: if 'NoSuchBucketPolicy' in str(ex): - return Permissions.values() + return Permissions.all() else: raise Exception(str(ex)) bucket_policy = json.loads(bucket_policy) if isinstance(bucket_policy, str) else bucket_policy @@ -316,7 +326,7 @@ class AWS_S3(_CloudStorage): 's3:GetObject': Permissions.READ, 's3:PutObject': Permissions.WRITE, } - allowed_actions = set(Permissions.values()) & {access.get(i) for i in allowed_actions} + allowed_actions = Permissions.all() & {access.get(i) for i in allowed_actions} return allowed_actions diff --git a/cvat/apps/engine/location.py b/cvat/apps/engine/location.py index 44ed1e8d..fd3fadf0 100644 --- a/cvat/apps/engine/location.py +++ b/cvat/apps/engine/location.py @@ -1,19 +1,15 @@ -# Copyright (C) 2022 CVAT.ai Corporation -# -# SPDX-License-Identifier: MIT - -from typing import Optional +from enum import Enum from cvat.apps.engine.models import Location -from cvat.apps.engine.utils import StrEnum -class StorageType(StrEnum): +class StorageType(str, Enum): TARGET = 'target_storage' SOURCE = 'source_storage' -def get_location_configuration( - obj, *, field_name: Optional[StorageType], use_settings: bool = False -): + def __str__(self): + return self.value + +def get_location_configuration(obj, field_name, use_settings=False): location_conf = dict() if use_settings: storage = getattr(obj, field_name) diff --git a/cvat/apps/engine/models.py b/cvat/apps/engine/models.py index c1c0e88e..d12f5200 100644 --- a/cvat/apps/engine/models.py +++ b/cvat/apps/engine/models.py @@ -6,6 +6,7 @@ import os import re import shutil +from enum import Enum from typing import Optional from django.conf import settings @@ -16,7 +17,7 @@ from django.db.models.fields import FloatField from django.core.serializers.json import DjangoJSONEncoder from drf_spectacular.types import OpenApiTypes from drf_spectacular.utils import extend_schema_field -from cvat.apps.engine.utils import DjangoEnum, parse_specific_attributes, StrEnum +from cvat.apps.engine.utils import parse_specific_attributes from cvat.apps.organizations.models import Organization class SafeCharField(models.CharField): @@ -27,18 +28,36 @@ class SafeCharField(models.CharField): return value -class DimensionType(DjangoEnum, StrEnum): +class DimensionType(str, Enum): DIM_3D = '3d' DIM_2D = '2d' -class StatusChoice(DjangoEnum, StrEnum): + @classmethod + def choices(cls): + return tuple((x.value, x.name) for x in cls) + + def __str__(self): + return self.value + +class StatusChoice(str, Enum): """Deprecated. Use StageChoice and StateChoice instead""" ANNOTATION = 'annotation' VALIDATION = 'validation' COMPLETED = 'completed' -class LabelType(DjangoEnum, StrEnum): + @classmethod + def choices(cls): + return tuple((x.value, x.name) for x in cls) + + @classmethod + def list(cls): + return list(map(lambda x: x.value, cls)) + + def __str__(self): + return self.value + +class LabelType(str, Enum): BBOX = 'bbox' ELLIPSE = 'ellipse' POLYGON = 'polygon' @@ -50,37 +69,90 @@ class LabelType(DjangoEnum, StrEnum): TAG = 'tag' ANY = 'any' -class StageChoice(DjangoEnum, StrEnum): + @classmethod + def choices(cls): + return tuple((x.value, x.name) for x in cls) + + @classmethod + def list(cls): + return list(map(lambda x: x.value, cls)) + + def __str__(self): + return self.value + +class StageChoice(str, Enum): ANNOTATION = 'annotation' VALIDATION = 'validation' ACCEPTANCE = 'acceptance' -class StateChoice(DjangoEnum, StrEnum): + @classmethod + def choices(cls): + return tuple((x.value, x.name) for x in cls) + + def __str__(self): + return self.value + +class StateChoice(str, Enum): NEW = 'new' IN_PROGRESS = 'in progress' COMPLETED = 'completed' REJECTED = 'rejected' -class DataChoice(DjangoEnum, StrEnum): + @classmethod + def choices(cls): + return tuple((x.value, x.name) for x in cls) + + def __str__(self): + return self.value + +class DataChoice(str, Enum): VIDEO = 'video' IMAGESET = 'imageset' LIST = 'list' -class StorageMethodChoice(DjangoEnum, StrEnum): + @classmethod + def choices(cls): + return tuple((x.value, x.name) for x in cls) + + def __str__(self): + return self.value + +class StorageMethodChoice(str, Enum): CACHE = 'cache' FILE_SYSTEM = 'file_system' -class StorageChoice(DjangoEnum, StrEnum): + @classmethod + def choices(cls): + return tuple((x.value, x.name) for x in cls) + + def __str__(self): + return self.value + +class StorageChoice(str, Enum): CLOUD_STORAGE = 'cloud_storage' LOCAL = 'local' SHARE = 'share' -class SortingMethod(DjangoEnum, StrEnum): + @classmethod + def choices(cls): + return tuple((x.value, x.name) for x in cls) + + def __str__(self): + return self.value + +class SortingMethod(str, Enum): LEXICOGRAPHICAL = 'lexicographical' NATURAL = 'natural' PREDEFINED = 'predefined' RANDOM = 'random' + @classmethod + def choices(cls): + return tuple((x.value, x.name) for x in cls) + + def __str__(self): + return self.value + class AbstractArrayField(models.TextField): separator = "," converter = lambda x: x @@ -449,13 +521,20 @@ class Skeleton(models.Model): default_permissions = () unique_together = ('root',) -class AttributeType(DjangoEnum, StrEnum): +class AttributeType(str, Enum): CHECKBOX = 'checkbox' RADIO = 'radio' NUMBER = 'number' TEXT = 'text' SELECT = 'select' + @classmethod + def choices(cls): + return tuple((x.value, x.name) for x in cls) + + def __str__(self): + return self.value + class AttributeSpec(models.Model): label = models.ForeignKey(Label, on_delete=models.CASCADE) name = models.CharField(max_length=64) @@ -482,7 +561,7 @@ class AttributeVal(models.Model): abstract = True default_permissions = () -class ShapeType(DjangoEnum, StrEnum): +class ShapeType(str, Enum): RECTANGLE = 'rectangle' # (x0, y0, x1, y1) POLYGON = 'polygon' # (x0, y0, ..., xn, yn) POLYLINE = 'polyline' # (x0, y0, ..., xn, yn) @@ -492,10 +571,24 @@ class ShapeType(DjangoEnum, StrEnum): MASK = 'mask' # (rle mask, left, top, right, bottom) SKELETON = 'skeleton' -class SourceType(DjangoEnum, StrEnum): + @classmethod + def choices(cls): + return tuple((x.value, x.name) for x in cls) + + def __str__(self): + return self.value + +class SourceType(str, Enum): AUTO = 'auto' MANUAL = 'manual' + @classmethod + def choices(cls): + return tuple((x.value, x.name) for x in cls) + + def __str__(self): + return self.value + class Annotation(models.Model): id = models.BigAutoField(primary_key=True) job = models.ForeignKey(Job, on_delete=models.CASCADE) @@ -605,19 +698,41 @@ class Comment(models.Model): def get_organization_id(self): return self.issue.get_organization_id() -class CloudProviderChoice(DjangoEnum, StrEnum): +class CloudProviderChoice(str, Enum): AWS_S3 = 'AWS_S3_BUCKET' AZURE_CONTAINER = 'AZURE_CONTAINER' GOOGLE_DRIVE = 'GOOGLE_DRIVE' GOOGLE_CLOUD_STORAGE = 'GOOGLE_CLOUD_STORAGE' -class CredentialsTypeChoice(DjangoEnum, StrEnum): + @classmethod + def choices(cls): + return tuple((x.value, x.name) for x in cls) + + @classmethod + def list(cls): + return list(map(lambda x: x.value, cls)) + + def __str__(self): + return self.value + +class CredentialsTypeChoice(str, Enum): # ignore bandit issues because false positives KEY_SECRET_KEY_PAIR = 'KEY_SECRET_KEY_PAIR' # nosec ACCOUNT_NAME_TOKEN_PAIR = 'ACCOUNT_NAME_TOKEN_PAIR' # nosec KEY_FILE_PATH = 'KEY_FILE_PATH' ANONYMOUS_ACCESS = 'ANONYMOUS_ACCESS' + @classmethod + def choices(cls): + return tuple((x.value, x.name) for x in cls) + + @classmethod + def list(cls): + return list(map(lambda x: x.value, cls)) + + def __str__(self): + return self.value + class Manifest(models.Model): filename = models.CharField(max_length=1024, default='manifest.jsonl') cloud_storage = models.ForeignKey('CloudStorage', on_delete=models.CASCADE, null=True, related_name='manifests') @@ -625,10 +740,21 @@ class Manifest(models.Model): def __str__(self): return '{}'.format(self.filename) -class Location(DjangoEnum, StrEnum): +class Location(str, Enum): CLOUD_STORAGE = 'cloud_storage' LOCAL = 'local' + @classmethod + def choices(cls): + return tuple((x.value, x.name) for x in cls) + + def __str__(self): + return self.value + + @classmethod + def list(cls): + return [i.value for i in cls] + class CloudStorage(models.Model): # restrictions: # AWS bucket name, Azure container name - 63, Google bucket name - 63 without dots and 222 with dots diff --git a/cvat/apps/engine/tests/test_rest_api.py b/cvat/apps/engine/tests/test_rest_api.py index ac826cd4..a503c811 100644 --- a/cvat/apps/engine/tests/test_rest_api.py +++ b/cvat/apps/engine/tests/test_rest_api.py @@ -1,5 +1,4 @@ # Copyright (C) 2020-2022 Intel Corporation -# Copyright (C) 2022 CVAT.ai Corporation # # SPDX-License-Identifier: MIT @@ -13,6 +12,7 @@ import tempfile import xml.etree.ElementTree as ET import zipfile from collections import defaultdict +from enum import Enum from glob import glob from io import BytesIO from unittest import mock @@ -28,7 +28,6 @@ from PIL import Image from pycocotools import coco as coco_loader from rest_framework import status from rest_framework.test import APIClient, APITestCase -from cvat.apps.engine.utils import StrEnum from datumaro.util.test_utils import TestDir from cvat.apps.engine.models import (AttributeSpec, AttributeType, Data, Job, @@ -3098,10 +3097,13 @@ def generate_manifest_file(data_type, manifest_path, sources): class TaskDataAPITestCase(APITestCase): _image_sizes = {} - class ChunkType(StrEnum): + class ChunkType(str, Enum): IMAGESET = 'imageset' VIDEO = 'video' + def __str__(self): + return self.value + def setUp(self): self.client = APIClient() diff --git a/cvat/apps/engine/utils.py b/cvat/apps/engine/utils.py index c1749309..ead30d65 100644 --- a/cvat/apps/engine/utils.py +++ b/cvat/apps/engine/utils.py @@ -1,12 +1,10 @@ # Copyright (C) 2020-2022 Intel Corporation -# Copyright (C) 2022 CVAT.ai Corporation # # SPDX-License-Identifier: MIT import ast import cv2 as cv from collections import namedtuple -from enum import Enum, unique import hashlib import importlib import sys @@ -135,54 +133,3 @@ def configure_dependent_job(queue, rq_id, rq_func, db_storage, filename, key): job_id=rq_job_id_download_file ) return rq_job_download_file - -class StrEnum(str, Enum): - """ - An enum with string items. - """ - - def __str__(self): - """ - Implements item conversion to string so that there is only item value in the result. - """ - return self.value - -@unique -class DjangoEnum(Enum): - """ - A enum that implements operations required by DRF. The main application is in Django models, - serializers, and OpenAPI schema where choices are supposed. Items are expected to be unique - and have human-readable names. - - This class must be inherited this way: - class (DjangoEnum, [mixin_type, ...] [data_type,] enum_type): - ... - - Example: - class MyEnum(DjangoEnum, StrEnum): - foo = "foo" - bar = "bar" - - Django docs: https://docs.djangoproject.com/en/3.0/ref/models/fields/#choices - """ - # https://stackoverflow.com/a/58051918 - - @classmethod - def choices(cls): - """ - Returns a tuple of (value, label) pairs. - """ - return tuple((v, k) for k, v in cls.__members__.items()) - - @classmethod - def values(cls): - return tuple(cls.__members__.values()) - - @classmethod - def names(cls): - return tuple(cls.__members__.keys()) - - @classmethod - def labels(cls): - # This class supposes labels to be equal to names - return cls.names() diff --git a/cvat/apps/engine/views.py b/cvat/apps/engine/views.py index baf7654f..61be2a36 100644 --- a/cvat/apps/engine/views.py +++ b/cvat/apps/engine/views.py @@ -70,7 +70,7 @@ from cvat.apps.engine.serializers import ( ProjectFileSerializer, TaskFileSerializer) from utils.dataset_manifest import ImageManifestManager -from cvat.apps.engine.utils import DjangoEnum, StrEnum, av_scan_paths, process_failed_job, configure_dependent_job +from cvat.apps.engine.utils import av_scan_paths, process_failed_job, configure_dependent_job from cvat.apps.engine import backup from cvat.apps.engine.mixins import PartialUpdateModelMixin, UploadMixin, AnnotationMixin, SerializeMixin, DestroyModelMixin, CreateModelMixin from cvat.apps.engine.location import get_location_configuration, StorageType @@ -343,7 +343,7 @@ class ProjectViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, enum=['download', 'import_status']), OpenApiParameter('location', description='Where need to save downloaded dataset', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - enum=Location.values()), + enum=Location.list()), OpenApiParameter('cloud_storage_id', description='Storage id', location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), OpenApiParameter('use_default_location', description='Use the location that was configured in project to import dataset', @@ -363,7 +363,7 @@ class ProjectViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False), OpenApiParameter('location', description='Where to import the dataset from', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - enum=Location.values()), + enum=Location.list()), OpenApiParameter('cloud_storage_id', description='Storage id', location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), OpenApiParameter('use_default_location', description='Use the location that was configured in the project to import annotations', @@ -495,7 +495,7 @@ class ProjectViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, enum=['download']), OpenApiParameter('location', description='Where need to save downloaded dataset', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - enum=Location.values()), + enum=Location.list()), OpenApiParameter('cloud_storage_id', description='Storage id', location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), OpenApiParameter('use_default_location', description='Use the location that was configured in project to export annotation', @@ -535,7 +535,7 @@ class ProjectViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False), OpenApiParameter('location', description='Where need to save downloaded backup', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - enum=Location.values()), + enum=Location.list()), OpenApiParameter('cloud_storage_id', description='Storage id', location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), OpenApiParameter('use_default_location', description='Use the location that was configured in project to export backup', @@ -555,7 +555,7 @@ class ProjectViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, parameters=[ OpenApiParameter('location', description='Where to import the backup file from', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - enum=Location.values(), default=Location.LOCAL), + enum=Location.list(), default=Location.LOCAL), OpenApiParameter('cloud_storage_id', description='Storage id', location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), OpenApiParameter('filename', description='Backup file name', @@ -607,24 +607,18 @@ class ProjectViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, return response -class DataChunkGetter: - class DataTypeChoices(DjangoEnum, StrEnum): - CHUNK = 'chunk' - FRAME = 'frame' - PREVIEW = 'preview' - CONTEXT_IMAGE = 'context_image' - - class ImageQualityChoices(DjangoEnum, StrEnum): - COMPRESSED = 'compressed' - ORIGINAL = 'original' +class DataChunkGetter: def __init__(self, data_type, data_num, data_quality, task_dim): - if data_type not in self.DataTypeChoices.values(): + possible_data_type_values = ('chunk', 'frame', 'preview', 'context_image') + possible_quality_values = ('compressed', 'original') + + if not data_type or data_type not in possible_data_type_values: raise ValidationError('Data type not specified or has wrong value') elif data_type == 'chunk' or data_type == 'frame': if not data_num: raise ValidationError('Number is not specified') - elif data_quality not in self.ImageQualityChoices.values(): + elif data_quality not in possible_quality_values: raise ValidationError('Wrong quality value') self.type = data_type @@ -752,7 +746,7 @@ class TaskViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, parameters=[ OpenApiParameter('location', description='Where to import the backup file from', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - enum=Location.values(), default=Location.LOCAL), + enum=Location.list(), default=Location.LOCAL), OpenApiParameter('cloud_storage_id', description='Storage id', location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), OpenApiParameter('filename', description='Backup file name', @@ -790,7 +784,7 @@ class TaskViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False), OpenApiParameter('location', description='Where need to save downloaded backup', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - enum=Location.values()), + enum=Location.list()), OpenApiParameter('cloud_storage_id', description='Storage id', location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), OpenApiParameter('use_default_location', description='Use the location that was configured in the task to export backup', @@ -947,10 +941,10 @@ class TaskViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, @extend_schema(methods=['GET'], summary='Method returns data for a specific task', parameters=[ OpenApiParameter('type', location=OpenApiParameter.QUERY, required=False, - type=OpenApiTypes.STR, enum=DataChunkGetter.DataTypeChoices.values(), + type=OpenApiTypes.STR, enum=['chunk', 'frame', 'preview', 'context_image'], description='Specifies the type of the requested data'), OpenApiParameter('quality', location=OpenApiParameter.QUERY, required=False, - type=OpenApiTypes.STR, enum=DataChunkGetter.ImageQualityChoices.values(), + type=OpenApiTypes.STR, enum=['compressed', 'original'], description="Specifies the quality level of the requested data, doesn't matter for 'preview' type"), OpenApiParameter('number', location=OpenApiParameter.QUERY, required=False, type=OpenApiTypes.INT, description="A unique number value identifying chunk or frame, doesn't matter for 'preview' type"), @@ -1010,7 +1004,7 @@ class TaskViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, type=OpenApiTypes.STR, required=False, enum=['download']), OpenApiParameter('location', description='Where need to save downloaded dataset', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - enum=Location.values()), + enum=Location.list()), OpenApiParameter('cloud_storage_id', description='Storage id', location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), OpenApiParameter('use_default_location', description='Use the location that was configured in the task to export annotation', @@ -1048,7 +1042,7 @@ class TaskViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, description='Input format name\nYou can get the list of supported formats at:\n/server/annotation/formats'), OpenApiParameter('location', description='where to import the annotation from', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - enum=Location.values()), + enum=Location.list()), OpenApiParameter('cloud_storage_id', description='Storage id', location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), OpenApiParameter('use_default_location', description='Use the location that was configured in task to import annotations', @@ -1246,7 +1240,7 @@ class TaskViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, default=True), OpenApiParameter('location', description='Where need to save downloaded dataset', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - enum=Location.values()), + enum=Location.list()), OpenApiParameter('cloud_storage_id', description='Storage id', location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), ], @@ -1365,7 +1359,7 @@ class JobViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, type=OpenApiTypes.STR, required=False, enum=['download']), OpenApiParameter('location', description='Where need to save downloaded annotation', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - enum=Location.values()), + enum=Location.list()), OpenApiParameter('cloud_storage_id', description='Storage id', location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), OpenApiParameter('use_default_location', description='Use the location that was configured in the task to export annotation', @@ -1388,7 +1382,7 @@ class JobViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, description='Input format name\nYou can get the list of supported formats at:\n/server/annotation/formats'), OpenApiParameter('location', description='where to import the annotation from', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - enum=Location.values()), + enum=Location.list()), OpenApiParameter('cloud_storage_id', description='Storage id', location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), OpenApiParameter('use_default_location', description='Use the location that was configured in the task to import annotation', @@ -1529,7 +1523,7 @@ class JobViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, default=True), OpenApiParameter('location', description='Where need to save downloaded dataset', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - enum=Location.values()), + enum=Location.list()), OpenApiParameter('cloud_storage_id', description='Storage id', location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), ], @@ -1964,7 +1958,7 @@ class CloudStorageViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, provider_type = self.request.query_params.get('provider_type', None) if provider_type: - if provider_type in CloudProviderChoice.values(): + if provider_type in CloudProviderChoice.list(): return queryset.filter(provider_type=provider_type) raise ValidationError('Unsupported type of cloud provider') return queryset diff --git a/cvat/apps/lambda_manager/views.py b/cvat/apps/lambda_manager/views.py index 898c9f4f..9a920931 100644 --- a/cvat/apps/lambda_manager/views.py +++ b/cvat/apps/lambda_manager/views.py @@ -6,6 +6,7 @@ import base64 import json from functools import wraps +from enum import Enum from copy import deepcopy import django_rq @@ -27,15 +28,16 @@ from cvat.apps.engine.models import ShapeType, SourceType from drf_spectacular.utils import extend_schema, extend_schema_view, OpenApiResponse, OpenApiParameter from drf_spectacular.types import OpenApiTypes -from cvat.apps.engine.utils import StrEnum - -class LambdaType(StrEnum): +class LambdaType(Enum): DETECTOR = "detector" INTERACTOR = "interactor" REID = "reid" TRACKER = "tracker" UNKNOWN = "unknown" + def __str__(self): + return self.value + class LambdaGateway: NUCLIO_ROOT_URL = '/api/functions' diff --git a/cvat/apps/webhooks/models.py b/cvat/apps/webhooks/models.py index de0477c8..44885ffd 100644 --- a/cvat/apps/webhooks/models.py +++ b/cvat/apps/webhooks/models.py @@ -2,21 +2,38 @@ # # SPDX-License-Identifier: MIT +from enum import Enum + from django.contrib.auth.models import User from django.db import models from cvat.apps.engine.models import Project -from cvat.apps.engine.utils import DjangoEnum, StrEnum from cvat.apps.organizations.models import Organization -class WebhookTypeChoice(DjangoEnum, StrEnum): +class WebhookTypeChoice(str, Enum): ORGANIZATION = "organization" PROJECT = "project" -class WebhookContentTypeChoice(DjangoEnum, StrEnum): + @classmethod + def choices(cls): + return tuple((x.value, x.name) for x in cls) + + def __str__(self): + return self.value + + +class WebhookContentTypeChoice(str, Enum): JSON = "application/json" + @classmethod + def choices(cls): + return tuple((x.value, x.name) for x in cls) + + def __str__(self): + return self.value + + class Webhook(models.Model): target_url = models.URLField() description = models.CharField(max_length=128, default="", blank=True)