You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

96 lines
3.3 KiB
Python

# Copyright (C) 2019-2020 Intel Corporation
#
# SPDX-License-Identifier: MIT
from collections import defaultdict
from glob import glob
import logging as log
import os.path as osp
from datumaro.components.extractor import Importer
from datumaro.util.log_utils import logging_disabled
from .format import CocoTask
class CocoImporter(Importer):
    """Builds a project whose sources are COCO annotation files found at a path.

    Annotation files are recognized by their file name, which must look like
    ``<task>_<subset>.json``, where ``<task>`` is the name of a ``CocoTask``
    member.
    """

    # Maps each supported COCO task to the 'format' value written into the
    # project source that is created for an annotation file of that task.
    _COCO_EXTRACTORS = {
        CocoTask.instances: 'coco_instances',
        CocoTask.person_keypoints: 'coco_person_keypoints',
        CocoTask.captions: 'coco_captions',
        CocoTask.labels: 'coco_labels',
        CocoTask.image_info: 'coco_image_info',
    }

    @classmethod
    def detect(cls, path):
        """Return True if at least one COCO annotation file is found at `path`."""
        # find_subsets() warns about every file it does not recognize;
        # logging_disabled(log.WARN) is presumably there to keep a mere
        # format probe from emitting those messages.
        with logging_disabled(log.WARN):
            return bool(cls.find_subsets(path))

    def __call__(self, path, **extra_params):
        """Create a project with one source per usable annotation file.

        Args:
            path: a COCO annotation ``.json`` file, or a directory that is
                searched recursively (see `find_subsets`).
            **extra_params: copied into the 'options' entry of every source
                that is added to the project.

        Returns:
            The populated ``Project``.

        Raises:
            Exception: if no annotation files are found at `path`.
        """
        from datumaro.components.project import Project # cyclic import

        project = Project()

        subsets = self.find_subsets(path)
        if not subsets:
            raise Exception("Failed to find 'coco' dataset at '%s'" % path)

        # TODO: should be removed when proper label merging is implemented
        conflicting_types = {CocoTask.instances,
            CocoTask.person_keypoints, CocoTask.labels}
        ann_types = set(t for s in subsets.values() for t in s) \
            & conflicting_types

        # Of the label-carrying task types that are present, only the one
        # whose name sorts first is imported; the others are skipped below.
        selected_ann_type = None
        if ann_types:
            ordered_types = sorted(ann_types, key=lambda x: x.name)
            selected_ann_type = ordered_types[0]
            if 1 < len(ordered_types):
                log.warning("Not implemented: "
                    "Found potentially conflicting source types with labels: %s. "
                    "Only one type will be used: %s",
                    ", ".join(t.name for t in ordered_types),
                    selected_ann_type.name)

        for ann_files in subsets.values():
            for ann_type, ann_file in ann_files.items():
                if ann_type in conflicting_types \
                        and ann_type is not selected_ann_type:
                    log.warning("Not implemented: "
                        "conflicting source '%s' is skipped.", ann_file)
                    continue
                log.info("Found a dataset at '%s'", ann_file)

                # The source is named after the annotation file (no extension).
                source_name = osp.splitext(osp.basename(ann_file))[0]
                project.add_source(source_name, {
                    'url': ann_file,
                    'format': self._COCO_EXTRACTORS[ann_type],
                    'options': dict(extra_params),
                })

        return project

    @staticmethod
    def find_subsets(path):
        """Locate COCO annotation files and group them by subset and task.

        Args:
            path: either an existing ``.json`` file, which is then the only
                candidate, or a directory that is searched recursively for
                files matching ``*_*.json``.

        Returns:
            A dict mapping subset name to a dict of
            ``{CocoTask member: annotation file path}``. Files whose name
            does not follow ``<task>_<subset>.json`` with a known task are
            skipped with a warning.
        """
        if path.endswith('.json') and osp.isfile(path):
            subset_paths = [path]
        else:
            subset_paths = glob(osp.join(path, '**', '*_*.json'),
                recursive=True)

        subsets = defaultdict(dict)
        for subset_path in subset_paths:
            # Split on the last underscore only, so that task names which
            # contain underscores themselves (e.g. 'person_keypoints')
            # stay intact.
            name_parts = osp.splitext(osp.basename(subset_path))[0] \
                .rsplit('_', maxsplit=1)

            ann_type = name_parts[0]
            try:
                ann_type = CocoTask[ann_type]
            except KeyError:
                log.warning("Skipping '%s': unknown subset "
                    "type '%s', the only known are: %s",
                    subset_path, ann_type,
                    ', '.join([e.name for e in CocoTask]))
                continue

            # A directly passed file such as 'instances.json' has a valid
            # task name but no '_<subset>' part; skip it instead of failing
            # with IndexError.
            if len(name_parts) < 2:
                log.warning("Skipping '%s': file name has no subset part, "
                    "expected '<type>_<subset>.json'", subset_path)
                continue

            subset_name = name_parts[1]
            subsets[subset_name][ann_type] = subset_path
        return dict(subsets)