You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
118 lines
3.6 KiB
Python
118 lines
3.6 KiB
Python
#!/usr/bin/env python
|
|
#
|
|
# Copyright (C) 2018 Intel Corporation
|
|
#
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
from __future__ import absolute_import, division, print_function
|
|
|
|
import argparse
|
|
import os
|
|
import glog as log
|
|
import numpy as np
|
|
import cv2
|
|
from lxml import etree
|
|
from tqdm import tqdm
|
|
|
|
|
|
def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Options may also be read from a file by passing ``@file`` on the
    command line (``fromfile_prefix_chars='@'``).

    Returns:
        argparse.Namespace with the attributes ``cvat_xml``,
        ``background_color``, ``label_color`` (a list, one entry per
        ``--label-color`` occurrence), ``mask_bitness`` (8 or 24) and
        ``output_dir``.
    """
    cli = argparse.ArgumentParser(
        description='Convert CVAT XML annotations to masks',
        fromfile_prefix_chars='@',
    )

    # Input annotation file (mandatory).
    cli.add_argument(
        '--cvat-xml',
        required=True,
        metavar='FILE',
        help='input file with CVAT annotation in xml format',
    )

    # Color options are kept as raw strings here; they are converted later.
    cli.add_argument(
        '--background-color',
        default="0,0,0",
        metavar='COLOR_BGR',
        help='specify background color (by default: 0,0,0)',
    )
    cli.add_argument(
        '--label-color',
        action='append',
        default=[],
        metavar='LABEL:COLOR_BGR',
        help=("specify a label's color (e.g. 255 or 255,0,0). The color will "
              "be interpreted in accordance with the mask format."),
    )

    # Mask format: 8 bits -> one channel, 24 bits -> three channels.
    cli.add_argument(
        '--mask-bitness',
        type=int,
        default=8,
        choices=[8, 24],
        help='choose bitness for masks',
    )

    # Output location (mandatory).
    cli.add_argument(
        '--output-dir',
        required=True,
        metavar='DIRECTORY',
        help='directory for output masks',
    )

    return cli.parse_args()
|
|
|
|
def parse_anno_file(cvat_xml):
    """Read a CVAT XML annotation file into a list of per-image dicts.

    Args:
        cvat_xml: source handed straight to ``etree.parse``.

    Returns:
        A list with one dict per ``<image>`` element. Each dict holds the
        element's XML attributes (as strings) plus a ``'shapes'`` list.
        Every shape is a dict with a ``'type'`` of ``'polygon'`` or
        ``'box'`` and the attributes of its tag; boxes additionally get a
        ``'points'`` string listing their four corners. Shapes are ordered
        by ascending integer ``z_order`` (0 when the attribute is absent).
    """
    def _z_order(shape):
        return int(shape.get('z_order', 0))

    annotations = []
    document = etree.parse(cvat_xml)
    for image_node in document.getroot().iter('image'):
        record = dict(image_node.items())

        polygons = []
        for node in image_node.iter('polygon'):
            shape = {'type': 'polygon'}
            shape.update(node.items())
            polygons.append(shape)

        boxes = []
        for node in image_node.iter('box'):
            shape = {'type': 'box'}
            shape.update(node.items())
            # Express the rectangle as a 4-corner polygon so that boxes and
            # polygons can be handled uniformly through 'points'.
            shape['points'] = "{0},{1};{2},{1};{2},{3};{0},{3}".format(
                shape['xtl'], shape['ytl'], shape['xbr'], shape['ybr'])
            boxes.append(shape)

        # Polygons come before boxes; the stable sort keeps that order
        # among shapes sharing the same z_order.
        record['shapes'] = sorted(polygons + boxes, key=_z_order)
        annotations.append(record)

    return annotations
|
|
|
|
def create_mask_file(mask_path, width, height, bitness, color_map, background, shapes):
    """Rasterize annotation shapes into a mask image and write it to disk.

    Args:
        mask_path: destination path passed to ``cv2.imwrite``.
        width: mask width in pixels.
        height: mask height in pixels.
        bitness: bits per pixel; the mask gets ``bitness // 8`` channels.
        color_map: mapping from label name to fill color.
        background: color used for pixels not covered by any shape and for
            shapes whose label is missing from ``color_map``. It must be
            broadcastable to ``bitness // 8`` channels.
        shapes: iterable of dicts with a ``'label'`` and a ``'points'``
            string of the form ``"x1,y1;x2,y2;..."``. Shapes are drawn in
            the given order, so later shapes paint over earlier ones.
    """
    # Start from the requested background color rather than from zeros;
    # otherwise a non-black background would never show up in the mask.
    mask = np.full((height, width, bitness // 8), background, dtype=np.uint8)

    for shape in shapes:
        color = color_map.get(shape['label'], background)

        vertices = []
        for point in shape['points'].split(';'):
            coords = [float(value) for value in point.split(',')]
            # Fractional coordinates are truncated to pixel indices.
            vertices.append((int(coords[0]), int(coords[1])))

        # fillPoly expects 32-bit integer vertex coordinates.
        polygon = np.array(vertices, dtype=np.int32)
        mask = cv2.fillPoly(mask, [polygon], color=color)

    cv2.imwrite(mask_path, mask)
|
|
|
|
def to_scalar(str, dim):
    """Convert a comma-separated color string into a ``dim``-element tuple.

    If the string has fewer than ``dim`` components, the last component is
    repeated to fill the remainder; surplus components are dropped.

    Args:
        str: color such as ``"255"`` or ``"255,0,0"``.
        dim: number of components wanted in the result.

    Returns:
        Tuple of ints.

    Raises:
        ValueError: if a component is not a valid integer.
    """
    components = [int(token) for token in str.split(',')]
    shortfall = dim - len(components)
    if shortfall > 0:
        # Pad with copies of the final component (e.g. "255" -> 255,255,255).
        components += [components[-1]] * shortfall
    return tuple(components[:dim])
|
|
|
|
def main():
    """Convert every image of a CVAT XML annotation file into a PNG mask.

    Reads the command-line options, parses the annotation file, builds the
    label-to-color table and writes one mask per annotated image under the
    output directory. The mask keeps the image's relative path with its
    extension replaced by ``.png``.

    Raises:
        ValueError: if a ``--label-color`` value is not of the form
            ``LABEL:COLOR_BGR`` or a color component is not an integer.
    """
    args = parse_args()
    anno = parse_anno_file(args.cvat_xml)

    # Number of channels per pixel: 1 for 8-bit masks, 3 for 24-bit masks.
    dim = args.mask_bitness // 8

    color_map = {}
    for item in args.label_color:
        # Split on the LAST colon: a color never contains ':', while a label
        # may, so "a:b:255" maps label "a:b" to color 255.
        label, sep, color = item.rpartition(':')
        if not sep:
            raise ValueError(
                "invalid --label-color value {!r}: expected LABEL:COLOR_BGR"
                .format(item))
        color_map[label] = to_scalar(color, dim)
    background = to_scalar(args.background_color, dim)

    for image in tqdm(anno, desc='Generate masks'):
        mask_path = os.path.join(args.output_dir, os.path.splitext(image['name'])[0] + '.png')
        # Image names may contain sub-directories; create them on demand.
        mask_dir = os.path.dirname(mask_path)
        if mask_dir:
            os.makedirs(mask_dir, exist_ok=True)
        create_mask_file(mask_path, int(image['width']), int(image['height']),
                         args.mask_bitness, color_map, background, image['shapes'])
|
|
|
|
|
|
# Run the conversion only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|