You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

169 lines
6.1 KiB
Python

import os
import sys
import json
import argparse
import random
import logging
import numpy as np
import cv2
# Put the directory two levels above this script at the front of sys.path
# (presumably the CVAT repository root -- confirm against the repo layout) so
# that the `cvat` package import below resolves when this file is run directly.
# The order matters: the path must be inserted before the project import runs.
work_dir = os.path.dirname(os.path.abspath(__file__))
cvat_dir = os.path.join(work_dir, '..', '..')
sys.path.insert(0, cvat_dir)
from cvat.apps.auto_annotation.inference import run_inference_engine_annotation
def _get_kwargs():
parser = argparse.ArgumentParser()
parser.add_argument('--py', required=True, help='Path to the python interpt file')
parser.add_argument('--xml', required=True, help='Path to the xml file')
parser.add_argument('--bin', required=True, help='Path to the bin file')
parser.add_argument('--json', required=True, help='Path to the JSON mapping file')
parser.add_argument('--restricted', dest='restricted', action='store_true')
parser.add_argument('--unrestricted', dest='restricted', action='store_false')
parser.add_argument('--image-files', nargs='*', help='Paths to image files you want to test')
parser.add_argument('--show-images', action='store_true', help='Show the results of the annotation in a window')
parser.add_argument('--show-image-delay', default=0, type=int, help='Displays the images for a set duration in milliseconds, default is until a key is pressed')
parser.add_argument('--serialize', default=False, action='store_true', help='Try to serialize the result')
return vars(parser.parse_args())
def random_color():
    """Return a random primary colour as a 3-tuple of channel intensities.

    Exactly one channel is 255 and the other two are 0; which channel is
    lit is chosen by shuffling, so only three distinct colours are possible.
    """
    channels = [255, 0, 0]
    random.shuffle(channels)
    return tuple(channels)
def pairwise(iterable):
    """Group a flat sequence ``[x0, y0, x1, y1, ...]`` into an (N, 2) array.

    Args:
        iterable: Flat iterable of numbers. A trailing unpaired element is
            dropped.

    Returns:
        ``numpy.ndarray`` of dtype int32 and shape ``(N, 2)``. Empty input
        (or a single element) yields shape ``(0, 2)`` so the result is
        always two-dimensional.
    """
    flat = np.asarray(list(iterable), dtype=np.int32)
    # Drop a dangling last value so the reshape into pairs cannot fail.
    usable = flat.size - flat.size % 2
    return flat[:usable].reshape(-1, 2)
def _load_label_map(mapping_file):
    """Read the JSON mapping file and return its `label_map` with int keys.

    Returns None, after logging a critical message, when the file is not
    valid JSON, has no `label_map` entry, or has non-integer label keys.
    """
    with open(mapping_file) as json_file:
        try:
            mapping = json.load(json_file)
        except json.decoder.JSONDecodeError:
            logging.critical('JSON file not able to be parsed! Check file')
            return None

    try:
        mapping = mapping['label_map']
    except (KeyError, TypeError):
        # TypeError covers a top-level JSON value that is not an object.
        logging.critical("JSON Mapping file must contain key `label_map`!")
        logging.critical("Exiting")
        return None

    try:
        return {int(k): v for k, v in mapping.items()}
    except (AttributeError, ValueError):
        logging.critical('`label_map` must be an object with integer keys!')
        return None


def _draw_detection(image, detection):
    """Draw one detected shape onto `image` in place using a random colour."""
    # Cv2 doesn't like floats for drawing
    points = [int(p) for p in detection['points']]
    color = random_color()

    if detection['type'] == 'rectangle':
        cv2.rectangle(image, (points[0], points[1]), (points[2], points[3]), color, 3)
    elif detection['type'] in ('polygon', 'polyline'):
        # polylines is picky about datatypes
        cv2.polylines(image, [pairwise(points)], 1, color)


