[Datumaro] Refactor explain and models CLI (#1714)

* Update explain cli

* Update model cli

* Update config for models

* Remove input size hint for models
main
zhiltsov-max 6 years ago committed by GitHub
parent 0e003157d3
commit be30aa6354
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -59,7 +59,7 @@ def build_parser(parser_ctor=argparse.ArgumentParser):
help="Confidence threshold for detections (default: include all)")
rise_parser.add_argument('-b', '--batch-size', default=1, type=int,
help="Inference batch size (default: %(default)s)")
rise_parser.add_argument('--progressive', action='store_true',
rise_parser.add_argument('--display', action='store_true',
help="Visualize results during computations")
parser.add_argument('-p', '--project', dest='project_dir', default='.',
@ -108,16 +108,13 @@ def explain_command(args):
if args.target[0] == TargetKinds.image:
image_path = args.target[1]
image = load_image(image_path)
if model.preferred_input_size() is not None:
h, w = model.preferred_input_size()
image = cv2.resize(image, (w, h))
log.info("Running inference explanation for '%s'" % image_path)
heatmap_iter = rise.apply(image, progressive=args.progressive)
heatmap_iter = rise.apply(image, progressive=args.display)
image = image / 255.0
file_name = osp.splitext(osp.basename(image_path))[0]
if args.progressive:
if args.display:
for i, heatmaps in enumerate(heatmap_iter):
for j, heatmap in enumerate(heatmaps):
hm_painted = cm.jet(heatmap)[:, :, 2::-1]
@ -154,35 +151,31 @@ def explain_command(args):
log.info("Running inference explanation for '%s'" % project_name)
for item in dataset:
image = item.image
image = item.image.data
if image is None:
log.warn(
"Dataset item %s does not have image data. Skipping." % \
(item.id))
continue
if model.preferred_input_size() is not None:
h, w = model.preferred_input_size()
image = cv2.resize(image, (w, h))
heatmap_iter = rise.apply(image)
image = image / 255.0
file_name = osp.splitext(osp.basename(image_path))[0]
heatmaps = next(heatmap_iter)
if args.save_dir is not None:
log.info("Saving inference heatmaps at '%s'" % args.save_dir)
log.info("Saving inference heatmaps to '%s'" % args.save_dir)
os.makedirs(args.save_dir, exist_ok=True)
for j, heatmap in enumerate(heatmaps):
save_path = osp.join(args.save_dir,
file_name + '-heatmap-%s.png' % j)
save_image(save_path, heatmap * 255.0)
save_image(osp.join(args.save_dir,
item.id + '-heatmap-%s.png' % j),
heatmap * 255.0, create_dir=True)
if args.progressive:
if not args.save_dir or args.display:
for j, heatmap in enumerate(heatmaps):
disp = (image + cm.jet(heatmap)[:, :, 2::-1]) / 2
cv2.imshow(file_name + '-heatmap-%s' % j, disp)
cv2.imshow(item.id + '-heatmap-%s' % j, disp)
cv2.waitKey(0)
else:
raise NotImplementedError()

@ -11,8 +11,9 @@ import re
from datumaro.components.config import DEFAULT_FORMAT
from datumaro.components.project import Environment
from ...util import add_subparser, MultilineFormatter
from ...util.project import load_project
from ...util import CliException, MultilineFormatter, add_subparser
from ...util.project import load_project, generate_next_dir_name
def build_add_parser(parser_ctor=argparse.ArgumentParser):
@ -63,19 +64,20 @@ def add_command(args):
except KeyError:
raise CliException("Launcher '%s' is not found" % args.launcher)
cli_plugin = launcher.cli_plugin
cli_plugin = getattr(launcher, 'cli_plugin', launcher)
model_args = cli_plugin.from_cmdline(args.extra_args)
if args.copy:
try:
log.info("Copying model data")
model_dir = project.local_model_dir(args.name)
os.makedirs(model_dir, exist_ok=False)
try:
cli_plugin.copy_model(model_dir, model_args)
except NotImplementedError:
except (AttributeError, NotImplementedError):
log.error("Can't copy: copying is not available for '%s' models" % \
(args.launcher))
args.launcher)
log.info("Adding the model")
project.add_model(args.name, {
@ -115,12 +117,14 @@ def remove_command(args):
def build_run_parser(parser_ctor=argparse.ArgumentParser):
parser = parser_ctor()
parser.add_argument('-o', '--output-dir', dest='dst_dir', required=True,
parser.add_argument('-o', '--output-dir', dest='dst_dir',
help="Directory to save output")
parser.add_argument('-m', '--model', dest='model_name', required=True,
help="Model to apply to the project")
parser.add_argument('-p', '--project', dest='project_dir', default='.',
help="Directory of the project to operate on (default: current dir)")
parser.add_argument('--overwrite', action='store_true',
help="Overwrite if exists")
parser.set_defaults(command=run_command)
return parser
@ -128,10 +132,17 @@ def build_run_parser(parser_ctor=argparse.ArgumentParser):
def run_command(args):
project = load_project(args.project_dir)
dst_dir = osp.abspath(args.dst_dir)
os.makedirs(dst_dir, exist_ok=False)
dst_dir = args.dst_dir
if dst_dir:
if not args.overwrite and osp.isdir(dst_dir) and os.listdir(dst_dir):
raise CliException("Directory '%s' already exists "
                "(pass --overwrite to overwrite)" % dst_dir)
else:
dst_dir = generate_next_dir_name('%s-inference' % \
(project.config.project_name))
project.make_dataset().apply_model(
save_dir=dst_dir,
save_dir=osp.abspath(dst_dir),
model=args.model_name)
log.info("Inference results have been saved to '%s'" % dst_dir)

@ -130,7 +130,7 @@ class Config:
return len(self.items())
def __iter__(self):
return iter(zip(self.keys(), self.values()))
return iter(self.keys())
def __getitem__(self, key):
default = object()

@ -21,7 +21,6 @@ class Source(Config):
MODEL_SCHEMA = _SchemaBuilder() \
.add('launcher', str) \
.add('model_dir', str, internal=True) \
.add('options', dict) \
.build()

@ -17,9 +17,6 @@ class Launcher:
def launch(self, inputs):
raise NotImplementedError()
def preferred_input_size(self):
return None
def categories(self):
return None
# pylint: enable=no-self-use

@ -817,9 +817,8 @@ class Project:
def make_executable_model(self, name):
model = self.get_model(name)
model.model_dir = self.local_model_dir(name)
return self.env.make_launcher(model.launcher,
**model.options, model_dir=model.model_dir)
**model.options, model_dir=self.local_model_dir(name))
def make_source_project(self, name):
source = self.get_source(name)

@ -186,6 +186,3 @@ class OpenVinoLauncher(Launcher):
def process_outputs(self, inputs, outputs):
return self._interpreter.process_outputs(inputs, outputs)
def preferred_input_size(self):
_, _, h, w = self._input_layout
return (h, w)

Loading…
Cancel
Save