Extend COCO format to skeletons (#4821)

main
Anastasia Yasakova 4 years ago committed by GitHub
parent 9da3cbd6d8
commit 8b3e247d52
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -26,7 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
Skeleton (<https://github.com/cvat-ai/cvat/pull/1>), (<https://github.com/opencv/cvat/pull/4829>) Skeleton (<https://github.com/cvat-ai/cvat/pull/1>), (<https://github.com/opencv/cvat/pull/4829>)
- Added helm chart support for serverless functions and analytics (<https://github.com/cvat-ai/cvat/pull/110>) - Added helm chart support for serverless functions and analytics (<https://github.com/cvat-ai/cvat/pull/110>)
- Added confirmation when remove a track (<https://github.com/opencv/cvat/pull/4846>) - Added confirmation when remove a track (<https://github.com/opencv/cvat/pull/4846>)
- [COCO Keypoints](https://cocodataset.org/#keypoints-2020) format support (<https://github.com/opencv/cvat/pull/4821>)
### Changed ### Changed
- Bumped nuclio version to 1.8.14 - Bumped nuclio version to 1.8.14

@ -3,28 +3,29 @@
# #
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
import sys
import rq
import os.path as osp import os.path as osp
from attr import attrib, attrs import re
import sys
from collections import namedtuple from collections import namedtuple
from types import SimpleNamespace
from pathlib import Path from pathlib import Path
from types import SimpleNamespace
from typing import (Any, Callable, DefaultDict, Dict, List, Literal, Mapping, from typing import (Any, Callable, DefaultDict, Dict, List, Literal, Mapping,
NamedTuple, OrderedDict, Tuple, Union, Set) NamedTuple, OrderedDict, Set, Tuple, Union)
import datumaro.components.annotation as datum_annotation import datumaro.components.annotation as datum_annotation
import datumaro.components.extractor as datum_extractor import datumaro.components.extractor as datum_extractor
import rq
from attr import attrib, attrs
from datumaro.components.dataset import Dataset from datumaro.components.dataset import Dataset
from datumaro.util import cast from datumaro.util import cast
from datumaro.util.image import ByteImage, Image from datumaro.util.image import ByteImage, Image
from django.utils import timezone from django.utils import timezone
from cvat.apps.dataset_manager.formats.utils import get_label_color
from cvat.apps.engine.frame_provider import FrameProvider from cvat.apps.engine.frame_provider import FrameProvider
from cvat.apps.engine.models import AttributeType, DimensionType, AttributeSpec, LabelType from cvat.apps.engine.models import AttributeSpec, AttributeType, DimensionType
from cvat.apps.engine.models import Image as Img from cvat.apps.engine.models import Image as Img
from cvat.apps.engine.models import Label, Project, ShapeType, Task from cvat.apps.engine.models import Label, LabelType, Project, ShapeType, Task
from cvat.apps.dataset_manager.formats.utils import get_label_color
from .annotation import AnnotationIR, AnnotationManager, TrackManager from .annotation import AnnotationIR, AnnotationManager, TrackManager
from .formats.transformations import EllipsesToMasks from .formats.transformations import EllipsesToMasks
@ -64,9 +65,9 @@ class InstanceLabelData:
**attr_mapping['immutable'], **attr_mapping['immutable'],
} }
def _get_label_id(self, label_name): def _get_label_id(self, label_name, parent_id=None):
for db_label in self._label_mapping.values(): for db_label in self._label_mapping.values():
if label_name == db_label.name: if label_name == db_label.name and parent_id == db_label.parent_id:
return db_label.id return db_label.id
raise ValueError("Label {!r} is not registered for this task".format(label_name)) raise ValueError("Label {!r} is not registered for this task".format(label_name))
@ -288,6 +289,8 @@ class TaskData(InstanceLabelData):
if db_label.type == str(LabelType.SKELETON): if db_label.type == str(LabelType.SKELETON):
label["svg"] = db_label.skeleton.svg label["svg"] = db_label.skeleton.svg
for db_sublabel in list(db_label.sublabels.all()):
label["svg"] = label["svg"].replace(f'data-label-id="{db_sublabel.id}"', f'data-label-name="{db_sublabel.name}"')
labels.append(('label', label)) labels.append(('label', label))
@ -488,9 +491,9 @@ class TaskData(InstanceLabelData):
] ]
return _tag return _tag
def _import_shape(self, shape): def _import_shape(self, shape, parent_label_id=None):
_shape = shape._asdict() _shape = shape._asdict()
label_id = self._get_label_id(_shape.pop('label')) label_id = self._get_label_id(_shape.pop('label'), parent_label_id)
_shape['frame'] = self.rel_frame_id(int(_shape['frame'])) _shape['frame'] = self.rel_frame_id(int(_shape['frame']))
_shape['label_id'] = label_id _shape['label_id'] = label_id
_shape['attributes'] = [self._import_attribute(label_id, attrib) _shape['attributes'] = [self._import_attribute(label_id, attrib)
@ -500,19 +503,19 @@ class TaskData(InstanceLabelData):
) )
] ]
_shape['points'] = list(map(float, _shape['points'])) _shape['points'] = list(map(float, _shape['points']))
_shape['elements'] = [self._import_shape(element) for element in _shape.get('elements', [])] _shape['elements'] = [self._import_shape(element, label_id) for element in _shape.get('elements', [])]
return _shape return _shape
def _import_track(self, track): def _import_track(self, track, parent_label_id=None):
_track = track._asdict() _track = track._asdict()
label_id = self._get_label_id(_track.pop('label')) label_id = self._get_label_id(_track.pop('label'), parent_label_id)
_track['frame'] = self.rel_frame_id( _track['frame'] = self.rel_frame_id(
min(int(shape.frame) for shape in _track['shapes'])) min(int(shape.frame) for shape in _track['shapes']))
_track['label_id'] = label_id _track['label_id'] = label_id
_track['attributes'] = [] _track['attributes'] = []
_track['shapes'] = [shape._asdict() for shape in _track['shapes']] _track['shapes'] = [shape._asdict() for shape in _track['shapes']]
_track['elements'] = [self._import_track(element) for element in _track.get('elements', [])] _track['elements'] = [self._import_track(element, label_id) for element in _track.get('elements', [])]
for shape in _track['shapes']: for shape in _track['shapes']:
shape['frame'] = self.rel_frame_id(int(shape['frame'])) shape['frame'] = self.rel_frame_id(int(shape['frame']))
_track['attributes'] = [self._import_attribute(label_id, attrib) _track['attributes'] = [self._import_attribute(label_id, attrib)
@ -778,22 +781,6 @@ class ProjectData(InstanceLabelData):
) for db_task in self._db_tasks.values() ) for db_task in self._db_tasks.values()
]), ]),
("labels", [
("label", OrderedDict([
("name", db_label.name),
("color", db_label.color),
("type", db_label.type),
("attributes", [
("attribute", OrderedDict([
("name", db_attr.name),
("mutable", str(db_attr.mutable)),
("input_type", db_attr.input_type),
("default_value", db_attr.default_value),
("values", db_attr.values)]))
for db_attr in db_label.attributespec_set.all()])
])) for db_label in self._label_mapping.values()
]),
("subsets", '\n'.join([s if s else datum_extractor.DEFAULT_SUBSET_NAME for s in self._subsets])), ("subsets", '\n'.join([s if s else datum_extractor.DEFAULT_SUBSET_NAME for s in self._subsets])),
("owner", OrderedDict([ ("owner", OrderedDict([
@ -809,6 +796,35 @@ class ProjectData(InstanceLabelData):
("dumped", str(timezone.localtime(timezone.now()))) ("dumped", str(timezone.localtime(timezone.now())))
]) ])
if self._label_mapping is not None:
labels = []
for db_label in self._label_mapping.values():
label = OrderedDict([
("name", db_label.name),
("color", db_label.color),
("type", db_label.type),
("attributes", [
("attribute", OrderedDict([
("name", db_attr.name),
("mutable", str(db_attr.mutable)),
("input_type", db_attr.input_type),
("default_value", db_attr.default_value),
("values", db_attr.values)]))
for db_attr in db_label.attributespec_set.all()])
])
if db_label.parent:
label["parent"] = db_label.parent.name
if db_label.type == str(LabelType.SKELETON):
label["svg"] = db_label.skeleton.svg
for db_sublabel in list(db_label.sublabels.all()):
label["svg"] = label["svg"].replace(f'data-label-id="{db_sublabel.id}"', f'data-label-name="{db_sublabel.name}"')
labels.append(('label', label))
self._meta['project']['labels'] = labels
def _export_tracked_shape(self, shape: dict, task_id: int): def _export_tracked_shape(self, shape: dict, task_id: int):
return ProjectData.TrackedShape( return ProjectData.TrackedShape(
type=shape["type"], type=shape["type"],
@ -1056,14 +1072,24 @@ class CVATDataExtractorMixin:
datum_annotation.Categories] = {} datum_annotation.Categories] = {}
label_categories = datum_annotation.LabelCategories(attributes=['occluded']) label_categories = datum_annotation.LabelCategories(attributes=['occluded'])
point_categories = datum_annotation.PointsCategories()
for _, label in labels: for _, label in labels:
label_categories.add(label['name']) if label.get('parent') is None:
label_id = label_categories.add(label['name'])
for _, attr in label['attributes']: for _, attr in label['attributes']:
label_categories.attributes.add(attr['name']) label_categories.attributes.add(attr['name'])
if label['type'] == str(LabelType.SKELETON):
labels_from = list(map(int, re.findall(r'data-node-from="(\d+)"', label['svg'])))
labels_to = list(map(int, re.findall(r'data-node-to="(\d+)"', label['svg'])))
sublabels = re.findall(r'data-label-name="(\w+)"', label['svg'])
joints = zip(labels_from, labels_to)
point_categories.add(label_id, sublabels, joints)
categories[datum_annotation.AnnotationType.label] = label_categories categories[datum_annotation.AnnotationType.label] = label_categories
categories[datum_annotation.AnnotationType.points] = point_categories
return categories return categories
@ -1423,6 +1449,21 @@ def convert_cvat_anno_to_dm(cvat_frame_anno, label_attrs, map_label, format_name
) )
else: else:
continue continue
elif shape_obj.type == ShapeType.SKELETON:
points = []
vis = []
for element in shape_obj.elements:
points.extend(element.points)
element_vis = datum_annotation.Points.Visibility.visible
if element.outside:
element_vis = datum_annotation.Points.Visibility.absent
elif element.occluded:
element_vis = datum_annotation.Points.Visibility.hidden
vis.append(element_vis)
anno = datum_annotation.Points(points, vis,
label=anno_label, attributes=anno_attr, group=anno_group,
z_order=shape_obj.z_order)
else: else:
raise Exception("Unknown shape type '%s'" % shape_obj.type) raise Exception("Unknown shape type '%s'" % shape_obj.type)
@ -1485,6 +1526,7 @@ def import_dm_annotations(dm_dataset: Dataset, instance_data: Union[TaskData, Pr
} }
label_cat = dm_dataset.categories()[datum_annotation.AnnotationType.label] label_cat = dm_dataset.categories()[datum_annotation.AnnotationType.label]
point_cat = dm_dataset.categories().get(datum_annotation.AnnotationType.points)
root_hint = find_dataset_root(dm_dataset, instance_data) root_hint = find_dataset_root(dm_dataset, instance_data)
@ -1540,9 +1582,30 @@ def import_dm_annotations(dm_dataset: Dataset, instance_data: Union[TaskData, Pr
source = ann.attributes.pop('source').lower() \ source = ann.attributes.pop('source').lower() \
if ann.attributes.get('source', '').lower() in {'auto', 'manual'} else 'manual' if ann.attributes.get('source', '').lower() in {'auto', 'manual'} else 'manual'
shape_type = shapes[ann.type]
elements = []
if point_cat and shape_type == ShapeType.POINTS:
labels = point_cat.items[ann.label].labels
shape_type = ShapeType.SKELETON
for i in range(len(ann.points) // 2):
label = None
if i < len(labels):
label = labels[i]
elements.append(instance_data.LabeledShape(
type=ShapeType.POINTS,
frame=frame_number,
points=ann.points[2 * i : 2 * i + 2],
label=label,
occluded=ann.visibility[i] == datum_annotation.Points.Visibility.hidden,
source=source,
attributes=[],
outside=ann.visibility[i] == datum_annotation.Points.Visibility.absent,
))
if track_id is None or dm_dataset.format != 'cvat' : if track_id is None or dm_dataset.format != 'cvat' :
instance_data.add_shape(instance_data.LabeledShape( instance_data.add_shape(instance_data.LabeledShape(
type=shapes[ann.type], type=shape_type,
frame=frame_number, frame=frame_number,
points=ann.points, points=ann.points,
label=label_cat.items[ann.label].name, label=label_cat.items[ann.label].name,
@ -1551,6 +1614,7 @@ def import_dm_annotations(dm_dataset: Dataset, instance_data: Union[TaskData, Pr
group=group_map.get(ann.group, 0), group=group_map.get(ann.group, 0),
source=source, source=source,
attributes=attributes, attributes=attributes,
elements=elements,
)) ))
continue continue

@ -6,6 +6,7 @@ import zipfile
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from datumaro.components.dataset import Dataset from datumaro.components.dataset import Dataset
from datumaro.components.annotation import AnnotationType
from cvat.apps.dataset_manager.bindings import GetCVATDataExtractor, \ from cvat.apps.dataset_manager.bindings import GetCVATDataExtractor, \
import_dm_annotations import_dm_annotations
@ -28,13 +29,43 @@ def _import(src_file, instance_data, load_data_callback=None):
if zipfile.is_zipfile(src_file): if zipfile.is_zipfile(src_file):
with TemporaryDirectory() as tmp_dir: with TemporaryDirectory() as tmp_dir:
zipfile.ZipFile(src_file).extractall(tmp_dir) zipfile.ZipFile(src_file).extractall(tmp_dir)
dataset = Dataset.import_from(tmp_dir, 'coco_instances', env=dm_env)
if load_data_callback is not None:
load_data_callback(dataset, instance_data)
import_dm_annotations(dataset, instance_data)
else:
dataset = Dataset.import_from(src_file.name,
'coco_instances', env=dm_env)
import_dm_annotations(dataset, instance_data)
@exporter(name='COCO Keypoints', ext='ZIP', version='1.0')
def _export(dst_file, instance_data, save_images=False):
dataset = Dataset.from_extractors(GetCVATDataExtractor(
instance_data, include_images=save_images), env=dm_env)
with TemporaryDirectory() as temp_dir:
dataset.export(temp_dir, 'coco_person_keypoints', save_images=save_images,
merge_images=True)
make_zip_archive(temp_dir, dst_file)
@importer(name='COCO Keypoints', ext='JSON, ZIP', version='1.0')
def _import(src_file, instance_data, load_data_callback=None):
def remove_extra_annotations(dataset):
for item in dataset:
annotations = [ann for ann in item.annotations
if ann.type != AnnotationType.bbox]
item.annotations = annotations
dataset = Dataset.import_from(tmp_dir, 'coco', env=dm_env) if zipfile.is_zipfile(src_file):
with TemporaryDirectory() as tmp_dir:
zipfile.ZipFile(src_file).extractall(tmp_dir)
dataset = Dataset.import_from(tmp_dir, 'coco_person_keypoints', env=dm_env)
remove_extra_annotations(dataset)
if load_data_callback is not None: if load_data_callback is not None:
load_data_callback(dataset, instance_data) load_data_callback(dataset, instance_data)
import_dm_annotations(dataset, instance_data) import_dm_annotations(dataset, instance_data)
else: else:
dataset = Dataset.import_from(src_file.name, dataset = Dataset.import_from(src_file.name,
'coco_instances', env=dm_env) 'coco_person_keypoints', env=dm_env)
remove_extra_annotations(dataset)
import_dm_annotations(dataset, instance_data) import_dm_annotations(dataset, instance_data)

@ -283,6 +283,139 @@
], ],
"tracks": [] "tracks": []
}, },
"COCO Keypoints 1.0": {
"version": 0,
"tags": [],
"shapes": [
{
"type": "skeleton",
"occluded": false,
"outside": false,
"z_order": 0,
"rotation": 0,
"points": [],
"frame": 0,
"label_id": null,
"group": 0,
"source": "manual",
"attributes": [],
"elements": [
{
"type": "points",
"occluded": false,
"outside": true,
"z_order": 0,
"rotation": 0,
"points": [
223.02,
72.83
],
"frame": 0,
"label_id": null,
"group": 0,
"source": "manual",
"attributes": []
},
{
"type": "points",
"occluded": false,
"outside": false,
"z_order": 0,
"rotation": 0,
"points": [
232.98,
124.6
],
"frame": 0,
"label_id": null,
"group": 0,
"source": "manual",
"attributes": []
},
{
"type": "points",
"occluded": false,
"outside": false,
"z_order": 0,
"rotation": 0,
"points": [
281.22,
36.63
],
"frame": 0,
"label_id": null,
"group": 0,
"source": "manual",
"attributes": []
}
]
},
{
"type": "skeleton",
"occluded": false,
"outside": false,
"z_order": 0,
"rotation": 0,
"points": [],
"frame": 1,
"label_id": null,
"group": 0,
"source": "manual",
"attributes": [],
"elements": [
{
"type": "points",
"occluded": false,
"outside": false,
"z_order": 0,
"rotation": 0,
"points": [
248.61,
72.83
],
"frame": 1,
"label_id": null,
"group": 0,
"source": "manual",
"attributes": []
},
{
"type": "points",
"occluded": false,
"outside": false,
"z_order": 0,
"rotation": 0,
"points": [
267.06,
122.43
],
"frame": 1,
"label_id": null,
"group": 0,
"source": "manual",
"attributes": []
},
{
"type": "points",
"occluded": true,
"outside": false,
"z_order": 0,
"rotation": 0,
"points": [
356.43,
38.15
],
"frame": 1,
"label_id": null,
"group": 0,
"source": "manual",
"attributes": []
}
]
}
],
"tracks": []
},
"ICDAR Localization 1.0": { "ICDAR Localization 1.0": {
"version": 0, "version": 0,
"tags": [], "tags": [],

@ -326,5 +326,46 @@
"project_id": 1, "project_id": 1,
"overlap": 0, "overlap": 0,
"segment_size": 100 "segment_size": 100
},
"COCO Keypoints 1.0": {
"name": "coco keupoints task",
"overlap": 0,
"segment_size": 100,
"labels": [
{
"name": "skeleton",
"color": "#2080c0",
"type": "skeleton",
"attributes": [],
"sublabels": [
{
"name": "1",
"color": "#d12345",
"attributes": [],
"type": "points"
},
{
"name": "2",
"color": "#350dea",
"attributes": [],
"type": "points"
},
{
"name": "3",
"color": "#479ffe",
"attributes": [
{
"name": "attr",
"mutable": false,
"input_type": "select",
"values": ["0", "1", "2"]
}
],
"type": "points"
}
],
"svg": "<line x1=\"38.92810821533203\" y1=\"53.31378173828125\" x2=\"80.23341369628906\" y2=\"18.36313819885254\" stroke=\"black\" data-type=\"edge\" data-node-from=\"2\" stroke-width=\"0.5\" data-node-to=\"3\"></line><line x1=\"30.399484634399414\" y1=\"32.74474334716797\" x2=\"38.92810821533203\" y2=\"53.31378173828125\" stroke=\"black\" data-type=\"edge\" data-node-from=\"1\" stroke-width=\"0.5\" data-node-to=\"2\"></line><circle r=\"1.5\" stroke=\"black\" fill=\"#b3b3b3\" cx=\"30.399484634399414\" cy=\"32.74474334716797\" stroke-width=\"0.1\" data-type=\"element node\" data-element-id=\"1\" data-node-id=\"1\" data-label-name=\"1\"></circle><circle r=\"1.5\" stroke=\"black\" fill=\"#b3b3b3\" cx=\"38.92810821533203\" cy=\"53.31378173828125\" stroke-width=\"0.1\" data-type=\"element node\" data-element-id=\"2\" data-node-id=\"2\" data-label-name=\"2\"></circle><circle r=\"1.5\" stroke=\"black\" fill=\"#b3b3b3\" cx=\"80.23341369628906\" cy=\"18.36313819885254\" stroke-width=\"0.1\" data-type=\"element node\" data-element-id=\"3\" data-node-id=\"3\" data-label-name=\"3\"></circle>"
}
]
} }
} }

