You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
65 lines
1.9 KiB
Python
65 lines
1.9 KiB
Python
|
|
# Copyright (C) 2019 Intel Corporation
|
|
#
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
import itertools

import numpy as np

from datumaro.components.extractor import Transform
|
|
|
|
|
|
# pylint: disable=no-self-use
|
|
class Launcher:
    """Base interface for objects that run model inference on input batches.

    Subclasses must override ``launch``. The other hooks return ``None`` in
    this base class.
    """

    def __init__(self, model_dir=None):
        """Create a launcher.

        ``model_dir`` is ignored by this base implementation.
        """

    def launch(self, inputs):
        """Run inference on a batch of ``inputs``.

        Raises:
            NotImplementedError: always; subclasses must provide this.
        """
        raise NotImplementedError

    def preferred_input_size(self):
        """Return the model's preferred input size (``None`` in the base class)."""
        return None

    def get_categories(self):
        """Return categories to use instead of the dataset's, or ``None``."""
        return None
|
|
# pylint: enable=no-self-use
|
|
|
|
class InferenceWrapper(Transform):
    """Transform that annotates extractor items with a launcher's predictions.

    Items from the wrapped extractor are grouped into batches of
    ``batch_size``. Each batch's image data is stacked into one array and
    passed to ``launcher.launch``. Every item is then re-emitted through
    ``wrap_item`` with the matching inference result as its annotations.
    """

    def __init__(self, extractor, launcher, batch_size=1):
        """
        Args:
            extractor: the source extractor whose items are run through the
                launcher.
            launcher: an object providing ``launch(inputs)`` and
                ``get_categories()`` (see ``Launcher``).
            batch_size: number of items passed to ``launcher.launch`` per
                call; must be at least 1.

        Raises:
            ValueError: if ``batch_size`` is less than 1.
        """
        # A non-positive batch size would make __iter__ yield nothing,
        # silently dropping every item, so reject it up front.
        if batch_size < 1:
            raise ValueError(
                "batch_size must be at least 1, got %r" % (batch_size,))
        super().__init__(extractor)
        self._launcher = launcher
        self._batch_size = batch_size

    def __iter__(self):
        """Yield the extractor's items with launcher predictions attached.

        Items are consumed lazily in batches of ``batch_size``; the last
        batch may be smaller. Each batch triggers one ``launcher.launch`` call.
        """
        data_iter = iter(self._extractor)
        while True:
            batch_items = list(itertools.islice(data_iter, self._batch_size))
            if not batch_items:
                break

            # NOTE(review): np.array stacks the images, so presumably all
            # images within a batch must share one shape - confirm.
            inputs = np.array([item.image.data for item in batch_items])
            inference = self._launcher.launch(inputs)

            # zip() stops at the shorter sequence: if the launcher returns
            # fewer results than inputs, the trailing items are not yielded.
            for item, annotations in zip(batch_items, inference):
                yield self.wrap_item(item, annotations=annotations)

    def get_subset(self, name):
        """Return a wrapper that runs the same launcher over subset ``name``."""
        subset = self._extractor.get_subset(name)
        return InferenceWrapper(subset, self._launcher, self._batch_size)

    def categories(self):
        """Return the launcher's categories if it has any, else the extractor's."""
        launcher_override = self._launcher.get_categories()
        if launcher_override is not None:
            return launcher_override
        return self._extractor.categories()

    def transform_item(self, item):
        """Run the launcher on a single item and return it with predictions.

        The input is built from ``item.image.data``, the same source that
        ``__iter__`` uses, so both paths feed the launcher identical data.
        """
        # Add a leading batch axis of size 1 around the single image.
        inputs = np.expand_dims(item.image.data, axis=0)
        annotations = self._launcher.launch(inputs)[0]
        return self.wrap_item(item, annotations=annotations)