[Datumaro] Refactor explain and models CLI (#1714)

* Update explain cli

* Update model cli

* Update config for models

* Remove input size hint for models
main
zhiltsov-max 6 years ago committed by GitHub
parent 0e003157d3
commit be30aa6354
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -59,7 +59,7 @@ def build_parser(parser_ctor=argparse.ArgumentParser):
help="Confidence threshold for detections (default: include all)") help="Confidence threshold for detections (default: include all)")
rise_parser.add_argument('-b', '--batch-size', default=1, type=int, rise_parser.add_argument('-b', '--batch-size', default=1, type=int,
help="Inference batch size (default: %(default)s)") help="Inference batch size (default: %(default)s)")
rise_parser.add_argument('--progressive', action='store_true', rise_parser.add_argument('--display', action='store_true',
help="Visualize results during computations") help="Visualize results during computations")
parser.add_argument('-p', '--project', dest='project_dir', default='.', parser.add_argument('-p', '--project', dest='project_dir', default='.',
@ -108,16 +108,13 @@ def explain_command(args):
if args.target[0] == TargetKinds.image: if args.target[0] == TargetKinds.image:
image_path = args.target[1] image_path = args.target[1]
image = load_image(image_path) image = load_image(image_path)
if model.preferred_input_size() is not None:
h, w = model.preferred_input_size()
image = cv2.resize(image, (w, h))
log.info("Running inference explanation for '%s'" % image_path) log.info("Running inference explanation for '%s'" % image_path)
heatmap_iter = rise.apply(image, progressive=args.progressive) heatmap_iter = rise.apply(image, progressive=args.display)
image = image / 255.0 image = image / 255.0
file_name = osp.splitext(osp.basename(image_path))[0] file_name = osp.splitext(osp.basename(image_path))[0]
if args.progressive: if args.display:
for i, heatmaps in enumerate(heatmap_iter): for i, heatmaps in enumerate(heatmap_iter):
for j, heatmap in enumerate(heatmaps): for j, heatmap in enumerate(heatmaps):
hm_painted = cm.jet(heatmap)[:, :, 2::-1] hm_painted = cm.jet(heatmap)[:, :, 2::-1]
@ -154,35 +151,31 @@ def explain_command(args):
log.info("Running inference explanation for '%s'" % project_name) log.info("Running inference explanation for '%s'" % project_name)
for item in dataset: for item in dataset:
image = item.image image = item.image.data
if image is None: if image is None:
log.warn( log.warn(
"Dataset item %s does not have image data. Skipping." % \ "Dataset item %s does not have image data. Skipping." % \
(item.id)) (item.id))
continue continue
if model.preferred_input_size() is not None:
h, w = model.preferred_input_size()
image = cv2.resize(image, (w, h))
heatmap_iter = rise.apply(image) heatmap_iter = rise.apply(image)
image = image / 255.0 image = image / 255.0
file_name = osp.splitext(osp.basename(image_path))[0]
heatmaps = next(heatmap_iter) heatmaps = next(heatmap_iter)
if args.save_dir is not None: if args.save_dir is not None:
log.info("Saving inference heatmaps at '%s'" % args.save_dir) log.info("Saving inference heatmaps to '%s'" % args.save_dir)
os.makedirs(args.save_dir, exist_ok=True) os.makedirs(args.save_dir, exist_ok=True)
for j, heatmap in enumerate(heatmaps): for j, heatmap in enumerate(heatmaps):
save_path = osp.join(args.save_dir, save_image(osp.join(args.save_dir,
file_name + '-heatmap-%s.png' % j) item.id + '-heatmap-%s.png' % j),
save_image(save_path, heatmap * 255.0) heatmap * 255.0, create_dir=True)
if args.progressive: if not args.save_dir or args.display:
for j, heatmap in enumerate(heatmaps): for j, heatmap in enumerate(heatmaps):
disp = (image + cm.jet(heatmap)[:, :, 2::-1]) / 2 disp = (image + cm.jet(heatmap)[:, :, 2::-1]) / 2
cv2.imshow(file_name + '-heatmap-%s' % j, disp) cv2.imshow(item.id + '-heatmap-%s' % j, disp)
cv2.waitKey(0) cv2.waitKey(0)
else: else:
raise NotImplementedError() raise NotImplementedError()

@ -11,8 +11,9 @@ import re
from datumaro.components.config import DEFAULT_FORMAT from datumaro.components.config import DEFAULT_FORMAT
from datumaro.components.project import Environment from datumaro.components.project import Environment
from ...util import add_subparser, MultilineFormatter
from ...util.project import load_project from ...util import CliException, MultilineFormatter, add_subparser
from ...util.project import load_project, generate_next_dir_name
def build_add_parser(parser_ctor=argparse.ArgumentParser): def build_add_parser(parser_ctor=argparse.ArgumentParser):
@ -63,19 +64,20 @@ def add_command(args):
except KeyError: except KeyError:
raise CliException("Launcher '%s' is not found" % args.launcher) raise CliException("Launcher '%s' is not found" % args.launcher)
cli_plugin = launcher.cli_plugin cli_plugin = getattr(launcher, 'cli_plugin', launcher)
model_args = cli_plugin.from_cmdline(args.extra_args) model_args = cli_plugin.from_cmdline(args.extra_args)
if args.copy: if args.copy:
try: log.info("Copying model data")
log.info("Copying model data")
model_dir = project.local_model_dir(args.name) model_dir = project.local_model_dir(args.name)
os.makedirs(model_dir, exist_ok=False) os.makedirs(model_dir, exist_ok=False)
try:
cli_plugin.copy_model(model_dir, model_args) cli_plugin.copy_model(model_dir, model_args)
except NotImplementedError: except (AttributeError, NotImplementedError):
log.error("Can't copy: copying is not available for '%s' models" % \ log.error("Can't copy: copying is not available for '%s' models" % \
(args.launcher)) args.launcher)
log.info("Adding the model") log.info("Adding the model")
project.add_model(args.name, { project.add_model(args.name, {
@ -115,12 +117,14 @@ def remove_command(args):
def build_run_parser(parser_ctor=argparse.ArgumentParser): def build_run_parser(parser_ctor=argparse.ArgumentParser):
parser = parser_ctor() parser = parser_ctor()
parser.add_argument('-o', '--output-dir', dest='dst_dir', required=True, parser.add_argument('-o', '--output-dir', dest='dst_dir',
help="Directory to save output") help="Directory to save output")
parser.add_argument('-m', '--model', dest='model_name', required=True, parser.add_argument('-m', '--model', dest='model_name', required=True,
help="Model to apply to the project") help="Model to apply to the project")
parser.add_argument('-p', '--project', dest='project_dir', default='.', parser.add_argument('-p', '--project', dest='project_dir', default='.',
help="Directory of the project to operate on (default: current dir)") help="Directory of the project to operate on (default: current dir)")
parser.add_argument('--overwrite', action='store_true',
help="Overwrite if exists")
parser.set_defaults(command=run_command) parser.set_defaults(command=run_command)
return parser return parser
@ -128,10 +132,17 @@ def build_run_parser(parser_ctor=argparse.ArgumentParser):
def run_command(args): def run_command(args):
project = load_project(args.project_dir) project = load_project(args.project_dir)
dst_dir = osp.abspath(args.dst_dir) dst_dir = args.dst_dir
os.makedirs(dst_dir, exist_ok=False) if dst_dir:
if not args.overwrite and osp.isdir(dst_dir) and os.listdir(dst_dir):
raise CliException("Directory '%s' already exists "
"(pass --overwrite overwrite)" % dst_dir)
else:
dst_dir = generate_next_dir_name('%s-inference' % \
(project.config.project_name))
project.make_dataset().apply_model( project.make_dataset().apply_model(
save_dir=dst_dir, save_dir=osp.abspath(dst_dir),
model=args.model_name) model=args.model_name)
log.info("Inference results have been saved to '%s'" % dst_dir) log.info("Inference results have been saved to '%s'" % dst_dir)

@ -130,7 +130,7 @@ class Config:
return len(self.items()) return len(self.items())
def __iter__(self): def __iter__(self):
return iter(zip(self.keys(), self.values())) return iter(self.keys())
def __getitem__(self, key): def __getitem__(self, key):
default = object() default = object()

@ -21,7 +21,6 @@ class Source(Config):
MODEL_SCHEMA = _SchemaBuilder() \ MODEL_SCHEMA = _SchemaBuilder() \
.add('launcher', str) \ .add('launcher', str) \
.add('model_dir', str, internal=True) \
.add('options', dict) \ .add('options', dict) \
.build() .build()

@ -17,9 +17,6 @@ class Launcher:
def launch(self, inputs): def launch(self, inputs):
raise NotImplementedError() raise NotImplementedError()
def preferred_input_size(self):
return None
def categories(self): def categories(self):
return None return None
# pylint: enable=no-self-use # pylint: enable=no-self-use

@ -817,9 +817,8 @@ class Project:
def make_executable_model(self, name): def make_executable_model(self, name):
model = self.get_model(name) model = self.get_model(name)
model.model_dir = self.local_model_dir(name)
return self.env.make_launcher(model.launcher, return self.env.make_launcher(model.launcher,
**model.options, model_dir=model.model_dir) **model.options, model_dir=self.local_model_dir(name))
def make_source_project(self, name): def make_source_project(self, name):
source = self.get_source(name) source = self.get_source(name)

@ -186,6 +186,3 @@ class OpenVinoLauncher(Launcher):
def process_outputs(self, inputs, outputs): def process_outputs(self, inputs, outputs):
return self._interpreter.process_outputs(inputs, outputs) return self._interpreter.process_outputs(inputs, outputs)
def preferred_input_size(self):
_, _, h, w = self._input_layout
return (h, w)

Loading…
Cancel
Save