You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
411 lines
15 KiB
Python
411 lines
15 KiB
Python
# Copyright (C) 2018 Intel Corporation
|
|
#
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
import os
|
|
import copy
|
|
from collections import OrderedDict, namedtuple
|
|
|
|
from django.utils import timezone
|
|
|
|
from cvat.apps.engine.data_manager import DataManager, TrackManager
|
|
from cvat.apps.engine.serializers import LabeledDataSerializer
|
|
|
|
class AnnotationIR:
    """Plain in-memory store for a set of annotations.

    Keeps a ``version`` counter together with three lists: ``tags``,
    ``shapes`` and ``tracks``. Items are stored exactly as given; ``slice``
    reads them through ``item['frame']`` and ``track['shapes']``, so they are
    expected to support item access.
    """

    def __init__(self, data=None):
        """Create an empty store, optionally seeded from *data*.

        *data* may be an object exposing ``tags``/``shapes``/``tracks``
        attributes (e.g. another ``AnnotationIR``) or a mapping with those
        keys. The version is not copied from *data*; it starts at 0.
        """
        self.reset()
        if not data:
            return

        def pick(field):
            # A missing or empty attribute falls back to item access.
            return getattr(data, field, []) or data[field]

        self._tags = pick('tags')
        self._shapes = pick('shapes')
        self._tracks = pick('tracks')

    def add_tag(self, tag):
        """Append a single tag."""
        self._tags.append(tag)

    def add_shape(self, shape):
        """Append a single shape."""
        self._shapes.append(shape)

    def add_track(self, track):
        """Append a single track."""
        self._tracks.append(track)

    @property
    def tags(self):
        return self._tags

    @tags.setter
    def tags(self, tags):
        self._tags = tags

    @property
    def shapes(self):
        return self._shapes

    @shapes.setter
    def shapes(self, shapes):
        self._shapes = shapes

    @property
    def tracks(self):
        return self._tracks

    @tracks.setter
    def tracks(self, tracks):
        self._tracks = tracks

    @property
    def version(self):
        return self._version

    @version.setter
    def version(self, version):
        self._version = version

    def __getitem__(self, key):
        # Allows dict-style access, e.g. ir['tags'], by delegating to attributes.
        return getattr(self, key)

    @property
    def data(self):
        """Dict with the 'version', 'tags', 'shapes' and 'tracks' entries."""
        return dict(
            version=self.version,
            tags=self.tags,
            shapes=self.shapes,
            tracks=self.tracks,
        )

    @data.setter
    def data(self, data):
        # Replace everything from a mapping holding the same four keys.
        self.version = data['version']
        self.tags = data['tags']
        self.shapes = data['shapes']
        self.tracks = data['tracks']

    def serialize(self):
        """Validate ``self.data`` with ``LabeledDataSerializer`` and return its data.

        Validation errors are raised by the serializer (``raise_exception=True``).
        """
        serializer = LabeledDataSerializer(data=self.data)
        if not serializer.is_valid(raise_exception=True):
            return None
        return serializer.data

    def slice(self, start, stop):
        """Return a new ``AnnotationIR`` with deep copies of items in [start, stop].

        Tags and shapes are kept when their own frame is inside the closed
        interval. A track is kept (with all of its shapes) when at least one of
        its shapes is inside the interval. The version is not copied.
        """
        def in_range(item):
            return start <= int(item['frame']) <= stop

        selected_tracks = []
        for track in self.tracks:
            # Every shape is checked (no short-circuit), matching a full filter pass.
            hits = [shape for shape in track['shapes'] if in_range(shape)]
            if len(hits):
                selected_tracks.append(track)

        part = AnnotationIR()
        part.tags = copy.deepcopy([tag for tag in self.tags if in_range(tag)])
        part.shapes = copy.deepcopy([shape for shape in self.shapes if in_range(shape)])
        part.tracks = copy.deepcopy(selected_tracks)
        return part

    def reset(self):
        """Drop all annotations and set the version back to 0."""
        self._version = 0
        self._tags, self._shapes, self._tracks = [], [], []
|
|
|
|
class Annotation:
    """Adapter between an ``AnnotationIR`` and format-level namedtuples for one task.

    Export side: turns IR items (label/attribute ids, IR frame indices) into
    the namedtuples declared below, carrying label/attribute names and frame
    numbers computed as ``start_frame + frame * frame_step``. It also builds
    the task metadata dict (``meta``).

    Import side: ``add_tag``/``add_shape``/``add_track`` convert those
    namedtuples back into IR dicts and append them to the IR. When a
    ``create_callback`` is given, the IR is serialized, passed to the callback
    and reset once it holds more than ``_MAX_ANNO_SIZE`` items.
    """

    Attribute = namedtuple('Attribute', 'name, value')
    LabeledShape = namedtuple('LabeledShape', 'type, frame, label, points, occluded, attributes, group, z_order')
    LabeledShape.__new__.__defaults__ = (0, 0)
    TrackedShape = namedtuple('TrackedShape', 'type, points, occluded, frame, attributes, outside, keyframe, z_order')
    TrackedShape.__new__.__defaults__ = (0, )
    Track = namedtuple('Track', 'label, group, shapes')
    Tag = namedtuple('Tag', 'frame, label, attributes, group')
    Tag.__new__.__defaults__ = (0, )
    Frame = namedtuple('Frame', 'frame, name, width, height, labeled_shapes, tags')

    def __init__(self, annotation_ir, db_task, scheme='', host='', create_callback=None):
        """Build label/attribute lookup tables, frame info and metadata.

        Args:
            annotation_ir: IR holding the annotations (read on export, appended to on import).
            db_task: task model; its labels, frames and segments are read here.
            scheme, host: used to build the per-segment job URLs in ``meta``.
            create_callback: optional callable receiving ``annotation_ir.serialize()``
                whenever the IR grows past ``_MAX_ANNO_SIZE`` items during import.
        """
        self._annotation_ir = annotation_ir
        self._db_task = db_task
        self._scheme = scheme
        self._host = host
        self._create_callback = create_callback
        self._MAX_ANNO_SIZE = 30000

        db_labels = self._db_task.label_set.all().prefetch_related('attributespec_set')

        self._label_mapping = {db_label.id: db_label for db_label in db_labels}

        # attribute spec id -> attribute name, split by mutability
        self._attribute_mapping = {
            'mutable': {},
            'immutable': {},
        }
        # label id -> ids of the attribute specs declared by that label. Import
        # resolves attribute names within the label, because several labels may
        # declare attributes with the same name.
        self._label_attributes = {}
        for db_label in db_labels:
            label_attribute_ids = self._label_attributes.setdefault(db_label.id, set())
            for db_attribute in db_label.attributespec_set.all():
                kind = 'mutable' if db_attribute.mutable else 'immutable'
                self._attribute_mapping[kind][db_attribute.id] = db_attribute.name
                label_attribute_ids.add(db_attribute.id)

        self._attribute_mapping_merged = {
            **self._attribute_mapping['mutable'],
            **self._attribute_mapping['immutable'],
        }

        self._init_frame_info()
        self._init_meta()

    def _get_label_id(self, label_name):
        """Return the id of the task label called *label_name*, or None."""
        for db_label in self._label_mapping.values():
            if label_name == db_label.name:
                return db_label.id
        return None

    def _get_label_name(self, label_id):
        """Return the name of label *label_id* (KeyError if unknown)."""
        return self._label_mapping[label_id].name

    def _get_attribute_name(self, attribute_id):
        """Return the name of attribute spec *attribute_id* (KeyError if unknown)."""
        return self._attribute_mapping_merged[attribute_id]

    def _get_attribute_id(self, attribute_name, attribute_type=None, label_id=None):
        """Return the id of the attribute spec named *attribute_name*, or None.

        Args:
            attribute_name: attribute name to look up.
            attribute_type: 'mutable' or 'immutable' to restrict the search;
                None searches both (mutable specs are checked first).
            label_id: when given, only attribute specs of that label match;
                None searches the specs of all labels.
        """
        if attribute_type:
            container = self._attribute_mapping[attribute_type]
        else:
            container = self._attribute_mapping_merged

        allowed_ids = None if label_id is None else self._label_attributes.get(label_id, ())

        for attr_id, attr_name in container.items():
            if attribute_name == attr_name and (allowed_ids is None or attr_id in allowed_ids):
                return attr_id
        return None

    def _get_mutable_attribute_id(self, attribute_name, label_id=None):
        return self._get_attribute_id(attribute_name, 'mutable', label_id)

    def _get_immutable_attribute_id(self, attribute_name, label_id=None):
        return self._get_attribute_id(attribute_name, 'immutable', label_id)

    def _abs_frame_number(self, frame):
        """Map an IR frame index to ``start_frame + frame * frame_step`` of the task."""
        return self._db_task.start_frame + frame * self._db_task.get_frame_step()

    def _init_frame_info(self):
        """Fill ``self._frame_info``: IR frame index -> {'path', 'width', 'height'}."""
        if self._db_task.mode == "interpolation":
            # Video task: all frames share the video size; paths are synthetic names.
            self._frame_info = {
                frame: {
                    "path": "frame_{:06d}".format(frame),
                    "width": self._db_task.video.width,
                    "height": self._db_task.video.height,
                } for frame in range(self._db_task.size)
            }
        else:
            self._frame_info = {db_image.frame: {
                "path": db_image.path,
                "width": db_image.width,
                "height": db_image.height,
            } for db_image in self._db_task.image_set.all()}

    def _init_meta(self):
        """Build ``self._meta``: task info, labels, segments, owner/assignee, dump time."""
        db_segments = self._db_task.segment_set.all().prefetch_related('job_set')
        self._meta = OrderedDict([
            ("task", OrderedDict([
                ("id", str(self._db_task.id)),
                ("name", self._db_task.name),
                ("size", str(self._db_task.size)),
                ("mode", self._db_task.mode),
                ("overlap", str(self._db_task.overlap)),
                ("bugtracker", self._db_task.bug_tracker),
                ("created", str(timezone.localtime(self._db_task.created_date))),
                ("updated", str(timezone.localtime(self._db_task.updated_date))),
                ("start_frame", str(self._db_task.start_frame)),
                ("stop_frame", str(self._db_task.stop_frame)),
                ("frame_filter", self._db_task.frame_filter),
                ("z_order", str(self._db_task.z_order)),

                ("labels", [
                    ("label", OrderedDict([
                        ("name", db_label.name),
                        ("attributes", [
                            ("attribute", OrderedDict([
                                ("name", db_attr.name),
                                ("mutable", str(db_attr.mutable)),
                                ("input_type", db_attr.input_type),
                                ("default_value", db_attr.default_value),
                                ("values", db_attr.values)]))
                            for db_attr in db_label.attributespec_set.all()])
                    ])) for db_label in self._label_mapping.values()
                ]),

                ("segments", [
                    ("segment", OrderedDict([
                        ("id", str(db_segment.id)),
                        ("start", str(db_segment.start_frame)),
                        ("stop", str(db_segment.stop_frame)),
                        # NOTE(review): assumes every segment has at least one job.
                        ("url", "{0}://{1}/?id={2}".format(
                            self._scheme, self._host, db_segment.job_set.all()[0].id))]
                    )) for db_segment in db_segments
                ]),

                ("owner", OrderedDict([
                    ("username", self._db_task.owner.username),
                    ("email", self._db_task.owner.email)
                ]) if self._db_task.owner else ""),

                ("assignee", OrderedDict([
                    ("username", self._db_task.assignee.username),
                    ("email", self._db_task.assignee.email)
                ]) if self._db_task.assignee else ""),
            ])),
            ("dumped", str(timezone.localtime(timezone.now())))
        ])

        if self._db_task.mode == "interpolation":
            self._meta["task"]["original_size"] = OrderedDict([
                ("width", str(self._db_task.video.width)),
                ("height", str(self._db_task.video.height))
            ])
            # Add source to dumped file
            self._meta["source"] = str(os.path.basename(self._db_task.video.path))

    def _export_attributes(self, attributes):
        """Convert IR attribute dicts ({'spec_id', 'value'}) to ``Attribute`` tuples."""
        return [
            Annotation.Attribute(
                name=self._attribute_mapping_merged[attr["spec_id"]],
                value=attr["value"],
            )
            for attr in attributes
        ]

    def _export_tracked_shape(self, shape):
        """Convert one IR track shape dict to a ``TrackedShape``."""
        return Annotation.TrackedShape(
            type=shape["type"],
            frame=self._abs_frame_number(shape["frame"]),
            points=shape["points"],
            occluded=shape["occluded"],
            outside=shape.get("outside", False),
            keyframe=shape.get("keyframe", True),
            z_order=shape["z_order"],
            attributes=self._export_attributes(shape["attributes"]),
        )

    def _export_labeled_shape(self, shape):
        """Convert one IR shape dict to a ``LabeledShape``."""
        return Annotation.LabeledShape(
            type=shape["type"],
            label=self._get_label_name(shape["label_id"]),
            frame=self._abs_frame_number(shape["frame"]),
            points=shape["points"],
            occluded=shape["occluded"],
            z_order=shape.get("z_order", 0),
            group=shape.get("group", 0),
            attributes=self._export_attributes(shape["attributes"]),
        )

    def _export_tag(self, tag):
        """Convert one IR tag dict to a ``Tag``."""
        return Annotation.Tag(
            frame=self._abs_frame_number(tag["frame"]),
            label=self._get_label_name(tag["label_id"]),
            group=tag.get("group", 0),
            attributes=self._export_attributes(tag["attributes"]),
        )

    def group_by_frame(self):
        """Return an iterator of ``Frame`` tuples, one per frame with annotations.

        Shapes come from ``DataManager(...).to_shapes(task.size)``; tags come
        from the IR directly.
        """
        def _get_frame(annotations, shape):
            db_image = self._frame_info[shape["frame"]]
            frame = self._abs_frame_number(shape["frame"])
            rpath = db_image['path'].split(os.path.sep)
            if len(rpath) != 1:
                # Keep only the part of the path after the ".upload" directory.
                rpath = os.path.sep.join(rpath[rpath.index(".upload")+1:])
            else:
                rpath = rpath[0]
            if frame not in annotations:
                annotations[frame] = Annotation.Frame(
                    frame=frame,
                    name=rpath,
                    height=db_image["height"],
                    width=db_image["width"],
                    labeled_shapes=[],
                    tags=[],
                )
            return annotations[frame]

        annotations = {}
        data_manager = DataManager(self._annotation_ir)
        for shape in data_manager.to_shapes(self._db_task.size):
            _get_frame(annotations, shape).labeled_shapes.append(self._export_labeled_shape(shape))

        for tag in self._annotation_ir.tags:
            _get_frame(annotations, tag).tags.append(self._export_tag(tag))

        return iter(annotations.values())

    @property
    def shapes(self):
        """Yield every IR shape as a ``LabeledShape``."""
        for shape in self._annotation_ir.shapes:
            yield self._export_labeled_shape(shape)

    @property
    def tracks(self):
        """Yield every IR track as a ``Track`` of shapes from ``TrackManager.get_interpolated_shapes``."""
        for track in self._annotation_ir.tracks:
            tracked_shapes = TrackManager.get_interpolated_shapes(track, 0, self._db_task.size)
            yield Annotation.Track(
                label=self._get_label_name(track["label_id"]),
                group=track['group'],
                shapes=[self._export_tracked_shape(shape) for shape in tracked_shapes],
            )

    @property
    def tags(self):
        """Yield every IR tag as a ``Tag``."""
        for tag in self._annotation_ir.tags:
            yield self._export_tag(tag)

    @property
    def meta(self):
        return self._meta

    def _import_attributes(self, attributes, label_id=None, attribute_type=None):
        """Convert ``Attribute`` tuples to IR dicts, dropping names that don't resolve.

        Names are resolved within *label_id* (all labels when None) and, if
        given, within *attribute_type* ('mutable' / 'immutable').
        """
        imported = []
        for attribute in attributes:
            spec_id = self._get_attribute_id(attribute.name, attribute_type, label_id)
            if spec_id:
                imported.append({
                    'spec_id': spec_id,
                    'value': attribute.value,
                })
        return imported

    def _import_attribute(self, attribute, label_id=None):
        """Convert one ``Attribute`` to an IR dict (spec_id may be None)."""
        return {
            'spec_id': self._get_attribute_id(attribute.name, label_id=label_id),
            'value': attribute.value,
        }

    def _import_tag(self, tag):
        """Convert a ``Tag`` to an IR dict with 'label_id' instead of 'label'."""
        _tag = tag._asdict()
        _tag['label_id'] = self._get_label_id(_tag.pop('label'))
        _tag['attributes'] = self._import_attributes(_tag['attributes'], _tag['label_id'])
        return _tag

    def _import_shape(self, shape):
        """Convert a ``LabeledShape`` to an IR dict with 'label_id' instead of 'label'."""
        _shape = shape._asdict()
        _shape['label_id'] = self._get_label_id(_shape.pop('label'))
        _shape['attributes'] = self._import_attributes(_shape['attributes'], _shape['label_id'])
        return _shape

    def _import_track(self, track):
        """Convert a ``Track`` to an IR dict.

        The track frame is the smallest shape frame. Immutable attributes found
        on any shape are collected into the track's 'attributes' (first value
        per spec wins); each shape keeps only its mutable attributes.
        """
        _track = track._asdict()
        _track['frame'] = min(shape.frame for shape in _track['shapes'])
        label_id = self._get_label_id(_track.pop('label'))
        _track['label_id'] = label_id
        _track['shapes'] = [shape._asdict() for shape in _track['shapes']]

        track_attributes = OrderedDict()
        for shape in _track['shapes']:
            for attribute in self._import_attributes(shape['attributes'], label_id, 'immutable'):
                track_attributes.setdefault(attribute['spec_id'], attribute)
            shape['attributes'] = self._import_attributes(shape['attributes'], label_id, 'mutable')
        _track['attributes'] = list(track_attributes.values())

        return _track

    def _call_callback(self):
        """Flush the IR through ``create_callback`` once it exceeds ``_MAX_ANNO_SIZE``.

        The IR is serialized, passed to the callback and then reset. Without a
        callback nothing is flushed and items keep accumulating in the IR.
        """
        if self._create_callback is not None and self._len() > self._MAX_ANNO_SIZE:
            self._create_callback(self._annotation_ir.serialize())
            self._annotation_ir.reset()

    def add_tag(self, tag):
        """Import a ``Tag`` into the IR; tags with an unknown label are skipped."""
        imported_tag = self._import_tag(tag)
        if imported_tag['label_id']:
            self._annotation_ir.add_tag(imported_tag)
            self._call_callback()

    def add_shape(self, shape):
        """Import a ``LabeledShape`` into the IR; unknown labels are skipped."""
        imported_shape = self._import_shape(shape)
        if imported_shape['label_id']:
            self._annotation_ir.add_shape(imported_shape)
            self._call_callback()

    def add_track(self, track):
        """Import a ``Track`` into the IR; unknown labels are skipped."""
        imported_track = self._import_track(track)
        if imported_track['label_id']:
            self._annotation_ir.add_track(imported_track)
            self._call_callback()

    @property
    def data(self):
        return self._annotation_ir

    def _len(self):
        """Number of tags + shapes + track shapes currently held by the IR."""
        track_len = sum(len(track['shapes']) for track in self._annotation_ir.tracks)
        return len(self._annotation_ir.tags) + len(self._annotation_ir.shapes) + track_len
|