Az/fix no dump default attrs (#656)

* fill absent attributes by default values during annotation save
* fill absent attributes by default values during init from db
* fixed tests
* updated changelog, added some coments, minor fixes
main
Andrey Zhavoronkov 7 years ago committed by Nikita Manovich
parent 7fb7ba150e
commit fc2b9c94cc

@ -39,7 +39,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Installation of CVAT with OpenVINO on the Windows platform - Installation of CVAT with OpenVINO on the Windows platform
- Background color was always black in utils/mask/converter.py - Background color was always black in utils/mask/converter.py
- Exception in attribute annotation mode when a label are switched to a value without any attributes - Exception in attribute annotation mode when a label are switched to a value without any attributes
- Handling of wrong labelamp json file in auto annotation (https://github.com/opencv/cvat/issues/554) - Handling of wrong labelamp json file in auto annotation (<https://github.com/opencv/cvat/issues/554>)
- No default attributes in dumped annotation (<https://github.com/opencv/cvat/issues/601>)
### Security ### Security
- -

@ -4,6 +4,7 @@
import os import os
from enum import Enum from enum import Enum
from collections import OrderedDict
from django.utils import timezone from django.utils import timezone
from PIL import Image from PIL import Image
@ -192,9 +193,21 @@ class JobAnnotation:
self.logger = slogger.job[self.db_job.id] self.logger = slogger.job[self.db_job.id]
self.db_labels = {db_label.id:db_label self.db_labels = {db_label.id:db_label
for db_label in db_segment.task.label_set.all()} for db_label in db_segment.task.label_set.all()}
self.db_attributes = {db_attr.id:db_attr
for db_attr in models.AttributeSpec.objects.filter( self.db_attributes = {}
label__task__id=db_segment.task.id)} for db_label in self.db_labels.values():
self.db_attributes[db_label.id] = {
"mutable": OrderedDict(),
"immutable": OrderedDict(),
"all": OrderedDict(),
}
for db_attr in db_label.attributespec_set.all():
if db_attr.mutable:
self.db_attributes[db_label.id]["mutable"][db_attr.id] = db_attr
else:
self.db_attributes[db_label.id]["immutable"][db_attr.id] = db_attr
self.db_attributes[db_label.id]["all"][db_attr.id] = db_attr
def reset(self): def reset(self):
self.ir_data.reset() self.ir_data.reset()
@ -214,7 +227,7 @@ class JobAnnotation:
for attr in track_attributes: for attr in track_attributes:
db_attrval = models.LabeledTrackAttributeVal(**attr) db_attrval = models.LabeledTrackAttributeVal(**attr)
if db_attrval.spec_id not in self.db_attributes: if db_attrval.spec_id not in self.db_attributes[db_track.label_id]["immutable"]:
raise AttributeError("spec_id `{}` is invalid".format(db_attrval.spec_id)) raise AttributeError("spec_id `{}` is invalid".format(db_attrval.spec_id))
db_attrval.track_id = len(db_tracks) db_attrval.track_id = len(db_tracks)
db_track_attrvals.append(db_attrval) db_track_attrvals.append(db_attrval)
@ -228,7 +241,7 @@ class JobAnnotation:
for attr in shape_attributes: for attr in shape_attributes:
db_attrval = models.TrackedShapeAttributeVal(**attr) db_attrval = models.TrackedShapeAttributeVal(**attr)
if db_attrval.spec_id not in self.db_attributes: if db_attrval.spec_id not in self.db_attributes[db_track.label_id]["mutable"]:
raise AttributeError("spec_id `{}` is invalid".format(db_attrval.spec_id)) raise AttributeError("spec_id `{}` is invalid".format(db_attrval.spec_id))
db_attrval.shape_id = len(db_shapes) db_attrval.shape_id = len(db_shapes)
db_shape_attrvals.append(db_attrval) db_shape_attrvals.append(db_attrval)
@ -295,8 +308,9 @@ class JobAnnotation:
for attr in attributes: for attr in attributes:
db_attrval = models.LabeledShapeAttributeVal(**attr) db_attrval = models.LabeledShapeAttributeVal(**attr)
if db_attrval.spec_id not in self.db_attributes: if db_attrval.spec_id not in self.db_attributes[db_shape.label_id]["all"]:
raise AttributeError("spec_id `{}` is invalid".format(db_attrval.spec_id)) raise AttributeError("spec_id `{}` is invalid".format(db_attrval.spec_id))
db_attrval.shape_id = len(db_shapes) db_attrval.shape_id = len(db_shapes)
db_attrvals.append(db_attrval) db_attrvals.append(db_attrval)
@ -335,7 +349,7 @@ class JobAnnotation:
for attr in attributes: for attr in attributes:
db_attrval = models.LabeledImageAttributeVal(**attr) db_attrval = models.LabeledImageAttributeVal(**attr)
if db_attrval.spec_id not in self.db_attributes: if db_attrval.spec_id not in self.db_attributes[db_tag.label_id]["all"]:
raise AttributeError("spec_id `{}` is invalid".format(db_attrval.spec_id)) raise AttributeError("spec_id `{}` is invalid".format(db_attrval.spec_id))
db_attrval.tag_id = len(db_tags) db_attrval.tag_id = len(db_tags)
db_attrvals.append(db_attrval) db_attrvals.append(db_attrval)
@ -350,7 +364,7 @@ class JobAnnotation:
) )
for db_attrval in db_attrvals: for db_attrval in db_attrvals:
db_attrval.tag_id = db_tags[db_attrval.tag_id].id db_attrval.image_id = db_tags[db_attrval.tag_id].id
bulk_create( bulk_create(
db_model=models.LabeledImageAttributeVal, db_model=models.LabeledImageAttributeVal,
@ -436,6 +450,16 @@ class JobAnnotation:
self._delete(data) self._delete(data)
self._commit() self._commit()
@staticmethod
def _extend_attributes(attributeval_set, attribute_specs):
shape_attribute_specs_set = set(attr.spec_id for attr in attributeval_set)
for db_attr_spec in attribute_specs:
if db_attr_spec.id not in shape_attribute_specs_set:
attributeval_set.append(OrderedDict([
('spec_id', db_attr_spec.id),
('value', db_attr_spec.default_value),
]))
def _init_tags_from_db(self): def _init_tags_from_db(self):
db_tags = self.db_job.labeledimage_set.prefetch_related( db_tags = self.db_job.labeledimage_set.prefetch_related(
"label", "label",
@ -461,6 +485,11 @@ class JobAnnotation:
}, },
field_id='id', field_id='id',
) )
for db_tag in db_tags:
self._extend_attributes(db_tag.labeledimageattributeval_set,
self.db_attributes[db_tag.label_id]["all"].values())
serializer = serializers.LabeledImageSerializer(db_tags, many=True) serializer = serializers.LabeledImageSerializer(db_tags, many=True)
self.ir_data.tags = serializer.data self.ir_data.tags = serializer.data
@ -493,6 +522,9 @@ class JobAnnotation:
}, },
field_id='id', field_id='id',
) )
for db_shape in db_shapes:
self._extend_attributes(db_shape.labeledshapeattributeval_set,
self.db_attributes[db_shape.label_id]["all"].values())
serializer = serializers.LabeledShapeSerializer(db_shapes, many=True) serializer = serializers.LabeledShapeSerializer(db_shapes, many=True)
self.ir_data.shapes = serializer.data self.ir_data.shapes = serializer.data
@ -558,10 +590,15 @@ class JobAnnotation:
# A result table can consist many equal rows for track/shape attributes # A result table can consist many equal rows for track/shape attributes
# We need filter unique attributes manually # We need filter unique attributes manually
db_track["labeledtrackattributeval_set"] = list(set(db_track["labeledtrackattributeval_set"])) db_track["labeledtrackattributeval_set"] = list(set(db_track["labeledtrackattributeval_set"]))
self._extend_attributes(db_track.labeledtrackattributeval_set,
self.db_attributes[db_track.label_id]["immutable"].values())
for db_shape in db_track["trackedshape_set"]: for db_shape in db_track["trackedshape_set"]:
db_shape["trackedshapeattributeval_set"] = list( db_shape["trackedshapeattributeval_set"] = list(
set(db_shape["trackedshapeattributeval_set"]) set(db_shape["trackedshapeattributeval_set"])
) )
self._extend_attributes(db_shape["trackedshapeattributeval_set"],
self.db_attributes[db_track.label_id]["mutable"].values())
serializer = serializers.LabeledTrackSerializer(db_tracks, many=True) serializer = serializers.LabeledTrackSerializer(db_tracks, many=True)
self.ir_data.tracks = serializer.data self.ir_data.tracks = serializer.data

