You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
111 lines
2.9 KiB
Python
111 lines
2.9 KiB
Python
|
|
# Copyright (C) 2019 Intel Corporation
|
|
#
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
import argparse
|
|
import cv2
|
|
from enum import Enum
|
|
|
|
from datumaro.components.project import Project
|
|
|
|
|
|
# Kinds of entities that a target argument value can be resolved to.
# Members are numbered from 1 in declaration order.
class TargetKinds(Enum):
    project = 1
    source = 2
    external_dataset = 3
    inference = 4
    image = 5
|
|
|
|
def is_project_name(value, project):
    """Tell whether ``value`` matches the name of ``project``.

    The name is read from ``project.config.project_name`` and compared
    with ``value`` using ``==``.
    """
    expected_name = project.config.project_name
    return value == expected_name
|
|
|
|
def is_project_path(value):
    """Check whether a project can be loaded from ``value``.

    Falsy values (such as ``None`` or an empty string) are rejected
    without attempting a load. Otherwise the result is ``True`` only if
    ``Project.load(value)`` completes without raising.
    """
    if not value:
        return False

    try:
        Project.load(value)
    except Exception:
        # Best-effort probe: any load failure means "not a project path".
        return False
    return True
|
|
|
|
def is_project(value, project=None):
    """Check whether ``value`` refers to a project.

    ``value`` is accepted if it is a loadable project path or, when a
    ``project`` is supplied, if it equals that project's name.
    """
    if is_project_path(value):
        return True

    if project is None:
        # Nothing to compare the name against.
        return False

    return is_project_name(value, project)
|
|
|
|
def is_source(value, project=None):
    """Check whether ``value`` names a source known to ``project``.

    Returns ``False`` when no project is given or when
    ``project.get_source(value)`` raises ``KeyError``; other exceptions
    propagate to the caller.
    """
    if project is None:
        return False

    try:
        project.get_source(value)
    except KeyError:
        return False
    return True
|
|
|
|
def is_external_source(value):
    """Placeholder check for external dataset paths.

    No detection is implemented here: ``value`` is ignored and every
    input is rejected.
    """
    return False
|
|
|
|
def is_inference_path(value):
    """Placeholder check for inference paths.

    No detection is implemented here: ``value`` is ignored and every
    input is rejected.
    """
    return False
|
|
|
|
def is_image_path(value):
    """Check whether ``cv2.imread`` yields an image for ``value``.

    The result is ``True`` when the call returns anything other than
    ``None``.
    """
    image = cv2.imread(value)
    return image is not None
|
|
|
|
|
|
class Target:
    """Describes one kind of value a target argument may take.

    Holds four fields: ``kind``, ``test`` (a callable applied to a
    candidate value), ``is_default`` and ``name``. An instance can be
    iterated or unpacked as ``(kind, test, is_default, name)``.
    """

    def __init__(self, kind, test, is_default=False, name=None):
        self.kind, self.test = kind, test
        self.is_default, self.name = is_default, name

    def _get_fields(self):
        # Field order defines the unpacking order seen by callers.
        field_names = ('kind', 'test', 'is_default', 'name')
        return [getattr(self, field) for field in field_names]

    def __str__(self):
        # Prefer the descriptive name; fall back to the kind.
        if self.name:
            return self.name
        return str(self.kind)

    def __len__(self):
        fields = self._get_fields()
        return len(fields)

    def __iter__(self):
        fields = self._get_fields()
        return iter(fields)
|
|
|
|
def ProjectTarget(kind=TargetKinds.project, test=None,
        is_default=False, name='project name or path',
        project=None):
    """Build a ``Target`` describing a project argument.

    When ``test`` is not supplied, the target accepts values for which
    ``is_project(value, project=project)`` is true.
    """
    if test is not None:
        return Target(kind, test, is_default, name)

    def matches_project(v):
        return is_project(v, project=project)

    return Target(kind, matches_project, is_default, name)
|
|
|
|
def SourceTarget(kind=TargetKinds.source, test=None,
        is_default=False, name='source name',
        project=None):
    """Build a ``Target`` describing a project source argument.

    When ``test`` is not supplied, the target accepts values for which
    ``is_source(value, project=project)`` is true.
    """
    if test is not None:
        return Target(kind, test, is_default, name)

    def matches_source(v):
        return is_source(v, project=project)

    return Target(kind, matches_source, is_default, name)
|
|
|
|
def ExternalDatasetTarget(kind=TargetKinds.external_dataset,
        test=is_external_source,
        is_default=False, name='external dataset path'):
    """Build a ``Target`` describing an external dataset argument.

    The default ``test`` is ``is_external_source``.
    """
    return Target(kind=kind, test=test, is_default=is_default, name=name)
|
|
|
|
def InferenceTarget(kind=TargetKinds.inference, test=is_inference_path,
        is_default=False, name='inference path'):
    """Build a ``Target`` describing an inference argument.

    The default ``test`` is ``is_inference_path``.
    """
    return Target(kind=kind, test=test, is_default=is_default, name=name)
|
|
|
|
def ImageTarget(kind=TargetKinds.image, test=is_image_path,
        is_default=False, name='image path'):
    """Build a ``Target`` describing an image argument.

    The default ``test`` is ``is_image_path``.
    """
    return Target(kind=kind, test=test, is_default=is_default, name=name)
|
|
|
|
|
|
def target_selector(*targets):
    """Create a function that classifies a value against ``targets``.

    Each target must unpack into ``(kind, test, is_default, name)``.
    The returned function tries the targets in order and returns
    ``(kind, value)`` for the first one that either is a default target
    given an empty value (``''`` or ``None``) or whose ``test`` accepts
    the value. If none matches, ``argparse.ArgumentTypeError`` is raised
    with a message listing the targets.
    """
    def selector(value):
        for target in targets:
            kind, accepts, is_default, _ = target

            # A default target claims an omitted / empty value.
            if is_default and (value == '' or value is None):
                return (kind, value)
            if accepts(value):
                return (kind, value)

        expected = ', '.join(map(str, targets))
        raise argparse.ArgumentTypeError(
            'Value should be one of: %s' % expected)

    return selector
|