You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
724 lines
23 KiB
Python
724 lines
23 KiB
Python
|
|
# Copyright (C) 2019 Intel Corporation
|
|
#
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
from collections import OrderedDict, defaultdict
|
|
import git
|
|
import importlib
|
|
from functools import reduce
|
|
import logging as log
|
|
import os
|
|
import os.path as osp
|
|
import sys
|
|
|
|
from datumaro.components.config import Config, DEFAULT_FORMAT
|
|
from datumaro.components.config_model import *
|
|
from datumaro.components.extractor import *
|
|
from datumaro.components.launcher import *
|
|
from datumaro.components.dataset_filter import XPathDatasetFilter
|
|
|
|
|
|
def import_foreign_module(name, path):
    """Import module ``name`` from the directory ``path``.

    ``path`` is temporarily placed at the front of ``sys.path`` for the
    duration of the import, and the module is removed from ``sys.modules``
    both before and after importing, so a same-named module located in a
    different directory can be imported later.

    Args:
        name: module name to import.
        path: directory to search first; converted to an absolute path.

    Returns:
        The imported module object, or ``None`` when the import raised
        ``ImportError`` (a warning is logged in that case).
    """
    module = None
    default_path = sys.path.copy()
    try:
        sys.path = [osp.abspath(path)] + default_path
        sys.modules.pop(name, None)  # remove from cache
        module = importlib.import_module(name)
        # A default is used here too: the module may have already removed
        # itself from the cache while being imported.
        sys.modules.pop(name, None)  # remove from cache
    except ImportError as e:
        # log.warn() is a deprecated alias of log.warning()
        log.warning("Failed to import module '%s': %s", name, e)
    finally:
        # always restore the search path, even on unexpected errors
        sys.path = default_path
    return module
|
|
|
|
|
|
class Registry:
    """A name -> item mapping with optional conversion of stored values.

    When ``item_type`` is set (truthy), every registered value is passed
    through ``item_type(value)`` before being stored. Subclasses override
    ``load()`` to populate the registry from a config.
    """

    def __init__(self, config=None, item_type=None):
        # item_type must be set before load(), since load() may register items
        self.item_type = item_type
        self.items = {}

        if config is None:
            return
        self.load(config)

    def load(self, config):
        """Populate the registry from ``config``. The base class does nothing."""
        pass

    def register(self, name, value):
        """Store ``value`` under ``name`` and return the stored object."""
        convert = self.item_type
        entry = convert(value) if convert else value
        self.items[name] = entry
        return entry

    def unregister(self, name):
        """Remove and return the entry for ``name``; ``None`` if it is absent."""
        return self.items.pop(name, None)

    def get(self, key):
        """Return the entry for ``key``; raises ``KeyError`` if it is absent."""
        found = self.items[key]  # returns a class / ctor
        return found
|
|
|
|
|
|
class ModelRegistry(Registry):
    """Registry whose entries are converted with ``Model`` on registration."""

    def __init__(self, config=None):
        super().__init__(config, item_type=Model)

    def load(self, config):
        """Register every entry of ``config.models`` when that key is present."""
        # TODO: list default dir, insert values
        if 'models' not in config:
            return
        for model_name, model_config in config.models.items():
            self.register(model_name, model_config)
|
|
|
|
|
|
class SourceRegistry(Registry):
    """Registry whose entries are converted with ``Source`` on registration."""

    def __init__(self, config=None):
        super().__init__(config, item_type=Source)

    def load(self, config):
        """Register every entry of ``config.sources`` when that key is present."""
        # TODO: list default dir, insert values
        if 'sources' not in config:
            return
        for source_name, source_config in config.sources.items():
            self.register(source_name, source_config)
|
|
|
|
|
|
class ModuleRegistry(Registry):
    """Registry filled from iterables of ``(name, value)`` pairs.

    ``builtin`` pairs are registered first and ``local`` pairs second, so a
    local entry replaces a built-in entry registered under the same name.
    """

    def __init__(self, config=None, builtin=None, local=None):
        super().__init__(config)

        for group in (builtin, local):
            if group is None:
                continue
            for key, value in group:
                self.register(key, value)
