From 8b3e247d526bba4f8e99af86476795ce7031f181 Mon Sep 17 00:00:00 2001 From: Anastasia Yasakova Date: Thu, 1 Sep 2022 11:50:53 +0300 Subject: [PATCH] Extend COCO format to skeletons (#4821) --- CHANGELOG.md | 2 +- cvat/apps/dataset_manager/bindings.py | 136 +++++++++++++----- cvat/apps/dataset_manager/formats/coco.py | 35 ++++- .../tests/assets/annotations.json | 133 +++++++++++++++++ .../dataset_manager/tests/assets/tasks.json | 41 ++++++ .../dataset_manager/tests/test_formats.py | 3 + .../tests/test_rest_api_formats.py | 44 +++++- cvat/apps/engine/serializers.py | 40 +++++- cvat/apps/engine/tests/test_rest_api.py | 5 +- 9 files changed, 387 insertions(+), 52 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 901dd4bc..45c77c1f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,7 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Skeleton (), () - Added helm chart support for serverless functions and analytics () - Added confirmation when remove a track () - +- [COCO Keypoints](https://cocodataset.org/#keypoints-2020) format support () ### Changed - Bumped nuclio version to 1.8.14 diff --git a/cvat/apps/dataset_manager/bindings.py b/cvat/apps/dataset_manager/bindings.py index fcefc12f..b416bcae 100644 --- a/cvat/apps/dataset_manager/bindings.py +++ b/cvat/apps/dataset_manager/bindings.py @@ -3,28 +3,29 @@ # # SPDX-License-Identifier: MIT -import sys -import rq import os.path as osp -from attr import attrib, attrs +import re +import sys from collections import namedtuple -from types import SimpleNamespace from pathlib import Path +from types import SimpleNamespace from typing import (Any, Callable, DefaultDict, Dict, List, Literal, Mapping, - NamedTuple, OrderedDict, Tuple, Union, Set) + NamedTuple, OrderedDict, Set, Tuple, Union) import datumaro.components.annotation as datum_annotation import datumaro.components.extractor as datum_extractor +import rq +from attr import attrib, attrs from datumaro.components.dataset import Dataset from datumaro.util import cast from datumaro.util.image import ByteImage, Image from django.utils import timezone +from cvat.apps.dataset_manager.formats.utils import get_label_color from cvat.apps.engine.frame_provider import FrameProvider -from cvat.apps.engine.models import AttributeType, DimensionType, AttributeSpec, LabelType +from cvat.apps.engine.models import AttributeSpec, AttributeType, DimensionType from cvat.apps.engine.models import Image as Img -from cvat.apps.engine.models import Label, Project, ShapeType, Task -from cvat.apps.dataset_manager.formats.utils import get_label_color +from cvat.apps.engine.models import Label, LabelType, Project, ShapeType, Task from .annotation import AnnotationIR, AnnotationManager, TrackManager from .formats.transformations import EllipsesToMasks @@ -64,9 +65,9 @@ class InstanceLabelData: **attr_mapping['immutable'], } - def _get_label_id(self, label_name): + def _get_label_id(self, label_name, parent_id=None): for db_label in self._label_mapping.values(): - if label_name == db_label.name: + if label_name == db_label.name and parent_id == db_label.parent_id: return db_label.id raise ValueError("Label {!r} is not registered for this task".format(label_name)) @@ -288,6 +289,8 @@ class TaskData(InstanceLabelData): if db_label.type == str(LabelType.SKELETON): label["svg"] = db_label.skeleton.svg + for db_sublabel in list(db_label.sublabels.all()): + label["svg"] = label["svg"].replace(f'data-label-id="{db_sublabel.id}"', f'data-label-name="{db_sublabel.name}"') labels.append(('label', label)) @@ -488,9 +491,9 @@ class TaskData(InstanceLabelData): ] return _tag - def _import_shape(self, shape): + def _import_shape(self, shape, parent_label_id=None): _shape = shape._asdict() - label_id = self._get_label_id(_shape.pop('label')) + label_id = self._get_label_id(_shape.pop('label'), parent_label_id) _shape['frame'] = self.rel_frame_id(int(_shape['frame'])) _shape['label_id'] = label_id _shape['attributes'] = [self._import_attribute(label_id, attrib) @@ -500,19 +503,19 @@ class TaskData(InstanceLabelData): ) ] _shape['points'] = list(map(float, _shape['points'])) - _shape['elements'] = [self._import_shape(element) for element in _shape.get('elements', [])] + _shape['elements'] = [self._import_shape(element, label_id) for element in _shape.get('elements', [])] return _shape - def _import_track(self, track): + def _import_track(self, track, parent_label_id=None): _track = track._asdict() - label_id = self._get_label_id(_track.pop('label')) + label_id = self._get_label_id(_track.pop('label'), parent_label_id) _track['frame'] = self.rel_frame_id( min(int(shape.frame) for shape in _track['shapes'])) _track['label_id'] = label_id _track['attributes'] = [] _track['shapes'] = [shape._asdict() for shape in _track['shapes']] - _track['elements'] = [self._import_track(element) for element in _track.get('elements', [])] + _track['elements'] = [self._import_track(element, label_id) for element in _track.get('elements', [])] for shape in _track['shapes']: shape['frame'] = self.rel_frame_id(int(shape['frame'])) _track['attributes'] = [self._import_attribute(label_id, attrib) @@ -778,22 +781,6 @@ class ProjectData(InstanceLabelData): ) for db_task in self._db_tasks.values() ]), - ("labels", [ - ("label", OrderedDict([ - ("name", db_label.name), - ("color", db_label.color), - ("type", db_label.type), - ("attributes", [ - ("attribute", OrderedDict([ - ("name", db_attr.name), - ("mutable", str(db_attr.mutable)), - ("input_type", db_attr.input_type), - ("default_value", db_attr.default_value), - ("values", db_attr.values)])) - for db_attr in db_label.attributespec_set.all()]) - ])) for db_label in self._label_mapping.values() - ]), - ("subsets", '\n'.join([s if s else datum_extractor.DEFAULT_SUBSET_NAME for s in self._subsets])), ("owner", OrderedDict([ @@ -809,6 +796,35 @@ class ProjectData(InstanceLabelData): ("dumped", str(timezone.localtime(timezone.now()))) ]) + if self._label_mapping is not None: + labels = [] + for db_label in self._label_mapping.values(): + label = OrderedDict([ + ("name", db_label.name), + ("color", db_label.color), + ("type", db_label.type), + ("attributes", [ + ("attribute", OrderedDict([ + ("name", db_attr.name), + ("mutable", str(db_attr.mutable)), + ("input_type", db_attr.input_type), + ("default_value", db_attr.default_value), + ("values", db_attr.values)])) + for db_attr in db_label.attributespec_set.all()]) + ]) + + if db_label.parent: + label["parent"] = db_label.parent.name + + if db_label.type == str(LabelType.SKELETON): + label["svg"] = db_label.skeleton.svg + for db_sublabel in list(db_label.sublabels.all()): + label["svg"] = label["svg"].replace(f'data-label-id="{db_sublabel.id}"', f'data-label-name="{db_sublabel.name}"') + + labels.append(('label', label)) + + self._meta['project']['labels'] = labels + def _export_tracked_shape(self, shape: dict, task_id: int): return ProjectData.TrackedShape( type=shape["type"], @@ -1056,14 +1072,24 @@ class CVATDataExtractorMixin: datum_annotation.Categories] = {} label_categories = datum_annotation.LabelCategories(attributes=['occluded']) + point_categories = datum_annotation.PointsCategories() for _, label in labels: - label_categories.add(label['name']) - for _, attr in label['attributes']: - label_categories.attributes.add(attr['name']) + if label.get('parent') is None: + label_id = label_categories.add(label['name']) + for _, attr in label['attributes']: + label_categories.attributes.add(attr['name']) + if label['type'] == str(LabelType.SKELETON): + labels_from = list(map(int, re.findall(r'data-node-from="(\d+)"', label['svg']))) + labels_to = list(map(int, re.findall(r'data-node-to="(\d+)"', label['svg']))) + sublabels = re.findall(r'data-label-name="(\w+)"', label['svg']) + joints = zip(labels_from, labels_to) + + point_categories.add(label_id, sublabels, joints) categories[datum_annotation.AnnotationType.label] = label_categories + categories[datum_annotation.AnnotationType.points] = point_categories return categories @@ -1423,6 +1449,21 @@ def convert_cvat_anno_to_dm(cvat_frame_anno, label_attrs, map_label, format_name ) else: continue + elif shape_obj.type == ShapeType.SKELETON: + points = [] + vis = [] + for element in shape_obj.elements: + points.extend(element.points) + element_vis = datum_annotation.Points.Visibility.visible + if element.outside: + element_vis = datum_annotation.Points.Visibility.absent + elif element.occluded: + element_vis = datum_annotation.Points.Visibility.hidden + vis.append(element_vis) + + anno = datum_annotation.Points(points, vis, + label=anno_label, attributes=anno_attr, group=anno_group, + z_order=shape_obj.z_order) else: raise Exception("Unknown shape type '%s'" % shape_obj.type) @@ -1485,6 +1526,7 @@ def import_dm_annotations(dm_dataset: Dataset, instance_data: Union[TaskData, Pr } label_cat = dm_dataset.categories()[datum_annotation.AnnotationType.label] + point_cat = dm_dataset.categories().get(datum_annotation.AnnotationType.points) root_hint = find_dataset_root(dm_dataset, instance_data) @@ -1540,9 +1582,30 @@ def import_dm_annotations(dm_dataset: Dataset, instance_data: Union[TaskData, Pr source = ann.attributes.pop('source').lower() \ if ann.attributes.get('source', '').lower() in {'auto', 'manual'} else 'manual' + shape_type = shapes[ann.type] + elements = [] + if point_cat and shape_type == ShapeType.POINTS: + labels = point_cat.items[ann.label].labels + shape_type = ShapeType.SKELETON + for i in range(len(ann.points) // 2): + label = None + if i < len(labels): + label = labels[i] + elements.append(instance_data.LabeledShape( + type=ShapeType.POINTS, + frame=frame_number, + points=ann.points[2 * i : 2 * i + 2], + label=label, + occluded=ann.visibility[i] == datum_annotation.Points.Visibility.hidden, + source=source, + attributes=[], + outside=ann.visibility[i] == datum_annotation.Points.Visibility.absent, + )) + + if track_id is None or dm_dataset.format != 'cvat' : instance_data.add_shape(instance_data.LabeledShape( - type=shapes[ann.type], + type=shape_type, frame=frame_number, points=ann.points, label=label_cat.items[ann.label].name, @@ -1551,6 +1614,7 @@ def import_dm_annotations(dm_dataset: Dataset, instance_data: Union[TaskData, Pr group=group_map.get(ann.group, 0), source=source, attributes=attributes, + elements=elements, )) continue diff --git a/cvat/apps/dataset_manager/formats/coco.py b/cvat/apps/dataset_manager/formats/coco.py index a184b241..fba7c3c4 100644 --- a/cvat/apps/dataset_manager/formats/coco.py +++ b/cvat/apps/dataset_manager/formats/coco.py @@ -6,6 +6,7 @@ import zipfile from tempfile import TemporaryDirectory from datumaro.components.dataset import Dataset +from datumaro.components.annotation import AnnotationType from cvat.apps.dataset_manager.bindings import GetCVATDataExtractor, \ import_dm_annotations @@ -28,13 +29,43 @@ def _import(src_file, instance_data, load_data_callback=None): if zipfile.is_zipfile(src_file): with TemporaryDirectory() as tmp_dir: zipfile.ZipFile(src_file).extractall(tmp_dir) + dataset = Dataset.import_from(tmp_dir, 'coco_instances', env=dm_env) + if load_data_callback is not None: + load_data_callback(dataset, instance_data) + import_dm_annotations(dataset, instance_data) + else: + dataset = Dataset.import_from(src_file.name, + 'coco_instances', env=dm_env) + import_dm_annotations(dataset, instance_data) + +@exporter(name='COCO Keypoints', ext='ZIP', version='1.0') +def _export(dst_file, instance_data, save_images=False): + dataset = Dataset.from_extractors(GetCVATDataExtractor( + instance_data, include_images=save_images), env=dm_env) + with TemporaryDirectory() as temp_dir: + dataset.export(temp_dir, 'coco_person_keypoints', save_images=save_images, + merge_images=True) + + make_zip_archive(temp_dir, dst_file) +@importer(name='COCO Keypoints', ext='JSON, ZIP', version='1.0') +def _import(src_file, instance_data, load_data_callback=None): + def remove_extra_annotations(dataset): + for item in dataset: + annotations = [ann for ann in item.annotations + if ann.type != AnnotationType.bbox] + item.annotations = annotations - dataset = Dataset.import_from(tmp_dir, 'coco', env=dm_env) + if zipfile.is_zipfile(src_file): + with TemporaryDirectory() as tmp_dir: + zipfile.ZipFile(src_file).extractall(tmp_dir) + dataset = Dataset.import_from(tmp_dir, 'coco_person_keypoints', env=dm_env) + remove_extra_annotations(dataset) if load_data_callback is not None: load_data_callback(dataset, instance_data) import_dm_annotations(dataset, instance_data) else: dataset = Dataset.import_from(src_file.name, - 'coco_instances', env=dm_env) + 'coco_person_keypoints', env=dm_env) + remove_extra_annotations(dataset) import_dm_annotations(dataset, instance_data) diff --git a/cvat/apps/dataset_manager/tests/assets/annotations.json b/cvat/apps/dataset_manager/tests/assets/annotations.json index 1a447963..4ec55a94 100644 --- a/cvat/apps/dataset_manager/tests/assets/annotations.json +++ b/cvat/apps/dataset_manager/tests/assets/annotations.json @@ -283,6 +283,139 @@ ], "tracks": [] }, + "COCO Keypoints 1.0": { + "version": 0, + "tags": [], + "shapes": [ + { + "type": "skeleton", + "occluded": false, + "outside": false, + "z_order": 0, + "rotation": 0, + "points": [], + "frame": 0, + "label_id": null, + "group": 0, + "source": "manual", + "attributes": [], + "elements": [ + { + "type": "points", + "occluded": false, + "outside": true, + "z_order": 0, + "rotation": 0, + "points": [ + 223.02, + 72.83 + ], + "frame": 0, + "label_id": null, + "group": 0, + "source": "manual", + "attributes": [] + }, + { + "type": "points", + "occluded": false, + "outside": false, + "z_order": 0, + "rotation": 0, + "points": [ + 232.98, + 124.6 + ], + "frame": 0, + "label_id": null, + "group": 0, + "source": "manual", + "attributes": [] + }, + { + "type": "points", + "occluded": false, + "outside": false, + "z_order": 0, + "rotation": 0, + "points": [ + 281.22, + 36.63 + ], + "frame": 0, + "label_id": null, + "group": 0, + "source": "manual", + "attributes": [] + } + ] + }, + { + "type": "skeleton", + "occluded": false, + "outside": false, + "z_order": 0, + "rotation": 0, + "points": [], + "frame": 1, + "label_id": null, + "group": 0, + "source": "manual", + "attributes": [], + "elements": [ + { + "type": "points", + "occluded": false, + "outside": false, + "z_order": 0, + "rotation": 0, + "points": [ + 248.61, + 72.83 + ], + "frame": 1, + "label_id": null, + "group": 0, + "source": "manual", + "attributes": [] + }, + { + "type": "points", + "occluded": false, + "outside": false, + "z_order": 0, + "rotation": 0, + "points": [ + 267.06, + 122.43 + ], + "frame": 1, + "label_id": null, + "group": 0, + "source": "manual", + "attributes": [] + }, + { + "type": "points", + "occluded": true, + "outside": false, + "z_order": 0, + "rotation": 0, + "points": [ + 356.43, + 38.15 + ], + "frame": 1, + "label_id": null, + "group": 0, + "source": "manual", + "attributes": [] + } + ] + } + ], + "tracks": [] + }, "ICDAR Localization 1.0": { "version": 0, "tags": [], diff --git a/cvat/apps/dataset_manager/tests/assets/tasks.json b/cvat/apps/dataset_manager/tests/assets/tasks.json index fcff87a5..e301c54a 100644 --- a/cvat/apps/dataset_manager/tests/assets/tasks.json +++ b/cvat/apps/dataset_manager/tests/assets/tasks.json @@ -326,5 +326,46 @@ "project_id": 1, "overlap": 0, "segment_size": 100 + }, + "COCO Keypoints 1.0": { + "name": "coco keupoints task", + "overlap": 0, + "segment_size": 100, + "labels": [ + { + "name": "skeleton", + "color": "#2080c0", + "type": "skeleton", + "attributes": [], + "sublabels": [ + { + "name": "1", + "color": "#d12345", + "attributes": [], + "type": "points" + }, + { + "name": "2", + "color": "#350dea", + "attributes": [], + "type": "points" + }, + { + "name": "3", + "color": "#479ffe", + "attributes": [ + { + "name": "attr", + "mutable": false, + "input_type": "select", + "values": ["0", "1", "2"] + } + ], + "type": "points" + } + ], + "svg": "" + } + ] } } diff --git a/cvat/apps/dataset_manager/tests/test_formats.py b/cvat/apps/dataset_manager/tests/test_formats.py index 4a771dcb..326a104b 100644 --- a/cvat/apps/dataset_manager/tests/test_formats.py +++ b/cvat/apps/dataset_manager/tests/test_formats.py @@ -275,6 +275,7 @@ class TaskExportTest(_DbTestBase): self.assertEqual({f.DISPLAY_NAME for f in formats}, { 'COCO 1.0', + 'COCO Keypoints 1.0', 'CVAT for images 1.1', 'CVAT for video 1.1', 'Datumaro 1.0', @@ -308,6 +309,7 @@ class TaskExportTest(_DbTestBase): self.assertEqual({f.DISPLAY_NAME for f in formats}, { 'COCO 1.0', + 'COCO Keypoints 1.0', 'CVAT 1.1', 'LabelMe 3.0', 'MOT 1.1', @@ -360,6 +362,7 @@ class TaskExportTest(_DbTestBase): for format_name, importer_name in [ ('COCO 1.0', 'coco'), + ('COCO Keypoints 1.0', 'coco_person_keypoints'), ('CVAT for images 1.1', 'cvat'), # ('CVAT for video 1.1', 'cvat'), # does not support ('Datumaro 1.0', 'datumaro'), diff --git a/cvat/apps/dataset_manager/tests/test_rest_api_formats.py b/cvat/apps/dataset_manager/tests/test_rest_api_formats.py index 47a1556a..140d1717 100644 --- a/cvat/apps/dataset_manager/tests/test_rest_api_formats.py +++ b/cvat/apps/dataset_manager/tests/test_rest_api_formats.py @@ -237,6 +237,36 @@ class _DbTestBase(APITestCase): "spec_id": spec_id, "value": value, }) + elements = tmp_annotations[item][index_elem].get("elements", []) + labels = task["labels"][0].get("sublabels", []) + for element, label in zip(elements, labels): + element["label_id"] = label["id"] + + for index_attribute, attribute in enumerate(label["attributes"]): + spec_id = label["attributes"][index_attribute]["id"] + + if key_get_values == "random": + if attribute["input_type"] == "number": + start = int(attribute["values"][0]) + stop = int(attribute["values"][1]) + 1 + step = int(attribute["values"][2]) + value = str(random.randrange(start, stop, step)) + else: + value = random.choice(label["attributes"][index_attribute]["values"]) + elif key_get_values == "default": + value = attribute["default_value"] + + if item == "tracks" and attribute["mutable"]: + for index_shape, _ in enumerate(element["shapes"]): + element["shapes"][index_shape]["attributes"].append({ + "spec_id": spec_id, + "value": value, + }) + else: + element["attributes"].append({ + "spec_id": spec_id, + "value": value, + }) response = self._put_api_v2_task_id_annotations(task["id"], tmp_annotations) self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -358,7 +388,7 @@ class TaskDumpUploadTest(_DbTestBase): if dump_format_name in [ "Market-1501 1.0", "Cityscapes 1.0", \ "ICDAR Localization 1.0", "ICDAR Recognition 1.0", \ - "ICDAR Segmentation 1.0" + "ICDAR Segmentation 1.0", "COCO Keypoints 1.0", ]: task = self._create_task(tasks[dump_format_name], images) else: @@ -421,7 +451,7 @@ class TaskDumpUploadTest(_DbTestBase): if upload_format_name in [ "Market-1501 1.0", "Cityscapes 1.0", \ "ICDAR Localization 1.0", "ICDAR Recognition 1.0", \ - "ICDAR Segmentation 1.0" + "ICDAR Segmentation 1.0", "COCO Keypoints 1.0", ]: task = self._create_task(tasks[upload_format_name], images) else: @@ -464,7 +494,7 @@ class TaskDumpUploadTest(_DbTestBase): if dump_format_name in [ "Market-1501 1.0", "Cityscapes 1.0", \ "ICDAR Localization 1.0", "ICDAR Recognition 1.0", \ - "ICDAR Segmentation 1.0" + "ICDAR Segmentation 1.0", "COCO Keypoints 1.0", ]: task = self._create_task(tasks[dump_format_name], video) else: @@ -525,7 +555,7 @@ class TaskDumpUploadTest(_DbTestBase): if upload_format_name in [ "Market-1501 1.0", "Cityscapes 1.0", \ "ICDAR Localization 1.0", "ICDAR Recognition 1.0", \ - "ICDAR Segmentation 1.0" + "ICDAR Segmentation 1.0", "COCO Keypoints 1.0", ]: task = self._create_task(tasks[upload_format_name], video) else: @@ -806,7 +836,7 @@ class TaskDumpUploadTest(_DbTestBase): if dump_format_name in [ "Market-1501 1.0", "Cityscapes 1.0", \ "ICDAR Localization 1.0", "ICDAR Recognition 1.0", \ - "ICDAR Segmentation 1.0" + "ICDAR Segmentation 1.0","COCO Keypoints 1.0", ]: task = self._create_task(tasks[dump_format_name], images) else: @@ -910,7 +940,7 @@ class TaskDumpUploadTest(_DbTestBase): if dump_format_name in [ "Market-1501 1.0", "Cityscapes 1.0", \ "ICDAR Localization 1.0", "ICDAR Recognition 1.0", \ - "ICDAR Segmentation 1.0" + "ICDAR Segmentation 1.0", "COCO Keypoints 1.0", ]: task = self._create_task(tasks[dump_format_name], images) else: @@ -1024,7 +1054,7 @@ class TaskDumpUploadTest(_DbTestBase): if dump_format_name in [ "Market-1501 1.0", "Cityscapes 1.0", \ "ICDAR Localization 1.0", "ICDAR Recognition 1.0", \ - "ICDAR Segmentation 1.0" + "ICDAR Segmentation 1.0", "COCO Keypoints 1.0" ]: task = self._create_task(tasks[dump_format_name], images) else: diff --git a/cvat/apps/engine/serializers.py b/cvat/apps/engine/serializers.py index 433e316e..81a06ee8 100644 --- a/cvat/apps/engine/serializers.py +++ b/cvat/apps/engine/serializers.py @@ -601,7 +601,10 @@ class TaskWriteSerializer(WriteOnceMixin, serializers.ModelSerializer): if instance.project_id is None: for old_label in instance.label_set.all(): try: - new_label = project.label_set.filter(name=old_label.name).first() + if old_label.parent: + new_label = project.label_set.filter(name=old_label.name, parent__name=old_label.parent.name).first() + else: + new_label = project.label_set.filter(name=old_label.name).first() except ValueError: raise serializers.ValidationError(f'Target project does not have label with name "{old_label.name}"') old_label.attributespec_set.all().delete() @@ -616,7 +619,10 @@ class TaskWriteSerializer(WriteOnceMixin, serializers.ModelSerializer): if len(new_label_for_name): old_label.name = new_label_for_name[0].get('name', old_label.name) try: - new_label = project.label_set.filter(name=old_label.name).first() + if old_label.parent: + new_label = project.label_set.filter(name=old_label.name, parent__name=old_label.parent.name).first() + else: + new_label = project.label_set.filter(name=old_label.name).first() except ValueError: raise serializers.ValidationError(f'Target project does not have label with name "{old_label.name}"') for (model, attr, attr_name) in ( @@ -651,18 +657,42 @@ class TaskWriteSerializer(WriteOnceMixin, serializers.ModelSerializer): # Check that all labels can be mapped new_label_names = set() old_labels = self.instance.project.label_set.all() if self.instance.project_id else self.instance.label_set.all() + new_sublabel_names = {} for old_label in old_labels: new_labels = tuple(filter(lambda x: x.get('id') == old_label.id, attrs.get('label_set', []))) if len(new_labels): - new_label_names.add(new_labels[0].get('name', old_label.name)) + parent = new_labels[0].get('parent', old_label.parent) + if parent: + if parent.name not in new_sublabel_names: + new_sublabel_names[parent.name] = set() + new_sublabel_names[parent.name].add(new_labels[0].get('name', old_label.name)) + else: + new_label_names.add(new_labels[0].get('name', old_label.name)) else: - new_label_names.add(old_label.name) + parent = old_label.parent + if parent: + if parent.name not in new_sublabel_names: + new_sublabel_names[parent.name] = set() + new_sublabel_names[parent.name].add(old_label.name) + else: + new_label_names.add(old_label.name) target_project = models.Project.objects.get(id=project_id) target_project_label_names = set() + target_project_sublabel_names = {} for label in target_project.label_set.all(): - target_project_label_names.add(label.name) + parent = label.parent + if parent: + if parent.name not in target_project_sublabel_names: + target_project_sublabel_names[parent.name] = set() + target_project_sublabel_names[parent.name].add(label.name) + else: + target_project_label_names.add(label.name) if not new_label_names.issubset(target_project_label_names): raise serializers.ValidationError('All task or project label names must be mapped to the target project') + + for label, sublabels in new_sublabel_names.items(): + if sublabels != target_project_sublabel_names.get(label): + raise serializers.ValidationError('All task or project label names must be mapped to the target project') else: if 'label_set' in attrs.keys(): label_names = [label['name'] for label in attrs.get('label_set')] diff --git a/cvat/apps/engine/tests/test_rest_api.py b/cvat/apps/engine/tests/test_rest_api.py index bb599227..43b82ab4 100644 --- a/cvat/apps/engine/tests/test_rest_api.py +++ b/cvat/apps/engine/tests/test_rest_api.py @@ -5416,6 +5416,9 @@ class TaskAnnotationAPITestCase(JobAnnotationAPITestCase): elif annotation_format == "COCO 1.0": annotations["shapes"] = polygon_shapes_wo_attrs + elif annotation_format == "COCO Keypoints 1.0": + annotations["shapes"] = points_wo_attrs + elif annotation_format == "Segmentation mask 1.1": annotations["shapes"] = rectangle_shapes_wo_attrs \ + polygon_shapes_wo_attrs @@ -5771,7 +5774,7 @@ class TaskAnnotationAPITestCase(JobAnnotationAPITestCase): self.assertTrue(zipfile.is_zipfile(content)) elif format_name in ['Kitti Raw Format 1.0','Sly Point Cloud Format 1.0']: self.assertTrue(zipfile.is_zipfile(content)) - elif format_name == "COCO 1.0": + elif format_name in ["COCO 1.0", "COCO Keypoints 1.0"]: with tempfile.TemporaryDirectory() as tmp_dir: zipfile.ZipFile(content).extractall(tmp_dir) jsons = glob(osp.join(tmp_dir, '**', '*.json'), recursive=True)