You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
189 lines
6.7 KiB
Python
189 lines
6.7 KiB
Python
|
|
# Copyright (C) 2019 Intel Corporation
|
|
#
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
# pylint: disable=exec-used
|
|
|
|
import cv2
|
|
import logging as log
|
|
import numpy as np
|
|
import os.path as osp
|
|
import shutil
|
|
|
|
from openvino.inference_engine import IECore
|
|
|
|
from datumaro.components.cli_plugin import CliPlugin
|
|
from datumaro.components.launcher import Launcher
|
|
|
|
|
|
class OpenVinoImporter(CliPlugin):
    """Command-line plugin describing the files that make up an OpenVINO model.

    A model is given by a description file (.xml), a weights file (.bin)
    and a Python interpreter script (.py) for the network outputs.
    """

    @classmethod
    def build_cmdline_parser(cls, **kwargs):
        """Extend the base plugin parser with the OpenVINO model options."""
        parser = super().build_cmdline_parser(**kwargs)
        parser.add_argument('-d', '--description', required=True,
            help="Path to the model description file (.xml)")
        parser.add_argument('-w', '--weights', required=True,
            help="Path to the model weights file (.bin)")
        parser.add_argument('-i', '--interpreter', required=True,
            help="Path to the network output interpreter script (.py)")
        parser.add_argument('--device', default='CPU',
            help="Target device (default: %(default)s)")
        return parser

    @staticmethod
    def copy_model(model_dir, model):
        """Copy the model files into *model_dir* and repoint *model* at them.

        Args:
            model_dir: destination directory for the copied files.
            model: mapping with 'description', 'weights' and 'interpreter'
                file paths. It is updated in place: each of these entries
                is replaced by the base name of the copied file.
        """
        for key in ('description', 'weights', 'interpreter'):
            file_name = osp.basename(model[key])
            shutil.copy(model[key], osp.join(model_dir, file_name))
            # Store the name relative to model_dir so the copy is relocatable
            model[key] = file_name
|
|
|
|
|
|
class InterpreterScript:
    """Hooks loaded from a user-provided network output interpreter script.

    The script is executed with ``exec`` and must define a
    ``process_outputs(inputs, outputs)`` function. It may also define a
    ``get_categories()`` function; when it does not, ``get_categories()``
    returns ``None``.

    NOTE(security): the script runs as arbitrary Python code in this process.
    Only load interpreter scripts from trusted sources.
    """

    def __init__(self, path):
        """Execute the script at *path* and bind the hooks it defines.

        Raises:
            Exception: if the script does not define a callable
                ``process_outputs``, or defines a non-callable
                ``get_categories``.
        """
        with open(path, 'r') as f:
            script = f.read()

        # Compile with the real file name so errors raised from inside the
        # script report the script file in tracebacks instead of "<string>".
        code = compile(script, path, 'exec')

        context = {}
        exec(code, context, context)

        process_outputs = context.get('process_outputs')
        if not callable(process_outputs):
            raise Exception("Can't find 'process_outputs' function in "
                "the interpreter script")
        # Instance attribute shadows the class-level default method
        self.__dict__['process_outputs'] = process_outputs

        get_categories = context.get('get_categories')
        # Explicit check instead of `assert`, which is stripped under -O
        if get_categories is not None and not callable(get_categories):
            raise Exception("'get_categories' in the interpreter script "
                "must be a function")
        if get_categories:
            self.__dict__['get_categories'] = get_categories

    @staticmethod
    def get_categories():
        """Default used when the script defines no ``get_categories``."""
        return None

    @staticmethod
    def process_outputs(inputs, outputs):
        """Placeholder; always replaced by the script's ``process_outputs``."""
        raise NotImplementedError(
            "Function should be implemented in the interpreter script")