def main():
    """Run a model over test images and optionally validate/show the result.

    Steps: parse the command line, check that the four input files exist,
    load the label map, load the images (or fall back to one blank white
    image), run `run_inference_engine_annotation`, optionally validate the
    result with `LabeledDataSerializer`, and optionally display each image
    with its detections drawn on top.

    Every failure on the input side is reported with `logging.critical` and
    ends the function early; the function always returns None.
    """
    kwargs = _get_kwargs()

    py_file = kwargs['py']
    bin_file = kwargs['bin']
    mapping_file = kwargs['json']
    xml_file = kwargs['xml']

    if not os.path.isfile(py_file):
        logging.critical('Py file not found! Check the path')
        return

    if not os.path.isfile(bin_file):
        logging.critical('Bin file is not found! Check path!')
        return

    if not os.path.isfile(xml_file):
        logging.critical('XML File not found! Check path!')
        return

    if not os.path.isfile(mapping_file):
        logging.critical('JSON file is not found! Check path!')
        return

    # An empty label map is a valid (falsy) result, so test for None explicitly.
    mapping = _load_label_map(mapping_file)
    if mapping is None:
        return

    restricted = kwargs['restricted']
    image_files = kwargs.get('image_files')

    if image_files:
        image_data = []
        for image_file in image_files:
            image = cv2.imread(image_file)
            # cv2.imread signals an unreadable file by returning None
            # rather than raising, so catch it before inference does.
            if image is None:
                logging.critical('Image file could not be read: %s', image_file)
                return
            image_data.append(image)
    else:
        # No images supplied: run the model on a single blank white frame.
        test_image = np.ones((1024, 1980, 3), np.uint8) * 255
        image_data = [test_image]

    attribute_spec = {}

    results = run_inference_engine_annotation(image_data,
                                              xml_file,
                                              bin_file,
                                              mapping,
                                              attribute_spec,
                                              py_file,
                                              restricted=restricted)

    if kwargs['serialize']:
        # Django must be configured before the serializer can be imported.
        os.environ['DJANGO_SETTINGS_MODULE'] = 'cvat.settings.production'
        import django
        django.setup()

        from cvat.apps.engine.serializers import LabeledDataSerializer

        # NOTE: We're actually using `run_inference_engine_annotation`
        # incorrectly here. The `mapping` dict is supposed to be a mapping
        # of integers -> integers and represents the transition from model
        # integers to the labels in the database. We're using a mapping of
        # integers -> strings. For testing purposes, this shortcut is fine.
        # We just want to make sure everything works. Until, that is....
        # we want to test using the label serializer. Then we have to transition
        # back to integers, otherwise the serializer complains about have a string
        # where an integer is expected. We'll just brute force that.
        for shape in results['shapes']:
            # Change the english label to an integer for serialization validation
            shape['label_id'] = 1

        serializer = LabeledDataSerializer(data=results)

        if not serializer.is_valid():
            logging.critical('Data unable to be serialized correctly!')
            # Re-run validation so the detailed error is raised to the user.
            serializer.is_valid(raise_exception=True)

    logging.warning('Program didn\'t have any errors.')

    if not kwargs.get('show_images', False):
        return

    # `--image-files` with no values yields an empty list, not None, so
    # test truthiness to match the image-loading branch above.
    if not image_files:
        logging.critical("Warning, no images provided!")
        logging.critical('Exiting without presenting results')
        return

    if not results['shapes']:
        logging.warning(str(results))
        logging.critical("No objects detected!")
        return

    show_image_delay = kwargs['show_image_delay']
    for index, data in enumerate(image_data):
        # Only draw the detections that belong to this frame.
        for detection in results['shapes']:
            if detection['frame'] != index:
                continue
            _draw_detection(data, detection)

        window_name = str(index)
        cv2.imshow(window_name, data)
        cv2.waitKey(show_image_delay)
        cv2.destroyWindow(window_name)
# Run the tester only when executed as a script, not when imported.
if __name__ == '__main__':
    main()