Refactorings (#1876)

main
zhiltsov-max 6 years ago committed by GitHub
parent c6b3c797ab
commit d7c9a0fd21
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -13,7 +13,8 @@ from datumaro.components.config import DEFAULT_FORMAT
from datumaro.components.project import Environment from datumaro.components.project import Environment
from ...util import CliException, MultilineFormatter, add_subparser from ...util import CliException, MultilineFormatter, add_subparser
from ...util.project import load_project, generate_next_dir_name from ...util.project import load_project, \
generate_next_name, generate_next_file_name
def build_add_parser(parser_ctor=argparse.ArgumentParser): def build_add_parser(parser_ctor=argparse.ArgumentParser):
@ -53,10 +54,8 @@ def add_command(args):
raise CliException("Model '%s' already exists " raise CliException("Model '%s' already exists "
"(pass --overwrite to overwrite)" % args.name) "(pass --overwrite to overwrite)" % args.name)
else: else:
existing_ids = [int(n.split('-')[1]) for n in project.config.models args.name = generate_next_name(
if re.match(r'model-\d+', n)] project.config.models, 'model', '-', default=0)
max_idx = max(existing_ids, default=len(project.config.models))
args.name = 'model-%d' % (max_idx + 1)
assert args.name not in project.config.models, args.name assert args.name not in project.config.models, args.name
try: try:
@ -79,13 +78,11 @@ def add_command(args):
log.error("Can't copy: copying is not available for '%s' models" % \ log.error("Can't copy: copying is not available for '%s' models" % \
args.launcher) args.launcher)
log.info("Adding the model") log.info("Checking the model")
project.add_model(args.name, { project.add_model(args.name, {
'launcher': args.launcher, 'launcher': args.launcher,
'options': model_args, 'options': model_args,
}) })
log.info("Checking the model")
project.make_executable_model(args.name) project.make_executable_model(args.name)
project.save() project.save()
@ -138,8 +135,8 @@ def run_command(args):
raise CliException("Directory '%s' already exists " raise CliException("Directory '%s' already exists "
"(pass --overwrite overwrite)" % dst_dir) "(pass --overwrite overwrite)" % dst_dir)
else: else:
dst_dir = generate_next_dir_name('%s-inference' % \ dst_dir = generate_next_file_name('%s-inference' % \
(project.config.project_name)) project.config.project_name)
project.make_dataset().apply_model( project.make_dataset().apply_model(
save_dir=osp.abspath(dst_dir), save_dir=osp.abspath(dst_dir),

@ -5,6 +5,7 @@
import argparse import argparse
from enum import Enum from enum import Enum
import json
import logging as log import logging as log
import os import os
import os.path as osp import os.path as osp
@ -20,7 +21,7 @@ from datumaro.components.operations import mean_std
from .diff import DiffVisualizer from .diff import DiffVisualizer
from ...util import add_subparser, CliException, MultilineFormatter, \ from ...util import add_subparser, CliException, MultilineFormatter, \
make_file_name make_file_name
from ...util.project import load_project, generate_next_dir_name from ...util.project import load_project, generate_next_file_name
def build_create_parser(parser_ctor=argparse.ArgumentParser): def build_create_parser(parser_ctor=argparse.ArgumentParser):
@ -329,7 +330,7 @@ def export_command(args):
raise CliException("Directory '%s' already exists " raise CliException("Directory '%s' already exists "
"(pass --overwrite to force creation)" % dst_dir) "(pass --overwrite to force creation)" % dst_dir)
else: else:
dst_dir = generate_next_dir_name('%s-%s' % \ dst_dir = generate_next_file_name('%s-%s' % \
(project.config.project_name, make_file_name(args.format))) (project.config.project_name, make_file_name(args.format)))
dst_dir = osp.abspath(dst_dir) dst_dir = osp.abspath(dst_dir)
@ -425,7 +426,7 @@ def extract_command(args):
raise CliException("Directory '%s' already exists " raise CliException("Directory '%s' already exists "
"(pass --overwrite to force creation)" % dst_dir) "(pass --overwrite to force creation)" % dst_dir)
else: else:
dst_dir = generate_next_dir_name('%s-filter' % \ dst_dir = generate_next_file_name('%s-filter' % \
project.config.project_name) project.config.project_name)
dst_dir = osp.abspath(dst_dir) dst_dir = osp.abspath(dst_dir)
@ -543,7 +544,7 @@ def diff_command(args):
raise CliException("Directory '%s' already exists " raise CliException("Directory '%s' already exists "
"(pass --overwrite to force creation)" % dst_dir) "(pass --overwrite to force creation)" % dst_dir)
else: else:
dst_dir = generate_next_dir_name('%s-%s-diff' % ( dst_dir = generate_next_file_name('%s-%s-diff' % (
first_project.config.project_name, first_project.config.project_name,
second_project.config.project_name) second_project.config.project_name)
) )
@ -603,7 +604,7 @@ def transform_command(args):
raise CliException("Directory '%s' already exists " raise CliException("Directory '%s' already exists "
"(pass --overwrite to force creation)" % dst_dir) "(pass --overwrite to force creation)" % dst_dir)
else: else:
dst_dir = generate_next_dir_name('%s-%s' % \ dst_dir = generate_next_file_name('%s-%s' % \
(project.config.project_name, make_file_name(args.transform))) (project.config.project_name, make_file_name(args.transform)))
dst_dir = osp.abspath(dst_dir) dst_dir = osp.abspath(dst_dir)
@ -633,7 +634,8 @@ def transform_command(args):
def build_stats_parser(parser_ctor=argparse.ArgumentParser): def build_stats_parser(parser_ctor=argparse.ArgumentParser):
parser = parser_ctor(help="Get project statistics", parser = parser_ctor(help="Get project statistics",
description=""" description="""
Outputs project statistics. Outputs various project statistics like image mean and std,
annotations count etc.
""", """,
formatter_class=MultilineFormatter) formatter_class=MultilineFormatter)

@ -4,31 +4,36 @@
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
import os import os
import re
from datumaro.components.project import Project from datumaro.components.project import Project
from datumaro.util import cast
def load_project(project_dir): def load_project(project_dir):
return Project.load(project_dir) return Project.load(project_dir)
def generate_next_dir_name(dirname, basedir='.', sep='.'): def generate_next_file_name(basename, basedir='.', sep='.', ext=''):
""" """
If basedir does not contain dirname, returns dirname itself, If basedir does not contain basename, returns basename,
else generates a dirname by appending separator to the dirname otherwise generates a name by appending sep to the basename
and the number, next to the last used number in the basedir for and the number, next to the last used number in the basedir for
files with dirname prefix. files with basename prefix. Optionally, appends ext.
""" """
def _to_int(s): return generate_next_name(os.listdir(basedir), basename, sep, ext)
try:
return int(s) def generate_next_name(names, basename, sep='.', suffix='', default=None):
except Exception: pattern = re.compile(r'%s(?:%s(\d+))?%s' % \
return 0 tuple(map(re.escape, [basename, sep, suffix])))
sep_count = dirname.count(sep) + 2 matches = [match for match in (pattern.match(n) for n in names) if match]
files = [e for e in os.listdir(basedir) if e.startswith(dirname)] max_idx = max([cast(match[1], int, 0) for match in matches], default=None)
if files: if max_idx is None:
files = [e.split(sep) for e in files] if default is not None:
files = [_to_int(e[-1]) for e in files if len(e) == sep_count] idx = sep + str(default)
dirname += '%s%s' % (sep, max(files, default=0) + 1) else:
return dirname idx = ''
else:
idx = sep + str(max_idx + 1)
return basename + idx + suffix

@ -5,7 +5,8 @@
import numpy as np import numpy as np
from datumaro.components.extractor import Transform from datumaro.components.extractor import (Transform, LabelCategories,
AnnotationType)
from datumaro.util import take_by from datumaro.util import take_by
@ -33,6 +34,7 @@ class ModelTransform(Transform):
inference = self._launcher.launch(inputs) inference = self._launcher.launch(inputs)
for item, annotations in zip(batch, inference): for item, annotations in zip(batch, inference):
self._check_annotations(annotations)
yield self.wrap_item(item, annotations=annotations) yield self.wrap_item(item, annotations=annotations)
def get_subset(self, name): def get_subset(self, name):
@ -48,4 +50,18 @@ class ModelTransform(Transform):
def transform_item(self, item): def transform_item(self, item):
inputs = np.expand_dims(item.image, axis=0) inputs = np.expand_dims(item.image, axis=0)
annotations = self._launcher.launch(inputs)[0] annotations = self._launcher.launch(inputs)[0]
return self.wrap_item(item, annotations=annotations) return self.wrap_item(item, annotations=annotations)
def _check_annotations(self, annotations):
    """
    Checks that label ids of the given annotations refer to defined labels.

    The number of defined labels is taken from the label categories
    returned by self.categories(); if there are no label categories,
    an empty LabelCategories() is used, so the count is zero.

    Parameters:
        annotations - an iterable of annotation objects to check

    Raises:
        Exception - if an annotation has a label id outside of
            the [0, labels_count) range
    """
    labels_count = len(self.categories().get(
        AnnotationType.label, LabelCategories()).items)

    for ann in annotations:
        # Use a default here: an annotation object without the 'label'
        # attribute is skipped the same way as one with label=None,
        # instead of failing with AttributeError.
        label = getattr(ann, 'label', None)
        if label is None:
            continue

        if label not in range(labels_count):
            raise Exception("Annotation has unexpected label id %s, "
                "while there is only %s defined labels." % \
                (label, labels_count))

@ -10,6 +10,7 @@ from datumaro.components.launcher import Launcher, ModelTransform
from datumaro.components.converter import Converter from datumaro.components.converter import Converter
from datumaro.components.extractor import (Extractor, DatasetItem, from datumaro.components.extractor import (Extractor, DatasetItem,
Label, Mask, Points, Polygon, PolyLine, Bbox, Caption, Label, Mask, Points, Polygon, PolyLine, Bbox, Caption,
LabelCategories, AnnotationType
) )
from datumaro.util.image import Image from datumaro.util.image import Image
from datumaro.components.config import Config, DefaultConfig, SchemaBuilder from datumaro.components.config import Config, DefaultConfig, SchemaBuilder
@ -169,6 +170,12 @@ class ProjectTest(TestCase):
yield DatasetItem(id=i, image=np.ones([2, 2, 3]) * i, yield DatasetItem(id=i, image=np.ones([2, 2, 3]) * i,
annotations=[Label(i)]) annotations=[Label(i)])
def categories(self):
    """
    Returns the categories of this test extractor: label categories
    with two labels, named '0' and '1', keyed by AnnotationType.label.
    """
    labels = LabelCategories()
    for label_name in ('0', '1'):
        labels.add(label_name)
    return {AnnotationType.label: labels}
class TestLauncher(Launcher): class TestLauncher(Launcher):
def launch(self, inputs): def launch(self, inputs):
for inp in inputs: for inp in inputs:

Loading…
Cancel
Save