# Copyright (C) 2019-2020 Intel Corporation
#
# SPDX-License-Identifier: MIT

from collections import OrderedDict
from enum import Enum

from django.conf import settings
from django.db import transaction
from django.utils import timezone

from cvat.apps.engine import models, serializers
from cvat.apps.engine.plugins import plugin_decorator
from cvat.apps.profiler import silk_profile

from .annotation import AnnotationIR, AnnotationManager
from .bindings import TaskData
from .formats.registry import make_exporter, make_importer

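# `dotdict` below wraps OrderedDict so that query rows can be accessed as
# attributes. An illustrative (not executed) example:
#   d = dotdict([('id', 1), ('value', 'x')])
#   d.id == 1 and d.value == 'x'
# Equality and hashing are based on d.id only; missing attributes return None
# because __getattr__ maps to OrderedDict.get.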
class dotdict(OrderedDict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = OrderedDict.get
    __setattr__ = OrderedDict.__setitem__
    __delattr__ = OrderedDict.__delitem__
    __eq__ = lambda self, other: self.id == other.id
    __hash__ = lambda self: self.id

class PatchAction(str, Enum):
    CREATE = "create"
    UPDATE = "update"
    DELETE = "delete"

    @classmethod
    def values(cls):
        return [item.value for item in cls]

    def __str__(self):
        return self.value

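# Note on bulk_create() below: PostgreSQL returns primary keys from
# bulk_create(), so the freshly created rows can be used directly. Other
# backends (e.g. SQLite) do not, so the ids that existed before the insert
# are recorded and the new rows are recovered afterwards by excluding them.
# A sketch of the idea, assuming a hypothetical model with a `job_id` field:
#
#   old_ids = list(Model.objects.filter(job_id=jid).values_list('id', flat=True))
#   Model.objects.bulk_create(new_objects)
#   created = Model.objects.exclude(id__in=old_ids).filter(job_id=jid)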
def bulk_create(db_model, objects, flt_param):
    if objects:
        if flt_param:
            if 'postgresql' in settings.DATABASES["default"]["ENGINE"]:
                return db_model.objects.bulk_create(objects)
            else:
                ids = list(db_model.objects.filter(**flt_param).values_list('id', flat=True))
                db_model.objects.bulk_create(objects)

                return list(db_model.objects.exclude(id__in=ids).filter(**flt_param))
        else:
            return db_model.objects.bulk_create(objects)

    return []

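# _merge_table_rows() below groups the flat rows produced by a values() query
# with joined attribute columns back into nested objects. An illustrative
# example (the field names are hypothetical):
#
#   rows = [
#       {'id': 1, 'frame': 0, 'attr__spec_id': 7, 'attr__value': 'a', 'attr__id': 10},
#       {'id': 1, 'frame': 0, 'attr__spec_id': 8, 'attr__value': 'b', 'attr__id': 11},
#   ]
#   _merge_table_rows(rows, {'attr_set': ['attr__spec_id', 'attr__value', 'attr__id']}, 'id')
#   # -> [{'id': 1, 'frame': 0,
#   #      'attr_set': [{'spec_id': 7, 'value': 'a', 'id': 10},
#   #                   {'spec_id': 8, 'value': 'b', 'id': 11}]}]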
def _merge_table_rows(rows, keys_for_merge, field_id):
    # It is necessary to keep a stable order of original rows
    # (e.g. for tracked boxes). Otherwise prev_box.frame can be bigger
    # than next_box.frame.
    merged_rows = OrderedDict()

    # Group all rows by field_id. In grouped rows replace fields in
    # accordance with keys_for_merge structure.
    for row in rows:
        row_id = row[field_id]
        if row_id not in merged_rows:
            merged_rows[row_id] = dotdict(row)
            for key in keys_for_merge:
                merged_rows[row_id][key] = []

        for key in keys_for_merge:
            item = dotdict({v.split('__', 1)[-1]: row[v] for v in keys_for_merge[key]})
            if item.id is not None:
                merged_rows[row_id][key].append(item)

    # Remove redundant keys from final objects
    redundant_keys = [item for values in keys_for_merge.values() for item in values]
    for i in merged_rows:
        for j in redundant_keys:
            del merged_rows[i][j]

    return list(merged_rows.values())

class JobAnnotation:
    def __init__(self, pk):
        self.db_job = models.Job.objects.select_related('segment__task') \
            .select_for_update().get(id=pk)

        db_segment = self.db_job.segment
        self.start_frame = db_segment.start_frame
        self.stop_frame = db_segment.stop_frame
        self.ir_data = AnnotationIR()

        self.db_labels = {db_label.id: db_label
            for db_label in db_segment.task.label_set.all()}

        self.db_attributes = {}
        for db_label in self.db_labels.values():
            self.db_attributes[db_label.id] = {
                "mutable": OrderedDict(),
                "immutable": OrderedDict(),
                "all": OrderedDict(),
            }
            for db_attr in db_label.attributespec_set.all():
                default_value = dotdict([
                    ('spec_id', db_attr.id),
                    ('value', db_attr.default_value),
                ])
                if db_attr.mutable:
                    self.db_attributes[db_label.id]["mutable"][db_attr.id] = default_value
                else:
                    self.db_attributes[db_label.id]["immutable"][db_attr.id] = default_value

                self.db_attributes[db_label.id]["all"][db_attr.id] = default_value

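    # The cache built above maps label_id -> {"mutable", "immutable", "all"},
    # each an OrderedDict of spec_id -> dotdict(spec_id, default value).
    # Illustrative layout (the ids and values are hypothetical):
    #   self.db_attributes[3] == {
    #       "mutable":   {5: {'spec_id': 5, 'value': 'false'}},
    #       "immutable": {6: {'spec_id': 6, 'value': 'person'}},
    #       "all":       {5: {...}, 6: {...}},
    #   }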
    def reset(self):
        self.ir_data.reset()

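    # Note on the three _save_*_to_db() methods below: before bulk_create()
    # runs, related objects cannot reference real primary keys, so each child
    # temporarily stores the *positional index* of its parent in the pending
    # list (e.g. db_attrval.track_id = len(db_tracks)). After bulk_create()
    # returns the saved rows, those indices are remapped to real ids:
    #
    #   for db_attrval in db_track_attrvals:
    #       db_attrval.track_id = db_tracks[db_attrval.track_id].id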
    def _save_tracks_to_db(self, tracks):
        db_tracks = []
        db_track_attrvals = []
        db_shapes = []
        db_shape_attrvals = []

        for track in tracks:
            track_attributes = track.pop("attributes", [])
            shapes = track.pop("shapes")
            db_track = models.LabeledTrack(job=self.db_job, **track)
            if db_track.label_id not in self.db_labels:
                raise AttributeError("label_id `{}` is invalid".format(db_track.label_id))

            for attr in track_attributes:
                db_attrval = models.LabeledTrackAttributeVal(**attr)
                if db_attrval.spec_id not in self.db_attributes[db_track.label_id]["immutable"]:
                    raise AttributeError("spec_id `{}` is invalid".format(db_attrval.spec_id))
                db_attrval.track_id = len(db_tracks)
                db_track_attrvals.append(db_attrval)

            for shape in shapes:
                shape_attributes = shape.pop("attributes", [])
                # FIXME: need to clamp points (be sure that all of them are inside the image)
                # Should we check here or implement a validator?
                db_shape = models.TrackedShape(**shape)
                db_shape.track_id = len(db_tracks)

                for attr in shape_attributes:
                    db_attrval = models.TrackedShapeAttributeVal(**attr)
                    if db_attrval.spec_id not in self.db_attributes[db_track.label_id]["mutable"]:
                        raise AttributeError("spec_id `{}` is invalid".format(db_attrval.spec_id))
                    db_attrval.shape_id = len(db_shapes)
                    db_shape_attrvals.append(db_attrval)

                db_shapes.append(db_shape)
                shape["attributes"] = shape_attributes

            db_tracks.append(db_track)
            track["attributes"] = track_attributes
            track["shapes"] = shapes

        db_tracks = bulk_create(
            db_model=models.LabeledTrack,
            objects=db_tracks,
            flt_param={"job_id": self.db_job.id}
        )

        for db_attrval in db_track_attrvals:
            db_attrval.track_id = db_tracks[db_attrval.track_id].id
        bulk_create(
            db_model=models.LabeledTrackAttributeVal,
            objects=db_track_attrvals,
            flt_param={}
        )

        for db_shape in db_shapes:
            db_shape.track_id = db_tracks[db_shape.track_id].id

        db_shapes = bulk_create(
            db_model=models.TrackedShape,
            objects=db_shapes,
            flt_param={"track__job_id": self.db_job.id}
        )

        for db_attrval in db_shape_attrvals:
            db_attrval.shape_id = db_shapes[db_attrval.shape_id].id

        bulk_create(
            db_model=models.TrackedShapeAttributeVal,
            objects=db_shape_attrvals,
            flt_param={}
        )

        shape_idx = 0
        for track, db_track in zip(tracks, db_tracks):
            track["id"] = db_track.id
            for shape in track["shapes"]:
                shape["id"] = db_shapes[shape_idx].id
                shape_idx += 1

        self.ir_data.tracks = tracks

    def _save_shapes_to_db(self, shapes):
        db_shapes = []
        db_attrvals = []

        for shape in shapes:
            attributes = shape.pop("attributes", [])
            # FIXME: need to clamp points (be sure that all of them are inside the image)
            # Should we check here or implement a validator?
            db_shape = models.LabeledShape(job=self.db_job, **shape)
            if db_shape.label_id not in self.db_labels:
                raise AttributeError("label_id `{}` is invalid".format(db_shape.label_id))

            for attr in attributes:
                db_attrval = models.LabeledShapeAttributeVal(**attr)
                if db_attrval.spec_id not in self.db_attributes[db_shape.label_id]["all"]:
                    raise AttributeError("spec_id `{}` is invalid".format(db_attrval.spec_id))

                db_attrval.shape_id = len(db_shapes)
                db_attrvals.append(db_attrval)

            db_shapes.append(db_shape)
            shape["attributes"] = attributes

        db_shapes = bulk_create(
            db_model=models.LabeledShape,
            objects=db_shapes,
            flt_param={"job_id": self.db_job.id}
        )

        for db_attrval in db_attrvals:
            db_attrval.shape_id = db_shapes[db_attrval.shape_id].id

        bulk_create(
            db_model=models.LabeledShapeAttributeVal,
            objects=db_attrvals,
            flt_param={}
        )

        for shape, db_shape in zip(shapes, db_shapes):
            shape["id"] = db_shape.id

        self.ir_data.shapes = shapes

    def _save_tags_to_db(self, tags):
        db_tags = []
        db_attrvals = []

        for tag in tags:
            attributes = tag.pop("attributes", [])
            db_tag = models.LabeledImage(job=self.db_job, **tag)
            if db_tag.label_id not in self.db_labels:
                raise AttributeError("label_id `{}` is invalid".format(db_tag.label_id))

            for attr in attributes:
                db_attrval = models.LabeledImageAttributeVal(**attr)
                if db_attrval.spec_id not in self.db_attributes[db_tag.label_id]["all"]:
                    raise AttributeError("spec_id `{}` is invalid".format(db_attrval.spec_id))
                db_attrval.tag_id = len(db_tags)
                db_attrvals.append(db_attrval)

            db_tags.append(db_tag)
            tag["attributes"] = attributes

        db_tags = bulk_create(
            db_model=models.LabeledImage,
            objects=db_tags,
            flt_param={"job_id": self.db_job.id}
        )

        for db_attrval in db_attrvals:
            db_attrval.image_id = db_tags[db_attrval.tag_id].id

        bulk_create(
            db_model=models.LabeledImageAttributeVal,
            objects=db_attrvals,
            flt_param={}
        )

        for tag, db_tag in zip(tags, db_tags):
            tag["id"] = db_tag.id

        self.ir_data.tags = tags

    def _commit(self):
        db_prev_commit = self.db_job.commits.last()
        db_curr_commit = models.JobCommit()
        if db_prev_commit:
            db_curr_commit.version = db_prev_commit.version + 1
        else:
            db_curr_commit.version = 1
        db_curr_commit.job = self.db_job
        db_curr_commit.message = "Changes: tags - {}; shapes - {}; tracks - {}".format(
            len(self.ir_data.tags), len(self.ir_data.shapes), len(self.ir_data.tracks))
        db_curr_commit.save()
        self.ir_data.version = db_curr_commit.version

    def _set_updated_date(self):
        db_task = self.db_job.segment.task
        db_task.updated_date = timezone.now()
        db_task.save()

    def _save_to_db(self, data):
        self.reset()
        self._save_tags_to_db(data["tags"])
        self._save_shapes_to_db(data["shapes"])
        self._save_tracks_to_db(data["tracks"])

        return self.ir_data.tags or self.ir_data.shapes or self.ir_data.tracks

    def _create(self, data):
        if self._save_to_db(data):
            self._set_updated_date()
            self.db_job.save()

    def create(self, data):
        self._create(data)
        self._commit()

    def put(self, data):
        self._delete()
        self._create(data)
        self._commit()

    def update(self, data):
        self._delete(data)
        self._create(data)
        self._commit()

    def _delete(self, data=None):
        deleted_shapes = 0
        if data is None:
            deleted_shapes += self.db_job.labeledimage_set.all().delete()[0]
            deleted_shapes += self.db_job.labeledshape_set.all().delete()[0]
            deleted_shapes += self.db_job.labeledtrack_set.all().delete()[0]
        else:
            labeledimage_ids = [image["id"] for image in data["tags"]]
            labeledshape_ids = [shape["id"] for shape in data["shapes"]]
            labeledtrack_ids = [track["id"] for track in data["tracks"]]
            labeledimage_set = self.db_job.labeledimage_set
            labeledimage_set = labeledimage_set.filter(pk__in=labeledimage_ids)
            labeledshape_set = self.db_job.labeledshape_set
            labeledshape_set = labeledshape_set.filter(pk__in=labeledshape_ids)
            labeledtrack_set = self.db_job.labeledtrack_set
            labeledtrack_set = labeledtrack_set.filter(pk__in=labeledtrack_ids)

            # It does not matter that the data may contain some "invalid"
            # objects which were skipped (not actually deleted). The main idea
            # is that all requested objects are absent from the DB after the
            # method finishes.
            self.ir_data.tags = data['tags']
            self.ir_data.shapes = data['shapes']
            self.ir_data.tracks = data['tracks']

            deleted_shapes += labeledimage_set.delete()[0]
            deleted_shapes += labeledshape_set.delete()[0]
            deleted_shapes += labeledtrack_set.delete()[0]

        if deleted_shapes:
            self._set_updated_date()

    def delete(self, data=None):
        self._delete(data)
        self._commit()

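    # _extend_attributes() below fills in attribute values that the client did
    # not send, using the supplied defaults. Illustrative example (spec ids
    # are hypothetical): if attributeval_set == [{'spec_id': 5, 'value': 'true'}]
    # and the defaults cover specs 5 and 6, only spec 6 is appended, so
    # explicitly provided values always win over defaults.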
    @staticmethod
    def _extend_attributes(attributeval_set, default_attribute_values):
        shape_attribute_specs_set = set(attr.spec_id for attr in attributeval_set)
        for db_attr in default_attribute_values:
            if db_attr.spec_id not in shape_attribute_specs_set:
                attributeval_set.append(dotdict([
                    ('spec_id', db_attr.spec_id),
                    ('value', db_attr.value),
                ]))

    def _init_tags_from_db(self):
        db_tags = self.db_job.labeledimage_set.prefetch_related(
            "label",
            "labeledimageattributeval_set"
        ).values(
            'id',
            'frame',
            'label_id',
            'group',
            'labeledimageattributeval__spec_id',
            'labeledimageattributeval__value',
            'labeledimageattributeval__id',
        ).order_by('frame')

        db_tags = _merge_table_rows(
            rows=db_tags,
            keys_for_merge={
                "labeledimageattributeval_set": [
                    'labeledimageattributeval__spec_id',
                    'labeledimageattributeval__value',
                    'labeledimageattributeval__id',
                ],
            },
            field_id='id',
        )

        for db_tag in db_tags:
            self._extend_attributes(db_tag.labeledimageattributeval_set,
                self.db_attributes[db_tag.label_id]["all"].values())

        serializer = serializers.LabeledImageSerializer(db_tags, many=True)
        self.ir_data.tags = serializer.data

    def _init_shapes_from_db(self):
        db_shapes = self.db_job.labeledshape_set.prefetch_related(
            "label",
            "labeledshapeattributeval_set"
        ).values(
            'id',
            'label_id',
            'type',
            'frame',
            'group',
            'occluded',
            'z_order',
            'points',
            'labeledshapeattributeval__spec_id',
            'labeledshapeattributeval__value',
            'labeledshapeattributeval__id',
        ).order_by('frame')

        db_shapes = _merge_table_rows(
            rows=db_shapes,
            keys_for_merge={
                'labeledshapeattributeval_set': [
                    'labeledshapeattributeval__spec_id',
                    'labeledshapeattributeval__value',
                    'labeledshapeattributeval__id',
                ],
            },
            field_id='id',
        )
        for db_shape in db_shapes:
            self._extend_attributes(db_shape.labeledshapeattributeval_set,
                self.db_attributes[db_shape.label_id]["all"].values())

        serializer = serializers.LabeledShapeSerializer(db_shapes, many=True)
        self.ir_data.shapes = serializer.data

    def _init_tracks_from_db(self):
        db_tracks = self.db_job.labeledtrack_set.prefetch_related(
            "label",
            "labeledtrackattributeval_set",
            "trackedshape_set__trackedshapeattributeval_set"
        ).values(
            "id",
            "frame",
            "label_id",
            "group",
            "labeledtrackattributeval__spec_id",
            "labeledtrackattributeval__value",
            "labeledtrackattributeval__id",
            "trackedshape__type",
            "trackedshape__occluded",
            "trackedshape__z_order",
            "trackedshape__points",
            "trackedshape__id",
            "trackedshape__frame",
            "trackedshape__outside",
            "trackedshape__trackedshapeattributeval__spec_id",
            "trackedshape__trackedshapeattributeval__value",
            "trackedshape__trackedshapeattributeval__id",
        ).order_by('id', 'trackedshape__frame')

        db_tracks = _merge_table_rows(
            rows=db_tracks,
            keys_for_merge={
                "labeledtrackattributeval_set": [
                    "labeledtrackattributeval__spec_id",
                    "labeledtrackattributeval__value",
                    "labeledtrackattributeval__id",
                ],
                "trackedshape_set": [
                    "trackedshape__type",
                    "trackedshape__occluded",
                    "trackedshape__z_order",
                    "trackedshape__points",
                    "trackedshape__id",
                    "trackedshape__frame",
                    "trackedshape__outside",
                    "trackedshape__trackedshapeattributeval__spec_id",
                    "trackedshape__trackedshapeattributeval__value",
                    "trackedshape__trackedshapeattributeval__id",
                ],
            },
            field_id="id",
        )

        for db_track in db_tracks:
            db_track["trackedshape_set"] = _merge_table_rows(db_track["trackedshape_set"], {
                'trackedshapeattributeval_set': [
                    'trackedshapeattributeval__value',
                    'trackedshapeattributeval__spec_id',
                    'trackedshapeattributeval__id',
                ]
            }, 'id')

            # The result table can contain many identical rows for track/shape
            # attributes, so duplicate attributes have to be filtered out manually.
            db_track["labeledtrackattributeval_set"] = list(set(db_track["labeledtrackattributeval_set"]))
            self._extend_attributes(db_track.labeledtrackattributeval_set,
                self.db_attributes[db_track.label_id]["immutable"].values())

            default_attribute_values = self.db_attributes[db_track.label_id]["mutable"].values()
            for db_shape in db_track["trackedshape_set"]:
                db_shape["trackedshapeattributeval_set"] = list(
                    set(db_shape["trackedshapeattributeval_set"])
                )
                # In the case of tracked shapes, attribute values need to be
                # interpolated and extended with the previous shape's attribute
                # values (not the default values).
                self._extend_attributes(db_shape["trackedshapeattributeval_set"], default_attribute_values)
                default_attribute_values = db_shape["trackedshapeattributeval_set"]

        serializer = serializers.LabeledTrackSerializer(db_tracks, many=True)
        self.ir_data.tracks = serializer.data

    def _init_version_from_db(self):
        db_commit = self.db_job.commits.last()
        self.ir_data.version = db_commit.version if db_commit else 0

    def init_from_db(self):
        self._init_tags_from_db()
        self._init_shapes_from_db()
        self._init_tracks_from_db()
        self._init_version_from_db()

    @property
    def data(self):
        return self.ir_data.data

    def import_annotations(self, src_file, importer):
        task_data = TaskData(
            annotation_ir=AnnotationIR(),
            db_task=self.db_job.segment.task,
            create_callback=self.create,
        )
        self.delete()

        importer(src_file, task_data)

        self.create(task_data.data.slice(self.start_frame, self.stop_frame).serialize())

class TaskAnnotation:
    def __init__(self, pk):
        self.db_task = models.Task.objects.prefetch_related("data__images").get(id=pk)

        # Postgres doesn't guarantee an order by default without explicit order_by
        self.db_jobs = models.Job.objects.select_related("segment").filter(segment__task_id=pk).order_by('id')
        self.ir_data = AnnotationIR()

    def reset(self):
        self.ir_data.reset()

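    # _patch_data() below fans task-level annotations out to jobs: the task's
    # frame range is sliced per job segment, each slice is saved via
    # put_job_data()/patch_job_data(), and the per-job results are merged back.
    # The task overlap is passed to AnnotationManager.merge() so that objects
    # duplicated in overlapping frames can be reconciled.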
    def _patch_data(self, data, action):
        _data = data if isinstance(data, AnnotationIR) else AnnotationIR(data)
        splitted_data = {}
        jobs = {}
        for db_job in self.db_jobs:
            jid = db_job.id
            start = db_job.segment.start_frame
            stop = db_job.segment.stop_frame
            jobs[jid] = { "start": start, "stop": stop }
            splitted_data[jid] = _data.slice(start, stop)

        for jid, job_data in splitted_data.items():
            _data = AnnotationIR()
            if action is None:
                _data.data = put_job_data(jid, job_data)
            else:
                _data.data = patch_job_data(jid, job_data, action)
            if _data.version > self.ir_data.version:
                self.ir_data.version = _data.version
            self._merge_data(_data, jobs[jid]["start"], self.db_task.overlap)

    def _merge_data(self, data, start_frame, overlap):
        annotation_manager = AnnotationManager(self.ir_data)
        annotation_manager.merge(data, start_frame, overlap)

    def put(self, data):
        self._patch_data(data, None)

    def create(self, data):
        self._patch_data(data, PatchAction.CREATE)

    def update(self, data):
        self._patch_data(data, PatchAction.UPDATE)

    def delete(self, data=None):
        if data:
            self._patch_data(data, PatchAction.DELETE)
        else:
            for db_job in self.db_jobs:
                delete_job_data(db_job.id)

    def init_from_db(self):
        self.reset()

        for db_job in self.db_jobs:
            annotation = JobAnnotation(db_job.id)
            annotation.init_from_db()
            if annotation.ir_data.version > self.ir_data.version:
                self.ir_data.version = annotation.ir_data.version
            db_segment = db_job.segment
            start_frame = db_segment.start_frame
            overlap = self.db_task.overlap
            self._merge_data(annotation.ir_data, start_frame, overlap)

    def export(self, dst_file, exporter, host='', **options):
        task_data = TaskData(
            annotation_ir=self.ir_data,
            db_task=self.db_task,
            host=host,
        )
        exporter(dst_file, task_data, **options)

    def import_annotations(self, src_file, importer, **options):
        task_data = TaskData(
            annotation_ir=AnnotationIR(),
            db_task=self.db_task,
            create_callback=self.create,
        )
        self.delete()

        importer(src_file, task_data, **options)

        self.create(task_data.data.serialize())

    @property
    def data(self):
        return self.ir_data.data


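# Module-level API: thin, profiled, transactional wrappers around
# JobAnnotation/TaskAnnotation. A typical call from a view might look like
# this (illustrative; `request` is hypothetical):
#
#   data = get_job_data(pk)
#   data = patch_job_data(pk, request.data, PatchAction.UPDATE)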
@silk_profile(name="GET job data")
|
|
@transaction.atomic
|
|
def get_job_data(pk):
|
|
annotation = JobAnnotation(pk)
|
|
annotation.init_from_db()
|
|
|
|
return annotation.data
|
|
|
|
@silk_profile(name="POST job data")
|
|
@transaction.atomic
|
|
def put_job_data(pk, data):
|
|
annotation = JobAnnotation(pk)
|
|
annotation.put(data)
|
|
|
|
return annotation.data
|
|
|
|
@silk_profile(name="UPDATE job data")
|
|
@plugin_decorator
|
|
@transaction.atomic
|
|
def patch_job_data(pk, data, action):
|
|
annotation = JobAnnotation(pk)
|
|
if action == PatchAction.CREATE:
|
|
annotation.create(data)
|
|
elif action == PatchAction.UPDATE:
|
|
annotation.update(data)
|
|
elif action == PatchAction.DELETE:
|
|
annotation.delete(data)
|
|
|
|
return annotation.data
|
|
|
|
@silk_profile(name="DELETE job data")
|
|
@transaction.atomic
|
|
def delete_job_data(pk):
|
|
annotation = JobAnnotation(pk)
|
|
annotation.delete()
|
|
|
|
@silk_profile(name="GET task data")
|
|
@transaction.atomic
|
|
def get_task_data(pk):
|
|
annotation = TaskAnnotation(pk)
|
|
annotation.init_from_db()
|
|
|
|
return annotation.data
|
|
|
|
@silk_profile(name="POST task data")
|
|
@transaction.atomic
|
|
def put_task_data(pk, data):
|
|
annotation = TaskAnnotation(pk)
|
|
annotation.put(data)
|
|
|
|
return annotation.data
|
|
|
|
@silk_profile(name="UPDATE task data")
|
|
@transaction.atomic
|
|
def patch_task_data(pk, data, action):
|
|
annotation = TaskAnnotation(pk)
|
|
if action == PatchAction.CREATE:
|
|
annotation.create(data)
|
|
elif action == PatchAction.UPDATE:
|
|
annotation.update(data)
|
|
elif action == PatchAction.DELETE:
|
|
annotation.delete(data)
|
|
|
|
return annotation.data
|
|
|
|
@silk_profile(name="DELETE task data")
|
|
@transaction.atomic
|
|
def delete_task_data(pk):
|
|
annotation = TaskAnnotation(pk)
|
|
annotation.delete()
|
|
|
|
def export_task(task_id, dst_file, format_name,
        server_url=None, save_images=False):
    # For big tasks the dump function may run for a long time, and
    # we don't need to hold the lock after the task has been initialized
    # from the DB. But there is a bug with a corrupted dump file when two
    # or more dump requests are received at the same time:
    # https://github.com/opencv/cvat/issues/217
    with transaction.atomic():
        task = TaskAnnotation(task_id)
        task.init_from_db()

    exporter = make_exporter(format_name)
    with open(dst_file, 'wb') as f:
        task.export(f, exporter, host=server_url,
            save_images=save_images)

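# Example use of export_task() (the format name is illustrative; valid names
# come from the format registry):
#
#   export_task(task_id=42, dst_file='/tmp/annotations.zip',
#       format_name='CVAT for video 1.1', server_url='http://localhost:8080')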
@transaction.atomic
def import_task_annotations(task_id, src_file, format_name):
    task = TaskAnnotation(task_id)
    task.init_from_db()

    importer = make_importer(format_name)
    with open(src_file, 'rb') as f:
        task.import_annotations(f, importer)

@transaction.atomic
def import_job_annotations(job_id, src_file, format_name):
    job = JobAnnotation(job_id)
    job.init_from_db()

    importer = make_importer(format_name)
    with open(src_file, 'rb') as f:
        job.import_annotations(f, importer)