You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
176 lines
5.8 KiB
Python
176 lines
5.8 KiB
Python
from collections import OrderedDict
|
|
import os
|
|
import os.path as osp
|
|
|
|
from django.db import transaction
|
|
|
|
from cvat.apps.annotation.annotation import Annotation
|
|
from cvat.apps.engine.annotation import TaskAnnotation
|
|
from cvat.apps.engine.models import Task, ShapeType
|
|
|
|
import datumaro.components.extractor as datumaro
|
|
from datumaro.util.image import lazy_image
|
|
|
|
|
|
class CvatImagesDirExtractor(datumaro.Extractor):
|
|
_SUPPORTED_FORMATS = ['.png', '.jpg']
|
|
|
|
def __init__(self, url):
|
|
super().__init__()
|
|
|
|
items = []
|
|
for (dirpath, _, filenames) in os.walk(url):
|
|
for name in filenames:
|
|
path = osp.join(dirpath, name)
|
|
if self._is_image(path):
|
|
item_id = Task.get_image_frame(path)
|
|
item = datumaro.DatasetItem(
|
|
id=item_id, image=lazy_image(path))
|
|
items.append((item.id, item))
|
|
|
|
items = sorted(items, key=lambda e: e[0])
|
|
items = OrderedDict(items)
|
|
self._items = items
|
|
|
|
self._subsets = None
|
|
|
|
def __iter__(self):
|
|
for item in self._items.values():
|
|
yield item
|
|
|
|
def __len__(self):
|
|
return len(self._items)
|
|
|
|
def subsets(self):
|
|
return self._subsets
|
|
|
|
def get(self, item_id, subset=None, path=None):
|
|
if path or subset:
|
|
raise KeyError()
|
|
return self._items[item_id]
|
|
|
|
def _is_image(self, path):
|
|
for ext in self._SUPPORTED_FORMATS:
|
|
if osp.isfile(path) and path.endswith(ext):
|
|
return True
|
|
return False
|
|
|
|
|
|
class CvatTaskExtractor(datumaro.Extractor):
|
|
def __init__(self, url, db_task, user):
|
|
self._db_task = db_task
|
|
self._categories = self._load_categories()
|
|
|
|
cvat_annotations = TaskAnnotation(db_task.id, user)
|
|
with transaction.atomic():
|
|
cvat_annotations.init_from_db()
|
|
cvat_annotations = Annotation(cvat_annotations.ir_data, db_task)
|
|
|
|
dm_annotations = []
|
|
|
|
for cvat_anno in cvat_annotations.group_by_frame():
|
|
dm_anno = self._read_cvat_anno(cvat_anno)
|
|
dm_item = datumaro.DatasetItem(
|
|
id=cvat_anno.frame, annotations=dm_anno)
|
|
dm_annotations.append((dm_item.id, dm_item))
|
|
|
|
dm_annotations = sorted(dm_annotations, key=lambda e: e[0])
|
|
self._items = OrderedDict(dm_annotations)
|
|
|
|
self._subsets = None
|
|
|
|
def __iter__(self):
|
|
for item in self._items.values():
|
|
yield item
|
|
|
|
def __len__(self):
|
|
return len(self._items)
|
|
|
|
def subsets(self):
|
|
return self._subsets
|
|
|
|
def get(self, item_id, subset=None, path=None):
|
|
if path or subset:
|
|
raise KeyError()
|
|
return self._items[item_id]
|
|
|
|
def _load_categories(self):
|
|
categories = {}
|
|
label_categories = datumaro.LabelCategories()
|
|
|
|
db_labels = self._db_task.label_set.all()
|
|
for db_label in db_labels:
|
|
db_attributes = db_label.attributespec_set.all()
|
|
label_categories.add(db_label.name)
|
|
|
|
for db_attr in db_attributes:
|
|
label_categories.attributes.add(db_attr.name)
|
|
|
|
categories[datumaro.AnnotationType.label] = label_categories
|
|
|
|
return categories
|
|
|
|
def categories(self):
|
|
return self._categories
|
|
|
|
def _read_cvat_anno(self, cvat_anno):
|
|
item_anno = []
|
|
|
|
categories = self.categories()
|
|
label_cat = categories[datumaro.AnnotationType.label]
|
|
|
|
label_map = {}
|
|
label_attrs = {}
|
|
db_labels = self._db_task.label_set.all()
|
|
for db_label in db_labels:
|
|
label_map[db_label.name] = label_cat.find(db_label.name)[0]
|
|
|
|
attrs = {}
|
|
db_attributes = db_label.attributespec_set.all()
|
|
for db_attr in db_attributes:
|
|
attrs[db_attr.name] = db_attr.default_value
|
|
label_attrs[db_label.name] = attrs
|
|
map_label = lambda label_db_name: label_map[label_db_name]
|
|
|
|
for tag_obj in cvat_anno.tags:
|
|
anno_group = tag_obj.group
|
|
if isinstance(anno_group, int):
|
|
anno_group = anno_group
|
|
anno_label = map_label(tag_obj.label)
|
|
anno_attr = dict(label_attrs[tag_obj.label])
|
|
for attr in tag_obj.attributes:
|
|
anno_attr[attr.name] = attr.value
|
|
|
|
anno = datumaro.LabelObject(label=anno_label,
|
|
attributes=anno_attr, group=anno_group)
|
|
item_anno.append(anno)
|
|
|
|
for shape_obj in cvat_anno.labeled_shapes:
|
|
anno_group = shape_obj.group
|
|
if isinstance(anno_group, int):
|
|
anno_group = anno_group
|
|
anno_label = map_label(shape_obj.label)
|
|
anno_attr = dict(label_attrs[shape_obj.label])
|
|
for attr in shape_obj.attributes:
|
|
anno_attr[attr.name] = attr.value
|
|
|
|
anno_points = shape_obj.points
|
|
if shape_obj.type == ShapeType.POINTS:
|
|
anno = datumaro.PointsObject(anno_points,
|
|
label=anno_label, attributes=anno_attr, group=anno_group)
|
|
elif shape_obj.type == ShapeType.POLYLINE:
|
|
anno = datumaro.PolyLineObject(anno_points,
|
|
label=anno_label, attributes=anno_attr, group=anno_group)
|
|
elif shape_obj.type == ShapeType.POLYGON:
|
|
anno = datumaro.PolygonObject(anno_points,
|
|
label=anno_label, attributes=anno_attr, group=anno_group)
|
|
elif shape_obj.type == ShapeType.RECTANGLE:
|
|
x0, y0, x1, y1 = anno_points
|
|
anno = datumaro.BboxObject(x0, y0, x1 - x0, y1 - y0,
|
|
label=anno_label, attributes=anno_attr, group=anno_group)
|
|
else:
|
|
raise Exception("Unknown shape type '%s'" % (shape_obj.type))
|
|
|
|
item_anno.append(anno)
|
|
|
|
return item_anno |