diff --git a/CHANGELOG.md b/CHANGELOG.md index e715b7c7..26e9014a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,7 +25,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 non-ascii paths while adding files from "Connected file share" (issue #4428) - Removed unnecessary volumes defined in docker-compose.serverless.yml () -- Project import with skeletons () +- Project import with skeletons (, + ) ### Security - TDB diff --git a/cvat/apps/dataset_manager/task.py b/cvat/apps/dataset_manager/task.py index f76274fc..ea486b68 100644 --- a/cvat/apps/dataset_manager/task.py +++ b/cvat/apps/dataset_manager/task.py @@ -1,5 +1,6 @@ # Copyright (C) 2019-2022 Intel Corporation +# Copyright (C) 2022 CVAT.ai Corporation # # SPDX-License-Identifier: MIT @@ -426,15 +427,21 @@ class JobAnnotation: ) shapes = {} + elements = {} for db_shape in db_shapes: self._extend_attributes(db_shape.labeledshapeattributeval_set, self.db_attributes[db_shape.label_id]["all"].values()) - db_shape.elements = [] + if db_shape.parent is None: shapes[db_shape.id] = db_shape else: - shapes[db_shape.parent].elements.append(db_shape) + if db_shape.parent not in elements: + elements[db_shape.parent] = [] + elements[db_shape.parent].append(db_shape) + + for shape_id, shape_elements in elements.items(): + shapes[shape_id].elements = shape_elements serializer = serializers.LabeledShapeSerializer(list(shapes.values()), many=True) self.ir_data.shapes = serializer.data @@ -493,6 +500,7 @@ class JobAnnotation: ) tracks = {} + elements = {} for db_track in db_tracks: db_track["trackedshape_set"] = _merge_table_rows(db_track["trackedshape_set"], { 'trackedshapeattributeval_set': [ @@ -518,11 +526,15 @@ class JobAnnotation: self._extend_attributes(db_shape["trackedshapeattributeval_set"], default_attribute_values) default_attribute_values = db_shape["trackedshapeattributeval_set"] - db_track.elements = [] if db_track.parent is None: tracks[db_track.id] = db_track else: - tracks[db_track.parent].elements.append(db_track) + if db_track.parent not in elements: + elements[db_track.parent] = [] + elements[db_track.parent].append(db_track) + + for track_id, track_elements in elements.items(): + tracks[track_id].elements = track_elements serializer = serializers.LabeledTrackSerializer(list(tracks.values()), many=True) self.ir_data.tracks = serializer.data diff --git a/tests/python/rest_api/test_tasks.py b/tests/python/rest_api/test_tasks.py index d4da5fca..9ce80a14 100644 --- a/tests/python/rest_api/test_tasks.py +++ b/tests/python/rest_api/test_tasks.py @@ -13,7 +13,7 @@ from cvat_sdk.core.helpers import get_paginated_collection import pytest from deepdiff import DeepDiff -from shared.utils.config import make_api_client +from shared.utils.config import get_method, make_api_client, patch_method from shared.utils.helpers import generate_image_files from .utils import export_dataset @@ -309,6 +309,174 @@ class TestPostTaskData: (task, _) = api_client.tasks_api.retrieve(task_id) assert task.size == 4 + def test_can_get_annotations_from_new_task_with_skeletons(self): + spec = { + "name": f'test admin1 to create a task with skeleton', + "labels": [ + { + "name": "s1", + "color": "#5c5eba", + "attributes": [], + "type": "skeleton", + "sublabels": [ + { + "name": "1", + "color": "#d12345", + "attributes": [], + "type": "points" + }, + { + "name": "2", + "color": "#350dea", + "attributes": [], + "type": "points" + } + ], + "svg": "" \ + "" \ + "" + } + ] + } + + task_data = { + 'image_quality': 75, + 'client_files': generate_image_files(3), + } + + task_id = self._test_create_task(self._USERNAME, spec, task_data, + content_type="multipart/form-data") + + response = get_method(self._USERNAME, f"tasks/{task_id}") + label_ids = {} + for label in response.json()["labels"]: + label_ids.setdefault(label["type"], []).append(label["id"]) + + job_id = response.json()["segments"][0]["jobs"][0]["id"] + patch_data = { + "shapes": [{ + "type": "skeleton", + "occluded": False, + "outside": False, + "z_order": 0, + "rotation": 0, + "points": [], + "frame": 0, + "label_id": label_ids["skeleton"][0], + "group": 0, + "source": "manual", + "attributes": [], + "elements": [ + { + "type": "points", + "occluded": False, + "outside": False, + "z_order": 0, + "rotation": 0, + "points": [ + 131.63947368421032, + 165.0868421052637 + ], + "frame": 0, + "label_id": label_ids["points"][0], + "group": 0, + "source": "manual", + "attributes": [] + }, + { + "type": "points", + "occluded": False, + "outside": False, + "z_order": 0, + "rotation": 0, + "points": [ + 354.98157894736823, + 304.2710526315795 + ], + "frame": 0, + "label_id": label_ids["points"][1], + "group": 0, + "source": "manual", + "attributes": [] + } + ] + }], + "tracks": [{ + "frame": 0, + "label_id": label_ids["skeleton"][0], + "group": 0, + "source": "manual", + "shapes": [ + { + "type": "skeleton", + "occluded": False, + "outside": False, + "z_order": 0, + "rotation": 0, + "points": [], + "frame": 0, + "attributes": [] + } + ], + "attributes": [], + "elements": [ + { + "frame": 0, + "label_id": label_ids["points"][0], + "group": 0, + "source": "manual", + "shapes": [ + { + "type": "points", + "occluded": False, + "outside": False, + "z_order": 0, + "rotation": 0, + "points": [ + 295.6394736842103, + 472.5868421052637 + ], + "frame": 0, + "attributes": [] + } + ], + "attributes": [] + }, + { + "frame": 0, + "label_id": label_ids["points"][1], + "group": 0, + "source": "manual", + "shapes": [ + { + "type": "points", + "occluded": False, + "outside": False, + "z_order": 0, + "rotation": 0, + "points": [ + 619.3236842105262, + 846.9815789473689 + ], + "frame": 0, + "attributes": [] + } + ], + "attributes": [] + } + ] + }], + "tags": [], + "version": 0 + } + + response = patch_method(self._USERNAME, f"jobs/{job_id}/annotations", patch_data, action="create") + response = get_method(self._USERNAME, f"jobs/{job_id}/annotations") + assert response.status_code == HTTPStatus.OK + @pytest.mark.parametrize('cloud_storage_id, manifest, use_bucket_content, org', [ (1, 'manifest.jsonl', False, ''), # public bucket (2, 'sub/manifest.jsonl', True, 'org2'), # private bucket