You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
224 lines
7.4 KiB
Python
224 lines
7.4 KiB
Python
#!/usr/bin/env python
|
|
#
|
|
# SPDX-License-Identifier: MIT
|
|
# coding: utf-8
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
Given a CVAT XML and a directory with the image dataset, this script reads the
|
|
CVAT XML and writes the annotations in tfrecords into a given
|
|
directory.
|
|
|
|
This implementation supports annotated images only.
|
|
"""
|
|
from __future__ import unicode_literals
|
|
import xml.etree.ElementTree as ET
|
|
import tensorflow as tf
|
|
from object_detection.utils import dataset_util
|
|
from collections import Counter
|
|
import codecs
|
|
import hashlib
|
|
from pathlib import Path
|
|
import argparse
|
|
import os
|
|
import string
|
|
|
|
# We need this to filter out non-ASCII characters from label/attribute
# strings; otherwise training will crash.
printable = set(string.printable)
|
|
|
|
def parse_args():
    """Build the CLI parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(
        description='Convert CVAT XML annotations to tfrecords format')

    # Required inputs/outputs.
    parser.add_argument('--cvat-xml', metavar='FILE', required=True,
                        help='input file with CVAT annotation in xml format')
    parser.add_argument('--image-dir', metavar='DIRECTORY', required=True,
                        help='directory which contains original images')
    parser.add_argument('--output-dir', metavar='DIRECTORY', required=True,
                        help='directory for output annotations in tfrecords format')

    # Optional tuning knobs.
    parser.add_argument('--train-percentage', metavar='PERCENTAGE',
                        required=False, default=90, type=int,
                        help='the percentage of training data to total data (default: 90)')
    parser.add_argument('--min-train', metavar='NUM',
                        required=False, default=10, type=int,
                        help='The minimum number of images above which the label is considered (default: 10)')
    parser.add_argument('--attribute', metavar='NAME',
                        required=False, default="", type=str,
                        help='The attribute name based on which the object can identified')

    return parser.parse_args()
|
|
|
|
def process_cvat_xml(args):
    """Transform a single XML in CVAT format to tfrecords.

    Reads the CVAT annotation XML, writes ``label_map.pbtxt`` plus
    ``train.tfrecord`` / ``eval.tfrecord`` into ``args.output_dir``.

    Args:
        args: parsed CLI namespace (cvat_xml, image_dir, output_dir,
            train_percentage, min_train, attribute).

    Returns:
        Tuple of (label name -> integer id mapping, number of images).

    Raises:
        ValueError: if train_percentage is not within [0, 100].
    """
    train_percentage = int(args.train_percentage)
    # Validate explicitly: `assert` is stripped when Python runs with -O.
    if not 0 <= train_percentage <= 100:
        raise ValueError(
            'train-percentage must be in [0, 100], got {}'.format(
                train_percentage))

    cvat_xml = ET.parse(args.cvat_xml).getroot()

    output_dir = Path(args.output_dir)
    if not output_dir.exists():
        print("Creating the output directory because it doesn't exist")
        # parents=True so a missing intermediate directory does not crash;
        # exist_ok=True guards against a concurrent-creation race.
        output_dir.mkdir(parents=True, exist_ok=True)

    cvat_name, output_dir, min_train = \
        args.attribute, output_dir.absolute(), args.min_train

    # Extract the object names, either from the box label itself or from
    # the requested attribute of each box.
    object_names = []
    num_imgs = 0
    for img in cvat_xml.findall('image'):
        num_imgs += 1
        for box in img:
            if cvat_name == "":
                obj_name = ''.join(filter(lambda x: x in printable,
                                          box.attrib['label']))
                object_names.append(obj_name)
            else:
                for attribute in box:
                    if attribute.attrib['name'] == cvat_name:
                        # attribute.text is None for empty attribute
                        # elements; skip those instead of crashing.
                        if attribute.text is None:
                            continue
                        obj_name = ''.join(
                            filter(lambda x: x in printable,
                                   attribute.text.lower()))
                        object_names.append(obj_name)

    # Counter may be empty (XML without boxes); zip(*{}.items()) would
    # raise, so fall back to empty tuples explicitly.
    counts = Counter(object_names)
    labels, values = zip(*counts.items()) if counts else ((), ())

    # Create the label map file; labels with fewer than min_train samples
    # are dropped entirely (they also never make it into saved_dict, so
    # their boxes are skipped during conversion).
    saved_dict = dict()
    reverse_dict = dict()
    with codecs.open(os.path.join(output_dir, 'label_map.pbtxt'),
                     'w', encoding='utf8') as f:
        counter = 1
        for iii, label in enumerate(labels):
            if values[iii] < min_train:
                continue
            saved_dict[label] = counter
            reverse_dict[counter] = label
            f.write(u'item {\n')
            f.write(u'\tid: {}\n'.format(counter))
            f.write(u"\tname: '{}'\n".format(label))
            f.write(u'}\n\n')
            counter += 1

    # First train_num images go to training, the rest to evaluation.
    eval_num = num_imgs * (100 - train_percentage) // 100
    train_num = num_imgs - eval_num

    # Open the tfrecord files for writing; close them in `finally` so a
    # failing image does not leak the file handles.
    writer_train = tf.python_io.TFRecordWriter(
        os.path.join(output_dir, 'train.tfrecord'))
    writer_eval = tf.python_io.TFRecordWriter(
        os.path.join(output_dir, 'eval.tfrecord'))
    try:
        for counter, example in enumerate(cvat_xml.findall('image')):
            tf_example = create_tf_example(
                example, args.attribute, saved_dict, args.image_dir)
            if tf_example is None:
                continue
            if counter < train_num:
                writer_train.write(tf_example.SerializeToString())
            else:
                writer_eval.write(tf_example.SerializeToString())
    finally:
        writer_train.close()
        writer_eval.close()

    return saved_dict, num_imgs
