From d7c9a0fd21ae01d71c31c89ea60afd891d07f697 Mon Sep 17 00:00:00 2001 From: zhiltsov-max Date: Mon, 13 Jul 2020 07:20:24 +0300 Subject: [PATCH] Refactorings (#1876) --- .../datumaro/cli/contexts/model/__init__.py | 17 ++++---- .../datumaro/cli/contexts/project/__init__.py | 14 ++++--- datumaro/datumaro/cli/util/project.py | 39 +++++++++++-------- datumaro/datumaro/components/launcher.py | 20 +++++++++- datumaro/tests/test_project.py | 7 ++++ 5 files changed, 62 insertions(+), 35 deletions(-) diff --git a/datumaro/datumaro/cli/contexts/model/__init__.py b/datumaro/datumaro/cli/contexts/model/__init__.py index 30cc4da8..a10f5c0d 100644 --- a/datumaro/datumaro/cli/contexts/model/__init__.py +++ b/datumaro/datumaro/cli/contexts/model/__init__.py @@ -13,7 +13,8 @@ from datumaro.components.config import DEFAULT_FORMAT from datumaro.components.project import Environment from ...util import CliException, MultilineFormatter, add_subparser -from ...util.project import load_project, generate_next_dir_name +from ...util.project import load_project, \ + generate_next_name, generate_next_file_name def build_add_parser(parser_ctor=argparse.ArgumentParser): @@ -53,10 +54,8 @@ def add_command(args): raise CliException("Model '%s' already exists " "(pass --overwrite to overwrite)" % args.name) else: - existing_ids = [int(n.split('-')[1]) for n in project.config.models - if re.match(r'model-\d+', n)] - max_idx = max(existing_ids, default=len(project.config.models)) - args.name = 'model-%d' % (max_idx + 1) + args.name = generate_next_name( + project.config.models, 'model', '-', default=0) assert args.name not in project.config.models, args.name try: @@ -79,13 +78,11 @@ def add_command(args): log.error("Can't copy: copying is not available for '%s' models" % \ args.launcher) - log.info("Adding the model") + log.info("Checking the model") project.add_model(args.name, { 'launcher': args.launcher, 'options': model_args, }) - - log.info("Checking the model") 
project.make_executable_model(args.name) project.save() @@ -138,8 +135,8 @@ def run_command(args): raise CliException("Directory '%s' already exists " "(pass --overwrite overwrite)" % dst_dir) else: - dst_dir = generate_next_dir_name('%s-inference' % \ - (project.config.project_name)) + dst_dir = generate_next_file_name('%s-inference' % \ + project.config.project_name) project.make_dataset().apply_model( save_dir=osp.abspath(dst_dir), diff --git a/datumaro/datumaro/cli/contexts/project/__init__.py b/datumaro/datumaro/cli/contexts/project/__init__.py index f116c797..15b49eb0 100644 --- a/datumaro/datumaro/cli/contexts/project/__init__.py +++ b/datumaro/datumaro/cli/contexts/project/__init__.py @@ -5,6 +5,7 @@ import argparse from enum import Enum +import json import logging as log import os import os.path as osp @@ -20,7 +21,7 @@ from datumaro.components.operations import mean_std from .diff import DiffVisualizer from ...util import add_subparser, CliException, MultilineFormatter, \ make_file_name -from ...util.project import load_project, generate_next_dir_name +from ...util.project import load_project, generate_next_file_name def build_create_parser(parser_ctor=argparse.ArgumentParser): @@ -329,7 +330,7 @@ def export_command(args): raise CliException("Directory '%s' already exists " "(pass --overwrite to force creation)" % dst_dir) else: - dst_dir = generate_next_dir_name('%s-%s' % \ + dst_dir = generate_next_file_name('%s-%s' % \ (project.config.project_name, make_file_name(args.format))) dst_dir = osp.abspath(dst_dir) @@ -425,7 +426,7 @@ def extract_command(args): raise CliException("Directory '%s' already exists " "(pass --overwrite to force creation)" % dst_dir) else: - dst_dir = generate_next_dir_name('%s-filter' % \ + dst_dir = generate_next_file_name('%s-filter' % \ project.config.project_name) dst_dir = osp.abspath(dst_dir) @@ -543,7 +544,7 @@ def diff_command(args): raise CliException("Directory '%s' already exists " "(pass --overwrite to force creation)" 
% dst_dir) else: - dst_dir = generate_next_dir_name('%s-%s-diff' % ( + dst_dir = generate_next_file_name('%s-%s-diff' % ( first_project.config.project_name, second_project.config.project_name) ) @@ -603,7 +604,7 @@ def transform_command(args): raise CliException("Directory '%s' already exists " "(pass --overwrite to force creation)" % dst_dir) else: - dst_dir = generate_next_dir_name('%s-%s' % \ + dst_dir = generate_next_file_name('%s-%s' % \ (project.config.project_name, make_file_name(args.transform))) dst_dir = osp.abspath(dst_dir) @@ -633,7 +634,8 @@ def transform_command(args): def build_stats_parser(parser_ctor=argparse.ArgumentParser): parser = parser_ctor(help="Get project statistics", description=""" - Outputs project statistics. + Outputs various project statistics like image mean and std, + annotations count etc. """, formatter_class=MultilineFormatter) diff --git a/datumaro/datumaro/cli/util/project.py b/datumaro/datumaro/cli/util/project.py index af92458b..75013053 100644 --- a/datumaro/datumaro/cli/util/project.py +++ b/datumaro/datumaro/cli/util/project.py @@ -4,31 +4,36 @@ # SPDX-License-Identifier: MIT import os +import re from datumaro.components.project import Project +from datumaro.util import cast def load_project(project_dir): return Project.load(project_dir) -def generate_next_dir_name(dirname, basedir='.', sep='.'): +def generate_next_file_name(basename, basedir='.', sep='.', ext=''): """ - If basedir does not contain dirname, returns dirname itself, - else generates a dirname by appending separator to the dirname + If basedir does not contain basename, returns basename, + otherwise generates a name by appending sep to the basename and the number, next to the last used number in the basedir for - files with dirname prefix. + files with basename prefix. Optionally, appends ext. 
""" - def _to_int(s): - try: - return int(s) - except Exception: - return 0 - sep_count = dirname.count(sep) + 2 - - files = [e for e in os.listdir(basedir) if e.startswith(dirname)] - if files: - files = [e.split(sep) for e in files] - files = [_to_int(e[-1]) for e in files if len(e) == sep_count] - dirname += '%s%s' % (sep, max(files, default=0) + 1) - return dirname \ No newline at end of file + return generate_next_name(os.listdir(basedir), basename, sep, ext) + +def generate_next_name(names, basename, sep='.', suffix='', default=None): + pattern = re.compile(r'%s(?:%s(\d+))?%s' % \ + tuple(map(re.escape, [basename, sep, suffix]))) + matches = [match for match in (pattern.match(n) for n in names) if match] + + max_idx = max([cast(match[1], int, 0) for match in matches], default=None) + if max_idx is None: + if default is not None: + idx = sep + str(default) + else: + idx = '' + else: + idx = sep + str(max_idx + 1) + return basename + idx + suffix \ No newline at end of file diff --git a/datumaro/datumaro/components/launcher.py b/datumaro/datumaro/components/launcher.py index 1a60cead..b66bf237 100644 --- a/datumaro/datumaro/components/launcher.py +++ b/datumaro/datumaro/components/launcher.py @@ -5,7 +5,8 @@ import numpy as np -from datumaro.components.extractor import Transform +from datumaro.components.extractor import (Transform, LabelCategories, + AnnotationType) from datumaro.util import take_by @@ -33,6 +34,7 @@ class ModelTransform(Transform): inference = self._launcher.launch(inputs) for item, annotations in zip(batch, inference): + self._check_annotations(annotations) yield self.wrap_item(item, annotations=annotations) def get_subset(self, name): @@ -48,4 +50,18 @@ class ModelTransform(Transform): def transform_item(self, item): inputs = np.expand_dims(item.image, axis=0) annotations = self._launcher.launch(inputs)[0] - return self.wrap_item(item, annotations=annotations) \ No newline at end of file + return self.wrap_item(item, 
annotations=annotations) + + def _check_annotations(self, annotations): + labels_count = len(self.categories().get( + AnnotationType.label, LabelCategories()).items) + + for ann in annotations: + label = getattr(ann, 'label', None) + if label is None: + continue + + if label not in range(labels_count): + raise Exception("Annotation has unexpected label id %s, " + "while there is only %s defined labels." % \ + (label, labels_count)) \ No newline at end of file diff --git a/datumaro/tests/test_project.py b/datumaro/tests/test_project.py index 6ab66799..267ef378 100644 --- a/datumaro/tests/test_project.py +++ b/datumaro/tests/test_project.py @@ -10,6 +10,7 @@ from datumaro.components.launcher import Launcher, ModelTransform from datumaro.components.converter import Converter from datumaro.components.extractor import (Extractor, DatasetItem, Label, Mask, Points, Polygon, PolyLine, Bbox, Caption, + LabelCategories, AnnotationType ) from datumaro.util.image import Image from datumaro.components.config import Config, DefaultConfig, SchemaBuilder @@ -169,6 +170,12 @@ class ProjectTest(TestCase): yield DatasetItem(id=i, image=np.ones([2, 2, 3]) * i, annotations=[Label(i)]) + def categories(self): + label_cat = LabelCategories() + label_cat.add('0') + label_cat.add('1') + return { AnnotationType.label: label_cat } + class TestLauncher(Launcher): def launch(self, inputs): for inp in inputs: