Fix export job dataset (#5052)

Maria Khrustaleva 3 years ago committed by GitHub
parent c37a5c0036
commit 15c3a63bb3

@@ -47,6 +47,7 @@ non-ascii paths while adding files from "Connected file share" (issue #4428)
   (<https://github.com/opencv/cvat/pull/5057>)
 - Double modal export/backup a task/project (<https://github.com/opencv/cvat/pull/5075>)
 - Fixed bug of computing Job's unsolved/resolved issues numbers (<https://github.com/opencv/cvat/pull/5101>)
+- Dataset export for job (<https://github.com/opencv/cvat/pull/5052>)

 ### Security
 - TDB

@@ -161,8 +161,7 @@ class InstanceLabelData:
         ))
         return exported_attributes

-
-class TaskData(InstanceLabelData):
+class CommonData(InstanceLabelData):
     Shape = namedtuple("Shape", 'id, label_id') # 3d
     LabeledShape = namedtuple(
         'LabeledShape', 'type, frame, label, points, occluded, attributes, source, rotation, group, z_order, elements, outside')
@@ -178,48 +177,69 @@ class TaskData(InstanceLabelData):
         'Frame', 'idx, id, frame, name, width, height, labeled_shapes, tags, shapes, labels')
     Labels = namedtuple('Label', 'id, name, color, type')

-    def __init__(self, annotation_ir, db_task, host='', create_callback=None):
+    def __init__(self, annotation_ir, db_task, host='', create_callback=None) -> None:
         self._annotation_ir = annotation_ir
-        self._db_task = db_task
         self._host = host
         self._create_callback = create_callback
         self._MAX_ANNO_SIZE = 30000
         self._frame_info = {}
         self._frame_mapping = {}
         self._frame_step = db_task.data.get_frame_step()
+        self._db_data = db_task.data

-        InstanceLabelData.__init__(self, db_task)
+        super().__init__(db_task)

         self._init_frame_info()
         self._init_meta()

+    @property
+    def rel_range(self):
+        raise NotImplementedError()
+
+    @property
+    def start(self) -> int:
+        return 0
+
+    @property
+    def stop(self) -> int:
+        return len(self)
+
+    def _get_queryset(self):
+        raise NotImplementedError()
+
     def abs_frame_id(self, relative_id):
-        if relative_id not in range(0, self._db_task.data.size):
+        # relative_id is frame index in segment for job, so it can start with more than just zero
+        if relative_id not in self.rel_range:
             raise ValueError("Unknown internal frame id %s" % relative_id)
-        return relative_id * self._frame_step + self._db_task.data.start_frame
+        return relative_id * self._frame_step + self._db_data.start_frame

     def rel_frame_id(self, absolute_id):
         d, m = divmod(
-            absolute_id - self._db_task.data.start_frame, self._frame_step)
-        if m or d not in range(0, self._db_task.data.size):
+            absolute_id - self._db_data.start_frame, self._frame_step)
+        if m or d not in self.rel_range:
             raise ValueError("Unknown frame %s" % absolute_id)
         return d

     def _init_frame_info(self):
-        self._deleted_frames = { k: True for k in self._db_task.data.deleted_frames }
-        if hasattr(self._db_task.data, 'video'):
-            self._frame_info = {frame: {
-                "path": "frame_{:06d}".format(self.abs_frame_id(frame)),
-                "width": self._db_task.data.video.width,
-                "height": self._db_task.data.video.height,
-            } for frame in range(self._db_task.data.size)}
+        self._deleted_frames = { k: True for k in self._db_data.deleted_frames }
+        if hasattr(self._db_data, 'video'):
+            self._frame_info = {
+                frame: {
+                    "path": "frame_{:06d}".format(self.abs_frame_id(frame)),
+                    "width": self._db_data.video.width,
+                    "height": self._db_data.video.height,
+                } for frame in self.rel_range
+            }
         else:
-            self._frame_info = {self.rel_frame_id(db_image.frame): {
-                "id": db_image.id,
-                "path": db_image.path,
-                "width": db_image.width,
-                "height": db_image.height,
-            } for db_image in self._db_task.data.images.all()}
+            queryset = self._get_queryset()
+            self._frame_info = {
+                self.rel_frame_id(db_image.frame): {
+                    "id": db_image.id,
+                    "path": db_image.path,
+                    "width": db_image.width,
+                    "height": db_image.height,
+                } for db_image in queryset
+            }

         self._frame_mapping = {
             self._get_filename(info["path"]): frame_number
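Note on the index arithmetic above: the following standalone sketch mirrors abs_frame_id/rel_frame_id with invented numbers (a segment spanning relative frames 5..9 of data that starts at absolute frame 100 with step 2). It runs without any CVAT models and is only an illustration.

FRAME_STEP = 2            # hypothetical db_data.get_frame_step()
DATA_START_FRAME = 100    # hypothetical db_data.start_frame
REL_RANGE = range(5, 10)  # hypothetical JobData.rel_range for a segment 5..9

def abs_frame_id(relative_id):
    # same validate-then-scale logic as CommonData.abs_frame_id above
    if relative_id not in REL_RANGE:
        raise ValueError("Unknown internal frame id %s" % relative_id)
    return relative_id * FRAME_STEP + DATA_START_FRAME

def rel_frame_id(absolute_id):
    # the inverse: reject ids that are misaligned or outside the segment
    d, m = divmod(absolute_id - DATA_START_FRAME, FRAME_STEP)
    if m or d not in REL_RANGE:
        raise ValueError("Unknown frame %s" % absolute_id)
    return d

assert abs_frame_id(5) == 110 and rel_frame_id(110) == 5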
@@ -227,94 +247,39 @@ class TaskData(InstanceLabelData):
         }

     @staticmethod
-    def meta_for_task(db_task, host, label_mapping=None):
-        db_segments = db_task.segment_set.all().prefetch_related('job_set')
-
-        meta = OrderedDict([
-            ("id", str(db_task.id)),
-            ("name", db_task.name),
-            ("size", str(db_task.data.size)),
-            ("mode", db_task.mode),
-            ("overlap", str(db_task.overlap)),
-            ("bugtracker", db_task.bug_tracker),
-            ("created", str(timezone.localtime(db_task.created_date))),
-            ("updated", str(timezone.localtime(db_task.updated_date))),
-            ("subset", db_task.subset or dm.DEFAULT_SUBSET_NAME),
-            ("start_frame", str(db_task.data.start_frame)),
-            ("stop_frame", str(db_task.data.stop_frame)),
-            ("frame_filter", db_task.data.frame_filter),
-
-            ("segments", [
-                ("segment", OrderedDict([
-                    ("id", str(db_segment.id)),
-                    ("start", str(db_segment.start_frame)),
-                    ("stop", str(db_segment.stop_frame)),
-                    ("url", "{}/?id={}".format(
-                        host, db_segment.job_set.all()[0].id))]
-                )) for db_segment in db_segments
-            ]),
-
-            ("owner", OrderedDict([
-                ("username", db_task.owner.username),
-                ("email", db_task.owner.email)
-            ]) if db_task.owner else ""),
-
-            ("assignee", OrderedDict([
-                ("username", db_task.assignee.username),
-                ("email", db_task.assignee.email)
-            ]) if db_task.assignee else ""),
-        ])
-
-        if label_mapping is not None:
-            labels = []
-            for db_label in label_mapping.values():
-                label = OrderedDict([
-                    ("name", db_label.name),
-                    ("color", db_label.color),
-                    ("type", db_label.type),
-                    ("attributes", [
-                        ("attribute", OrderedDict([
-                            ("name", db_attr.name),
-                            ("mutable", str(db_attr.mutable)),
-                            ("input_type", db_attr.input_type),
-                            ("default_value", db_attr.default_value),
-                            ("values", db_attr.values)]))
-                        for db_attr in db_label.attributespec_set.all()])
-                ])
-
-                if db_label.parent:
-                    label["parent"] = db_label.parent.name
-
-                if db_label.type == str(LabelType.SKELETON):
-                    label["svg"] = db_label.skeleton.svg
-                    for db_sublabel in list(db_label.sublabels.all()):
-                        label["svg"] = label["svg"].replace(f'data-label-id="{db_sublabel.id}"', f'data-label-name="{db_sublabel.name}"')
-
-                labels.append(('label', label))
-
-            meta['labels'] = labels
-
-        if hasattr(db_task.data, "video"):
-            meta["original_size"] = OrderedDict([
-                ("width", str(db_task.data.video.width)),
-                ("height", str(db_task.data.video.height))
-            ])
-
-        return meta
+    def _convert_db_labels(db_labels):
+        labels = []
+        for db_label in db_labels:
+            label = OrderedDict([
+                ("name", db_label.name),
+                ("color", db_label.color),
+                ("type", db_label.type),
+                ("attributes", [
+                    ("attribute", OrderedDict([
+                        ("name", db_attr.name),
+                        ("mutable", str(db_attr.mutable)),
+                        ("input_type", db_attr.input_type),
+                        ("default_value", db_attr.default_value),
+                        ("values", db_attr.values)]))
+                    for db_attr in db_label.attributespec_set.all()])
+            ])
+
+            if db_label.parent:
+                label["parent"] = db_label.parent.name
+
+            if db_label.type == str(LabelType.SKELETON):
+                label["svg"] = db_label.skeleton.svg
+                for db_sublabel in list(db_label.sublabels.all()):
+                    label["svg"] = label["svg"].replace(f'data-label-id="{db_sublabel.id}"', f'data-label-name="{db_sublabel.name}"')
+
+            labels.append(('label', label))
+        return labels

     def _init_meta(self):
-        self._meta = OrderedDict([
-            ("task", self.meta_for_task(self._db_task, self._host, self._label_mapping)),
-            ("dumped", str(timezone.localtime(timezone.now())))
-        ])
-
-        if hasattr(self._db_task.data, "video"):
-            # Add source to dumped file
-            self._meta["source"] = str(
-                osp.basename(self._db_task.data.video.path))
+        raise NotImplementedError()

     def _export_tracked_shape(self, shape):
-        return TaskData.TrackedShape(
+        return CommonData.TrackedShape(
             type=shape["type"],
             frame=self.abs_frame_id(shape["frame"]),
             label=self._get_label_name(shape["label_id"]),
@@ -332,7 +297,7 @@ class TaskData(InstanceLabelData):
         )

     def _export_labeled_shape(self, shape):
-        return TaskData.LabeledShape(
+        return CommonData.LabeledShape(
             type=shape["type"],
             label=self._get_label_name(shape["label_id"]),
             frame=self.abs_frame_id(shape["frame"]),
@@ -348,13 +313,13 @@ class TaskData(InstanceLabelData):
         )

     def _export_shape(self, shape):
-        return TaskData.Shape(
+        return CommonData.Shape(
             id=shape["id"],
             label_id=shape["label_id"]
         )

     def _export_tag(self, tag):
-        return TaskData.Tag(
+        return CommonData.Tag(
             frame=self.abs_frame_id(tag["frame"]),
             label=self._get_label_name(tag["label_id"]),
             group=tag.get("group", 0),
@@ -363,9 +328,9 @@ class TaskData(InstanceLabelData):
         )

     def _export_track(self, track, idx):
-        track['shapes'] = list(filter(lambda x: x['frame'] not in self._deleted_frames, track['shapes']))
+        track['shapes'] = list(filter(lambda x: not self._is_frame_deleted(x['frame']), track['shapes']))
         tracked_shapes = TrackManager.get_interpolated_shapes(
-            track, 0, self._db_task.data.size)
+            track, 0, len(self))
         for tracked_shape in tracked_shapes:
             tracked_shape["attributes"] += track["attributes"]
             tracked_shape["track_id"] = idx
@@ -373,18 +338,18 @@ class TaskData(InstanceLabelData):
             tracked_shape["source"] = track["source"]
             tracked_shape["label_id"] = track["label_id"]

-        return TaskData.Track(
+        return CommonData.Track(
             label=self._get_label_name(track["label_id"]),
             group=track["group"],
             source=track["source"],
             shapes=[self._export_tracked_shape(shape)
-                for shape in tracked_shapes if shape["frame"] not in self._deleted_frames],
+                for shape in tracked_shapes if not self._is_frame_deleted(shape["frame"])],
             elements=[self._export_track(element, i) for i, element in enumerate(track.get("elements", []))]
         )

     @staticmethod
     def _export_label(label):
-        return TaskData.Labels(
+        return CommonData.Labels(
             id=label.id,
             name=label.name,
             color=label.color,
@@ -397,11 +362,11 @@ class TaskData(InstanceLabelData):
             frame_info = self._frame_info[idx]
             frame = self.abs_frame_id(idx)
             if frame not in frames:
-                frames[frame] = TaskData.Frame(
+                frames[frame] = CommonData.Frame(
                     idx=idx,
-                    id=frame_info.get('id',0),
+                    id=frame_info.get("id", 0),
                     frame=frame,
-                    name=frame_info['path'],
+                    name=frame_info["path"],
                     height=frame_info["height"],
                     width=frame_info["width"],
                     labeled_shapes=[],
@@ -413,14 +378,14 @@ class TaskData(InstanceLabelData):
         if include_empty:
             for idx in self._frame_info:
-                if idx not in self._deleted_frames:
+                if not self._is_frame_deleted(idx):
                     get_frame(idx)

         anno_manager = AnnotationManager(self._annotation_ir)
         shape_data = ''
-        for shape in sorted(anno_manager.to_shapes(self._db_task.data.size),
+        for shape in sorted(anno_manager.to_shapes(len(self)),
                 key=lambda shape: shape.get("z_order", 0)):
-            if shape['frame'] not in self._frame_info or shape['frame'] in self._deleted_frames:
+            if shape['frame'] not in self._frame_info or self._is_frame_deleted(shape['frame']):
                 # After interpolation there can be a finishing frame
                 # outside of the task boundaries. Filter it out to avoid errors.
                 # https://github.com/openvinotoolkit/cvat/issues/2827
@@ -450,9 +415,12 @@ class TaskData(InstanceLabelData):
     @property
     def shapes(self):
         for shape in self._annotation_ir.shapes:
-            if shape["frame"] not in self._deleted_frames:
+            if not self._is_frame_deleted(shape["frame"]):
                 yield self._export_labeled_shape(shape)

+    def _is_frame_deleted(self, frame):
+        return frame in self._deleted_frames
+
     @property
     def tracks(self):
         for idx, track in enumerate(self._annotation_ir.tracks):
@@ -577,8 +545,15 @@ class TaskData(InstanceLabelData):
         return self._frame_step

     @property
-    def db_task(self):
-        return self._db_task
+    def db_instance(self):
+        raise NotImplementedError()
+
+    @property
+    def db_data(self):
+        return self._db_data
+
+    def __len__(self):
+        raise NotImplementedError()

     @staticmethod
     def _get_filename(path):
@@ -605,7 +580,176 @@ class TaskData(InstanceLabelData):
                 return v
         return None

+class JobData(CommonData):
+    META_FIELD = "job"
+
+    def __init__(self, annotation_ir, db_job, host='', create_callback=None):
+        self._db_job = db_job
+        self._db_task = db_job.segment.task
+
+        super().__init__(annotation_ir, self._db_task, host, create_callback)
+
+    def _init_meta(self):
+        db_segment = self._db_job.segment
+        self._meta = OrderedDict([
+            (JobData.META_FIELD, OrderedDict([
+                ("id", str(self._db_job.id)),
+                ("size", str(len(self))),
+                ("mode", self._db_task.mode),
+                ("overlap", str(self._db_task.overlap)),
+                ("bugtracker", self._db_task.bug_tracker),
+                ("created", str(timezone.localtime(self._db_task.created_date))),
+                ("updated", str(timezone.localtime(self._db_job.updated_date))),
+                ("subset", self._db_task.subset or dm.DEFAULT_SUBSET_NAME),
+                ("start_frame", str(self._db_data.start_frame + db_segment.start_frame * self._frame_step)),
+                ("stop_frame", str(self._db_data.start_frame + db_segment.stop_frame * self._frame_step)),
+                ("frame_filter", self._db_data.frame_filter),
+                ("segments", [
+                    ("segment", OrderedDict([
+                        ("id", str(db_segment.id)),
+                        ("start", str(db_segment.start_frame)),
+                        ("stop", str(db_segment.stop_frame)),
+                        ("url", "{}/api/jobs/{}".format(self._host, self._db_job.id))])),
+                ]),
+                ("owner", OrderedDict([
+                    ("username", self._db_task.owner.username),
+                    ("email", self._db_task.owner.email)
+                ]) if self._db_task.owner else ""),
+                ("assignee", OrderedDict([
+                    ("username", self._db_job.assignee.username),
+                    ("email", self._db_job.assignee.email)
+                ]) if self._db_job.assignee else ""),
+            ])),
+            ("dumped", str(timezone.localtime(timezone.now()))),
+        ])
+
+        if self._label_mapping is not None:
+            self._meta[JobData.META_FIELD]["labels"] = CommonData._convert_db_labels(self._label_mapping.values())
+
+        if hasattr(self._db_data, "video"):
+            self._meta["original_size"] = OrderedDict([
+                ("width", str(self._db_data.video.width)),
+                ("height", str(self._db_data.video.height))
+            ])
+
+    def __len__(self):
+        segment = self._db_job.segment
+        return segment.stop_frame - segment.start_frame + 1
+
+    def _get_queryset(self):
+        return self._db_data.images.filter(frame__in=self.abs_range)
+
+    @property
+    def abs_range(self):
+        segment = self._db_job.segment
+        step = self._frame_step
+        start_frame = self._db_data.start_frame + segment.start_frame * step
+        stop_frame = self._db_data.start_frame + segment.stop_frame * step + 1
+
+        return range(start_frame, stop_frame, step)
+
+    @property
+    def rel_range(self):
+        segment = self._db_job.segment
+        return range(segment.start_frame, segment.stop_frame + 1)
+
+    @property
+    def start(self) -> int:
+        segment = self._db_job.segment
+        return segment.start_frame
+
+    @property
+    def stop(self) -> int:
+        segment = self._db_job.segment
+        return segment.stop_frame + 1
+
+    @property
+    def db_instance(self):
+        return self._db_job
+
+class TaskData(CommonData):
+    META_FIELD = "task"
+
+    def __init__(self, annotation_ir, db_task, host='', create_callback=None):
+        self._db_task = db_task
+        super().__init__(annotation_ir, db_task, host, create_callback)
+
+    @staticmethod
+    def meta_for_task(db_task, host, label_mapping=None):
+        db_segments = db_task.segment_set.all().prefetch_related('job_set')
+
+        meta = OrderedDict([
+            ("id", str(db_task.id)),
+            ("name", db_task.name),
+            ("size", str(db_task.data.size)),
+            ("mode", db_task.mode),
+            ("overlap", str(db_task.overlap)),
+            ("bugtracker", db_task.bug_tracker),
+            ("created", str(timezone.localtime(db_task.created_date))),
+            ("updated", str(timezone.localtime(db_task.updated_date))),
+            ("subset", db_task.subset or dm.DEFAULT_SUBSET_NAME),
+            ("start_frame", str(db_task.data.start_frame)),
+            ("stop_frame", str(db_task.data.stop_frame)),
+            ("frame_filter", db_task.data.frame_filter),
+            ("segments", [
+                ("segment", OrderedDict([
+                    ("id", str(db_segment.id)),
+                    ("start", str(db_segment.start_frame)),
+                    ("stop", str(db_segment.stop_frame)),
+                    ("url", "{}/api/jobs/{}".format(
+                        host, db_segment.job_set.all()[0].id))]
+                )) for db_segment in db_segments
+            ]),
+            ("owner", OrderedDict([
+                ("username", db_task.owner.username),
+                ("email", db_task.owner.email)
+            ]) if db_task.owner else ""),
+            ("assignee", OrderedDict([
+                ("username", db_task.assignee.username),
+                ("email", db_task.assignee.email)
+            ]) if db_task.assignee else ""),
+        ])
+
+        if label_mapping is not None:
+            meta['labels'] = CommonData._convert_db_labels(label_mapping.values())
+
+        if hasattr(db_task.data, "video"):
+            meta["original_size"] = OrderedDict([
+                ("width", str(db_task.data.video.width)),
+                ("height", str(db_task.data.video.height))
+            ])
+
+        return meta
+
+    def _init_meta(self):
+        self._meta = OrderedDict([
+            (TaskData.META_FIELD, self.meta_for_task(self._db_task, self._host, self._label_mapping)),
+            ("dumped", str(timezone.localtime(timezone.now())))
+        ])
+
+        if hasattr(self._db_task.data, "video"):
+            # Add source to dumped file
+            self._meta["source"] = str(
+                osp.basename(self._db_task.data.video.path))
+
+    def __len__(self):
+        return self._db_data.size
+
+    @property
+    def rel_range(self):
+        return range(len(self))
+
+    @property
+    def db_instance(self):
+        return self._db_task
+
+    def _get_queryset(self):
+        return self._db_data.images.all()
+
 class ProjectData(InstanceLabelData):
+    META_FIELD = 'project'
     @attrs
     class LabeledShape:
         type: str = attrib()
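A quick worked example of the new JobData ranges, with invented numbers: a segment covering task frames 10..19, data.start_frame of 0 and frame step 5. It reproduces only the arithmetic above, no CVAT objects involved.

seg_start, seg_stop = 10, 19       # hypothetical segment bounds
data_start_frame, step = 0, 5      # hypothetical data settings

rel_range = range(seg_start, seg_stop + 1)             # 10..19, used for validation
abs_range = range(data_start_frame + seg_start * step,
                  data_start_frame + seg_stop * step + 1,
                  step)                                 # 50, 55, ..., 95: raw frame numbers

assert len(rel_range) == seg_stop - seg_start + 1      # __len__ of the job: 10 frames
assert list(abs_range) == list(range(50, 96, 5))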
@@ -765,7 +909,7 @@ class ProjectData(InstanceLabelData):
     def _init_meta(self):
         self._meta = OrderedDict([
-            ('project', OrderedDict([
+            (ProjectData.META_FIELD, OrderedDict([
                 ('id', str(self._db_project.id)),
                 ('name', self._db_project.name),
                 ("bugtracker", self._db_project.bug_tracker),
@@ -819,7 +963,7 @@ class ProjectData(InstanceLabelData):
             labels.append(('label', label))

-        self._meta['project']['labels'] = labels
+        self._meta[ProjectData.META_FIELD]['labels'] = labels

     def _export_tracked_shape(self, shape: dict, task_id: int):
         return ProjectData.TrackedShape(
@@ -1040,7 +1184,7 @@ class ProjectData(InstanceLabelData):
         for task_data in self.task_data:
             if task_data._db_task.id not in self.new_tasks:
                 continue
-            subset_dataset: dm.Dataset = dataset.subsets()[task_data.db_task.subset].as_dataset()
+            subset_dataset: dm.Dataset = dataset.subsets()[task_data.db_instance.subset].as_dataset()
             yield subset_dataset, task_data

     def add_labels(self, labels: List[dict]):
@@ -1096,7 +1240,7 @@ class CVATDataExtractorMixin:
             "updatedAt": meta['updated']
         }

-    def _read_cvat_anno(self, cvat_frame_anno: Union[ProjectData.Frame, TaskData.Frame], labels: list):
+    def _read_cvat_anno(self, cvat_frame_anno: Union[ProjectData.Frame, CommonData.Frame], labels: list):
         categories = self.categories()
         label_cat = categories[dm.AnnotationType.label]
         def map_label(name, parent=''): return label_cat.find(name, parent)[0]
@@ -1108,16 +1252,17 @@
         return convert_cvat_anno_to_dm(cvat_frame_anno, label_attrs, map_label)

-class CvatTaskDataExtractor(dm.SourceExtractor, CVATDataExtractorMixin):
-    def __init__(self, task_data, include_images=False, format_type=None, dimension=DimensionType.DIM_2D):
+class CvatTaskOrJobDataExtractor(dm.SourceExtractor, CVATDataExtractorMixin):
+    def __init__(self, instance_data: CommonData, include_images=False, format_type=None, dimension=DimensionType.DIM_2D):
         super().__init__(media_type=dm.Image if dimension == DimensionType.DIM_2D else PointCloud)
-        self._categories = self._load_categories(task_data.meta['task']['labels'])
-        self._user = self._load_user_info(task_data.meta['task']) if dimension == DimensionType.DIM_3D else {}
+        instance_meta = instance_data.meta[instance_data.META_FIELD]
+        self._categories = self._load_categories(instance_meta['labels'])
+        self._user = self._load_user_info(instance_meta) if dimension == DimensionType.DIM_3D else {}
         self._dimension = dimension
         self._format_type = format_type

         dm_items = []

-        is_video = task_data.meta['task']['mode'] == 'interpolation'
+        is_video = instance_meta['mode'] == 'interpolation'
         ext = ''
         if is_video:
             ext = FrameProvider.VIDEO_FRAME_EXT
@@ -1125,7 +1270,7 @@ class CvatTaskDataExtractor(dm.SourceExtractor, CVATDataExtractorMixin):
         if dimension == DimensionType.DIM_3D:
             def _make_image(image_id, **kwargs):
                 loader = osp.join(
-                    task_data.db_task.data.get_upload_dirname(), kwargs['path'])
+                    instance_data.db_data.get_upload_dirname(), kwargs['path'])
                 related_images = []
                 image = Img.objects.get(id=image_id)
                 for i in image.related_files.all():
@@ -1135,7 +1280,7 @@ class CvatTaskDataExtractor(dm.SourceExtractor, CVATDataExtractorMixin):
                 return loader, related_images

         elif include_images:
-            frame_provider = FrameProvider(task_data.db_task.data)
+            frame_provider = FrameProvider(instance_data.db_data)
             if is_video:
                 # optimization for videos: use numpy arrays instead of bytes
                 # some formats or transforms can require image data
@@ -1152,7 +1297,7 @@ class CvatTaskDataExtractor(dm.SourceExtractor, CVATDataExtractorMixin):
                         out_type=frame_provider.Type.BUFFER)[0].getvalue()
                 return dm.ByteImage(data=loader, **kwargs)

-        for frame_data in task_data.group_by_frame(include_empty=True):
+        for frame_data in instance_data.group_by_frame(include_empty=True):
             image_args = {
                 'path': frame_data.name + ext,
                 'size': (frame_data.height, frame_data.width),
@@ -1164,7 +1309,7 @@ class CvatTaskDataExtractor(dm.SourceExtractor, CVATDataExtractorMixin):
                 dm_image = _make_image(frame_data.idx, **image_args)
             else:
                 dm_image = dm.Image(**image_args)
-            dm_anno = self._read_cvat_anno(frame_data, task_data.meta['task']['labels'])
+            dm_anno = self._read_cvat_anno(frame_data, instance_meta['labels'])

             if dimension == DimensionType.DIM_2D:
                 dm_item = dm.DatasetItem(
@@ -1179,7 +1324,7 @@ class CvatTaskDataExtractor(dm.SourceExtractor, CVATDataExtractorMixin):
                 attributes["createdAt"] = self._user["createdAt"]
                 attributes["updatedAt"] = self._user["updatedAt"]
                 attributes["labels"] = []
-                for (idx, (_, label)) in enumerate(task_data.meta['task']['labels']):
+                for (idx, (_, label)) in enumerate(instance_meta['labels']):
                     attributes["labels"].append({"label_id": idx, "name": label["name"], "color": label["color"], "type": label["type"]})
                 attributes["track_id"] = -1
@@ -1193,7 +1338,7 @@ class CvatTaskDataExtractor(dm.SourceExtractor, CVATDataExtractorMixin):
         self._items = dm_items

-    def _read_cvat_anno(self, cvat_frame_anno: TaskData.Frame, labels: list):
+    def _read_cvat_anno(self, cvat_frame_anno: CommonData.Frame, labels: list):
         categories = self.categories()
         label_cat = categories[dm.AnnotationType.label]
         def map_label(name, parent=''): return label_cat.find(name, parent)[0]
@@ -1207,8 +1352,8 @@ class CvatTaskDataExtractor(dm.SourceExtractor, CVATDataExtractorMixin):
 class CVATProjectDataExtractor(dm.Extractor, CVATDataExtractorMixin):
     def __init__(self, project_data: ProjectData, include_images: bool = False, format_type: str = None, dimension: DimensionType = DimensionType.DIM_2D):
         super().__init__(media_type=dm.Image if dimension == DimensionType.DIM_2D else PointCloud)
-        self._categories = self._load_categories(project_data.meta['project']['labels'])
-        self._user = self._load_user_info(project_data.meta['project']) if dimension == DimensionType.DIM_3D else {}
+        self._categories = self._load_categories(project_data.meta[project_data.META_FIELD]['labels'])
+        self._user = self._load_user_info(project_data.meta[project_data.META_FIELD]) if dimension == DimensionType.DIM_3D else {}
         self._dimension = dimension
         self._format_type = format_type
@@ -1271,7 +1416,7 @@ class CVATProjectDataExtractor(dm.Extractor, CVATDataExtractorMixin):
                 dm_image = image_maker_per_task[frame_data.task_id](frame_data.idx, **image_args)
             else:
                 dm_image = dm.Image(**image_args)
-            dm_anno = self._read_cvat_anno(frame_data, project_data.meta['project']['labels'])
+            dm_anno = self._read_cvat_anno(frame_data, project_data.meta[project_data.META_FIELD]['labels'])
             if self._dimension == DimensionType.DIM_2D:
                 dm_item = dm.DatasetItem(
                     id=osp.splitext(frame_data.name)[0],
@@ -1286,7 +1431,7 @@ class CVATProjectDataExtractor(dm.Extractor, CVATDataExtractorMixin):
                 attributes["createdAt"] = self._user["createdAt"]
                 attributes["updatedAt"] = self._user["updatedAt"]
                 attributes["labels"] = []
-                for (idx, (_, label)) in enumerate(project_data.meta['project']['labels']):
+                for (idx, (_, label)) in enumerate(project_data.meta[project_data.META_FIELD]['labels']):
                     attributes["labels"].append({"label_id": idx, "name": label["name"], "color": label["color"], "type": label["type"]})
                 attributes["track_id"] = -1
@@ -1309,11 +1454,16 @@ class CVATProjectDataExtractor(dm.Extractor, CVATDataExtractorMixin):
         return len(self._items)

-def GetCVATDataExtractor(instance_data: Union[ProjectData, TaskData], include_images: bool = False, format_type: str = None, dimension: DimensionType = DimensionType.DIM_2D):
+def GetCVATDataExtractor(
+    instance_data: Union[ProjectData, CommonData],
+    include_images: bool = False,
+    format_type: str = None,
+    dimension: DimensionType = DimensionType.DIM_2D,
+):
     if isinstance(instance_data, ProjectData):
         return CVATProjectDataExtractor(instance_data, include_images, format_type, dimension)
     else:
-        return CvatTaskDataExtractor(instance_data, include_images, format_type, dimension)
+        return CvatTaskOrJobDataExtractor(instance_data, include_images, format_type, dimension)

 class CvatImportError(Exception):
     pass
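Call sites stay a single expression: the factory dispatches on the instance type, so job export reuses the task/job extractor. A hedged usage sketch mirroring how the updated tests below build a dataset; `instance_data` is assumed to be an already constructed ProjectData, TaskData or JobData.

from datumaro.components.project import Dataset

from cvat.apps.dataset_manager.bindings import GetCVATDataExtractor

extractor = GetCVATDataExtractor(instance_data, include_images=False)
dataset = Dataset.from_extractors(extractor)  # same pattern as in the tests below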
@@ -1469,25 +1619,25 @@ def convert_cvat_anno_to_dm(cvat_frame_anno, label_attrs, map_label, format_name
     return item_anno

-def match_dm_item(item, task_data, root_hint=None):
-    is_video = task_data.meta['task']['mode'] == 'interpolation'
+def match_dm_item(item, instance_data, root_hint=None):
+    is_video = instance_data.meta[instance_data.META_FIELD]['mode'] == 'interpolation'

     frame_number = None
     if frame_number is None and item.has_image:
-        frame_number = task_data.match_frame(item.id + item.image.ext, root_hint)
+        frame_number = instance_data.match_frame(item.id + item.image.ext, root_hint)
     if frame_number is None:
-        frame_number = task_data.match_frame(item.id, root_hint, path_has_ext=False)
+        frame_number = instance_data.match_frame(item.id, root_hint, path_has_ext=False)
     if frame_number is None:
         frame_number = dm.util.cast(item.attributes.get('frame', item.id), int)
     if frame_number is None and is_video:
         frame_number = dm.util.cast(osp.basename(item.id)[len('frame_'):], int)

-    if not frame_number in task_data.frame_info:
+    if not frame_number in instance_data.frame_info:
         raise CvatImportError("Could not match item id: "
             "'%s' with any task frame" % item.id)
     return frame_number

-def find_dataset_root(dm_dataset, instance_data: Union[TaskData, ProjectData]):
+def find_dataset_root(dm_dataset, instance_data: Union[ProjectData, CommonData]):
     longest_path = max(dm_dataset, key=lambda x: len(Path(x.id).parts),
         default=None)
     if longest_path is None:
@@ -1503,7 +1653,7 @@ def find_dataset_root(dm_dataset, instance_data: Union[TaskData, ProjectData]):
         prefix = prefix[:-1]
     return prefix

-def import_dm_annotations(dm_dataset: dm.Dataset, instance_data: Union[TaskData, ProjectData]):
+def import_dm_annotations(dm_dataset: dm.Dataset, instance_data: Union[ProjectData, CommonData]):
     if len(dm_dataset) == 0:
         return

@@ -21,7 +21,7 @@ from datumaro.components.extractor import (DEFAULT_SUBSET_NAME, Extractor,
 from datumaro.util.image import Image
 from defusedxml import ElementTree

-from cvat.apps.dataset_manager.bindings import (ProjectData, TaskData,
+from cvat.apps.dataset_manager.bindings import (ProjectData, CommonData,
                                                 get_defaulted_subset,
                                                 import_dm_annotations,
                                                 match_dm_item)
@@ -984,11 +984,11 @@ def dump_as_cvat_interpolation(dumper, annotations):
             counter += 1

     for shape in annotations.shapes:
-        frame_step = annotations.frame_step if isinstance(annotations, TaskData) else annotations.frame_step[shape.task_id]
-        if isinstance(annotations, TaskData):
-            stop_frame = int(annotations.meta['task']['stop_frame'])
+        frame_step = annotations.frame_step if not isinstance(annotations, ProjectData) else annotations.frame_step[shape.task_id]
+        if not isinstance(annotations, ProjectData):
+            stop_frame = int(annotations.meta[annotations.META_FIELD]['stop_frame'])
         else:
-            task_meta = list(filter(lambda task: int(task[1]['id']) == shape.task_id, annotations.meta['project']['tasks']))[0][1]
+            task_meta = list(filter(lambda task: int(task[1]['id']) == shape.task_id, annotations.meta[annotations.META_FIELD]['tasks']))[0][1]
             stop_frame = int(task_meta['stop_frame'])
         track = {
             'label': shape.label,
@@ -1102,7 +1102,7 @@ def load_anno(file_object, annotations):
                         attributes={'frame': el.attrib['id']},
                         image=el.attrib['name']
                     ),
-                    task_data=annotations
+                    instance_data=annotations
                 ))
         elif el.tag in supported_shapes and (track is not None or image_is_opened):
             if shape and shape['type'] == 'skeleton':
@@ -1258,10 +1258,10 @@ def load_anno(file_object, annotations):
             tag = None
         el.clear()

-def dump_task_anno(dst_file, task_data, callback):
+def dump_task_or_job_anno(dst_file, instance_data, callback):
     dumper = create_xml_dumper(dst_file)
     dumper.open_document()
-    callback(dumper, task_data)
+    callback(dumper, instance_data)
     dumper.close_document()

 def dump_project_anno(dst_file: BufferedWriter, project_data: ProjectData, callback: Callable):
@@ -1270,33 +1270,34 @@ def dump_project_anno(dst_file: BufferedWriter, project_data: ProjectData, callb
     callback(dumper, project_data)
     dumper.close_document()

-def dump_media_files(task_data: TaskData, img_dir: str, project_data: ProjectData = None):
+def dump_media_files(instance_data: CommonData, img_dir: str, project_data: ProjectData = None):
     ext = ''
-    if task_data.meta['task']['mode'] == 'interpolation':
+    if instance_data.meta[instance_data.META_FIELD]['mode'] == 'interpolation':
         ext = FrameProvider.VIDEO_FRAME_EXT

-    frame_provider = FrameProvider(task_data.db_task.data)
+    frame_provider = FrameProvider(instance_data.db_data)
     frames = frame_provider.get_frames(
+        instance_data.start, instance_data.stop,
         frame_provider.Quality.ORIGINAL,
         frame_provider.Type.BUFFER)
-    for frame_id, (frame_data, _) in enumerate(frames):
-        if (project_data is not None and (task_data.db_task.id, frame_id) in project_data.deleted_frames) \
-                or frame_id in task_data.deleted_frames:
+    for frame_id, (frame_data, _) in zip(instance_data.rel_range, frames):
+        if (project_data is not None and (instance_data.db_instance.id, frame_id) in project_data.deleted_frames) \
+                or frame_id in instance_data.deleted_frames:
             continue
-        frame_name = task_data.frame_info[frame_id]['path'] if project_data is None \
-            else project_data.frame_info[(task_data.db_task.id, frame_id)]['path']
+        frame_name = instance_data.frame_info[frame_id]['path'] if project_data is None \
+            else project_data.frame_info[(instance_data.db_instance.id, frame_id)]['path']
         img_path = osp.join(img_dir, frame_name + ext)
         os.makedirs(osp.dirname(img_path), exist_ok=True)
         with open(img_path, 'wb') as f:
             f.write(frame_data.getvalue())

-def _export_task(dst_file, task_data, anno_callback, save_images=False):
+def _export_task_or_job(dst_file, instance_data, anno_callback, save_images=False):
     with TemporaryDirectory() as temp_dir:
         with open(osp.join(temp_dir, 'annotations.xml'), 'wb') as f:
-            dump_task_anno(f, task_data, anno_callback)
+            dump_task_or_job_anno(f, instance_data, anno_callback)

         if save_images:
-            dump_media_files(task_data, osp.join(temp_dir, 'images'))
+            dump_media_files(instance_data, osp.join(temp_dir, 'images'))

         make_zip_archive(temp_dir, dst_file)
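The switch from enumerate(frames) to zip(instance_data.rel_range, frames) is the point of this hunk: for a job, frame ids must start at the segment offset rather than 0, or they would not match frame_info and deleted_frames. A toy illustration with invented values:

rel_range = range(10, 15)                     # a job's relative frame ids
frames = ['f10', 'f11', 'f12', 'f13', 'f14']  # stand-ins for get_frames() output

ids = [frame_id for frame_id, _frame in zip(rel_range, frames)]
assert ids == [10, 11, 12, 13, 14]            # enumerate() would have given 0..4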
@@ -1307,7 +1308,7 @@ def _export_project(dst_file: str, project_data: ProjectData, anno_callback: Cal
         if save_images:
             for task_data in project_data.task_data:
-                subset = get_defaulted_subset(task_data.db_task.subset, project_data.subsets)
+                subset = get_defaulted_subset(task_data.db_instance.subset, project_data.subsets)
                 subset_dir = osp.join(temp_dir, 'images', subset)
                 os.makedirs(subset_dir, exist_ok=True)
                 dump_media_files(task_data, subset_dir, project_data)
@@ -1320,7 +1321,7 @@ def _export_video(dst_file, instance_data, save_images=False):
         _export_project(dst_file, instance_data,
             anno_callback=dump_as_cvat_interpolation, save_images=save_images)
     else:
-        _export_task(dst_file, instance_data,
+        _export_task_or_job(dst_file, instance_data,
             anno_callback=dump_as_cvat_interpolation, save_images=save_images)

 @exporter(name='CVAT for images', ext='ZIP', version='1.1')
@@ -1329,7 +1330,7 @@ def _export_images(dst_file, instance_data, save_images=False):
         _export_project(dst_file, instance_data,
             anno_callback=dump_as_cvat_annotation, save_images=save_images)
     else:
-        _export_task(dst_file, instance_data,
+        _export_task_or_job(dst_file, instance_data,
             anno_callback=dump_as_cvat_annotation, save_images=save_images)

 @importer(name='CVAT', ext='XML, ZIP', version='1.1')

@@ -1,4 +1,5 @@
 # Copyright (C) 2021-2022 Intel Corporation
+# Copyright (C) 2022 CVAT.ai Corporation
 #
 # SPDX-License-Identifier: MIT
@@ -9,8 +10,7 @@ from datumaro.components.dataset import Dataset
 from datumaro.plugins.kitti_format.format import KittiPath, write_label_map
 from pyunpack import Archive

-from cvat.apps.dataset_manager.bindings import (GetCVATDataExtractor,
-    ProjectData, import_dm_annotations)
+from cvat.apps.dataset_manager.bindings import (GetCVATDataExtractor, import_dm_annotations)
 from cvat.apps.dataset_manager.util import make_zip_archive

 from .transformations import RotatedBoxesToPolygons
@@ -45,8 +45,7 @@ def _import(src_file, instance_data):
         write_label_map(color_map_path, color_map)

         dataset = Dataset.import_from(tmp_dir, format='kitti', env=dm_env)
-        labels_meta = instance_data.meta['project']['labels'] \
-            if isinstance(instance_data, ProjectData) else instance_data.meta['task']['labels']
+        labels_meta = instance_data.meta[instance_data.META_FIELD]['labels']
         if 'background' not in [label['name'] for _, label in labels_meta]:
             dataset.filter('/item/annotation[label != "background"]',
                 filter_annotations=True)

@@ -1,4 +1,5 @@
 # Copyright (C) 2019-2022 Intel Corporation
+# Copyright (C) 2022 CVAT.ai Corporation
 #
 # SPDX-License-Identifier: MIT
@@ -13,13 +14,15 @@ from cvat.apps.dataset_manager.util import make_zip_archive
 from .registry import dm_env, exporter, importer

-def _import_task(dataset, task_data):
+def _import_to_task(dataset, instance_data):
     tracks = {}
     label_cat = dataset.categories()[datumaro.AnnotationType.label]

     for item in dataset:
-        frame_number = int(item.id) - 1 # NOTE: MOT frames start from 1
-        frame_number = task_data.abs_frame_id(frame_number)
+        # NOTE: MOT frames start from 1
+        # job has an offset, for task offset is 0
+        frame_number = int(item.id) - 1 + instance_data.start
+        frame_number = instance_data.abs_frame_id(frame_number)

         for ann in item.annotations:
             if ann.type != datumaro.AnnotationType.bbox:
@@ -28,7 +31,7 @@ def _import_task(dataset, task_data):
             track_id = ann.attributes.get('track_id')
             if track_id is None:
                 # Extension. Import regular boxes:
-                task_data.add_shape(task_data.LabeledShape(
+                instance_data.add_shape(instance_data.LabeledShape(
                     type='rectangle',
                     label=label_cat.items[ann.label].name,
                     points=ann.points,
@@ -41,7 +44,7 @@ def _import_task(dataset, task_data):
                 ))
                 continue

-            shape = task_data.TrackedShape(
+            shape = instance_data.TrackedShape(
                 type='rectangle',
                 points=ann.points,
                 occluded=ann.attributes.get('occluded') is True,
@@ -55,7 +58,7 @@ def _import_task(dataset, task_data):
             # build trajectories as lists of shapes in track dict
             if track_id not in tracks:
-                tracks[track_id] = task_data.Track(
+                tracks[track_id] = instance_data.Track(
                     label_cat.items[ann.label].name, 0, 'manual', [])
             tracks[track_id].shapes.append(shape)
@@ -67,10 +70,10 @@ def _import_task(dataset, task_data):
         prev_shape_idx = 0
         prev_shape = track.shapes[0]
         for shape in track.shapes[1:]:
-            has_skip = task_data.frame_step < shape.frame - prev_shape.frame
+            has_skip = instance_data.frame_step < shape.frame - prev_shape.frame
             if has_skip and not prev_shape.outside:
                 prev_shape = prev_shape._replace(outside=True,
-                    frame=prev_shape.frame + task_data.frame_step)
+                    frame=prev_shape.frame + instance_data.frame_step)
                 prev_shape_idx += 1
                 track.shapes.insert(prev_shape_idx, prev_shape)
             prev_shape = shape
@@ -78,12 +81,12 @@ def _import_task(dataset, task_data):
         # Append a shape with outside=True to finish the track
         last_shape = track.shapes[-1]
-        if last_shape.frame + task_data.frame_step <= \
-                int(task_data.meta['task']['stop_frame']):
+        if last_shape.frame + instance_data.frame_step <= \
+                int(instance_data.meta[instance_data.META_FIELD]['stop_frame']):
             track.shapes.append(last_shape._replace(outside=True,
-                frame=last_shape.frame + task_data.frame_step)
+                frame=last_shape.frame + instance_data.frame_step)
             )
-        task_data.add_track(track)
+        instance_data.add_track(track)

 @exporter(name='MOT', ext='ZIP', version='1.1')
@@ -107,7 +110,7 @@ def _import(src_file, instance_data, load_data_callback=None):
     # Dirty way to determine instance type to avoid circular dependency
     if hasattr(instance_data, '_db_project'):
         for sub_dataset, task_data in instance_data.split_dataset(dataset):
-            _import_task(sub_dataset, task_data)
+            _import_to_task(sub_dataset, task_data)
     else:
-        _import_task(dataset, instance_data)
+        _import_to_task(dataset, instance_data)
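The offset arithmetic above, spelled out with invented numbers: MOT item ids are 1-based, and instance_data.start contributes the segment offset (0 for a whole task, the segment's first relative frame for a job).

mot_item_id = 1   # first frame in a MOT annotation file
job_start = 10    # hypothetical JobData.start
task_start = 0    # TaskData.start is always 0

assert mot_item_id - 1 + job_start == 10   # first relative frame of the job's segment
assert mot_item_id - 1 + task_start == 0   # unchanged behaviour for whole-task imports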

@@ -1,4 +1,5 @@
 # Copyright (C) 2019-2022 Intel Corporation
+# Copyright (C) 2022 CVAT.ai Corporation
 #
 # SPDX-License-Identifier: MIT
@@ -22,16 +23,16 @@ class KeepTracks(ItemTransform):
         return item.wrap(annotations=[a for a in item.annotations
             if 'track_id' in a.attributes])

-def _import_task(dataset, task_data):
+def _import_to_task(dataset, instance_data):
     tracks = {}
     label_cat = dataset.categories()[AnnotationType.label]

-    root_hint = find_dataset_root(dataset, task_data)
+    root_hint = find_dataset_root(dataset, instance_data)

     shift = 0
     for item in dataset:
-        frame_number = task_data.abs_frame_id(
-            match_dm_item(item, task_data, root_hint=root_hint))
+        frame_number = instance_data.abs_frame_id(
+            match_dm_item(item, instance_data, root_hint=root_hint))

         track_ids = set()
@@ -49,7 +50,7 @@ def _import_task(dataset, task_data):
             else:
                 track_ids.add(track_id)

-                shape = task_data.TrackedShape(
+                shape = instance_data.TrackedShape(
                     type='polygon',
                     points=ann.points,
                     occluded=ann.attributes.get('occluded') is True,
@@ -64,7 +65,7 @@ def _import_task(dataset, task_data):
                 # build trajectories as lists of shapes in track dict
                 if track_id not in tracks:
-                    tracks[track_id] = task_data.Track(
+                    tracks[track_id] = instance_data.Track(
                         label_cat.items[ann.label].name, 0, 'manual', [])
                 tracks[track_id].shapes.append(shape)
@@ -75,10 +76,10 @@ def _import_task(dataset, task_data):
         prev_shape_idx = 0
         prev_shape = track.shapes[0]
         for shape in track.shapes[1:]:
-            has_skip = task_data.frame_step < shape.frame - prev_shape.frame
+            has_skip = instance_data.frame_step < shape.frame - prev_shape.frame
             if has_skip and not prev_shape.outside:
                 prev_shape = prev_shape._replace(outside=True,
-                    frame=prev_shape.frame + task_data.frame_step)
+                    frame=prev_shape.frame + instance_data.frame_step)
                 prev_shape_idx += 1
                 track.shapes.insert(prev_shape_idx, prev_shape)
             prev_shape = shape
@@ -86,12 +87,12 @@ def _import_task(dataset, task_data):
         # Append a shape with outside=True to finish the track
         last_shape = track.shapes[-1]
-        if last_shape.frame + task_data.frame_step <= \
-                int(task_data.meta['task']['stop_frame']):
+        if last_shape.frame + instance_data.frame_step <= \
+                int(instance_data.meta[instance_data.META_FIELD]['stop_frame']):
             track.shapes.append(last_shape._replace(outside=True,
-                frame=last_shape.frame + task_data.frame_step)
+                frame=last_shape.frame + instance_data.frame_step)
             )
-        task_data.add_track(track)
+        instance_data.add_track(track)

 @exporter(name='MOTS PNG', ext='ZIP', version='1.0')
 def _export(dst_file, instance_data, save_images=False):
@@ -120,7 +121,7 @@ def _import(src_file, instance_data, load_data_callback=None):
     # Dirty way to determine instance type to avoid circular dependency
     if hasattr(instance_data, '_db_project'):
         for sub_dataset, task_data in instance_data.split_dataset(dataset):
-            _import_task(sub_dataset, task_data)
+            _import_to_task(sub_dataset, task_data)
     else:
-        _import_task(dataset, instance_data)
+        _import_to_task(dataset, instance_data)

@@ -1,4 +1,5 @@
 # Copyright (C) 2020-2022 Intel Corporation
+# Copyright (C) 2022 CVAT.ai Corporation
 #
 # SPDX-License-Identifier: MIT
@@ -11,8 +12,7 @@ from tempfile import TemporaryDirectory
 from datumaro.components.dataset import Dataset
 from pyunpack import Archive

-from cvat.apps.dataset_manager.bindings import (GetCVATDataExtractor,
-    ProjectData, import_dm_annotations)
+from cvat.apps.dataset_manager.bindings import (GetCVATDataExtractor, import_dm_annotations)
 from cvat.apps.dataset_manager.util import make_zip_archive

 from .registry import dm_env, exporter, importer
@@ -36,8 +36,7 @@ def _import(src_file, instance_data, load_data_callback=None):
         # put label map from the task if not present
         labelmap_file = osp.join(tmp_dir, 'labelmap.txt')
         if not osp.isfile(labelmap_file):
-            labels_meta = instance_data.meta['project']['labels'] \
-                if isinstance(instance_data, ProjectData) else instance_data.meta['task']['labels']
+            labels_meta = instance_data.meta[instance_data.META_FIELD]['labels']
             labels = (label['name'] + ':::' for _, label in labels_meta)
             with open(labelmap_file, 'w') as f:
                 f.write('\n'.join(labels))

@@ -49,8 +49,7 @@ def hex2rgb(color):
     return tuple(int(color.lstrip('#')[i:i+2], 16) for i in (0, 2, 4))

 def make_colormap(instance_data):
-    instance_name = 'project' if 'project' in instance_data.meta.keys() else 'task'
-    labels = [label for _, label in instance_data.meta[instance_name]['labels']]
+    labels = [label for _, label in instance_data.meta[instance_data.META_FIELD]['labels']]
     label_names = [label['name'] for label in labels]

     if 'background' not in label_names:

@@ -16,7 +16,7 @@ from cvat.apps.engine.plugins import plugin_decorator
 from cvat.apps.profiler import silk_profile

 from .annotation import AnnotationIR, AnnotationManager
-from .bindings import TaskData
+from .bindings import TaskData, JobData
 from .formats.registry import make_exporter, make_importer
 from .util import bulk_create
@@ -553,24 +553,24 @@ class JobAnnotation:
         return self.ir_data.data

     def export(self, dst_file, exporter, host='', **options):
-        task_data = TaskData(
+        job_data = JobData(
             annotation_ir=self.ir_data,
-            db_task=self.db_job.segment.task,
+            db_job=self.db_job,
             host=host,
         )
-        exporter(dst_file, task_data, **options)
+        exporter(dst_file, job_data, **options)

     def import_annotations(self, src_file, importer):
-        task_data = TaskData(
+        job_data = JobData(
             annotation_ir=AnnotationIR(),
-            db_task=self.db_job.segment.task,
+            db_job=self.db_job,
             create_callback=self.create,
         )
         self.delete()

-        importer(src_file, task_data)
+        importer(src_file, job_data)

-        self.create(task_data.data.slice(self.start_frame, self.stop_frame).serialize())
+        self.create(job_data.data.slice(self.start_frame, self.stop_frame).serialize())

 class TaskAnnotation:
     def __init__(self, pk):
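With this change, exporting from a job yields a dataset scoped to the job's segment, with metadata under the "job" key (JobData.META_FIELD) instead of task-wide metadata. A hedged usage sketch; `job_id`, the output path and the host are placeholders, and a configured CVAT environment (database, media) is assumed.

from cvat.apps.dataset_manager.task import JobAnnotation
from cvat.apps.dataset_manager.formats.registry import make_exporter

job_annotation = JobAnnotation(job_id)  # hypothetical existing job id
job_annotation.init_from_db()
exporter = make_exporter('CVAT for images 1.1')
# writes a zip covering only the job's segment, per _export_task_or_job above
job_annotation.export('/tmp/job_dataset.zip', exporter, host='https://cvat.example.com')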

@@ -21,7 +21,7 @@ from rest_framework.test import APIClient, APITestCase

 import cvat.apps.dataset_manager as dm
 from cvat.apps.dataset_manager.annotation import AnnotationIR
-from cvat.apps.dataset_manager.bindings import (CvatTaskDataExtractor,
+from cvat.apps.dataset_manager.bindings import (CvatTaskOrJobDataExtractor,
                                                 TaskData, find_dataset_root)
 from cvat.apps.dataset_manager.task import TaskAnnotation
 from cvat.apps.dataset_manager.util import make_zip_archive
@@ -417,7 +417,7 @@ class TaskExportTest(_DbTestBase):
         task_ann.init_from_db()
         task_data = TaskData(task_ann.ir_data, Task.objects.get(pk=task["id"]))
-        extractor = CvatTaskDataExtractor(task_data)
+        extractor = CvatTaskOrJobDataExtractor(task_data)
         dm_dataset = datumaro.components.project.Dataset.from_extractors(extractor)
         self.assertEqual(4, len(dm_dataset.get("image_1").annotations))

@@ -1,4 +1,5 @@
 # Copyright (C) 2021-2022 Intel Corporation
+# Copyright (C) 2022 CVAT.ai Corporation
 #
 # SPDX-License-Identifier: MIT
@@ -22,7 +23,7 @@ from rest_framework import status
 from rest_framework.test import APIClient, APITestCase

 import cvat.apps.dataset_manager as dm
-from cvat.apps.dataset_manager.bindings import CvatTaskDataExtractor, TaskData
+from cvat.apps.dataset_manager.bindings import CvatTaskOrJobDataExtractor, TaskData
 from cvat.apps.dataset_manager.task import TaskAnnotation
 from cvat.apps.engine.models import Task
@@ -185,7 +186,7 @@ class _DbTestBase(APITestCase):
         task_ann = TaskAnnotation(task_id)
         task_ann.init_from_db()
         task_data = TaskData(task_ann.ir_data, Task.objects.get(pk=task_id))
-        extractor = CvatTaskDataExtractor(task_data, include_images=include_images)
+        extractor = CvatTaskOrJobDataExtractor(task_data, include_images=include_images)

         return Dataset.from_extractors(extractor)

     def _get_request_with_data(self, path, data, user):

@@ -1,4 +1,5 @@
 # Copyright (C) 2020-2022 Intel Corporation
+# Copyright (C) 2022 CVAT.ai Corporation
 #
 # SPDX-License-Identifier: MIT
@@ -182,6 +183,6 @@ class FrameProvider:
             return (frame, self.VIDEO_FRAME_MIME)
         return (frame, mimetypes.guess_type(frame_name)[0])

-    def get_frames(self, quality=Quality.ORIGINAL, out_type=Type.BUFFER):
-        for idx in range(self._db_data.size):
+    def get_frames(self, start_frame, stop_frame, quality=Quality.ORIGINAL, out_type=Type.BUFFER):
+        for idx in range(start_frame, stop_frame):
             yield self.get_frame(idx, quality=quality, out_type=out_type)
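Callers of the widened get_frames now pass the frame window explicitly (a job's segment, or the whole data range for a task), as dump_media_files does above. A minimal sketch; `db_data`, `start` and `stop` are assumed to come from an instance_data object.

frame_provider = FrameProvider(db_data)
frames = frame_provider.get_frames(
    start, stop,
    quality=frame_provider.Quality.ORIGINAL,
    out_type=frame_provider.Type.BUFFER)
for frame_buffer, mime_type in frames:
    pass  # each item matches get_frame()'s (data, mime) output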