|
|
|
|
|
|
class OpenVinoLauncher(Launcher):
    """Runs inference with an OpenVINO Inference Engine network.

    Raw network outputs are converted by a user-provided interpreter script.
    """

    cli_plugin = OpenVinoImporter

    def __init__(self, description, weights, interpreter,
            plugins_path=None, device=None, model_dir=None):
        """Load the network and compile it for inference.

        Args:
            description: path to the model description file (.xml).
            weights: path to the model weights file (.bin).
            interpreter: path to the output interpreter script (.py).
            plugins_path: accepted for interface compatibility; not used
                by this implementation.
            device: target device name; 'CPU' when not given.
            model_dir: directory used to resolve the paths above when they
                do not point at existing files as given.

        Raises:
            Exception: if one of the model files cannot be found.
            NotImplementedError: if some network layers are not supported
                on the target device.
        """
        model_dir = model_dir or ''
        description = self._resolve_model_file(description, model_dir,
            'model description file')
        weights = self._resolve_model_file(weights, model_dir,
            'model weights file')
        interpreter = self._resolve_model_file(interpreter, model_dir,
            'model interpreter script file')

        self._interpreter = InterpreterScript(interpreter)

        self._device = device or 'CPU'

        self._ie = IECore()
        if hasattr(self._ie, 'read_network'):
            self._network = self._ie.read_network(description, weights)
        else: # backward compatibility
            from openvino.inference_engine import IENetwork
            self._network = IENetwork.from_ir(description, weights)
        self._check_model_support(self._network, self._device)
        self._load_executable_net()

    @staticmethod
    def _resolve_model_file(path, model_dir, what):
        """Return *path* if it is a file, else *path* joined to *model_dir*.

        Raises:
            Exception: if neither location is an existing file. *what*
                names the file kind in the error message.
        """
        if not osp.isfile(path):
            path = osp.join(model_dir, path)
            if not osp.isfile(path):
                raise Exception('Failed to open %s "%s"' % (what, path))
        return path

    def _check_model_support(self, net, device):
        """Raise if *net* contains layers that *device* cannot execute.

        The unsupported layer names are logged before raising.

        Raises:
            NotImplementedError: if any layer is unsupported.
        """
        supported_layers = set(self._ie.query_network(net, device))
        not_supported_layers = set(net.layers) - supported_layers
        if not_supported_layers:
            log.error("The following layers are not supported "
                "by the plugin for device '%s': %s.",
                device, ', '.join(not_supported_layers))
            raise NotImplementedError(
                "Some layers are not supported on the device")

    def _load_executable_net(self, batch_size=1):
        """Reshape the network to *batch_size* and load it onto the device.

        The first network input is used as the image input, unless it is
        the auxiliary 'image_info' input, in which case the next one is.
        """
        network = self._network

        iter_inputs = iter(network.inputs)
        self._input_blob_name = next(iter_inputs)
        self._output_blob_name = next(iter(network.outputs))

        # NOTE: handling for the inclusion of `image_info` in OpenVino2019
        self._require_image_info = 'image_info' in network.inputs
        if self._input_blob_name == 'image_info':
            self._input_blob_name = next(iter_inputs)

        input_type = network.inputs[self._input_blob_name]
        self._input_layout = input_type if isinstance(input_type, list) \
            else input_type.shape

        # infer() unpacks this layout as (N, C, H, W)
        self._input_layout[0] = batch_size
        network.reshape({self._input_blob_name: self._input_layout})
        self._batch_size = batch_size

        self._net = self._ie.load_network(network=network, num_requests=1,
            device_name=self._device)

    def infer(self, inputs):
        """Run the network on a batch of images.

        Args:
            inputs: array of BGR images in (N, H, W, C) layout. Images are
                resized to the network input size when it differs.

        Returns:
            The output blob when the network returns a single output,
            otherwise the full result mapping from the Inference Engine.
        """
        # The shape is wrapped in a 1-tuple: formatting with the bare shape
        # tuple would raise TypeError ("not all arguments converted")
        # instead of producing the assertion message.
        assert len(inputs.shape) == 4, \
            "Expected an input image in (N, H, W, C) format, got %s" % \
            (inputs.shape, )
        assert inputs.shape[3] == 3, \
            "Expected BGR input, got %s" % (inputs.shape, )

        _, c, h, w = self._input_layout
        if inputs.shape[1:3] != (h, w):
            # Size by the actual image count: slots past it would otherwise
            # be uninitialized np.empty memory fed to the network as images.
            resized_inputs = np.empty((len(inputs), h, w, c),
                dtype=inputs.dtype)
            for inp, resized_input in zip(inputs, resized_inputs):
                cv2.resize(inp, (w, h), resized_input)
            inputs = resized_inputs
        inputs = inputs.transpose((0, 3, 1, 2)) # NHWC to NCHW
        inputs = {self._input_blob_name: inputs}
        if self._require_image_info:
            info = np.zeros([1, 3])
            info[0, 0] = h
            info[0, 1] = w
            info[0, 2] = 1.0 # scale
            inputs['image_info'] = info

        results = self._net.infer(inputs)
        if len(results) == 1:
            return results[self._output_blob_name]
        return results

    def launch(self, inputs):
        """Infer *inputs* and convert the outputs with the interpreter script.

        The network is recompiled whenever the batch size differs from the
        one it was last loaded with, so the network input always matches
        the number of images passed in.
        """
        batch_size = len(inputs)
        if self._batch_size != batch_size:
            self._load_executable_net(batch_size)

        outputs = self.infer(inputs)
        return self.process_outputs(inputs, outputs)

    def categories(self):
        """Return the categories provided by the interpreter script."""
        return self._interpreter.get_categories()

    def process_outputs(self, inputs, outputs):
        """Convert raw network *outputs* using the interpreter script."""
        return self._interpreter.process_outputs(inputs, outputs)
|
|
|