|
|
|
|
|
|
class GitWrapper:
    """Thin convenience layer over a ``git.Repo`` object.

    The repository is opened or created at ``config.project_dir`` when a
    config is given; otherwise ``repo`` stays ``None`` until ``init()`` is
    called.
    """

    def __init__(self, config=None):
        self.repo = None

        if config is None:
            return
        self.init(config.project_dir)

    @staticmethod
    def _git_dir(base_path):
        """Return the path of the ``.git`` entry inside ``base_path``."""
        return osp.join(base_path, '.git')

    def init(self, path):
        """Initialize the repository at ``path`` and return it.

        When ``path`` had no ``.git`` directory before the call, an
        'Initial commit' is also created.
        """
        # must be checked before Repo.init(), hence evaluated first
        is_new = not osp.isdir(GitWrapper._git_dir(path))
        self.repo = git.Repo.init(path=path)
        if not is_new:
            return self.repo

        author = git.Actor("Nobody", "nobody@example.com")
        self.repo.index.commit('Initial commit', author=author)
        return self.repo

    def get_repo(self):
        """Return the wrapped repository object (``None`` if not initialized)."""
        return self.repo

    def is_initialized(self):
        """Tell whether a repository object has been set."""
        return self.repo is not None

    def create_submodule(self, name, dst_dir, **kwargs):
        """Forward to ``repo.create_submodule``; the result is not returned."""
        self.repo.create_submodule(name, dst_dir, **kwargs)

    def has_submodule(self, name):
        """Tell whether the repository has a submodule called ``name``."""
        known_names = [sm.name for sm in self.repo.submodules]
        return name in known_names

    def remove_submodule(self, name, **kwargs):
        """Remove the submodule ``name`` and return the result of ``remove()``."""
        submodule = self.repo.submodule(name)
        return submodule.remove(**kwargs)
|
|
|
|
def load_project_as_dataset(url):
    """Placeholder that always raises ``NotImplementedError``.

    It only declares the name; a later definition with the same name in
    this file replaces it.
    """
    # symbol forward declaration
    raise NotImplementedError()
