|
|
|
|
@ -3,6 +3,7 @@
|
|
|
|
|
#
|
|
|
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
|
|
|
|
|
|
from enum import Enum
|
|
|
|
|
import logging as log
|
|
|
|
|
import os.path as osp
|
|
|
|
|
import random
|
|
|
|
|
@ -10,7 +11,9 @@ import random
|
|
|
|
|
import pycocotools.mask as mask_utils
|
|
|
|
|
|
|
|
|
|
from datumaro.components.extractor import (Transform, AnnotationType,
|
|
|
|
|
RleMask, Polygon, Bbox)
|
|
|
|
|
RleMask, Polygon, Bbox,
|
|
|
|
|
LabelCategories, MaskCategories, PointsCategories
|
|
|
|
|
)
|
|
|
|
|
from datumaro.components.cli_plugin import CliPlugin
|
|
|
|
|
import datumaro.util.mask_tools as mask_tools
|
|
|
|
|
from datumaro.util.annotation_tools import find_group_leader, find_instances
|
|
|
|
|
@ -46,7 +49,7 @@ class CropCoveredSegments(Transform, CliPlugin):
|
|
|
|
|
segments.append(s.points)
|
|
|
|
|
elif s.type == AnnotationType.mask:
|
|
|
|
|
if isinstance(s, RleMask):
|
|
|
|
|
rle = s._rle
|
|
|
|
|
rle = s.rle
|
|
|
|
|
else:
|
|
|
|
|
rle = mask_tools.mask_to_rle(s.image)
|
|
|
|
|
segments.append(rle)
|
|
|
|
|
@ -365,3 +368,116 @@ class IdFromImageName(Transform, CliPlugin):
|
|
|
|
|
if item.has_image and item.image.filename:
|
|
|
|
|
name = osp.splitext(item.image.filename)[0]
|
|
|
|
|
return self.wrap_item(item, id=name)
|
|
|
|
|
|
|
|
|
|
class RemapLabels(Transform, CliPlugin):
|
|
|
|
|
DefaultAction = Enum('DefaultAction', ['keep', 'delete'])
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def _split_arg(s):
|
|
|
|
|
parts = s.split(':')
|
|
|
|
|
if len(parts) != 2:
|
|
|
|
|
import argparse
|
|
|
|
|
raise argparse.ArgumentTypeError()
|
|
|
|
|
return (parts[0], parts[1])
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def build_cmdline_parser(cls, **kwargs):
|
|
|
|
|
parser = super().build_cmdline_parser(**kwargs)
|
|
|
|
|
parser.add_argument('-l', '--label', action='append',
|
|
|
|
|
type=cls._split_arg, dest='mapping',
|
|
|
|
|
help="Label in the form of: '<src>:<dst>' (repeatable)")
|
|
|
|
|
parser.add_argument('--default',
|
|
|
|
|
choices=[a.name for a in cls.DefaultAction],
|
|
|
|
|
default=cls.DefaultAction.keep.name,
|
|
|
|
|
help="Action for unspecified labels")
|
|
|
|
|
return parser
|
|
|
|
|
|
|
|
|
|
def __init__(self, extractor, mapping, default=None):
|
|
|
|
|
super().__init__(extractor)
|
|
|
|
|
|
|
|
|
|
assert isinstance(default, (str, self.DefaultAction))
|
|
|
|
|
if isinstance(default, str):
|
|
|
|
|
default = self.DefaultAction[default]
|
|
|
|
|
|
|
|
|
|
assert isinstance(mapping, (dict, list))
|
|
|
|
|
if isinstance(mapping, list):
|
|
|
|
|
mapping = dict(mapping)
|
|
|
|
|
|
|
|
|
|
self._categories = {}
|
|
|
|
|
|
|
|
|
|
src_label_cat = self._extractor.categories().get(AnnotationType.label)
|
|
|
|
|
if src_label_cat is not None:
|
|
|
|
|
self._make_label_id_map(src_label_cat, mapping, default)
|
|
|
|
|
|
|
|
|
|
src_mask_cat = self._extractor.categories().get(AnnotationType.mask)
|
|
|
|
|
if src_mask_cat is not None:
|
|
|
|
|
assert src_label_cat is not None
|
|
|
|
|
dst_mask_cat = MaskCategories(attributes=src_mask_cat.attributes)
|
|
|
|
|
dst_mask_cat.colormap = {
|
|
|
|
|
id: src_mask_cat.colormap[id]
|
|
|
|
|
for id, _ in enumerate(src_label_cat.items)
|
|
|
|
|
if self._map_id(id) or id == 0
|
|
|
|
|
}
|
|
|
|
|
self._categories[AnnotationType.mask] = dst_mask_cat
|
|
|
|
|
|
|
|
|
|
src_points_cat = self._extractor.categories().get(AnnotationType.points)
|
|
|
|
|
if src_points_cat is not None:
|
|
|
|
|
assert src_label_cat is not None
|
|
|
|
|
dst_points_cat = PointsCategories(attributes=src_points_cat.attributes)
|
|
|
|
|
dst_points_cat.items = {
|
|
|
|
|
id: src_points_cat.items[id]
|
|
|
|
|
for id, item in enumerate(src_label_cat.items)
|
|
|
|
|
if self._map_id(id) or id == 0
|
|
|
|
|
}
|
|
|
|
|
self._categories[AnnotationType.points] = dst_points_cat
|
|
|
|
|
|
|
|
|
|
def _make_label_id_map(self, src_label_cat, label_mapping, default_action):
|
|
|
|
|
dst_label_cat = LabelCategories(attributes=src_label_cat.attributes)
|
|
|
|
|
id_mapping = {}
|
|
|
|
|
for src_index, src_label in enumerate(src_label_cat.items):
|
|
|
|
|
dst_label = label_mapping.get(src_label.name)
|
|
|
|
|
if not dst_label and default_action == self.DefaultAction.keep:
|
|
|
|
|
dst_label = src_label.name # keep unspecified as is
|
|
|
|
|
if not dst_label:
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
dst_index = dst_label_cat.find(dst_label)[0]
|
|
|
|
|
if dst_index is None:
|
|
|
|
|
dst_label_cat.add(dst_label,
|
|
|
|
|
src_label.parent, src_label.attributes)
|
|
|
|
|
dst_index = dst_label_cat.find(dst_label)[0]
|
|
|
|
|
id_mapping[src_index] = dst_index
|
|
|
|
|
|
|
|
|
|
if log.getLogger().isEnabledFor(log.DEBUG):
|
|
|
|
|
log.debug("Label mapping:")
|
|
|
|
|
for src_id, src_label in enumerate(src_label_cat.items):
|
|
|
|
|
if id_mapping.get(src_id):
|
|
|
|
|
log.debug("#%s '%s' -> #%s '%s'",
|
|
|
|
|
src_id, src_label.name, id_mapping[src_id],
|
|
|
|
|
dst_label_cat.items[id_mapping[src_id]].name
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
log.debug("#%s '%s' -> <deleted>", src_id, src_label.name)
|
|
|
|
|
|
|
|
|
|
self._map_id = lambda src_id: id_mapping.get(src_id, None)
|
|
|
|
|
self._categories[AnnotationType.label] = dst_label_cat
|
|
|
|
|
|
|
|
|
|
def categories(self):
|
|
|
|
|
return self._categories
|
|
|
|
|
|
|
|
|
|
def transform_item(self, item):
|
|
|
|
|
# TODO: provide non-inplace version
|
|
|
|
|
annotations = []
|
|
|
|
|
for ann in item.annotations:
|
|
|
|
|
if ann.type in { AnnotationType.label, AnnotationType.mask,
|
|
|
|
|
AnnotationType.points, AnnotationType.polygon,
|
|
|
|
|
AnnotationType.polyline, AnnotationType.bbox
|
|
|
|
|
} and ann.label is not None:
|
|
|
|
|
conv_label = self._map_id(ann.label)
|
|
|
|
|
if conv_label is not None:
|
|
|
|
|
ann._label = conv_label
|
|
|
|
|
annotations.append(ann)
|
|
|
|
|
else:
|
|
|
|
|
annotations.append(ann)
|
|
|
|
|
item._annotations = annotations
|
|
|
|
|
return item
|