@@ -8,25 +8,22 @@ import cv2
 import os
 import numpy as np
 
-from cvat.apps.auto_annotation.inference_engine import make_plugin, make_network
+from cvat.apps.auto_annotation.inference_engine import make_plugin_or_core, make_network
 
 class ModelLoader():
     def __init__(self, model, weights):
         self._model = model
         self._weights = weights
 
-        IE_PLUGINS_PATH = os.getenv("IE_PLUGINS_PATH")
-        if not IE_PLUGINS_PATH:
-            raise OSError("Inference engine plugin path env not found in the system.")
-
-        plugin = make_plugin()
+        core_or_plugin = make_plugin_or_core()
         network = make_network(self._model, self._weights)
 
-        supported_layers = plugin.get_supported_layers(network)
-        not_supported_layers = [l for l in network.layers.keys() if l not in supported_layers]
-        if len(not_supported_layers) != 0:
-            raise Exception("Following layers are not supported by the plugin for specified device {}:\n {}".
-                format(plugin.device, ", ".join(not_supported_layers)))
+        if getattr(core_or_plugin, 'get_supported_layers', False):
+            supported_layers = core_or_plugin.get_supported_layers(network)
+            not_supported_layers = [l for l in network.layers.keys() if l not in supported_layers]
+            if len(not_supported_layers) != 0:
+                raise Exception("Following layers are not supported by the plugin for specified device {}:\n {}".
+                    format(core_or_plugin.device, ", ".join(not_supported_layers)))
 
         iter_inputs = iter(network.inputs)
         self._input_blob_name = next(iter_inputs)
@@ -45,7 +42,12 @@ class ModelLoader():
         if self._input_blob_name in info_names:
             self._input_blob_name = next(iter_inputs)
 
-        self._net = plugin.load(network=network, num_requests=2)
+        if getattr(core_or_plugin, 'load_network', False):
+            self._net = core_or_plugin.load_network(network,
+                                                    "CPU",
+                                                    num_requests=2)
+        else:
+            self._net = core_or_plugin.load(network=network, num_requests=2)
         input_type = network.inputs[self._input_blob_name]
         self._input_layout = input_type if isinstance(input_type, list) else input_type.shape
 
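Note on the change above: the getattr() checks let ModelLoader work with either generation of the OpenVINO Python API. The legacy IEPlugin object exposes get_supported_layers() and load(), while the newer IECore object exposes query_network() and load_network(), so probing the returned object at runtime picks the right call. As a rough illustration only (the real helper lives in cvat/apps/auto_annotation/inference_engine.py and may differ), make_plugin_or_core could dispatch roughly like this:

    # Illustrative sketch, not the actual CVAT implementation.
    def make_plugin_or_core():
        try:
            # Newer OpenVINO releases (2019 R3 and later) provide IECore, which
            # replaces per-device plugins and offers load_network()/query_network().
            from openvino.inference_engine import IECore
            return IECore()
        except ImportError:
            # Older releases only provide IEPlugin, which offers
            # get_supported_layers()/load().
            from openvino.inference_engine import IEPlugin
            return IEPlugin(device="CPU")

With this shape, the caller only relies on the attributes it probes with getattr(), and the IE_PLUGINS_PATH lookup that __init__ used to perform can live inside the helper on the legacy path instead.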