|
|
|
@ -59,7 +59,7 @@ def build_parser(parser_ctor=argparse.ArgumentParser):
|
|
|
|
help="Confidence threshold for detections (default: include all)")
|
|
|
|
help="Confidence threshold for detections (default: include all)")
|
|
|
|
rise_parser.add_argument('-b', '--batch-size', default=1, type=int,
|
|
|
|
rise_parser.add_argument('-b', '--batch-size', default=1, type=int,
|
|
|
|
help="Inference batch size (default: %(default)s)")
|
|
|
|
help="Inference batch size (default: %(default)s)")
|
|
|
|
rise_parser.add_argument('--progressive', action='store_true',
|
|
|
|
rise_parser.add_argument('--display', action='store_true',
|
|
|
|
help="Visualize results during computations")
|
|
|
|
help="Visualize results during computations")
|
|
|
|
|
|
|
|
|
|
|
|
parser.add_argument('-p', '--project', dest='project_dir', default='.',
|
|
|
|
parser.add_argument('-p', '--project', dest='project_dir', default='.',
|
|
|
|
@ -108,16 +108,13 @@ def explain_command(args):
|
|
|
|
if args.target[0] == TargetKinds.image:
|
|
|
|
if args.target[0] == TargetKinds.image:
|
|
|
|
image_path = args.target[1]
|
|
|
|
image_path = args.target[1]
|
|
|
|
image = load_image(image_path)
|
|
|
|
image = load_image(image_path)
|
|
|
|
if model.preferred_input_size() is not None:
|
|
|
|
|
|
|
|
h, w = model.preferred_input_size()
|
|
|
|
|
|
|
|
image = cv2.resize(image, (w, h))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
log.info("Running inference explanation for '%s'" % image_path)
|
|
|
|
log.info("Running inference explanation for '%s'" % image_path)
|
|
|
|
heatmap_iter = rise.apply(image, progressive=args.progressive)
|
|
|
|
heatmap_iter = rise.apply(image, progressive=args.display)
|
|
|
|
|
|
|
|
|
|
|
|
image = image / 255.0
|
|
|
|
image = image / 255.0
|
|
|
|
file_name = osp.splitext(osp.basename(image_path))[0]
|
|
|
|
file_name = osp.splitext(osp.basename(image_path))[0]
|
|
|
|
if args.progressive:
|
|
|
|
if args.display:
|
|
|
|
for i, heatmaps in enumerate(heatmap_iter):
|
|
|
|
for i, heatmaps in enumerate(heatmap_iter):
|
|
|
|
for j, heatmap in enumerate(heatmaps):
|
|
|
|
for j, heatmap in enumerate(heatmaps):
|
|
|
|
hm_painted = cm.jet(heatmap)[:, :, 2::-1]
|
|
|
|
hm_painted = cm.jet(heatmap)[:, :, 2::-1]
|
|
|
|
@ -154,35 +151,31 @@ def explain_command(args):
|
|
|
|
log.info("Running inference explanation for '%s'" % project_name)
|
|
|
|
log.info("Running inference explanation for '%s'" % project_name)
|
|
|
|
|
|
|
|
|
|
|
|
for item in dataset:
|
|
|
|
for item in dataset:
|
|
|
|
image = item.image
|
|
|
|
image = item.image.data
|
|
|
|
if image is None:
|
|
|
|
if image is None:
|
|
|
|
log.warn(
|
|
|
|
log.warn(
|
|
|
|
"Dataset item %s does not have image data. Skipping." % \
|
|
|
|
"Dataset item %s does not have image data. Skipping." % \
|
|
|
|
(item.id))
|
|
|
|
(item.id))
|
|
|
|
continue
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
|
|
if model.preferred_input_size() is not None:
|
|
|
|
|
|
|
|
h, w = model.preferred_input_size()
|
|
|
|
|
|
|
|
image = cv2.resize(image, (w, h))
|
|
|
|
|
|
|
|
heatmap_iter = rise.apply(image)
|
|
|
|
heatmap_iter = rise.apply(image)
|
|
|
|
|
|
|
|
|
|
|
|
image = image / 255.0
|
|
|
|
image = image / 255.0
|
|
|
|
file_name = osp.splitext(osp.basename(image_path))[0]
|
|
|
|
|
|
|
|
heatmaps = next(heatmap_iter)
|
|
|
|
heatmaps = next(heatmap_iter)
|
|
|
|
|
|
|
|
|
|
|
|
if args.save_dir is not None:
|
|
|
|
if args.save_dir is not None:
|
|
|
|
log.info("Saving inference heatmaps at '%s'" % args.save_dir)
|
|
|
|
log.info("Saving inference heatmaps to '%s'" % args.save_dir)
|
|
|
|
os.makedirs(args.save_dir, exist_ok=True)
|
|
|
|
os.makedirs(args.save_dir, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
for j, heatmap in enumerate(heatmaps):
|
|
|
|
for j, heatmap in enumerate(heatmaps):
|
|
|
|
save_path = osp.join(args.save_dir,
|
|
|
|
save_image(osp.join(args.save_dir,
|
|
|
|
file_name + '-heatmap-%s.png' % j)
|
|
|
|
item.id + '-heatmap-%s.png' % j),
|
|
|
|
save_image(save_path, heatmap * 255.0)
|
|
|
|
heatmap * 255.0, create_dir=True)
|
|
|
|
|
|
|
|
|
|
|
|
if args.progressive:
|
|
|
|
if not args.save_dir or args.display:
|
|
|
|
for j, heatmap in enumerate(heatmaps):
|
|
|
|
for j, heatmap in enumerate(heatmaps):
|
|
|
|
disp = (image + cm.jet(heatmap)[:, :, 2::-1]) / 2
|
|
|
|
disp = (image + cm.jet(heatmap)[:, :, 2::-1]) / 2
|
|
|
|
cv2.imshow(file_name + '-heatmap-%s' % j, disp)
|
|
|
|
cv2.imshow(item.id + '-heatmap-%s' % j, disp)
|
|
|
|
cv2.waitKey(0)
|
|
|
|
cv2.waitKey(0)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
raise NotImplementedError()
|
|
|
|
raise NotImplementedError()
|
|
|
|
|