Refactorings (#1876)

main
zhiltsov-max 6 years ago committed by GitHub
parent c6b3c797ab
commit d7c9a0fd21
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -13,7 +13,8 @@ from datumaro.components.config import DEFAULT_FORMAT
from datumaro.components.project import Environment
from ...util import CliException, MultilineFormatter, add_subparser
from ...util.project import load_project, generate_next_dir_name
from ...util.project import load_project, \
generate_next_name, generate_next_file_name
def build_add_parser(parser_ctor=argparse.ArgumentParser):
@ -53,10 +54,8 @@ def add_command(args):
raise CliException("Model '%s' already exists "
"(pass --overwrite to overwrite)" % args.name)
else:
existing_ids = [int(n.split('-')[1]) for n in project.config.models
if re.match(r'model-\d+', n)]
max_idx = max(existing_ids, default=len(project.config.models))
args.name = 'model-%d' % (max_idx + 1)
args.name = generate_next_name(
project.config.models, 'model', '-', default=0)
assert args.name not in project.config.models, args.name
try:
@ -79,13 +78,11 @@ def add_command(args):
log.error("Can't copy: copying is not available for '%s' models" % \
args.launcher)
log.info("Adding the model")
log.info("Checking the model")
project.add_model(args.name, {
'launcher': args.launcher,
'options': model_args,
})
log.info("Checking the model")
project.make_executable_model(args.name)
project.save()
@ -138,8 +135,8 @@ def run_command(args):
raise CliException("Directory '%s' already exists "
"(pass --overwrite overwrite)" % dst_dir)
else:
dst_dir = generate_next_dir_name('%s-inference' % \
(project.config.project_name))
dst_dir = generate_next_file_name('%s-inference' % \
project.config.project_name)
project.make_dataset().apply_model(
save_dir=osp.abspath(dst_dir),

@ -5,6 +5,7 @@
import argparse
from enum import Enum
import json
import logging as log
import os
import os.path as osp
@ -20,7 +21,7 @@ from datumaro.components.operations import mean_std
from .diff import DiffVisualizer
from ...util import add_subparser, CliException, MultilineFormatter, \
make_file_name
from ...util.project import load_project, generate_next_dir_name
from ...util.project import load_project, generate_next_file_name
def build_create_parser(parser_ctor=argparse.ArgumentParser):
@ -329,7 +330,7 @@ def export_command(args):
raise CliException("Directory '%s' already exists "
"(pass --overwrite to force creation)" % dst_dir)
else:
dst_dir = generate_next_dir_name('%s-%s' % \
dst_dir = generate_next_file_name('%s-%s' % \
(project.config.project_name, make_file_name(args.format)))
dst_dir = osp.abspath(dst_dir)
@ -425,7 +426,7 @@ def extract_command(args):
raise CliException("Directory '%s' already exists "
"(pass --overwrite to force creation)" % dst_dir)
else:
dst_dir = generate_next_dir_name('%s-filter' % \
dst_dir = generate_next_file_name('%s-filter' % \
project.config.project_name)
dst_dir = osp.abspath(dst_dir)
@ -543,7 +544,7 @@ def diff_command(args):
raise CliException("Directory '%s' already exists "
"(pass --overwrite to force creation)" % dst_dir)
else:
dst_dir = generate_next_dir_name('%s-%s-diff' % (
dst_dir = generate_next_file_name('%s-%s-diff' % (
first_project.config.project_name,
second_project.config.project_name)
)
@ -603,7 +604,7 @@ def transform_command(args):
raise CliException("Directory '%s' already exists "
"(pass --overwrite to force creation)" % dst_dir)
else:
dst_dir = generate_next_dir_name('%s-%s' % \
dst_dir = generate_next_file_name('%s-%s' % \
(project.config.project_name, make_file_name(args.transform)))
dst_dir = osp.abspath(dst_dir)
@ -633,7 +634,8 @@ def transform_command(args):
def build_stats_parser(parser_ctor=argparse.ArgumentParser):
parser = parser_ctor(help="Get project statistics",
description="""
Outputs project statistics.
Outputs various project statistics like image mean and std,
annotations count etc.
""",
formatter_class=MultilineFormatter)

@ -4,31 +4,36 @@
# SPDX-License-Identifier: MIT
import os
import re
from datumaro.components.project import Project
from datumaro.util import cast
def load_project(project_dir):
return Project.load(project_dir)
def generate_next_dir_name(dirname, basedir='.', sep='.'):
def generate_next_file_name(basename, basedir='.', sep='.', ext=''):
"""
If basedir does not contain dirname, returns dirname itself,
else generates a dirname by appending separator to the dirname
If basedir does not contain basename, returns basename,
otherwise generates a name by appending sep to the basename
and the number, next to the last used number in the basedir for
files with dirname prefix.
files with basename prefix. Optionally, appends ext.
"""
def _to_int(s):
try:
return int(s)
except Exception:
return 0
sep_count = dirname.count(sep) + 2
files = [e for e in os.listdir(basedir) if e.startswith(dirname)]
if files:
files = [e.split(sep) for e in files]
files = [_to_int(e[-1]) for e in files if len(e) == sep_count]
dirname += '%s%s' % (sep, max(files, default=0) + 1)
return dirname
return generate_next_name(os.listdir(basedir), basename, sep, ext)
def generate_next_name(names, basename, sep='.', suffix='', default=None):
pattern = re.compile(r'%s(?:%s(\d+))?%s' % \
tuple(map(re.escape, [basename, sep, suffix])))
matches = [match for match in (pattern.match(n) for n in names) if match]
max_idx = max([cast(match[1], int, 0) for match in matches], default=None)
if max_idx is None:
if default is not None:
idx = sep + str(default)
else:
idx = ''
else:
idx = sep + str(max_idx + 1)
return basename + idx + suffix

@ -5,7 +5,8 @@
import numpy as np
from datumaro.components.extractor import Transform
from datumaro.components.extractor import (Transform, LabelCategories,
AnnotationType)
from datumaro.util import take_by
@ -33,6 +34,7 @@ class ModelTransform(Transform):
inference = self._launcher.launch(inputs)
for item, annotations in zip(batch, inference):
self._check_annotations(annotations)
yield self.wrap_item(item, annotations=annotations)
def get_subset(self, name):
@ -48,4 +50,18 @@ class ModelTransform(Transform):
def transform_item(self, item):
inputs = np.expand_dims(item.image, axis=0)
annotations = self._launcher.launch(inputs)[0]
return self.wrap_item(item, annotations=annotations)
return self.wrap_item(item, annotations=annotations)
def _check_annotations(self, annotations):
labels_count = len(self.categories().get(
AnnotationType.label, LabelCategories()).items)
for ann in annotations:
label = getattr(ann, 'label')
if label is None:
continue
if label not in range(labels_count):
raise Exception("Annotation has unexpected label id %s, "
"while there is only %s defined labels." % \
(label, labels_count))

@ -10,6 +10,7 @@ from datumaro.components.launcher import Launcher, ModelTransform
from datumaro.components.converter import Converter
from datumaro.components.extractor import (Extractor, DatasetItem,
Label, Mask, Points, Polygon, PolyLine, Bbox, Caption,
LabelCategories, AnnotationType
)
from datumaro.util.image import Image
from datumaro.components.config import Config, DefaultConfig, SchemaBuilder
@ -169,6 +170,12 @@ class ProjectTest(TestCase):
yield DatasetItem(id=i, image=np.ones([2, 2, 3]) * i,
annotations=[Label(i)])
def categories(self):
label_cat = LabelCategories()
label_cat.add('0')
label_cat.add('1')
return { AnnotationType.label: label_cat }
class TestLauncher(Launcher):
def launch(self, inputs):
for inp in inputs:

Loading…
Cancel
Save