|
|
|
|
class Environment:
    """Holds the registries available to a project.

    The registries are: models, sources, importers, extractors, launchers,
    converters, statistics and visualizers, plus a ``GitWrapper`` for the
    project directory. Module registries combine the items of the built-in
    ``datumaro.components.*`` packages with custom items found in the
    project environment directory.
    """

    # name under which the "load a project as a dataset" extractor is stored
    PROJECT_EXTRACTOR_NAME = 'project'

    def __init__(self, config=None):
        """Build all registries.

        Args:
            config: project configuration (anything accepted by ``Config``);
                missing values fall back to ``PROJECT_DEFAULT_CONFIG``.
        """
        config = Config(config,
            fallback=PROJECT_DEFAULT_CONFIG, schema=PROJECT_SCHEMA)

        # The environment has its own config file inside the env directory;
        # it is applied over the defaults only if the file exists.
        env_dir = osp.join(config.project_dir, config.env_dir)
        env_config_path = osp.join(env_dir, config.env_filename)
        env_config = Config(fallback=ENV_DEFAULT_CONFIG, schema=ENV_SCHEMA)
        if osp.isfile(env_config_path):
            env_config.update(Config.parse(env_config_path))

        self.config = env_config

        # models come from the environment config, sources from the project one
        self.models = ModelRegistry(env_config)
        self.sources = SourceRegistry(config)

        import datumaro.components.importers as builtin_importers
        builtin_importers = builtin_importers.items
        custom_importers = self._get_custom_module_items(
            env_dir, env_config.importers_dir)
        self.importers = ModuleRegistry(config,
            builtin=builtin_importers, local=custom_importers)

        import datumaro.components.extractors as builtin_extractors
        builtin_extractors = builtin_extractors.items
        custom_extractors = self._get_custom_module_items(
            env_dir, env_config.extractors_dir)
        self.extractors = ModuleRegistry(config,
            builtin=builtin_extractors, local=custom_extractors)
        self.extractors.register(self.PROJECT_EXTRACTOR_NAME,
            load_project_as_dataset)

        import datumaro.components.launchers as builtin_launchers
        builtin_launchers = builtin_launchers.items
        custom_launchers = self._get_custom_module_items(
            env_dir, env_config.launchers_dir)
        self.launchers = ModuleRegistry(config,
            builtin=builtin_launchers, local=custom_launchers)

        import datumaro.components.converters as builtin_converters
        builtin_converters = builtin_converters.items
        # _get_custom_module_items() already returns the items themselves
        # (or None), exactly as for the other module kinds, so the result is
        # passed on as is. Taking '.items' of it again was an error.
        custom_converters = self._get_custom_module_items(
            env_dir, env_config.converters_dir)
        self.converters = ModuleRegistry(config,
            builtin=builtin_converters, local=custom_converters)

        self.statistics = ModuleRegistry(config)
        self.visualizers = ModuleRegistry(config)
        self.git = GitWrapper(config)

    def _get_custom_module_items(self, module_dir, module_name):
        """Collect the items of the custom module ``module_name``.

        Args:
            module_dir: directory that contains the module.
            module_name: name of the module (file or directory) inside it.

        Returns:
            ``None`` when the module path does not exist or the import
            failed; the module's ``items`` attribute when it has one;
            otherwise the list built by ``_find_custom_module_items()`` from
            the files inside the module directory.
        """
        module_path = osp.join(module_dir, module_name)
        if not osp.exists(module_path):
            return None

        module = import_foreign_module(module_name, module_dir)
        if module is None:
            return None

        if hasattr(module, 'items'):
            return module.items
        return self._find_custom_module_items(module_path)

    @staticmethod
    def _find_custom_module_items(module_dir):
        """Import every ``*.py`` file of ``module_dir`` and gather its items.

        ``__init__.py`` is skipped. A file contributes either its ``items``
        attribute or, failing that, a single ``(name, attribute)`` pair when
        it has an attribute named like the file itself. Files providing
        neither are skipped with a warning.

        Returns:
            A list with the items of all the files.
        """
        files = [p for p in os.listdir(module_dir)
            if p.endswith('.py') and p != '__init__.py']

        all_items = []
        for f in files:
            name = osp.splitext(f)[0]
            # may be None if the import failed; hasattr() is then False
            module = import_foreign_module(name, module_dir)

            items = []
            if hasattr(module, 'items'):
                items = module.items
            elif hasattr(module, name):
                items = [(name, getattr(module, name))]
            else:
                # log.warn() is a deprecated alias of log.warning()
                log.warning("Failed to import custom module '%s'."
                    " Custom module is expected to provide 'items' "
                    "list or have an item matching its file name."
                    " Skipping this module.",
                    module_dir + '.' + name)

            all_items.extend(items)

        return all_items

    def save(self, path):
        """Dump the environment config to ``path``."""
        self.config.dump(path)

    def make_extractor(self, name, *args, **kwargs):
        """Call the extractor registered as ``name`` with the given arguments."""
        return self.extractors.get(name)(*args, **kwargs)

    def make_importer(self, name, *args, **kwargs):
        """Call the importer registered as ``name`` with the given arguments."""
        return self.importers.get(name)(*args, **kwargs)

    def make_launcher(self, name, *args, **kwargs):
        """Call the launcher registered as ``name`` with the given arguments."""
        return self.launchers.get(name)(*args, **kwargs)

    def make_converter(self, name, *args, **kwargs):
        """Call the converter registered as ``name`` with the given arguments."""
        return self.converters.get(name)(*args, **kwargs)

    def register_model(self, name, model):
        """Add ``model`` to both the environment config and the registry."""
        self.config.models[name] = model
        self.models.register(name, model)

    def unregister_model(self, name):
        """Remove ``name`` from both the environment config and the registry."""
        self.config.models.remove(name)
        self.models.unregister(name)
|
|
|
|
|
|
class Subset(Extractor):
    """Ordered collection of the items of one subset.

    Items are kept in ``items`` (an ``OrderedDict``) and the categories are
    taken from the parent passed to the constructor.
    """

    def __init__(self, parent):
        self._parent = parent
        self.items = OrderedDict()

    def __iter__(self):
        # yields the stored items in insertion order
        yield from self.items.values()

    def __len__(self):
        return len(self.items)

    def categories(self):
        """Return the categories of the parent."""
        parent = self._parent
        return parent.categories()
