From 2640e6c57965517b6d8b4d97141a3b62972c0de2 Mon Sep 17 00:00:00 2001 From: Ben Hoff Date: Tue, 15 Oct 2019 06:32:43 -0400 Subject: [PATCH] allow security segmentation models to be used in auto annotation (#759) --- cvat/apps/auto_annotation/model_loader.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/cvat/apps/auto_annotation/model_loader.py b/cvat/apps/auto_annotation/model_loader.py index 73d33d81..15a7c792 100644 --- a/cvat/apps/auto_annotation/model_loader.py +++ b/cvat/apps/auto_annotation/model_loader.py @@ -31,14 +31,19 @@ class ModelLoader(): iter_inputs = iter(network.inputs) self._input_blob_name = next(iter_inputs) + self._input_info_name = '' self._output_blob_name = next(iter(network.outputs)) self._require_image_info = False + info_names = ('image_info', 'im_info') + # NOTE: handeling for the inclusion of `image_info` in OpenVino2019 - if 'image_info' in network.inputs: + if any(s in network.inputs for s in info_names): self._require_image_info = True - if self._input_blob_name == 'image_info': + self._input_info_name = set(network.inputs).intersection(info_names) + self._input_info_name = self._input_info_name.pop() + if self._input_blob_name in info_names: self._input_blob_name = next(iter_inputs) self._net = plugin.load(network=network, num_requests=2) @@ -56,7 +61,7 @@ class ModelLoader(): info[0, 1] = w # frame number info[0, 2] = 1 - inputs['image_info'] = info + inputs[self._input_info_name] = info results = self._net.infer(inputs) if len(results) == 1: