[Datumaro] Label remapping transform (#1233)

* Add label remapping transform

* Apply transforms before project saving

* Refactor voc converter
main
Nikita Manovich 6 years ago committed by GitHub
parent 78dad73de9
commit be5577d013
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -634,6 +634,8 @@ class ProjectDataset(Dataset):
return self._sources return self._sources
def _save_branch_project(self, extractor, save_dir=None): def _save_branch_project(self, extractor, save_dir=None):
extractor = Dataset.from_extractors(extractor) # apply lazy transforms
# NOTE: probably this function should be in the ViewModel layer # NOTE: probably this function should be in the ViewModel layer
save_dir = osp.abspath(save_dir) save_dir = osp.abspath(save_dir)
if save_dir: if save_dir:

@ -3,6 +3,7 @@
# #
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
from enum import Enum
import logging as log import logging as log
import os.path as osp import os.path as osp
import random import random
@ -10,7 +11,9 @@ import random
import pycocotools.mask as mask_utils import pycocotools.mask as mask_utils
from datumaro.components.extractor import (Transform, AnnotationType, from datumaro.components.extractor import (Transform, AnnotationType,
RleMask, Polygon, Bbox) RleMask, Polygon, Bbox,
LabelCategories, MaskCategories, PointsCategories
)
from datumaro.components.cli_plugin import CliPlugin from datumaro.components.cli_plugin import CliPlugin
import datumaro.util.mask_tools as mask_tools import datumaro.util.mask_tools as mask_tools
from datumaro.util.annotation_tools import find_group_leader, find_instances from datumaro.util.annotation_tools import find_group_leader, find_instances
@ -46,7 +49,7 @@ class CropCoveredSegments(Transform, CliPlugin):
segments.append(s.points) segments.append(s.points)
elif s.type == AnnotationType.mask: elif s.type == AnnotationType.mask:
if isinstance(s, RleMask): if isinstance(s, RleMask):
rle = s._rle rle = s.rle
else: else:
rle = mask_tools.mask_to_rle(s.image) rle = mask_tools.mask_to_rle(s.image)
segments.append(rle) segments.append(rle)
@ -365,3 +368,116 @@ class IdFromImageName(Transform, CliPlugin):
if item.has_image and item.image.filename: if item.has_image and item.image.filename:
name = osp.splitext(item.image.filename)[0] name = osp.splitext(item.image.filename)[0]
return self.wrap_item(item, id=name) return self.wrap_item(item, id=name)
class RemapLabels(Transform, CliPlugin):
DefaultAction = Enum('DefaultAction', ['keep', 'delete'])
@staticmethod
def _split_arg(s):
parts = s.split(':')
if len(parts) != 2:
import argparse
raise argparse.ArgumentTypeError()
return (parts[0], parts[1])
@classmethod
def build_cmdline_parser(cls, **kwargs):
parser = super().build_cmdline_parser(**kwargs)
parser.add_argument('-l', '--label', action='append',
type=cls._split_arg, dest='mapping',
help="Label in the form of: '<src>:<dst>' (repeatable)")
parser.add_argument('--default',
choices=[a.name for a in cls.DefaultAction],
default=cls.DefaultAction.keep.name,
help="Action for unspecified labels")
return parser
def __init__(self, extractor, mapping, default=None):
super().__init__(extractor)
assert isinstance(default, (str, self.DefaultAction))
if isinstance(default, str):
default = self.DefaultAction[default]
assert isinstance(mapping, (dict, list))
if isinstance(mapping, list):
mapping = dict(mapping)
self._categories = {}
src_label_cat = self._extractor.categories().get(AnnotationType.label)
if src_label_cat is not None:
self._make_label_id_map(src_label_cat, mapping, default)
src_mask_cat = self._extractor.categories().get(AnnotationType.mask)
if src_mask_cat is not None:
assert src_label_cat is not None
dst_mask_cat = MaskCategories(attributes=src_mask_cat.attributes)
dst_mask_cat.colormap = {
id: src_mask_cat.colormap[id]
for id, _ in enumerate(src_label_cat.items)
if self._map_id(id) or id == 0
}
self._categories[AnnotationType.mask] = dst_mask_cat
src_points_cat = self._extractor.categories().get(AnnotationType.points)
if src_points_cat is not None:
assert src_label_cat is not None
dst_points_cat = PointsCategories(attributes=src_points_cat.attributes)
dst_points_cat.items = {
id: src_points_cat.items[id]
for id, item in enumerate(src_label_cat.items)
if self._map_id(id) or id == 0
}
self._categories[AnnotationType.points] = dst_points_cat
def _make_label_id_map(self, src_label_cat, label_mapping, default_action):
dst_label_cat = LabelCategories(attributes=src_label_cat.attributes)
id_mapping = {}
for src_index, src_label in enumerate(src_label_cat.items):
dst_label = label_mapping.get(src_label.name)
if not dst_label and default_action == self.DefaultAction.keep:
dst_label = src_label.name # keep unspecified as is
if not dst_label:
continue
dst_index = dst_label_cat.find(dst_label)[0]
if dst_index is None:
dst_label_cat.add(dst_label,
src_label.parent, src_label.attributes)
dst_index = dst_label_cat.find(dst_label)[0]
id_mapping[src_index] = dst_index
if log.getLogger().isEnabledFor(log.DEBUG):
log.debug("Label mapping:")
for src_id, src_label in enumerate(src_label_cat.items):
if id_mapping.get(src_id):
log.debug("#%s '%s' -> #%s '%s'",
src_id, src_label.name, id_mapping[src_id],
dst_label_cat.items[id_mapping[src_id]].name
)
else:
log.debug("#%s '%s' -> <deleted>", src_id, src_label.name)
self._map_id = lambda src_id: id_mapping.get(src_id, None)
self._categories[AnnotationType.label] = dst_label_cat
def categories(self):
return self._categories
def transform_item(self, item):
# TODO: provide non-inplace version
annotations = []
for ann in item.annotations:
if ann.type in { AnnotationType.label, AnnotationType.mask,
AnnotationType.points, AnnotationType.polygon,
AnnotationType.polyline, AnnotationType.bbox
} and ann.label is not None:
conv_label = self._map_id(ann.label)
if conv_label is not None:
ann._label = conv_label
annotations.append(ann)
else:
annotations.append(ann)
item._annotations = annotations
return item

@ -53,14 +53,13 @@ LabelmapType = Enum('LabelmapType', ['voc', 'source', 'guess'])
class _Converter: class _Converter:
def __init__(self, extractor, save_dir, def __init__(self, extractor, save_dir,
tasks=None, apply_colormap=True, save_images=False, label_map=None): tasks=None, apply_colormap=True, save_images=False, label_map=None):
assert tasks is None or isinstance(tasks, (VocTask, list)) assert tasks is None or isinstance(tasks, (VocTask, list, set))
if tasks is None: if tasks is None:
tasks = list(VocTask) tasks = set(VocTask)
elif isinstance(tasks, VocTask): elif isinstance(tasks, VocTask):
tasks = [tasks] tasks = {tasks}
else: else:
tasks = [t if t in VocTask else VocTask[t] for t in tasks] tasks = set(t if t in VocTask else VocTask[t] for t in tasks)
self._tasks = tasks self._tasks = tasks
self._extractor = extractor self._extractor = extractor
@ -259,10 +258,10 @@ class _Converter:
if len(actions_elem) != 0: if len(actions_elem) != 0:
obj_elem.append(actions_elem) obj_elem.append(actions_elem)
if set(self._tasks) & set([None, if self._tasks & {None,
VocTask.detection, VocTask.detection,
VocTask.person_layout, VocTask.person_layout,
VocTask.action_classification]): VocTask.action_classification}:
with open(osp.join(self._ann_dir, item.id + '.xml'), 'w') as f: with open(osp.join(self._ann_dir, item.id + '.xml'), 'w') as f:
f.write(ET.tostring(root_elem, f.write(ET.tostring(root_elem,
encoding='unicode', pretty_print=True)) encoding='unicode', pretty_print=True))
@ -302,19 +301,19 @@ class _Converter:
action_list[item.id] = None action_list[item.id] = None
segm_list[item.id] = None segm_list[item.id] = None
if set(self._tasks) & set([None, if self._tasks & {None,
VocTask.classification, VocTask.classification,
VocTask.detection, VocTask.detection,
VocTask.action_classification, VocTask.action_classification,
VocTask.person_layout]): VocTask.person_layout}:
self.save_clsdet_lists(subset_name, clsdet_list) self.save_clsdet_lists(subset_name, clsdet_list)
if set(self._tasks) & set([None, VocTask.classification]): if self._tasks & {None, VocTask.classification}:
self.save_class_lists(subset_name, class_lists) self.save_class_lists(subset_name, class_lists)
if set(self._tasks) & set([None, VocTask.action_classification]): if self._tasks & {None, VocTask.action_classification}:
self.save_action_lists(subset_name, action_list) self.save_action_lists(subset_name, action_list)
if set(self._tasks) & set([None, VocTask.person_layout]): if self._tasks & {None, VocTask.person_layout}:
self.save_layout_lists(subset_name, layout_list) self.save_layout_lists(subset_name, layout_list)
if set(self._tasks) & set([None, VocTask.segmentation]): if self._tasks & {None, VocTask.segmentation}:
self.save_segm_lists(subset_name, segm_list) self.save_segm_lists(subset_name, segm_list)
def save_action_lists(self, subset_name, action_list): def save_action_lists(self, subset_name, action_list):

@ -3,10 +3,12 @@ import numpy as np
from unittest import TestCase from unittest import TestCase
from datumaro.components.extractor import (Extractor, DatasetItem, from datumaro.components.extractor import (Extractor, DatasetItem,
Mask, Polygon, PolyLine, Points, Bbox Mask, Polygon, PolyLine, Points, Bbox, Label,
LabelCategories, MaskCategories, AnnotationType
) )
from datumaro.util.test_utils import compare_datasets import datumaro.util.mask_tools as mask_tools
import datumaro.plugins.transforms as transforms import datumaro.plugins.transforms as transforms
from datumaro.util.test_utils import compare_datasets
class TransformsTest(TestCase): class TransformsTest(TestCase):
@ -361,3 +363,95 @@ class TransformsTest(TestCase):
('train', -0.5), ('train', -0.5),
('test', 1.5), ('test', 1.5),
]) ])
def test_remap_labels(self):
class SrcExtractor(Extractor):
def __iter__(self):
return iter([
DatasetItem(id=1, annotations=[
# Should be remapped
Label(1),
Bbox(1, 2, 3, 4, label=2),
Mask(image=np.array([1]), label=3),
# Should be kept
Polygon([1, 1, 2, 2, 3, 4], label=4),
PolyLine([1, 3, 4, 2, 5, 6], label=None)
]),
])
def categories(self):
label_cat = LabelCategories()
label_cat.add('label0')
label_cat.add('label1')
label_cat.add('label2')
label_cat.add('label3')
label_cat.add('label4')
mask_cat = MaskCategories(
colormap=mask_tools.generate_colormap(5))
return {
AnnotationType.label: label_cat,
AnnotationType.mask: mask_cat,
}
class DstExtractor(Extractor):
def __iter__(self):
return iter([
DatasetItem(id=1, annotations=[
Label(1),
Bbox(1, 2, 3, 4, label=0),
Mask(image=np.array([1]), label=1),
Polygon([1, 1, 2, 2, 3, 4], label=2),
PolyLine([1, 3, 4, 2, 5, 6], label=None)
]),
])
def categories(self):
label_cat = LabelCategories()
label_cat.add('label0')
label_cat.add('label9')
label_cat.add('label4')
mask_cat = MaskCategories(colormap={
k: v for k, v in mask_tools.generate_colormap(5).items()
if k in { 0, 1, 3, 4 }
})
return {
AnnotationType.label: label_cat,
AnnotationType.mask: mask_cat,
}
actual = transforms.RemapLabels(SrcExtractor(), mapping={
'label1': 'label9',
'label2': 'label0',
'label3': 'label9',
}, default='keep')
compare_datasets(self, DstExtractor(), actual)
def test_remap_labels_delete_unspecified(self):
class SrcExtractor(Extractor):
def __iter__(self):
return iter([ DatasetItem(id=1, annotations=[ Label(0) ]) ])
def categories(self):
label_cat = LabelCategories()
label_cat.add('label0')
return { AnnotationType.label: label_cat }
class DstExtractor(Extractor):
def __iter__(self):
return iter([ DatasetItem(id=1, annotations=[]) ])
def categories(self):
return { AnnotationType.label: LabelCategories() }
actual = transforms.RemapLabels(SrcExtractor(),
mapping={}, default='delete')
compare_datasets(self, DstExtractor(), actual)

Loading…
Cancel
Save