|
|
|
|
class DatasetItemWrapper(DatasetItem):
    """Wraps a dataset item, replacing its path, annotations and image.

    ``id`` and ``subset`` are read from the wrapped item. ``path`` and
    ``annotations`` are the values given to the constructor. ``image`` is
    the given image when it is not ``None`` (called first if it is
    callable), otherwise the image of the wrapped item.
    """

    def __init__(self, item, path, annotations, image=None):
        self._item = item
        self._path = path
        self._annotations = annotations
        self._image = image

    @DatasetItem.id.getter
    def id(self):
        return self._item.id

    @DatasetItem.subset.getter
    def subset(self):
        return self._item.subset

    @DatasetItem.path.getter
    def path(self):
        return self._path

    @DatasetItem.annotations.getter
    def annotations(self):
        return self._annotations

    @DatasetItem.has_image.getter
    def has_image(self):
        # an image given explicitly counts; otherwise ask the wrapped item
        return self._image is not None or self._item.has_image

    @DatasetItem.image.getter
    def image(self):
        override = self._image
        if override is None:
            return self._item.image
        # a callable is treated as a lazy image provider
        return override() if callable(override) else override
|
|
|
|
class ProjectDataset(Extractor):
    """Dataset of a project: the items of its sources merged with its own.

    Items are grouped into ``Subset`` objects by subset name. Items taken
    from sources are wrapped so that their ``path`` starts with the source
    name. Items of the project's own dataset directory override source
    items with the same id and subset.
    """

    def __init__(self, project):
        """Load and merge all sources of ``project`` and its own dataset."""
        super().__init__()

        self._project = project
        config = self.config
        env = self.env

        dataset_filter = None
        if config.filter:
            dataset_filter = XPathDatasetFilter(config.filter)
        self._filter = dataset_filter

        sources = {}
        for s_name, source in config.sources.items():
            s_format = source.format
            if not s_format:
                # sources without a format are loaded as projects
                s_format = env.PROJECT_EXTRACTOR_NAME
            options = {}
            options.update(source.options)

            url = source.url
            if not source.url:
                # sources without a url live in the project sources directory
                url = osp.join(config.project_dir, config.sources_dir, s_name)
            sources[s_name] = env.make_extractor(s_format,
                url, **options)
        self._sources = sources

        own_source = None
        own_source_dir = osp.join(config.project_dir, config.dataset_dir)
        if osp.isdir(own_source_dir):
            own_source = env.make_extractor(DEFAULT_FORMAT, own_source_dir)

        # merge categories
        # TODO: implement properly with merging and annotations remapping
        categories = {}
        for source in self._sources.values():
            categories.update(source.categories())
        # all sources are required to agree on every category type
        for source in self._sources.values():
            for cat_type, source_cat in source.categories().items():
                assert categories[cat_type] == source_cat
        if own_source is not None and len(own_source) != 0:
            categories.update(own_source.categories())
        self._categories = categories

        # merge items
        subsets = defaultdict(lambda: Subset(self))
        for source_name, source in self._sources.items():
            log.info("Loading '%s' source contents...", source_name)
            for item in source:
                if dataset_filter and not dataset_filter(item):
                    continue

                existing_item = subsets[item.subset].items.get(item.id)
                if existing_item is not None:
                    # the item was already provided by another source:
                    # keep the first image and merge the annotations
                    image = None
                    if existing_item.has_image:
                        # TODO: think of image comparison
                        image = self._lazy_image(existing_item)

                    path = existing_item.path
                    if item.path != path:
                        path = None
                    item = DatasetItemWrapper(item=item, path=path,
                        image=image, annotations=self._merge_anno(
                            existing_item.annotations, item.annotations))
                else:
                    s_config = config.sources[source_name]
                    if s_config and \
                            s_config.format != self.env.PROJECT_EXTRACTOR_NAME:
                        # NOTE: consider imported sources as our own dataset
                        path = None
                    else:
                        path = item.path
                    if path is None:
                        path = []
                    path = [source_name] + path
                    item = DatasetItemWrapper(item=item, path=path,
                        annotations=item.annotations)

                subsets[item.subset].items[item.id] = item

        # override with our items, fallback to existing images
        if own_source is not None:
            log.info("Loading own dataset...")
            for item in own_source:
                if dataset_filter and not dataset_filter(item):
                    continue

                if not item.has_image:
                    existing_item = subsets[item.subset].items.get(item.id)
                    if existing_item is not None:
                        image = None
                        if existing_item.has_image:
                            # TODO: think of image comparison
                            image = self._lazy_image(existing_item)
                        item = DatasetItemWrapper(item=item, path=None,
                            annotations=item.annotations, image=image)

                subsets[item.subset].items[item.id] = item

        # TODO: implement subset remapping when needed
        subsets_filter = config.subsets
        if len(subsets_filter) != 0:
            subsets = {k: v for k, v in subsets.items() if k in subsets_filter}
        # a plain dict, so that unknown subset names raise KeyError later
        self._subsets = dict(subsets)

        self._length = None

    @staticmethod
    def _lazy_image(item):
        """Return a callable that gives the image of ``item`` when called."""
        # NOTE: avoid https://docs.python.org/3/faq/programming.html#why-do-lambdas-defined-in-a-loop-with-different-values-all-return-the-same-result
        return lambda: item.image

    @staticmethod
    def _merge_anno(a, b):
        """Concatenate ``a`` and ``b`` leaving out elements equal to an earlier one."""
        from itertools import chain
        merged = []
        for item in chain(a, b):
            if not any(elem == item for elem in merged):
                merged.append(item)

        return merged

    def iterate_own(self):
        """Select the items whose ``path`` is empty or ``None``."""
        return self.select(lambda item: not item.path)

    def __iter__(self):
        for subset in self._subsets.values():
            for item in subset:
                if self._filter and not self._filter(item):
                    continue
                yield item

    def __len__(self):
        # the total is cached; put() resets the cache
        if self._length is None:
            self._length = sum(len(s) for s in self._subsets.values())
        return self._length

    def get_subset(self, name):
        """Return the ``Subset`` called ``name``; ``KeyError`` if unknown."""
        return self._subsets[name]

    def subsets(self):
        """Return the list of subset names."""
        return list(self._subsets)

    def categories(self):
        """Return the merged categories."""
        return self._categories

    def define_categories(self, categories):
        """Set the categories; allowed only while none are defined."""
        assert not self._categories
        self._categories = categories

    def get(self, item_id, subset=None, path=None):
        """Return an item by id and subset.

        With a non-empty ``path`` the request is forwarded to the source
        named by its first element, passing the rest of the path along.
        """
        if path:
            source = path[0]
            rest_path = path[1:]
            return self._sources[source].get(
                item_id=item_id, subset=subset, path=rest_path)
        return self._subsets[subset].items[item_id]

    def put(self, item, item_id=None, subset=None, path=None):
        """Store ``item`` in the dataset and return the stored wrapper.

        Args:
            item: the item to store.
            item_id: id to store it under; defaults to ``item.id``.
            subset: subset to store it in; defaults to ``item.subset``.
            path: source path of the item; defaults to ``item.path``. When
                not empty, the item is also put into the source named by
                the first path element.
        """
        if path is None:
            path = item.path
        if path:
            source = path[0]
            rest_path = path[1:]
            # TODO: reverse remapping
            self._sources[source].put(item,
                item_id=item_id, subset=subset, path=rest_path)

        if item_id is None:
            item_id = item.id
        if subset is None:
            subset = item.subset

        item = DatasetItemWrapper(item=item, path=path,
            annotations=item.annotations)
        # The existence check must use the same key as the assignment below,
        # otherwise an explicit 'subset' different from item.subset fails
        # with KeyError.
        if subset not in self._subsets:
            self._subsets[subset] = Subset(self)
        self._subsets[subset].items[item_id] = item
        self._length = None

        return item

    def build(self, tasks=None):
        """Not implemented; does nothing."""
        pass

    def docs(self):
        """Not implemented; does nothing."""
        pass

    def transform(self, model_name, save_dir=None):
        """Run the model ``model_name`` over the dataset and save the results.

        The results are saved, merged, as a project without sources in
        ``save_dir`` (the project directory by default).
        """
        project = Project(self.config)
        project.config.remove('sources')

        if save_dir is None:
            save_dir = self.config.project_dir
        project.config.project_dir = save_dir

        dataset = project.make_dataset()
        launcher = self._project.make_executable_model(model_name)
        inference = InferenceWrapper(self, launcher)
        dataset.update(inference)

        dataset.save(merge=True)

    def export(self, save_dir, output_format,
            filter_expr=None, **converter_kwargs):
        """Convert the dataset to ``output_format`` and write it to ``save_dir``.

        ``filter_expr``, when given, selects the items to export; the other
        keyword arguments are passed to the converter constructor.
        """
        save_dir = osp.abspath(save_dir)
        os.makedirs(save_dir, exist_ok=True)

        dataset = self
        if filter_expr:
            dataset_filter = XPathDatasetFilter(filter_expr)
            dataset = dataset.select(dataset_filter)

        converter = self.env.make_converter(output_format, **converter_kwargs)
        converter(dataset, save_dir)

    def extract(self, save_dir, filter_expr=None):
        """Save a copy of the project, with ``filter_expr`` set, to ``save_dir``."""
        project = Project(self.config)
        if filter_expr:
            # constructed only to check the expression before it is stored
            XPathDatasetFilter(filter_expr)
            project.set_filter(filter_expr)
        project.save(save_dir)

    def update(self, items):
        """Put every item passing the dataset filter; returns ``self``."""
        for item in items:
            if self._filter and not self._filter(item):
                continue
            self.put(item)
        return self

    def save(self, save_dir=None, merge=False, recursive=True,
            save_images=False):
        """Write the dataset and the project files.

        Args:
            save_dir: target directory; defaults to the project directory.
                Giving a directory forces ``merge``.
            merge: write all items as one dataset into a project that has
                no sources, instead of writing only the own items.
            recursive: without ``merge``, also save the sources that are
                ``ProjectDataset`` instances.
            save_images: passed to the converter.
        """
        if save_dir is None:
            assert self.config.project_dir
            save_dir = self.config.project_dir
            project = self._project
        else:
            merge = True

        if merge:
            project = Project(Config(self.config))
            project.config.remove('sources')

        save_dir = osp.abspath(save_dir)
        os.makedirs(save_dir, exist_ok=True)

        dataset_save_dir = osp.join(save_dir, project.config.dataset_dir)
        os.makedirs(dataset_save_dir, exist_ok=True)

        converter_kwargs = {
            'save_images': save_images,
        }

        if merge:
            # merge and save the resulting dataset
            converter = self.env.make_converter(
                DEFAULT_FORMAT, **converter_kwargs)
            converter(self, dataset_save_dir)
        else:
            if recursive:
                # children items should already be updated
                # so we just save them recursively
                for source in self._sources.values():
                    if isinstance(source, ProjectDataset):
                        source.save(**converter_kwargs)

            converter = self.env.make_converter(
                DEFAULT_FORMAT, **converter_kwargs)
            converter(self.iterate_own(), dataset_save_dir)

        project.save(save_dir)

    @property
    def env(self):
        """The environment of the owning project."""
        return self._project.env

    @property
    def config(self):
        """The config of the owning project."""
        return self._project.config

    @property
    def sources(self):
        """Mapping of source name to the extractor created for it."""
        return self._sources