|
|
|
|
|
|
# Defining the main conversion function
def create_tf_example(example, cvat_name, saved_dict, img_dir):
    """Build one tf.train.Example from a CVAT <image> element.

    Args:
        example: the <image> XML element; must carry 'height', 'width'
            and 'name' attributes, with annotated boxes as children.
        cvat_name: attribute name identifying the object class, or ''
            to use the box 'label' attribute directly.
        saved_dict: mapping of class name -> integer id; boxes whose
            class is absent from this mapping are skipped.
        img_dir: directory containing the original image files.

    Returns:
        A tf.train.Example, or None when the image extension is neither
        jpg/jpeg nor png.
    """
    height = int(example.attrib['height'])  # image height in pixels
    width = int(example.attrib['width'])    # image width in pixels
    filename = os.path.join(img_dir, example.attrib['name'])
    _, ext = os.path.splitext(example.attrib['name'])

    filename = filename.encode('utf8')
    with tf.gfile.GFile(filename, 'rb') as fid:
        encoded_image = fid.read()

    # SHA-256 of the raw bytes, stored alongside the image in the record.
    key = hashlib.sha256(encoded_image).hexdigest()

    if ext.lower() in ['.jpg', '.jpeg']:
        image_format = 'jpeg'.encode('utf8')
    elif ext.lower() == '.png':
        image_format = 'png'.encode('utf8')
    else:
        print('File Format not supported, Skipping')
        return None

    xmins = []  # List of normalized left x coordinates (1 per box)
    xmaxs = []  # List of normalized right x coordinates (1 per box)
    ymins = []  # List of normalized top y coordinates (1 per box)
    ymaxs = []  # List of normalized bottom y coordinates (1 per box)
    classes_text = []  # List of string class names (1 per box)
    classes = []  # List of integer class ids (1 per box)

    # Loop over the boxes and fill the above fields.
    for box in example:
        box_name = ''
        if cvat_name == "":
            box_name = box.attrib['label']
        else:
            for attr in box:
                # attr.text is None for empty attribute elements; treat
                # them as "no class" instead of crashing on .lower().
                if attr.attrib['name'] == cvat_name and attr.text is not None:
                    box_name = attr.text.lower()

        # filter out non-ASCII characters
        box_name = ''.join(filter(lambda x: x in printable, box_name))

        if box_name in saved_dict:
            xmins.append(float(box.attrib['xtl']) / width)
            xmaxs.append(float(box.attrib['xbr']) / width)
            ymins.append(float(box.attrib['ytl']) / height)
            ymaxs.append(float(box.attrib['ybr']) / height)
            classes_text.append(box_name.encode('utf8'))
            classes.append(saved_dict[box_name])

    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(filename),
        'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
        'image/encoded': dataset_util.bytes_feature(encoded_image),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
    return tf_example
|
|
|
|
def main():
    """Script entry point: parse the CLI arguments and run the conversion."""
    cli_args = parse_args()
    process_cvat_xml(cli_args)


if __name__ == '__main__':
    main()
|
|
|