diff --git a/cvat/apps/dataset_manager/task.py b/cvat/apps/dataset_manager/task.py index d1d6af67..80aa529d 100644 --- a/cvat/apps/dataset_manager/task.py +++ b/cvat/apps/dataset_manager/task.py @@ -1,11 +1,9 @@ - # Copyright (C) 2019-2022 Intel Corporation # Copyright (C) 2022 CVAT.ai Corporation # # SPDX-License-Identifier: MIT from collections import OrderedDict -from enum import Enum from django.db import transaction from django.db.models.query import Prefetch @@ -14,6 +12,7 @@ from django.utils import timezone from cvat.apps.engine import models, serializers from cvat.apps.engine.plugins import plugin_decorator from cvat.apps.profiler import silk_profile +from cvat.apps.engine.utils import DjangoEnum, StrEnum from .annotation import AnnotationIR, AnnotationManager from .bindings import TaskData, JobData @@ -29,18 +28,11 @@ class dotdict(OrderedDict): __eq__ = lambda self, other: self.id == other.id __hash__ = lambda self: self.id -class PatchAction(str, Enum): +class PatchAction(DjangoEnum, StrEnum): CREATE = "create" UPDATE = "update" DELETE = "delete" - @classmethod - def values(cls): - return [item.value for item in cls] - - def __str__(self): - return self.value - def _merge_table_rows(rows, keys_for_merge, field_id): # It is necessary to keep a stable order of original rows # (e.g. for tracked boxes). Otherwise prev_box.frame can be bigger diff --git a/cvat/apps/dataset_repo/models.py b/cvat/apps/dataset_repo/models.py index b57ed264..62fbb292 100644 --- a/cvat/apps/dataset_repo/models.py +++ b/cvat/apps/dataset_repo/models.py @@ -1,21 +1,18 @@ # Copyright (C) 2018-2022 Intel Corporation +# Copyright (C) 2022 CVAT.ai Corporation # # SPDX-License-Identifier: MIT from django.db import models from cvat.apps.engine.models import Task -from enum import Enum +from cvat.apps.engine.utils import DjangoEnum, StrEnum -class GitStatusChoice(Enum): +class GitStatusChoice(DjangoEnum, StrEnum): NON_SYNCED = '!sync' SYNCING = 'syncing' SYNCED = 'sync' MERGED = 'merged' - def __str__(self): - return self.value - - class GitData(models.Model): task = models.OneToOneField(Task, on_delete = models.CASCADE, primary_key = True) url = models.URLField(max_length = 2000) diff --git a/cvat/apps/engine/backup.py b/cvat/apps/engine/backup.py index b7661d8c..7a22a5e1 100644 --- a/cvat/apps/engine/backup.py +++ b/cvat/apps/engine/backup.py @@ -3,9 +3,9 @@ # # SPDX-License-Identifier: MIT +from enum import Enum import io import os -from enum import Enum import re import shutil import tempfile diff --git a/cvat/apps/engine/cloud_provider.py b/cvat/apps/engine/cloud_provider.py index 7135225c..180d2b2b 100644 --- a/cvat/apps/engine/cloud_provider.py +++ b/cvat/apps/engine/cloud_provider.py @@ -1,4 +1,5 @@ # Copyright (C) 2021-2022 Intel Corporation +# Copyright (C) 2022 CVAT.ai Corporation # # SPDX-License-Identifier: MIT @@ -25,29 +26,18 @@ from google.cloud.exceptions import NotFound as GoogleCloudNotFound, Forbidden a from cvat.apps.engine.log import slogger from cvat.apps.engine.models import CredentialsTypeChoice, CloudProviderChoice +from cvat.apps.engine.utils import DjangoEnum, StrEnum -class Status(str, Enum): +class Status(DjangoEnum, StrEnum): AVAILABLE = 'AVAILABLE' NOT_FOUND = 'NOT_FOUND' FORBIDDEN = 'FORBIDDEN' - @classmethod - def choices(cls): - return tuple((x.value, x.name) for x in cls) - - def __str__(self): - return self.value - -class Permissions(str, Enum): +class Permissions(DjangoEnum, StrEnum): READ = 'read' WRITE = 'write' - @classmethod - def all(cls): - return {i.value for i in cls} - class _CloudStorage(ABC): - def __init__(self): self._files = [] @@ -313,7 +303,7 @@ class AWS_S3(_CloudStorage): bucket_policy = self._bucket.Policy().policy except ClientError as ex: if 'NoSuchBucketPolicy' in str(ex): - return Permissions.all() + return Permissions.values() else: raise Exception(str(ex)) bucket_policy = json.loads(bucket_policy) if isinstance(bucket_policy, str) else bucket_policy @@ -326,7 +316,7 @@ class AWS_S3(_CloudStorage): 's3:GetObject': Permissions.READ, 's3:PutObject': Permissions.WRITE, } - allowed_actions = Permissions.all() & {access.get(i) for i in allowed_actions} + allowed_actions = set(Permissions.values()) & {access.get(i) for i in allowed_actions} return allowed_actions diff --git a/cvat/apps/engine/location.py b/cvat/apps/engine/location.py index fd3fadf0..44ed1e8d 100644 --- a/cvat/apps/engine/location.py +++ b/cvat/apps/engine/location.py @@ -1,15 +1,19 @@ -from enum import Enum +# Copyright (C) 2022 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +from typing import Optional from cvat.apps.engine.models import Location +from cvat.apps.engine.utils import StrEnum -class StorageType(str, Enum): +class StorageType(StrEnum): TARGET = 'target_storage' SOURCE = 'source_storage' - def __str__(self): - return self.value - -def get_location_configuration(obj, field_name, use_settings=False): +def get_location_configuration( + obj, *, field_name: Optional[StorageType], use_settings: bool = False +): location_conf = dict() if use_settings: storage = getattr(obj, field_name) diff --git a/cvat/apps/engine/models.py b/cvat/apps/engine/models.py index d12f5200..c1c0e88e 100644 --- a/cvat/apps/engine/models.py +++ b/cvat/apps/engine/models.py @@ -6,7 +6,6 @@ import os import re import shutil -from enum import Enum from typing import Optional from django.conf import settings @@ -17,7 +16,7 @@ from django.db.models.fields import FloatField from django.core.serializers.json import DjangoJSONEncoder from drf_spectacular.types import OpenApiTypes from drf_spectacular.utils import extend_schema_field -from cvat.apps.engine.utils import parse_specific_attributes +from cvat.apps.engine.utils import DjangoEnum, parse_specific_attributes, StrEnum from cvat.apps.organizations.models import Organization class SafeCharField(models.CharField): @@ -28,36 +27,18 @@ class SafeCharField(models.CharField): return value -class DimensionType(str, Enum): +class DimensionType(DjangoEnum, StrEnum): DIM_3D = '3d' DIM_2D = '2d' - @classmethod - def choices(cls): - return tuple((x.value, x.name) for x in cls) - - def __str__(self): - return self.value - -class StatusChoice(str, Enum): +class StatusChoice(DjangoEnum, StrEnum): """Deprecated. Use StageChoice and StateChoice instead""" ANNOTATION = 'annotation' VALIDATION = 'validation' COMPLETED = 'completed' - @classmethod - def choices(cls): - return tuple((x.value, x.name) for x in cls) - - @classmethod - def list(cls): - return list(map(lambda x: x.value, cls)) - - def __str__(self): - return self.value - -class LabelType(str, Enum): +class LabelType(DjangoEnum, StrEnum): BBOX = 'bbox' ELLIPSE = 'ellipse' POLYGON = 'polygon' @@ -69,90 +50,37 @@ class LabelType(str, Enum): TAG = 'tag' ANY = 'any' - @classmethod - def choices(cls): - return tuple((x.value, x.name) for x in cls) - - @classmethod - def list(cls): - return list(map(lambda x: x.value, cls)) - - def __str__(self): - return self.value - -class StageChoice(str, Enum): +class StageChoice(DjangoEnum, StrEnum): ANNOTATION = 'annotation' VALIDATION = 'validation' ACCEPTANCE = 'acceptance' - @classmethod - def choices(cls): - return tuple((x.value, x.name) for x in cls) - - def __str__(self): - return self.value - -class StateChoice(str, Enum): +class StateChoice(DjangoEnum, StrEnum): NEW = 'new' IN_PROGRESS = 'in progress' COMPLETED = 'completed' REJECTED = 'rejected' - @classmethod - def choices(cls): - return tuple((x.value, x.name) for x in cls) - - def __str__(self): - return self.value - -class DataChoice(str, Enum): +class DataChoice(DjangoEnum, StrEnum): VIDEO = 'video' IMAGESET = 'imageset' LIST = 'list' - @classmethod - def choices(cls): - return tuple((x.value, x.name) for x in cls) - - def __str__(self): - return self.value - -class StorageMethodChoice(str, Enum): +class StorageMethodChoice(DjangoEnum, StrEnum): CACHE = 'cache' FILE_SYSTEM = 'file_system' - @classmethod - def choices(cls): - return tuple((x.value, x.name) for x in cls) - - def __str__(self): - return self.value - -class StorageChoice(str, Enum): +class StorageChoice(DjangoEnum, StrEnum): CLOUD_STORAGE = 'cloud_storage' LOCAL = 'local' SHARE = 'share' - @classmethod - def choices(cls): - return tuple((x.value, x.name) for x in cls) - - def __str__(self): - return self.value - -class SortingMethod(str, Enum): +class SortingMethod(DjangoEnum, StrEnum): LEXICOGRAPHICAL = 'lexicographical' NATURAL = 'natural' PREDEFINED = 'predefined' RANDOM = 'random' - @classmethod - def choices(cls): - return tuple((x.value, x.name) for x in cls) - - def __str__(self): - return self.value - class AbstractArrayField(models.TextField): separator = "," converter = lambda x: x @@ -521,20 +449,13 @@ class Skeleton(models.Model): default_permissions = () unique_together = ('root',) -class AttributeType(str, Enum): +class AttributeType(DjangoEnum, StrEnum): CHECKBOX = 'checkbox' RADIO = 'radio' NUMBER = 'number' TEXT = 'text' SELECT = 'select' - @classmethod - def choices(cls): - return tuple((x.value, x.name) for x in cls) - - def __str__(self): - return self.value - class AttributeSpec(models.Model): label = models.ForeignKey(Label, on_delete=models.CASCADE) name = models.CharField(max_length=64) @@ -561,7 +482,7 @@ class AttributeVal(models.Model): abstract = True default_permissions = () -class ShapeType(str, Enum): +class ShapeType(DjangoEnum, StrEnum): RECTANGLE = 'rectangle' # (x0, y0, x1, y1) POLYGON = 'polygon' # (x0, y0, ..., xn, yn) POLYLINE = 'polyline' # (x0, y0, ..., xn, yn) @@ -571,24 +492,10 @@ class ShapeType(str, Enum): MASK = 'mask' # (rle mask, left, top, right, bottom) SKELETON = 'skeleton' - @classmethod - def choices(cls): - return tuple((x.value, x.name) for x in cls) - - def __str__(self): - return self.value - -class SourceType(str, Enum): +class SourceType(DjangoEnum, StrEnum): AUTO = 'auto' MANUAL = 'manual' - @classmethod - def choices(cls): - return tuple((x.value, x.name) for x in cls) - - def __str__(self): - return self.value - class Annotation(models.Model): id = models.BigAutoField(primary_key=True) job = models.ForeignKey(Job, on_delete=models.CASCADE) @@ -698,41 +605,19 @@ class Comment(models.Model): def get_organization_id(self): return self.issue.get_organization_id() -class CloudProviderChoice(str, Enum): +class CloudProviderChoice(DjangoEnum, StrEnum): AWS_S3 = 'AWS_S3_BUCKET' AZURE_CONTAINER = 'AZURE_CONTAINER' GOOGLE_DRIVE = 'GOOGLE_DRIVE' GOOGLE_CLOUD_STORAGE = 'GOOGLE_CLOUD_STORAGE' - @classmethod - def choices(cls): - return tuple((x.value, x.name) for x in cls) - - @classmethod - def list(cls): - return list(map(lambda x: x.value, cls)) - - def __str__(self): - return self.value - -class CredentialsTypeChoice(str, Enum): +class CredentialsTypeChoice(DjangoEnum, StrEnum): # ignore bandit issues because false positives KEY_SECRET_KEY_PAIR = 'KEY_SECRET_KEY_PAIR' # nosec ACCOUNT_NAME_TOKEN_PAIR = 'ACCOUNT_NAME_TOKEN_PAIR' # nosec KEY_FILE_PATH = 'KEY_FILE_PATH' ANONYMOUS_ACCESS = 'ANONYMOUS_ACCESS' - @classmethod - def choices(cls): - return tuple((x.value, x.name) for x in cls) - - @classmethod - def list(cls): - return list(map(lambda x: x.value, cls)) - - def __str__(self): - return self.value - class Manifest(models.Model): filename = models.CharField(max_length=1024, default='manifest.jsonl') cloud_storage = models.ForeignKey('CloudStorage', on_delete=models.CASCADE, null=True, related_name='manifests') @@ -740,21 +625,10 @@ class Manifest(models.Model): def __str__(self): return '{}'.format(self.filename) -class Location(str, Enum): +class Location(DjangoEnum, StrEnum): CLOUD_STORAGE = 'cloud_storage' LOCAL = 'local' - @classmethod - def choices(cls): - return tuple((x.value, x.name) for x in cls) - - def __str__(self): - return self.value - - @classmethod - def list(cls): - return [i.value for i in cls] - class CloudStorage(models.Model): # restrictions: # AWS bucket name, Azure container name - 63, Google bucket name - 63 without dots and 222 with dots diff --git a/cvat/apps/engine/tests/test_rest_api.py b/cvat/apps/engine/tests/test_rest_api.py index a503c811..ac826cd4 100644 --- a/cvat/apps/engine/tests/test_rest_api.py +++ b/cvat/apps/engine/tests/test_rest_api.py @@ -1,4 +1,5 @@ # Copyright (C) 2020-2022 Intel Corporation +# Copyright (C) 2022 CVAT.ai Corporation # # SPDX-License-Identifier: MIT @@ -12,7 +13,6 @@ import tempfile import xml.etree.ElementTree as ET import zipfile from collections import defaultdict -from enum import Enum from glob import glob from io import BytesIO from unittest import mock @@ -28,6 +28,7 @@ from PIL import Image from pycocotools import coco as coco_loader from rest_framework import status from rest_framework.test import APIClient, APITestCase +from cvat.apps.engine.utils import StrEnum from datumaro.util.test_utils import TestDir from cvat.apps.engine.models import (AttributeSpec, AttributeType, Data, Job, @@ -3097,13 +3098,10 @@ def generate_manifest_file(data_type, manifest_path, sources): class TaskDataAPITestCase(APITestCase): _image_sizes = {} - class ChunkType(str, Enum): + class ChunkType(StrEnum): IMAGESET = 'imageset' VIDEO = 'video' - def __str__(self): - return self.value - def setUp(self): self.client = APIClient() diff --git a/cvat/apps/engine/utils.py b/cvat/apps/engine/utils.py index ead30d65..c1749309 100644 --- a/cvat/apps/engine/utils.py +++ b/cvat/apps/engine/utils.py @@ -1,10 +1,12 @@ # Copyright (C) 2020-2022 Intel Corporation +# Copyright (C) 2022 CVAT.ai Corporation # # SPDX-License-Identifier: MIT import ast import cv2 as cv from collections import namedtuple +from enum import Enum, unique import hashlib import importlib import sys @@ -133,3 +135,54 @@ def configure_dependent_job(queue, rq_id, rq_func, db_storage, filename, key): job_id=rq_job_id_download_file ) return rq_job_download_file + +class StrEnum(str, Enum): + """ + An enum with string items. + """ + + def __str__(self): + """ + Implements item conversion to string so that there is only item value in the result. + """ + return self.value + +@unique +class DjangoEnum(Enum): + """ + A enum that implements operations required by DRF. The main application is in Django models, + serializers, and OpenAPI schema where choices are supposed. Items are expected to be unique + and have human-readable names. + + This class must be inherited this way: + class (DjangoEnum, [mixin_type, ...] [data_type,] enum_type): + ... + + Example: + class MyEnum(DjangoEnum, StrEnum): + foo = "foo" + bar = "bar" + + Django docs: https://docs.djangoproject.com/en/3.0/ref/models/fields/#choices + """ + # https://stackoverflow.com/a/58051918 + + @classmethod + def choices(cls): + """ + Returns a tuple of (value, label) pairs. + """ + return tuple((v, k) for k, v in cls.__members__.items()) + + @classmethod + def values(cls): + return tuple(cls.__members__.values()) + + @classmethod + def names(cls): + return tuple(cls.__members__.keys()) + + @classmethod + def labels(cls): + # This class supposes labels to be equal to names + return cls.names() diff --git a/cvat/apps/engine/views.py b/cvat/apps/engine/views.py index 61be2a36..baf7654f 100644 --- a/cvat/apps/engine/views.py +++ b/cvat/apps/engine/views.py @@ -70,7 +70,7 @@ from cvat.apps.engine.serializers import ( ProjectFileSerializer, TaskFileSerializer) from utils.dataset_manifest import ImageManifestManager -from cvat.apps.engine.utils import av_scan_paths, process_failed_job, configure_dependent_job +from cvat.apps.engine.utils import DjangoEnum, StrEnum, av_scan_paths, process_failed_job, configure_dependent_job from cvat.apps.engine import backup from cvat.apps.engine.mixins import PartialUpdateModelMixin, UploadMixin, AnnotationMixin, SerializeMixin, DestroyModelMixin, CreateModelMixin from cvat.apps.engine.location import get_location_configuration, StorageType @@ -343,7 +343,7 @@ class ProjectViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, enum=['download', 'import_status']), OpenApiParameter('location', description='Where need to save downloaded dataset', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - enum=Location.list()), + enum=Location.values()), OpenApiParameter('cloud_storage_id', description='Storage id', location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), OpenApiParameter('use_default_location', description='Use the location that was configured in project to import dataset', @@ -363,7 +363,7 @@ class ProjectViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False), OpenApiParameter('location', description='Where to import the dataset from', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - enum=Location.list()), + enum=Location.values()), OpenApiParameter('cloud_storage_id', description='Storage id', location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), OpenApiParameter('use_default_location', description='Use the location that was configured in the project to import annotations', @@ -495,7 +495,7 @@ class ProjectViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, enum=['download']), OpenApiParameter('location', description='Where need to save downloaded dataset', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - enum=Location.list()), + enum=Location.values()), OpenApiParameter('cloud_storage_id', description='Storage id', location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), OpenApiParameter('use_default_location', description='Use the location that was configured in project to export annotation', @@ -535,7 +535,7 @@ class ProjectViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False), OpenApiParameter('location', description='Where need to save downloaded backup', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - enum=Location.list()), + enum=Location.values()), OpenApiParameter('cloud_storage_id', description='Storage id', location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), OpenApiParameter('use_default_location', description='Use the location that was configured in project to export backup', @@ -555,7 +555,7 @@ class ProjectViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, parameters=[ OpenApiParameter('location', description='Where to import the backup file from', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - enum=Location.list(), default=Location.LOCAL), + enum=Location.values(), default=Location.LOCAL), OpenApiParameter('cloud_storage_id', description='Storage id', location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), OpenApiParameter('filename', description='Backup file name', @@ -607,18 +607,24 @@ class ProjectViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, return response - class DataChunkGetter: - def __init__(self, data_type, data_num, data_quality, task_dim): - possible_data_type_values = ('chunk', 'frame', 'preview', 'context_image') - possible_quality_values = ('compressed', 'original') + class DataTypeChoices(DjangoEnum, StrEnum): + CHUNK = 'chunk' + FRAME = 'frame' + PREVIEW = 'preview' + CONTEXT_IMAGE = 'context_image' + + class ImageQualityChoices(DjangoEnum, StrEnum): + COMPRESSED = 'compressed' + ORIGINAL = 'original' - if not data_type or data_type not in possible_data_type_values: + def __init__(self, data_type, data_num, data_quality, task_dim): + if data_type not in self.DataTypeChoices.values(): raise ValidationError('Data type not specified or has wrong value') elif data_type == 'chunk' or data_type == 'frame': if not data_num: raise ValidationError('Number is not specified') - elif data_quality not in possible_quality_values: + elif data_quality not in self.ImageQualityChoices.values(): raise ValidationError('Wrong quality value') self.type = data_type @@ -746,7 +752,7 @@ class TaskViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, parameters=[ OpenApiParameter('location', description='Where to import the backup file from', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - enum=Location.list(), default=Location.LOCAL), + enum=Location.values(), default=Location.LOCAL), OpenApiParameter('cloud_storage_id', description='Storage id', location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), OpenApiParameter('filename', description='Backup file name', @@ -784,7 +790,7 @@ class TaskViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False), OpenApiParameter('location', description='Where need to save downloaded backup', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - enum=Location.list()), + enum=Location.values()), OpenApiParameter('cloud_storage_id', description='Storage id', location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), OpenApiParameter('use_default_location', description='Use the location that was configured in the task to export backup', @@ -941,10 +947,10 @@ class TaskViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, @extend_schema(methods=['GET'], summary='Method returns data for a specific task', parameters=[ OpenApiParameter('type', location=OpenApiParameter.QUERY, required=False, - type=OpenApiTypes.STR, enum=['chunk', 'frame', 'preview', 'context_image'], + type=OpenApiTypes.STR, enum=DataChunkGetter.DataTypeChoices.values(), description='Specifies the type of the requested data'), OpenApiParameter('quality', location=OpenApiParameter.QUERY, required=False, - type=OpenApiTypes.STR, enum=['compressed', 'original'], + type=OpenApiTypes.STR, enum=DataChunkGetter.ImageQualityChoices.values(), description="Specifies the quality level of the requested data, doesn't matter for 'preview' type"), OpenApiParameter('number', location=OpenApiParameter.QUERY, required=False, type=OpenApiTypes.INT, description="A unique number value identifying chunk or frame, doesn't matter for 'preview' type"), @@ -1004,7 +1010,7 @@ class TaskViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, type=OpenApiTypes.STR, required=False, enum=['download']), OpenApiParameter('location', description='Where need to save downloaded dataset', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - enum=Location.list()), + enum=Location.values()), OpenApiParameter('cloud_storage_id', description='Storage id', location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), OpenApiParameter('use_default_location', description='Use the location that was configured in the task to export annotation', @@ -1042,7 +1048,7 @@ class TaskViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, description='Input format name\nYou can get the list of supported formats at:\n/server/annotation/formats'), OpenApiParameter('location', description='where to import the annotation from', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - enum=Location.list()), + enum=Location.values()), OpenApiParameter('cloud_storage_id', description='Storage id', location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), OpenApiParameter('use_default_location', description='Use the location that was configured in task to import annotations', @@ -1240,7 +1246,7 @@ class TaskViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, default=True), OpenApiParameter('location', description='Where need to save downloaded dataset', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - enum=Location.list()), + enum=Location.values()), OpenApiParameter('cloud_storage_id', description='Storage id', location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), ], @@ -1359,7 +1365,7 @@ class JobViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, type=OpenApiTypes.STR, required=False, enum=['download']), OpenApiParameter('location', description='Where need to save downloaded annotation', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - enum=Location.list()), + enum=Location.values()), OpenApiParameter('cloud_storage_id', description='Storage id', location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), OpenApiParameter('use_default_location', description='Use the location that was configured in the task to export annotation', @@ -1382,7 +1388,7 @@ class JobViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, description='Input format name\nYou can get the list of supported formats at:\n/server/annotation/formats'), OpenApiParameter('location', description='where to import the annotation from', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - enum=Location.list()), + enum=Location.values()), OpenApiParameter('cloud_storage_id', description='Storage id', location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), OpenApiParameter('use_default_location', description='Use the location that was configured in the task to import annotation', @@ -1523,7 +1529,7 @@ class JobViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, default=True), OpenApiParameter('location', description='Where need to save downloaded dataset', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False, - enum=Location.list()), + enum=Location.values()), OpenApiParameter('cloud_storage_id', description='Storage id', location=OpenApiParameter.QUERY, type=OpenApiTypes.NUMBER, required=False), ], @@ -1958,7 +1964,7 @@ class CloudStorageViewSet(viewsets.GenericViewSet, mixins.ListModelMixin, provider_type = self.request.query_params.get('provider_type', None) if provider_type: - if provider_type in CloudProviderChoice.list(): + if provider_type in CloudProviderChoice.values(): return queryset.filter(provider_type=provider_type) raise ValidationError('Unsupported type of cloud provider') return queryset diff --git a/cvat/apps/lambda_manager/views.py b/cvat/apps/lambda_manager/views.py index 9a920931..898c9f4f 100644 --- a/cvat/apps/lambda_manager/views.py +++ b/cvat/apps/lambda_manager/views.py @@ -6,7 +6,6 @@ import base64 import json from functools import wraps -from enum import Enum from copy import deepcopy import django_rq @@ -28,16 +27,15 @@ from cvat.apps.engine.models import ShapeType, SourceType from drf_spectacular.utils import extend_schema, extend_schema_view, OpenApiResponse, OpenApiParameter from drf_spectacular.types import OpenApiTypes -class LambdaType(Enum): +from cvat.apps.engine.utils import StrEnum + +class LambdaType(StrEnum): DETECTOR = "detector" INTERACTOR = "interactor" REID = "reid" TRACKER = "tracker" UNKNOWN = "unknown" - def __str__(self): - return self.value - class LambdaGateway: NUCLIO_ROOT_URL = '/api/functions' diff --git a/cvat/apps/webhooks/models.py b/cvat/apps/webhooks/models.py index 44885ffd..de0477c8 100644 --- a/cvat/apps/webhooks/models.py +++ b/cvat/apps/webhooks/models.py @@ -2,38 +2,21 @@ # # SPDX-License-Identifier: MIT -from enum import Enum - from django.contrib.auth.models import User from django.db import models from cvat.apps.engine.models import Project +from cvat.apps.engine.utils import DjangoEnum, StrEnum from cvat.apps.organizations.models import Organization -class WebhookTypeChoice(str, Enum): +class WebhookTypeChoice(DjangoEnum, StrEnum): ORGANIZATION = "organization" PROJECT = "project" - @classmethod - def choices(cls): - return tuple((x.value, x.name) for x in cls) - - def __str__(self): - return self.value - - -class WebhookContentTypeChoice(str, Enum): +class WebhookContentTypeChoice(DjangoEnum, StrEnum): JSON = "application/json" - @classmethod - def choices(cls): - return tuple((x.value, x.name) for x in cls) - - def __str__(self): - return self.value - - class Webhook(models.Model): target_url = models.URLField() description = models.CharField(max_length=128, default="", blank=True)