@ -275,6 +275,7 @@ class TaskExportTest(_DbTestBase):
self.assertEqual({f.DISPLAY_NAME for f in formats}, self.assertEqual({f.DISPLAY_NAME for f in formats},
{ {
'COCO 1.0', 'COCO 1.0',
'COCO Keypoints 1.0',
'CVAT for images 1.1', 'CVAT for images 1.1',
'CVAT for video 1.1', 'CVAT for video 1.1',
'Datumaro 1.0', 'Datumaro 1.0',
@ -308,6 +309,7 @@ class TaskExportTest(_DbTestBase):
self.assertEqual({f.DISPLAY_NAME for f in formats}, self.assertEqual({f.DISPLAY_NAME for f in formats},
{ {
'COCO 1.0', 'COCO 1.0',
'COCO Keypoints 1.0',
'CVAT 1.1', 'CVAT 1.1',
'LabelMe 3.0', 'LabelMe 3.0',
'MOT 1.1', 'MOT 1.1',
@ -360,6 +362,7 @@ class TaskExportTest(_DbTestBase):
for format_name, importer_name in [ for format_name, importer_name in [
('COCO 1.0', 'coco'), ('COCO 1.0', 'coco'),
('COCO Keypoints 1.0', 'coco_person_keypoints'),
('CVAT for images 1.1', 'cvat'), ('CVAT for images 1.1', 'cvat'),
# ('CVAT for video 1.1', 'cvat'), # does not support # ('CVAT for video 1.1', 'cvat'), # does not support
('Datumaro 1.0', 'datumaro'), ('Datumaro 1.0', 'datumaro'),

@ -237,6 +237,36 @@ class _DbTestBase(APITestCase):
"spec_id": spec_id, "spec_id": spec_id,
"value": value, "value": value,
}) })
elements = tmp_annotations[item][index_elem].get("elements", [])
labels = task["labels"][0].get("sublabels", [])
for element, label in zip(elements, labels):
element["label_id"] = label["id"]
for index_attribute, attribute in enumerate(label["attributes"]):
spec_id = label["attributes"][index_attribute]["id"]
if key_get_values == "random":
if attribute["input_type"] == "number":
start = int(attribute["values"][0])
stop = int(attribute["values"][1]) + 1
step = int(attribute["values"][2])
value = str(random.randrange(start, stop, step))
else:
value = random.choice(label["attributes"][index_attribute]["values"])
elif key_get_values == "default":
value = attribute["default_value"]
if item == "tracks" and attribute["mutable"]:
for index_shape, _ in enumerate(element["shapes"]):
element["shapes"][index_shape]["attributes"].append({
"spec_id": spec_id,
"value": value,
})
else:
element["attributes"].append({
"spec_id": spec_id,
"value": value,
})
response = self._put_api_v2_task_id_annotations(task["id"], tmp_annotations) response = self._put_api_v2_task_id_annotations(task["id"], tmp_annotations)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -358,7 +388,7 @@ class TaskDumpUploadTest(_DbTestBase):
if dump_format_name in [ if dump_format_name in [
"Market-1501 1.0", "Cityscapes 1.0", \ "Market-1501 1.0", "Cityscapes 1.0", \
"ICDAR Localization 1.0", "ICDAR Recognition 1.0", \ "ICDAR Localization 1.0", "ICDAR Recognition 1.0", \
"ICDAR Segmentation 1.0" "ICDAR Segmentation 1.0", "COCO Keypoints 1.0",
]: ]:
task = self._create_task(tasks[dump_format_name], images) task = self._create_task(tasks[dump_format_name], images)
else: else:
@ -421,7 +451,7 @@ class TaskDumpUploadTest(_DbTestBase):
if upload_format_name in [ if upload_format_name in [
"Market-1501 1.0", "Cityscapes 1.0", \ "Market-1501 1.0", "Cityscapes 1.0", \
"ICDAR Localization 1.0", "ICDAR Recognition 1.0", \ "ICDAR Localization 1.0", "ICDAR Recognition 1.0", \
"ICDAR Segmentation 1.0" "ICDAR Segmentation 1.0", "COCO Keypoints 1.0",
]: ]:
task = self._create_task(tasks[upload_format_name], images) task = self._create_task(tasks[upload_format_name], images)
else: else:
@ -464,7 +494,7 @@ class TaskDumpUploadTest(_DbTestBase):
if dump_format_name in [ if dump_format_name in [
"Market-1501 1.0", "Cityscapes 1.0", \ "Market-1501 1.0", "Cityscapes 1.0", \
"ICDAR Localization 1.0", "ICDAR Recognition 1.0", \ "ICDAR Localization 1.0", "ICDAR Recognition 1.0", \
"ICDAR Segmentation 1.0" "ICDAR Segmentation 1.0", "COCO Keypoints 1.0",
]: ]:
task = self._create_task(tasks[dump_format_name], video) task = self._create_task(tasks[dump_format_name], video)
else: else:
@ -525,7 +555,7 @@ class TaskDumpUploadTest(_DbTestBase):
if upload_format_name in [ if upload_format_name in [
"Market-1501 1.0", "Cityscapes 1.0", \ "Market-1501 1.0", "Cityscapes 1.0", \
"ICDAR Localization 1.0", "ICDAR Recognition 1.0", \ "ICDAR Localization 1.0", "ICDAR Recognition 1.0", \
"ICDAR Segmentation 1.0" "ICDAR Segmentation 1.0", "COCO Keypoints 1.0",
]: ]:
task = self._create_task(tasks[upload_format_name], video) task = self._create_task(tasks[upload_format_name], video)
else: else:
@ -806,7 +836,7 @@ class TaskDumpUploadTest(_DbTestBase):
if dump_format_name in [ if dump_format_name in [
"Market-1501 1.0", "Cityscapes 1.0", \ "Market-1501 1.0", "Cityscapes 1.0", \
"ICDAR Localization 1.0", "ICDAR Recognition 1.0", \ "ICDAR Localization 1.0", "ICDAR Recognition 1.0", \
"ICDAR Segmentation 1.0" "ICDAR Segmentation 1.0","COCO Keypoints 1.0",
]: ]:
task = self._create_task(tasks[dump_format_name], images) task = self._create_task(tasks[dump_format_name], images)
else: else:
@ -910,7 +940,7 @@ class TaskDumpUploadTest(_DbTestBase):
if dump_format_name in [ if dump_format_name in [
"Market-1501 1.0", "Cityscapes 1.0", \ "Market-1501 1.0", "Cityscapes 1.0", \
"ICDAR Localization 1.0", "ICDAR Recognition 1.0", \ "ICDAR Localization 1.0", "ICDAR Recognition 1.0", \
"ICDAR Segmentation 1.0" "ICDAR Segmentation 1.0", "COCO Keypoints 1.0",
]: ]:
task = self._create_task(tasks[dump_format_name], images) task = self._create_task(tasks[dump_format_name], images)
else: else:
@ -1024,7 +1054,7 @@ class TaskDumpUploadTest(_DbTestBase):
if dump_format_name in [ if dump_format_name in [
"Market-1501 1.0", "Cityscapes 1.0", \ "Market-1501 1.0", "Cityscapes 1.0", \
"ICDAR Localization 1.0", "ICDAR Recognition 1.0", \ "ICDAR Localization 1.0", "ICDAR Recognition 1.0", \
"ICDAR Segmentation 1.0" "ICDAR Segmentation 1.0", "COCO Keypoints 1.0"
]: ]:
task = self._create_task(tasks[dump_format_name], images) task = self._create_task(tasks[dump_format_name], images)
else: else:

@ -601,6 +601,9 @@ class TaskWriteSerializer(WriteOnceMixin, serializers.ModelSerializer):
if instance.project_id is None: if instance.project_id is None:
for old_label in instance.label_set.all(): for old_label in instance.label_set.all():
try: try:
if old_label.parent:
new_label = project.label_set.filter(name=old_label.name, parent__name=old_label.parent.name).first()
else:
new_label = project.label_set.filter(name=old_label.name).first() new_label = project.label_set.filter(name=old_label.name).first()
except ValueError: except ValueError:
raise serializers.ValidationError(f'Target project does not have label with name "{old_label.name}"') raise serializers.ValidationError(f'Target project does not have label with name "{old_label.name}"')
@ -616,6 +619,9 @@ class TaskWriteSerializer(WriteOnceMixin, serializers.ModelSerializer):
if len(new_label_for_name): if len(new_label_for_name):
old_label.name = new_label_for_name[0].get('name', old_label.name) old_label.name = new_label_for_name[0].get('name', old_label.name)
try: try:
if old_label.parent:
new_label = project.label_set.filter(name=old_label.name, parent__name=old_label.parent.name).first()
else:
new_label = project.label_set.filter(name=old_label.name).first() new_label = project.label_set.filter(name=old_label.name).first()
except ValueError: except ValueError:
raise serializers.ValidationError(f'Target project does not have label with name "{old_label.name}"') raise serializers.ValidationError(f'Target project does not have label with name "{old_label.name}"')
@ -651,18 +657,42 @@ class TaskWriteSerializer(WriteOnceMixin, serializers.ModelSerializer):
# Check that all labels can be mapped # Check that all labels can be mapped
new_label_names = set() new_label_names = set()
old_labels = self.instance.project.label_set.all() if self.instance.project_id else self.instance.label_set.all() old_labels = self.instance.project.label_set.all() if self.instance.project_id else self.instance.label_set.all()
new_sublabel_names = {}
for old_label in old_labels: for old_label in old_labels:
new_labels = tuple(filter(lambda x: x.get('id') == old_label.id, attrs.get('label_set', []))) new_labels = tuple(filter(lambda x: x.get('id') == old_label.id, attrs.get('label_set', [])))
if len(new_labels): if len(new_labels):
parent = new_labels[0].get('parent', old_label.parent)
if parent:
if parent.name not in new_sublabel_names:
new_sublabel_names[parent.name] = set()
new_sublabel_names[parent.name].add(new_labels[0].get('name', old_label.name))
else:
new_label_names.add(new_labels[0].get('name', old_label.name)) new_label_names.add(new_labels[0].get('name', old_label.name))
else:
parent = old_label.parent
if parent:
if parent.name not in new_sublabel_names:
new_sublabel_names[parent.name] = set()
new_sublabel_names[parent.name].add(old_label.name)
else: else:
new_label_names.add(old_label.name) new_label_names.add(old_label.name)
target_project = models.Project.objects.get(id=project_id) target_project = models.Project.objects.get(id=project_id)
target_project_label_names = set() target_project_label_names = set()
target_project_sublabel_names = {}
for label in target_project.label_set.all(): for label in target_project.label_set.all():
parent = label.parent
if parent:
if parent.name not in target_project_sublabel_names:
target_project_sublabel_names[parent.name] = set()
target_project_sublabel_names[parent.name].add(label.name)
else:
target_project_label_names.add(label.name) target_project_label_names.add(label.name)
if not new_label_names.issubset(target_project_label_names): if not new_label_names.issubset(target_project_label_names):
raise serializers.ValidationError('All task or project label names must be mapped to the target project') raise serializers.ValidationError('All task or project label names must be mapped to the target project')
for label, sublabels in new_sublabel_names.items():
if sublabels != target_project_sublabel_names.get(label):
raise serializers.ValidationError('All task or project label names must be mapped to the target project')
else: else:
if 'label_set' in attrs.keys(): if 'label_set' in attrs.keys():
label_names = [label['name'] for label in attrs.get('label_set')] label_names = [label['name'] for label in attrs.get('label_set')]

@ -5416,6 +5416,9 @@ class TaskAnnotationAPITestCase(JobAnnotationAPITestCase):
elif annotation_format == "COCO 1.0": elif annotation_format == "COCO 1.0":
annotations["shapes"] = polygon_shapes_wo_attrs annotations["shapes"] = polygon_shapes_wo_attrs
elif annotation_format == "COCO Keypoints 1.0":
annotations["shapes"] = points_wo_attrs
elif annotation_format == "Segmentation mask 1.1": elif annotation_format == "Segmentation mask 1.1":
annotations["shapes"] = rectangle_shapes_wo_attrs \ annotations["shapes"] = rectangle_shapes_wo_attrs \
+ polygon_shapes_wo_attrs + polygon_shapes_wo_attrs
@ -5771,7 +5774,7 @@ class TaskAnnotationAPITestCase(JobAnnotationAPITestCase):
self.assertTrue(zipfile.is_zipfile(content)) self.assertTrue(zipfile.is_zipfile(content))
elif format_name in ['Kitti Raw Format 1.0','Sly Point Cloud Format 1.0']: elif format_name in ['Kitti Raw Format 1.0','Sly Point Cloud Format 1.0']:
self.assertTrue(zipfile.is_zipfile(content)) self.assertTrue(zipfile.is_zipfile(content))
elif format_name == "COCO 1.0": elif format_name in ["COCO 1.0", "COCO Keypoints 1.0"]:
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
zipfile.ZipFile(content).extractall(tmp_dir) zipfile.ZipFile(content).extractall(tmp_dir)
jsons = glob(osp.join(tmp_dir, '**', '*.json'), recursive=True) jsons = glob(osp.join(tmp_dir, '**', '*.json'), recursive=True)

Loading…
Cancel
Save