@ -1177,7 +1177,7 @@ class JobAnnotationAPITestCase(APITestCase):
"mutable": False, "mutable": False,
"input_type": "select", "input_type": "select",
"default_value": "mazda", "default_value": "mazda",
"values": ["bmw", "mazda", "reno"] "values": ["bmw", "mazda", "renault"]
}, },
{ {
"name": "parked", "name": "parked",
@ -1212,6 +1212,27 @@ class JobAnnotationAPITestCase(APITestCase):
return (task, jobs) return (task, jobs)
@staticmethod
def _get_default_attr_values(task):
default_attr_values = {}
for label in task["labels"]:
default_attr_values[label["id"]] = {
"mutable": [],
"immutable": [],
"all": [],
}
for attr in label["attributes"]:
default_value = {
"spec_id": attr["id"],
"value": attr["default_value"],
}
if attr["mutable"]:
default_attr_values[label["id"]]["mutable"].append(default_value)
else:
default_attr_values[label["id"]]["immutable"].append(default_value)
default_attr_values[label["id"]]["all"].append(default_value)
return default_attr_values
def _put_api_v1_jobs_id_data(self, jid, user, data): def _put_api_v1_jobs_id_data(self, jid, user, data):
with ForceLogin(user, self.client): with ForceLogin(user, self.client):
response = self.client.put("/api/v1/jobs/{}/annotations".format(jid), response = self.client.put("/api/v1/jobs/{}/annotations".format(jid),
@ -1288,7 +1309,7 @@ class JobAnnotationAPITestCase(APITestCase):
}, },
{ {
"spec_id": task["labels"][0]["attributes"][1]["id"], "spec_id": task["labels"][0]["attributes"][1]["id"],
"value": task["labels"][0]["attributes"][0]["default_value"] "value": task["labels"][0]["attributes"][1]["default_value"]
} }
], ],
"points": [1.0, 2.1, 100, 300.222], "points": [1.0, 2.1, 100, 300.222],
@ -1310,7 +1331,12 @@ class JobAnnotationAPITestCase(APITestCase):
"frame": 0, "frame": 0,
"label_id": task["labels"][0]["id"], "label_id": task["labels"][0]["id"],
"group": None, "group": None,
"attributes": [], "attributes": [
{
"spec_id": task["labels"][0]["attributes"][0]["id"],
"value": task["labels"][0]["attributes"][0]["values"][0]
},
],
"shapes": [ "shapes": [
{ {
"frame": 0, "frame": 0,
@ -1319,14 +1345,10 @@ class JobAnnotationAPITestCase(APITestCase):
"occluded": False, "occluded": False,
"outside": False, "outside": False,
"attributes": [ "attributes": [
{
"spec_id": task["labels"][0]["attributes"][0]["id"],
"value": task["labels"][0]["attributes"][0]["values"][0]
},
{ {
"spec_id": task["labels"][0]["attributes"][1]["id"], "spec_id": task["labels"][0]["attributes"][1]["id"],
"value": task["labels"][0]["attributes"][0]["default_value"] "value": task["labels"][0]["attributes"][1]["default_value"]
} },
] ]
}, },
{ {
@ -1357,6 +1379,8 @@ class JobAnnotationAPITestCase(APITestCase):
}, },
] ]
} }
default_attr_values = self._get_default_attr_values(task)
response = self._put_api_v1_jobs_id_data(job["id"], annotator, data) response = self._put_api_v1_jobs_id_data(job["id"], annotator, data)
data["version"] += 1 # need to update the version data["version"] += 1 # need to update the version
self.assertEqual(response.status_code, HTTP_200_OK) self.assertEqual(response.status_code, HTTP_200_OK)
@ -1364,6 +1388,9 @@ class JobAnnotationAPITestCase(APITestCase):
response = self._get_api_v1_jobs_id_data(job["id"], annotator) response = self._get_api_v1_jobs_id_data(job["id"], annotator)
self.assertEqual(response.status_code, HTTP_200_OK) self.assertEqual(response.status_code, HTTP_200_OK)
# server should add default attribute values if puted data doesn't contain it
data["tags"][0]["attributes"] = default_attr_values[data["tags"][0]["label_id"]]["all"]
data["tracks"][0]["shapes"][1]["attributes"] = default_attr_values[data["tracks"][0]["label_id"]]["mutable"]
self._check_response(response, data) self._check_response(response, data)
response = self._delete_api_v1_jobs_id_data(job["id"], annotator) response = self._delete_api_v1_jobs_id_data(job["id"], annotator)
@ -1402,7 +1429,7 @@ class JobAnnotationAPITestCase(APITestCase):
}, },
{ {
"spec_id": task["labels"][0]["attributes"][1]["id"], "spec_id": task["labels"][0]["attributes"][1]["id"],
"value": task["labels"][0]["attributes"][0]["default_value"] "value": task["labels"][0]["attributes"][1]["default_value"]
} }
], ],
"points": [1.0, 2.1, 100, 300.222], "points": [1.0, 2.1, 100, 300.222],
@ -1424,7 +1451,12 @@ class JobAnnotationAPITestCase(APITestCase):
"frame": 0, "frame": 0,
"label_id": task["labels"][0]["id"], "label_id": task["labels"][0]["id"],
"group": None, "group": None,
"attributes": [], "attributes": [
{
"spec_id": task["labels"][0]["attributes"][0]["id"],
"value": task["labels"][0]["attributes"][0]["values"][0]
},
],
"shapes": [ "shapes": [
{ {
"frame": 0, "frame": 0,
@ -1433,14 +1465,10 @@ class JobAnnotationAPITestCase(APITestCase):
"occluded": False, "occluded": False,
"outside": False, "outside": False,
"attributes": [ "attributes": [
{
"spec_id": task["labels"][0]["attributes"][0]["id"],
"value": task["labels"][0]["attributes"][0]["values"][0]
},
{ {
"spec_id": task["labels"][0]["attributes"][1]["id"], "spec_id": task["labels"][0]["attributes"][1]["id"],
"value": task["labels"][0]["attributes"][0]["default_value"] "value": task["labels"][0]["attributes"][1]["default_value"]
} },
] ]
}, },
{ {
@ -1479,6 +1507,9 @@ class JobAnnotationAPITestCase(APITestCase):
response = self._get_api_v1_jobs_id_data(job["id"], annotator) response = self._get_api_v1_jobs_id_data(job["id"], annotator)
self.assertEqual(response.status_code, HTTP_200_OK) self.assertEqual(response.status_code, HTTP_200_OK)
# server should add default attribute values if puted data doesn't contain it
data["tags"][0]["attributes"] = default_attr_values[data["tags"][0]["label_id"]]["all"]
data["tracks"][0]["shapes"][1]["attributes"] = default_attr_values[data["tracks"][0]["label_id"]]["mutable"]
self._check_response(response, data) self._check_response(response, data)
data = response.data data = response.data
@ -1576,7 +1607,7 @@ class JobAnnotationAPITestCase(APITestCase):
}, },
{ {
"spec_id": task["labels"][0]["attributes"][1]["id"], "spec_id": task["labels"][0]["attributes"][1]["id"],
"value": task["labels"][0]["attributes"][0]["default_value"] "value": task["labels"][0]["attributes"][1]["default_value"]
} }
] ]
}, },
@ -1733,7 +1764,12 @@ class TaskAnnotationAPITestCase(JobAnnotationAPITestCase):
"frame": 0, "frame": 0,
"label_id": task["labels"][0]["id"], "label_id": task["labels"][0]["id"],
"group": None, "group": None,
"attributes": [], "attributes": [
{
"spec_id": task["labels"][0]["attributes"][0]["id"],
"value": task["labels"][0]["attributes"][0]["values"][0]
},
],
"shapes": [ "shapes": [
{ {
"frame": 0, "frame": 0,
@ -1742,13 +1778,9 @@ class TaskAnnotationAPITestCase(JobAnnotationAPITestCase):
"occluded": False, "occluded": False,
"outside": False, "outside": False,
"attributes": [ "attributes": [
{
"spec_id": task["labels"][0]["attributes"][0]["id"],
"value": task["labels"][0]["attributes"][0]["values"][0]
},
{ {
"spec_id": task["labels"][0]["attributes"][1]["id"], "spec_id": task["labels"][0]["attributes"][1]["id"],
"value": task["labels"][0]["attributes"][0]["default_value"] "value": task["labels"][0]["attributes"][1]["default_value"]
} }
] ]
}, },
@ -1782,10 +1814,15 @@ class TaskAnnotationAPITestCase(JobAnnotationAPITestCase):
} }
response = self._put_api_v1_tasks_id_annotations(task["id"], annotator, data) response = self._put_api_v1_tasks_id_annotations(task["id"], annotator, data)
data["version"] += 1 data["version"] += 1
self.assertEqual(response.status_code, HTTP_200_OK) self.assertEqual(response.status_code, HTTP_200_OK)
self._check_response(response, data) self._check_response(response, data)
default_attr_values = self._get_default_attr_values(task)
response = self._get_api_v1_tasks_id_annotations(task["id"], annotator) response = self._get_api_v1_tasks_id_annotations(task["id"], annotator)
# server should add default attribute values if puted data doesn't contain it
data["tags"][0]["attributes"] = default_attr_values[data["tags"][0]["label_id"]]["all"]
data["tracks"][0]["shapes"][1]["attributes"] = default_attr_values[data["tracks"][0]["label_id"]]["mutable"]
self.assertEqual(response.status_code, HTTP_200_OK) self.assertEqual(response.status_code, HTTP_200_OK)
self._check_response(response, data) self._check_response(response, data)
@ -1847,7 +1884,12 @@ class TaskAnnotationAPITestCase(JobAnnotationAPITestCase):
"frame": 0, "frame": 0,
"label_id": task["labels"][0]["id"], "label_id": task["labels"][0]["id"],
"group": None, "group": None,
"attributes": [], "attributes": [
{
"spec_id": task["labels"][0]["attributes"][0]["id"],
"value": task["labels"][0]["attributes"][0]["values"][0]
},
],
"shapes": [ "shapes": [
{ {
"frame": 0, "frame": 0,
@ -1856,13 +1898,9 @@ class TaskAnnotationAPITestCase(JobAnnotationAPITestCase):
"occluded": False, "occluded": False,
"outside": False, "outside": False,
"attributes": [ "attributes": [
{
"spec_id": task["labels"][0]["attributes"][0]["id"],
"value": task["labels"][0]["attributes"][0]["values"][0]
},
{ {
"spec_id": task["labels"][0]["attributes"][1]["id"], "spec_id": task["labels"][0]["attributes"][1]["id"],
"value": task["labels"][0]["attributes"][0]["default_value"] "value": task["labels"][0]["attributes"][1]["default_value"]
} }
] ]
}, },
@ -1901,6 +1939,9 @@ class TaskAnnotationAPITestCase(JobAnnotationAPITestCase):
self._check_response(response, data) self._check_response(response, data)
response = self._get_api_v1_tasks_id_annotations(task["id"], annotator) response = self._get_api_v1_tasks_id_annotations(task["id"], annotator)
# server should add default attribute values if puted data doesn't contain it
data["tags"][0]["attributes"] = default_attr_values[data["tags"][0]["label_id"]]["all"]
data["tracks"][0]["shapes"][1]["attributes"] = default_attr_values[data["tracks"][0]["label_id"]]["mutable"]
self.assertEqual(response.status_code, HTTP_200_OK) self.assertEqual(response.status_code, HTTP_200_OK)
self._check_response(response, data) self._check_response(response, data)

Loading…
Cancel
Save