|
|
|
|
class Project:
    """A project: its configuration plus the environment built from it."""

    @staticmethod
    def load(path):
        """Load a project from a config file or from a project directory.

        For a directory, the default project file name is looked up in it.
        """
        path = osp.abspath(path)
        if osp.isdir(path):
            path = osp.join(path, PROJECT_DEFAULT_CONFIG.project_filename)
        config = Config.parse(path)
        config.project_dir = osp.dirname(path)
        config.project_filename = osp.basename(path)
        return Project(config)

    def save(self, save_dir=None):
        """Write the project and environment configs.

        Args:
            save_dir: target directory; defaults to the project directory,
                which then must be set in the config.
        """
        config = self.config
        if save_dir is None:
            assert config.project_dir
            save_dir = osp.abspath(config.project_dir)
        config_path = osp.join(save_dir, config.project_filename)

        # creating the env directory also creates save_dir if needed
        env_dir = osp.join(save_dir, config.env_dir)
        os.makedirs(env_dir, exist_ok=True)
        self.env.save(osp.join(env_dir, config.env_filename))

        config.dump(config_path)

    @staticmethod
    def generate(save_dir, config=None):
        """Create a project from ``config``, save it to ``save_dir``, return it."""
        project = Project(config)
        project.save(save_dir)
        project.config.project_dir = save_dir
        return project

    @staticmethod
    def import_from(path, dataset_format, env=None, **kwargs):
        """Run the importer registered for ``dataset_format`` on ``path``.

        A default ``Environment`` is created when ``env`` is not given.
        Returns whatever the importer returns.
        """
        if env is None:
            env = Environment()
        importer = env.make_importer(dataset_format)
        return importer(path, **kwargs)

    def __init__(self, config=None):
        self.config = Config(config,
            fallback=PROJECT_DEFAULT_CONFIG, schema=PROJECT_SCHEMA)
        self.env = Environment(self.config)

    def make_dataset(self):
        """Build a ``ProjectDataset`` for this project."""
        return ProjectDataset(self)

    def add_source(self, name, value=None):
        """Add a source to the config and to the environment registry.

        Args:
            name: source name.
            value: a ``Source``, or a dict / ``Config`` to build one from.
                When omitted, a new default ``Source`` is created for this
                call, rather than sharing one instance between all calls.
        """
        if value is None:
            value = Source()
        if isinstance(value, (dict, Config)):
            value = Source(value)
        self.config.sources[name] = value
        self.env.sources.register(name, value)

    def remove_source(self, name):
        """Remove a source from the config and from the environment registry."""
        self.config.sources.remove(name)
        self.env.sources.unregister(name)

    def get_source(self, name):
        """Return the config of the source ``name``.

        Raises:
            KeyError: with a descriptive message, if there is no such source.
        """
        try:
            return self.config.sources[name]
        except KeyError:
            raise KeyError("Source '%s' is not found" % name)

    def get_subsets(self):
        """Return the ``subsets`` value of the config."""
        return self.config.subsets

    def set_subsets(self, value):
        """Set the ``subsets`` config value; an empty value removes it."""
        if not value:
            self.config.remove('subsets')
        else:
            self.config.subsets = value

    def add_model(self, name, value=None):
        """Register a model in the environment.

        Args:
            name: model name.
            value: a ``Model``, or a dict / ``Config`` to build one from.
                When omitted, a new default ``Model`` is created for this
                call, rather than sharing one instance between all calls.
        """
        if value is None:
            value = Model()
        if isinstance(value, (dict, Config)):
            value = Model(value)
        self.env.register_model(name, value)

    def get_model(self, name):
        """Return the registered model ``name``.

        Raises:
            KeyError: with a descriptive message, if there is no such model.
        """
        try:
            return self.env.models.get(name)
        except KeyError:
            raise KeyError("Model '%s' is not found" % name)

    def remove_model(self, name):
        """Unregister the model ``name`` from the environment."""
        self.env.unregister_model(name)

    def make_executable_model(self, name):
        """Create a launcher for the model ``name``.

        The model's ``model_dir`` is set to its local directory and passed
        to the launcher together with the model options.
        """
        model = self.get_model(name)
        model.model_dir = self.local_model_dir(name)
        return self.env.make_launcher(model.launcher,
            **model.options, model_dir=model.model_dir)

    def make_source_project(self, name):
        """Create a project that contains only the source ``name``.

        The new project uses a copy of this project's config without the
        ``sources``, ``subsets`` and ``filter`` entries.
        """
        source = self.get_source(name)

        config = Config(self.config)
        config.remove('sources')
        config.remove('subsets')
        config.remove('filter')
        project = Project(config)
        project.add_source(name, source)
        return project

    def get_filter(self):
        """Return the filter expression, or ``''`` when none is set."""
        if 'filter' in self.config:
            return self.config.filter
        return ''

    def set_filter(self, value=None):
        """Set the filter expression; an empty value removes the filter."""
        if not value:
            self.config.remove('filter')
        else:
            # check filter
            XPathDatasetFilter(value)
            self.config.filter = value

    def local_model_dir(self, model_name):
        """Return the directory of ``model_name`` inside the env directory."""
        return osp.join(
            self.config.env_dir, self.env.config.models_dir, model_name)

    def local_source_dir(self, source_name):
        """Return the directory of ``source_name`` inside the sources directory."""
        return osp.join(self.config.sources_dir, source_name)
|
|
|
|
# pylint: disable=function-redefined
|
|
def load_project_as_dataset(url):
    """Load the project located at ``url`` and return its dataset."""
    # implement the function declared above
    project = Project.load(url)
    return project.make_dataset()
|
|
# pylint: enable=function-redefined |