# Copyright (C) 2019-2020 Intel Corporation
#
# SPDX-License-Identifier: MIT
from collections import OrderedDict
from enum import Enum
from django.conf import settings
from django.db import transaction
from django.utils import timezone
from cvat.apps.engine import models, serializers
from cvat.apps.engine.plugins import plugin_decorator
from cvat.apps.profiler import silk_profile
from .annotation import AnnotationIR, AnnotationManager
from .bindings import TaskData
from .formats.registry import make_exporter, make_importer
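
# This module implements the annotation storage layer of the engine app:
# JobAnnotation and TaskAnnotation load, save, patch and delete annotation
# data for a single job or a whole task, and the module-level get_/put_/
# patch_/delete_ helpers wrap them in database transactions.
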
class dotdict(OrderedDict):
"""dot.notation access to dictionary attributes"""
__getattr__ = OrderedDict.get
__setattr__ = OrderedDict.__setitem__
__delattr__ = OrderedDict.__delitem__
__eq__ = lambda self, other: self.id == other.id
__hash__ = lambda self: self.id

class PatchAction(str, Enum):
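    """Kinds of actions accepted by patch_job_data/patch_task_data."""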
CREATE = "create"
UPDATE = "update"
DELETE = "delete"
@classmethod
def values(cls):
return [item.value for item in cls]
def __str__(self):
return self.value

def bulk_create(db_model, objects, flt_param):
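    """Insert objects in bulk and return the created rows.

    PostgreSQL returns primary keys from bulk_create() directly; on other
    backends the created rows are recovered by diffing the set of ids that
    match flt_param before and after the insert. With an empty flt_param
    the plain bulk_create() result is returned.
    """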
if objects:
if flt_param:
if 'postgresql' in settings.DATABASES["default"]["ENGINE"]:
return db_model.objects.bulk_create(objects)
else:
ids = list(db_model.objects.filter(**flt_param).values_list('id', flat=True))
db_model.objects.bulk_create(objects)
return list(db_model.objects.exclude(id__in=ids).filter(**flt_param))
else:
return db_model.objects.bulk_create(objects)
return []

def _merge_table_rows(rows, keys_for_merge, field_id):
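    """Group flat values() rows into nested objects keyed by field_id.

    Rows sharing the same field_id are collapsed into one dotdict; the
    columns listed in keys_for_merge are gathered into lists of dotdicts
    keyed by the part of the column name after '__'.
    """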
# It is necessary to keep a stable order of original rows
# (e.g. for tracked boxes). Otherwise prev_box.frame can be bigger
# than next_box.frame.
merged_rows = OrderedDict()
# Group all rows by field_id. In grouped rows replace fields in
# accordance with keys_for_merge structure.
for row in rows:
row_id = row[field_id]
        if row_id not in merged_rows:
merged_rows[row_id] = dotdict(row)
for key in keys_for_merge:
merged_rows[row_id][key] = []
for key in keys_for_merge:
item = dotdict({v.split('__', 1)[-1]:row[v] for v in keys_for_merge[key]})
if item.id is not None:
merged_rows[row_id][key].append(item)
# Remove redundant keys from final objects
redundant_keys = [item for values in keys_for_merge.values() for item in values]
for i in merged_rows:
for j in redundant_keys:
del merged_rows[i][j]
return list(merged_rows.values())

class JobAnnotation:
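    """Reads and writes annotation data for a single job.

    Creating an instance locks the job row with select_for_update(), so
    instances must be used inside a database transaction.
    """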
def __init__(self, pk):
self.db_job = models.Job.objects.select_related('segment__task') \
.select_for_update().get(id=pk)
db_segment = self.db_job.segment
self.start_frame = db_segment.start_frame
self.stop_frame = db_segment.stop_frame
self.ir_data = AnnotationIR()
self.db_labels = {db_label.id:db_label
for db_label in db_segment.task.label_set.all()}
self.db_attributes = {}
for db_label in self.db_labels.values():
self.db_attributes[db_label.id] = {
"mutable": OrderedDict(),
"immutable": OrderedDict(),
"all": OrderedDict(),
}
for db_attr in db_label.attributespec_set.all():
default_value = dotdict([
('spec_id', db_attr.id),
('value', db_attr.default_value),
])
if db_attr.mutable:
self.db_attributes[db_label.id]["mutable"][db_attr.id] = default_value
else:
self.db_attributes[db_label.id]["immutable"][db_attr.id] = default_value
self.db_attributes[db_label.id]["all"][db_attr.id] = default_value
def reset(self):
self.ir_data.reset()
def _save_tracks_to_db(self, tracks):
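        """Persist tracks, their shapes and attribute values, then mirror
        the saved data into self.ir_data.

        Primary keys are unknown until bulk_create() runs, so attribute
        values and shapes temporarily store the *index* of their parent in
        db_tracks/db_shapes; the index is replaced with the real id once
        the parents have been created.
        """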
db_tracks = []
db_track_attrvals = []
db_shapes = []
db_shape_attrvals = []
for track in tracks:
track_attributes = track.pop("attributes", [])
shapes = track.pop("shapes")
db_track = models.LabeledTrack(job=self.db_job, **track)
if db_track.label_id not in self.db_labels:
raise AttributeError("label_id `{}` is invalid".format(db_track.label_id))
for attr in track_attributes:
db_attrval = models.LabeledTrackAttributeVal(**attr)
if db_attrval.spec_id not in self.db_attributes[db_track.label_id]["immutable"]:
raise AttributeError("spec_id `{}` is invalid".format(db_attrval.spec_id))
db_attrval.track_id = len(db_tracks)
db_track_attrvals.append(db_attrval)
for shape in shapes:
shape_attributes = shape.pop("attributes", [])
                # FIXME: need to clamp points (ensure that all of them are
                # inside the image). Should we check here or implement a validator?
db_shape = models.TrackedShape(**shape)
db_shape.track_id = len(db_tracks)
for attr in shape_attributes:
db_attrval = models.TrackedShapeAttributeVal(**attr)
if db_attrval.spec_id not in self.db_attributes[db_track.label_id]["mutable"]:
raise AttributeError("spec_id `{}` is invalid".format(db_attrval.spec_id))
db_attrval.shape_id = len(db_shapes)
db_shape_attrvals.append(db_attrval)
db_shapes.append(db_shape)
shape["attributes"] = shape_attributes
db_tracks.append(db_track)
track["attributes"] = track_attributes
track["shapes"] = shapes
db_tracks = bulk_create(
db_model=models.LabeledTrack,
objects=db_tracks,
flt_param={"job_id": self.db_job.id}
)
for db_attrval in db_track_attrvals:
db_attrval.track_id = db_tracks[db_attrval.track_id].id
bulk_create(
db_model=models.LabeledTrackAttributeVal,
objects=db_track_attrvals,
flt_param={}
)
for db_shape in db_shapes:
db_shape.track_id = db_tracks[db_shape.track_id].id
db_shapes = bulk_create(
db_model=models.TrackedShape,
objects=db_shapes,
flt_param={"track__job_id": self.db_job.id}
)
for db_attrval in db_shape_attrvals:
db_attrval.shape_id = db_shapes[db_attrval.shape_id].id
bulk_create(
db_model=models.TrackedShapeAttributeVal,
objects=db_shape_attrvals,
flt_param={}
)
shape_idx = 0
for track, db_track in zip(tracks, db_tracks):
track["id"] = db_track.id
for shape in track["shapes"]:
shape["id"] = db_shapes[shape_idx].id
shape_idx += 1
self.ir_data.tracks = tracks
def _save_shapes_to_db(self, shapes):
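        """Persist shapes and their attribute values, using the same
        temporary-index scheme as _save_tracks_to_db."""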
db_shapes = []
db_attrvals = []
for shape in shapes:
attributes = shape.pop("attributes", [])
            # FIXME: need to clamp points (ensure that all of them are
            # inside the image). Should we check here or implement a validator?
db_shape = models.LabeledShape(job=self.db_job, **shape)
if db_shape.label_id not in self.db_labels:
raise AttributeError("label_id `{}` is invalid".format(db_shape.label_id))
for attr in attributes:
db_attrval = models.LabeledShapeAttributeVal(**attr)
if db_attrval.spec_id not in self.db_attributes[db_shape.label_id]["all"]:
raise AttributeError("spec_id `{}` is invalid".format(db_attrval.spec_id))
db_attrval.shape_id = len(db_shapes)
db_attrvals.append(db_attrval)
db_shapes.append(db_shape)
shape["attributes"] = attributes
db_shapes = bulk_create(
db_model=models.LabeledShape,
objects=db_shapes,
flt_param={"job_id": self.db_job.id}
)
for db_attrval in db_attrvals:
db_attrval.shape_id = db_shapes[db_attrval.shape_id].id
bulk_create(
db_model=models.LabeledShapeAttributeVal,
objects=db_attrvals,
flt_param={}
)
for shape, db_shape in zip(shapes, db_shapes):
shape["id"] = db_shape.id
self.ir_data.shapes = shapes
def _save_tags_to_db(self, tags):
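        """Persist tags (LabeledImage rows) and their attribute values;
        tag_id temporarily holds an index into db_tags until the real
        image_id is assigned after bulk_create()."""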
db_tags = []
db_attrvals = []
for tag in tags:
attributes = tag.pop("attributes", [])
db_tag = models.LabeledImage(job=self.db_job, **tag)
if db_tag.label_id not in self.db_labels:
raise AttributeError("label_id `{}` is invalid".format(db_tag.label_id))
for attr in attributes:
db_attrval = models.LabeledImageAttributeVal(**attr)
if db_attrval.spec_id not in self.db_attributes[db_tag.label_id]["all"]:
raise AttributeError("spec_id `{}` is invalid".format(db_attrval.spec_id))
db_attrval.tag_id = len(db_tags)
db_attrvals.append(db_attrval)
db_tags.append(db_tag)
tag["attributes"] = attributes
db_tags = bulk_create(
db_model=models.LabeledImage,
objects=db_tags,
flt_param={"job_id": self.db_job.id}
)
for db_attrval in db_attrvals:
db_attrval.image_id = db_tags[db_attrval.tag_id].id
bulk_create(
db_model=models.LabeledImageAttributeVal,
objects=db_attrvals,
flt_param={}
)
for tag, db_tag in zip(tags, db_tags):
tag["id"] = db_tag.id
self.ir_data.tags = tags
def _commit(self):
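        """Record a JobCommit with a version one greater than the last
        commit (or 1 if this is the first commit)."""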
db_prev_commit = self.db_job.commits.last()
db_curr_commit = models.JobCommit()
if db_prev_commit:
db_curr_commit.version = db_prev_commit.version + 1
else:
db_curr_commit.version = 1
db_curr_commit.job = self.db_job
db_curr_commit.message = "Changes: tags - {}; shapes - {}; tracks - {}".format(
len(self.ir_data.tags), len(self.ir_data.shapes), len(self.ir_data.tracks))
db_curr_commit.save()
self.ir_data.version = db_curr_commit.version
def _set_updated_date(self):
db_task = self.db_job.segment.task
db_task.updated_date = timezone.now()
db_task.save()
def _save_to_db(self, data):
self.reset()
self._save_tags_to_db(data["tags"])
self._save_shapes_to_db(data["shapes"])
self._save_tracks_to_db(data["tracks"])
return self.ir_data.tags or self.ir_data.shapes or self.ir_data.tracks
def _create(self, data):
if self._save_to_db(data):
self._set_updated_date()
self.db_job.save()
def create(self, data):
self._create(data)
self._commit()
def put(self, data):
self._delete()
self._create(data)
self._commit()
def update(self, data):
self._delete(data)
self._create(data)
self._commit()
def _delete(self, data=None):
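        """Delete all annotations of the job, or, if data is given, only
        the objects whose ids are listed in it."""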
deleted_shapes = 0
if data is None:
deleted_shapes += self.db_job.labeledimage_set.all().delete()[0]
deleted_shapes += self.db_job.labeledshape_set.all().delete()[0]
deleted_shapes += self.db_job.labeledtrack_set.all().delete()[0]
else:
labeledimage_ids = [image["id"] for image in data["tags"]]
labeledshape_ids = [shape["id"] for shape in data["shapes"]]
labeledtrack_ids = [track["id"] for track in data["tracks"]]
labeledimage_set = self.db_job.labeledimage_set
labeledimage_set = labeledimage_set.filter(pk__in=labeledimage_ids)
labeledshape_set = self.db_job.labeledshape_set
labeledshape_set = labeledshape_set.filter(pk__in=labeledshape_ids)
labeledtrack_set = self.db_job.labeledtrack_set
labeledtrack_set = labeledtrack_set.filter(pk__in=labeledtrack_ids)
            # It does not matter if data contained some "invalid" objects that
            # were skipped (not actually deleted). The main point is that all
            # requested objects are absent from the DB after this method.
self.ir_data.tags = data['tags']
self.ir_data.shapes = data['shapes']
self.ir_data.tracks = data['tracks']
deleted_shapes += labeledimage_set.delete()[0]
deleted_shapes += labeledshape_set.delete()[0]
deleted_shapes += labeledtrack_set.delete()[0]
if deleted_shapes:
self._set_updated_date()
def delete(self, data=None):
self._delete(data)
self._commit()
@staticmethod
def _extend_attributes(attributeval_set, default_attribute_values):
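        """Append default values for attribute specs that are missing from
        attributeval_set."""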
shape_attribute_specs_set = set(attr.spec_id for attr in attributeval_set)
for db_attr in default_attribute_values:
if db_attr.spec_id not in shape_attribute_specs_set:
attributeval_set.append(dotdict([
('spec_id', db_attr.spec_id),
('value', db_attr.value),
]))
def _init_tags_from_db(self):
db_tags = self.db_job.labeledimage_set.prefetch_related(
"label",
"labeledimageattributeval_set"
).values(
'id',
'frame',
'label_id',
'group',
'labeledimageattributeval__spec_id',
'labeledimageattributeval__value',
'labeledimageattributeval__id',
).order_by('frame')
db_tags = _merge_table_rows(
rows=db_tags,
keys_for_merge={
"labeledimageattributeval_set": [
'labeledimageattributeval__spec_id',
'labeledimageattributeval__value',
'labeledimageattributeval__id',
],
},
field_id='id',
)
for db_tag in db_tags:
self._extend_attributes(db_tag.labeledimageattributeval_set,
self.db_attributes[db_tag.label_id]["all"].values())
serializer = serializers.LabeledImageSerializer(db_tags, many=True)
self.ir_data.tags = serializer.data
def _init_shapes_from_db(self):
db_shapes = self.db_job.labeledshape_set.prefetch_related(
"label",
"labeledshapeattributeval_set"
).values(
'id',
'label_id',
'type',
'frame',
'group',
'occluded',
'z_order',
'points',
'labeledshapeattributeval__spec_id',
'labeledshapeattributeval__value',
'labeledshapeattributeval__id',
).order_by('frame')
db_shapes = _merge_table_rows(
rows=db_shapes,
keys_for_merge={
'labeledshapeattributeval_set': [
'labeledshapeattributeval__spec_id',
'labeledshapeattributeval__value',
'labeledshapeattributeval__id',
],
},
field_id='id',
)
for db_shape in db_shapes:
self._extend_attributes(db_shape.labeledshapeattributeval_set,
self.db_attributes[db_shape.label_id]["all"].values())
serializer = serializers.LabeledShapeSerializer(db_shapes, many=True)
self.ir_data.shapes = serializer.data
def _init_tracks_from_db(self):
db_tracks = self.db_job.labeledtrack_set.prefetch_related(
"label",
"labeledtrackattributeval_set",
"trackedshape_set__trackedshapeattributeval_set"
).values(
"id",
"frame",
"label_id",
"group",
"labeledtrackattributeval__spec_id",
"labeledtrackattributeval__value",
"labeledtrackattributeval__id",
"trackedshape__type",
"trackedshape__occluded",
"trackedshape__z_order",
"trackedshape__points",
"trackedshape__id",
"trackedshape__frame",
"trackedshape__outside",
"trackedshape__trackedshapeattributeval__spec_id",
"trackedshape__trackedshapeattributeval__value",
"trackedshape__trackedshapeattributeval__id",
).order_by('id', 'trackedshape__frame')
db_tracks = _merge_table_rows(
rows=db_tracks,
keys_for_merge={
"labeledtrackattributeval_set": [
"labeledtrackattributeval__spec_id",
"labeledtrackattributeval__value",
"labeledtrackattributeval__id",
],
"trackedshape_set":[
"trackedshape__type",
"trackedshape__occluded",
"trackedshape__z_order",
"trackedshape__points",
"trackedshape__id",
"trackedshape__frame",
"trackedshape__outside",
"trackedshape__trackedshapeattributeval__spec_id",
"trackedshape__trackedshapeattributeval__value",
"trackedshape__trackedshapeattributeval__id",
],
},
field_id="id",
)
for db_track in db_tracks:
db_track["trackedshape_set"] = _merge_table_rows(db_track["trackedshape_set"], {
'trackedshapeattributeval_set': [
'trackedshapeattributeval__value',
'trackedshapeattributeval__spec_id',
'trackedshapeattributeval__id',
]
}, 'id')
            # The result table can contain many identical rows for track/shape
            # attributes, so we have to filter out the duplicates manually.
db_track["labeledtrackattributeval_set"] = list(set(db_track["labeledtrackattributeval_set"]))
self._extend_attributes(db_track.labeledtrackattributeval_set,
self.db_attributes[db_track.label_id]["immutable"].values())
default_attribute_values = self.db_attributes[db_track.label_id]["mutable"].values()
for db_shape in db_track["trackedshape_set"]:
db_shape["trackedshapeattributeval_set"] = list(
set(db_shape["trackedshapeattributeval_set"])
)
                # For tracked shapes we need to interpolate attribute values and
                # extend them with the previous shape's attribute values
                # (not the default values).
self._extend_attributes(db_shape["trackedshapeattributeval_set"], default_attribute_values)
default_attribute_values = db_shape["trackedshapeattributeval_set"]
serializer = serializers.LabeledTrackSerializer(db_tracks, many=True)
self.ir_data.tracks = serializer.data
def _init_version_from_db(self):
db_commit = self.db_job.commits.last()
self.ir_data.version = db_commit.version if db_commit else 0
def init_from_db(self):
self._init_tags_from_db()
self._init_shapes_from_db()
self._init_tracks_from_db()
self._init_version_from_db()
@property
def data(self):
return self.ir_data.data
def import_annotations(self, src_file, importer):
task_data = TaskData(
annotation_ir=AnnotationIR(),
db_task=self.db_job.segment.task,
create_callback=self.create,
)
self.delete()
importer(src_file, task_data)
self.create(task_data.data.slice(self.start_frame, self.stop_frame).serialize())

class TaskAnnotation:
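    """Aggregates annotation data over all jobs of a task: writes are
    split per job and delegated to the job-level functions, reads merge
    the per-job results back together."""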
def __init__(self, pk):
self.db_task = models.Task.objects.prefetch_related("data__images").get(id=pk)
        # Postgres doesn't guarantee row order without an explicit order_by
self.db_jobs = models.Job.objects.select_related("segment").filter(segment__task_id=pk).order_by('id')
self.ir_data = AnnotationIR()
def reset(self):
self.ir_data.reset()
def _patch_data(self, data, action):
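        """Slice data by each job's segment frame range, apply the patch
        action per job (a full put when action is None), and merge the
        results back, taking segment overlap into account."""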
_data = data if isinstance(data, AnnotationIR) else AnnotationIR(data)
splitted_data = {}
jobs = {}
for db_job in self.db_jobs:
jid = db_job.id
start = db_job.segment.start_frame
stop = db_job.segment.stop_frame
jobs[jid] = { "start": start, "stop": stop }
splitted_data[jid] = _data.slice(start, stop)
for jid, job_data in splitted_data.items():
_data = AnnotationIR()
if action is None:
_data.data = put_job_data(jid, job_data)
else:
_data.data = patch_job_data(jid, job_data, action)
if _data.version > self.ir_data.version:
self.ir_data.version = _data.version
self._merge_data(_data, jobs[jid]["start"], self.db_task.overlap)
def _merge_data(self, data, start_frame, overlap):
annotation_manager = AnnotationManager(self.ir_data)
annotation_manager.merge(data, start_frame, overlap)
def put(self, data):
self._patch_data(data, None)
def create(self, data):
self._patch_data(data, PatchAction.CREATE)
def update(self, data):
self._patch_data(data, PatchAction.UPDATE)
def delete(self, data=None):
if data:
self._patch_data(data, PatchAction.DELETE)
else:
for db_job in self.db_jobs:
delete_job_data(db_job.id)
def init_from_db(self):
self.reset()
for db_job in self.db_jobs:
annotation = JobAnnotation(db_job.id)
annotation.init_from_db()
if annotation.ir_data.version > self.ir_data.version:
self.ir_data.version = annotation.ir_data.version
db_segment = db_job.segment
start_frame = db_segment.start_frame
overlap = self.db_task.overlap
self._merge_data(annotation.ir_data, start_frame, overlap)
def export(self, dst_file, exporter, host='', **options):
task_data = TaskData(
annotation_ir=self.ir_data,
db_task=self.db_task,
host=host,
)
exporter(dst_file, task_data, **options)
def import_annotations(self, src_file, importer, **options):
task_data = TaskData(
annotation_ir=AnnotationIR(),
db_task=self.db_task,
create_callback=self.create,
)
self.delete()
importer(src_file, task_data, **options)
self.create(task_data.data.serialize())
@property
def data(self):
return self.ir_data.data
@silk_profile(name="GET job data")
@transaction.atomic
def get_job_data(pk):
annotation = JobAnnotation(pk)
annotation.init_from_db()
return annotation.data
@silk_profile(name="POST job data")
@transaction.atomic
def put_job_data(pk, data):
annotation = JobAnnotation(pk)
annotation.put(data)
return annotation.data
@silk_profile(name="UPDATE job data")
@plugin_decorator
@transaction.atomic
def patch_job_data(pk, data, action):
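    """Apply a PatchAction (create/update/delete) to the job's annotations
    and return the resulting data."""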
annotation = JobAnnotation(pk)
if action == PatchAction.CREATE:
annotation.create(data)
elif action == PatchAction.UPDATE:
annotation.update(data)
elif action == PatchAction.DELETE:
annotation.delete(data)
return annotation.data
@silk_profile(name="DELETE job data")
@transaction.atomic
def delete_job_data(pk):
annotation = JobAnnotation(pk)
annotation.delete()
@silk_profile(name="GET task data")
@transaction.atomic
def get_task_data(pk):
annotation = TaskAnnotation(pk)
annotation.init_from_db()
return annotation.data
@silk_profile(name="POST task data")
@transaction.atomic
def put_task_data(pk, data):
annotation = TaskAnnotation(pk)
annotation.put(data)
return annotation.data
@silk_profile(name="UPDATE task data")
@transaction.atomic
def patch_task_data(pk, data, action):
annotation = TaskAnnotation(pk)
if action == PatchAction.CREATE:
annotation.create(data)
elif action == PatchAction.UPDATE:
annotation.update(data)
elif action == PatchAction.DELETE:
annotation.delete(data)
return annotation.data
@silk_profile(name="DELETE task data")
@transaction.atomic
def delete_task_data(pk):
annotation = TaskAnnotation(pk)
annotation.delete()

def export_task(task_id, dst_file, format_name,
server_url=None, save_images=False):
    # For big tasks the dump function may run for a long time, and we don't
    # need to hold the lock after the task has been initialized from the DB.
    # However, there is a bug that corrupts the dump file when two or more
    # dump requests are received at the same time:
    # https://github.com/opencv/cvat/issues/217
with transaction.atomic():
task = TaskAnnotation(task_id)
task.init_from_db()
exporter = make_exporter(format_name)
with open(dst_file, 'wb') as f:
task.export(f, exporter, host=server_url,
save_images=save_images)

@transaction.atomic
def import_task_annotations(task_id, src_file, format_name):
task = TaskAnnotation(task_id)
task.init_from_db()
importer = make_importer(format_name)
with open(src_file, 'rb') as f:
task.import_annotations(f, importer)

@transaction.atomic
def import_job_annotations(job_id, src_file, format_name):
job = JobAnnotation(job_id)
job.init_from_db()
importer = make_importer(format_name)
with open(src_file, 'rb') as f:
job.import_annotations(f, importer)
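
# A minimal usage sketch (assuming a task with id 42 exists and that the
# format name is registered in formats.registry, e.g. "CVAT for images 1.1"):
#
#   data = get_task_data(42)          # read the merged annotations
#   export_task(42, '/tmp/annotations.zip', 'CVAT for images 1.1')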