import csv
import os
import re
import sys
import shlex
import logging
import shutil
import tempfile
from io import StringIO
from PIL import Image
from traceback import print_exception
import mimetypes
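
# Register extra MIME types from the bundled media.mimetypes file so that
# _get_mime() below can classify uploads the stdlib table does not know about
# (presumably additional video/archive container types).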
_SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__))
_MEDIA_MIMETYPES_FILE = os.path.join(_SCRIPT_DIR, "media.mimetypes")
mimetypes.init(files=[_MEDIA_MIMETYPES_FILE])

import django_rq
from django.conf import settings
from django.db import transaction

from ffmpy import FFmpeg
from pyunpack import Archive
from distutils.dir_util import copy_tree

from . import models
from .logging import task_logger, job_logger

global_logger = logging.getLogger(__name__)

############################# Low Level server API

@transaction.atomic
def create_empty(params):
    """Create empty directory structure for a new task, add it to DB."""
    db_task = models.Task()
    db_task.name = params['task_name']
    db_task.bug_tracker = params['bug_tracker_link']
    db_task.path = ""
    db_task.size = 0
    db_task.owner = params['owner']
    db_task.save()

    task_path = os.path.join(settings.DATA_ROOT, str(db_task.id))
    db_task.set_task_dirname(task_path)

    task_path = db_task.get_task_dirname()
    if os.path.isdir(task_path):
        shutil.rmtree(task_path)
    os.mkdir(task_path)

    upload_dir = db_task.get_upload_dirname()
    os.makedirs(upload_dir)
    output_dir = db_task.get_data_dirname()
    os.makedirs(output_dir)

    return db_task

def create(tid, params):
    """Schedule the task"""
    q = django_rq.get_queue('default')
    q.enqueue_call(func=_create_thread, args=(tid, params),
        job_id="task.create/{}".format(tid))

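# check() polls RQ using the same "task.create/<tid>" job id that create()
# enqueued, mapping the job lifecycle onto the states the client understands:
# unknown, error, created, or started.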
def check(tid):
    """Check status of the scheduled task"""
    response = {}
    queue = django_rq.get_queue('default')
    job = queue.fetch_job("task.create/{}".format(tid))
    if job is None:
        response = {"state": "unknown"}
    elif job.is_failed:
        response = {"state": "error", "stderr": "Could not create the task. " + job.exc_info}
    elif job.is_finished:
        response = {"state": "created"}
    else:
        response = {"state": "started"}

    return response

@transaction.atomic
def delete(tid):
    """Delete the task"""
    db_task = models.Task.objects.select_for_update().get(pk=tid)
    if db_task:
        db_task.delete()
        shutil.rmtree(db_task.get_task_dirname(), ignore_errors=True)
    else:
        raise Exception("The task doesn't exist")

@transaction.atomic
def update(tid, labels):
    """Update labels for the task"""
    db_task = models.Task.objects.select_for_update().get(pk=tid)
    db_labels = list(db_task.label_set.prefetch_related('attributespec_set').all())
    new_labels = _parse_labels(labels)
    old_labels = _parse_db_labels(db_labels)

    for label_name in new_labels:
        if label_name in old_labels:
            db_label = [l for l in db_labels if l.name == label_name][0]
            for attr_name in new_labels[label_name]:
                if attr_name in old_labels[label_name]:
                    db_attr = [attr for attr in db_label.attributespec_set.all()
                        if attr.get_name() == attr_name][0]
                    new_attr = new_labels[label_name][attr_name]
                    old_attr = old_labels[label_name][attr_name]
                    if new_attr['prefix'] != old_attr['prefix']:
                        raise Exception("Cannot change the prefix of attribute '{}'".format(attr_name))
                    if new_attr['type'] != old_attr['type']:
                        raise Exception("Cannot change the type of attribute '{}'".format(attr_name))
                    if set(old_attr['values']) - set(new_attr['values']):
                        raise Exception("Values of attribute '{}' can only be appended, not removed".format(attr_name))
                    db_attr.text = "{}{}={}:{}".format(new_attr['prefix'],
                        new_attr['type'], attr_name, ",".join(new_attr['values']))
                    db_attr.save()
                else:
                    db_attr = models.AttributeSpec()
                    attr = new_labels[label_name][attr_name]
                    db_attr.text = "{}{}={}:{}".format(attr['prefix'],
                        attr['type'], attr_name, ",".join(attr['values']))
                    db_attr.label = db_label
                    db_attr.save()
        else:
            db_label = models.Label()
            db_label.name = label_name
            db_label.task = db_task
            db_label.save()
            for attr_name in new_labels[label_name]:
                db_attr = models.AttributeSpec()
                attr = new_labels[label_name][attr_name]
                db_attr.text = "{}{}={}:{}".format(attr['prefix'],
                    attr['type'], attr_name, ",".join(attr['values']))
                db_attr.label = db_label
                db_attr.save()

def get_frame_path(tid, frame):
    """Return the path to the corresponding frame of the task"""
    db_task = models.Task.objects.get(pk=tid)
    path = _get_frame_path(frame, db_task.get_data_dirname())

    return path

def get(tid):
    """Get the task as dictionary of attributes"""
    db_task = models.Task.objects.get(pk=tid)
    if db_task:
        db_labels = db_task.label_set.prefetch_related('attributespec_set').all()
        attributes = {}
        for db_label in db_labels:
            attributes[db_label.id] = {}
            for db_attrspec in db_label.attributespec_set.all():
                attributes[db_label.id][db_attrspec.id] = db_attrspec.text
        db_segments = list(db_task.segment_set.prefetch_related('job_set').all())
        segment_length = max(db_segments[0].stop_frame - db_segments[0].start_frame + 1, 1)
        job_indexes = [segment.job_set.first().id for segment in db_segments]

        response = {
            "status": db_task.status.capitalize(),
            "spec": {
                "labels": { db_label.id:db_label.name for db_label in db_labels },
                "attributes": attributes
            },
            "size": db_task.size,
            "blowradius": 0,
            "taskid": db_task.id,
            "name": db_task.name,
            "mode": db_task.mode,
            "segment_length": segment_length,
            "jobs": job_indexes,
            "overlap": db_task.overlap
        }
    else:
        raise Exception("Cannot find the task: {}".format(tid))

    return response

def get_job(jid):
    """Get the job as dictionary of attributes"""
    db_job = models.Job.objects.select_related("segment__task").get(id=jid)
    if db_job:
        db_segment = db_job.segment
        db_task = db_segment.task
        db_labels = db_task.label_set.prefetch_related('attributespec_set').all()
        attributes = {}
        for db_label in db_labels:
            attributes[db_label.id] = {}
            for db_attrspec in db_label.attributespec_set.all():
                attributes[db_label.id][db_attrspec.id] = db_attrspec.text

        response = {
            "status": db_task.status.capitalize(),
            "labels": { db_label.id:db_label.name for db_label in db_labels },
            "stop": db_segment.stop_frame,
            "blowradius": 0,
            "taskid": db_task.id,
            "slug": db_task.name,
            "jobid": jid,
            "start": db_segment.start_frame,
            "mode": db_task.mode,
            "overlap": db_task.overlap,
            "attributes": attributes,
        }
    else:
        raise Exception("Cannot find the job: {}".format(jid))

    return response

def is_task_owner(user, tid):
    """Check that the user owns the task or belongs to the admin group"""
    try:
        return user == models.Task.objects.get(pk=tid).owner or \
            user.groups.filter(name='admin').exists()
    except Exception:
        return False

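# RQ exception handler for failed "task.create/<tid>" jobs: writes the
# traceback to the task log, deletes the half-created task, and returns False
# so that RQ stops invoking any further exception handlers.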
@transaction.atomic
def rq_handler(job, exc_type, exc_value, traceback):
    tid = job.id.split('/')[1]
    db_task = models.Task.objects.select_for_update().get(pk=tid)
    with open(db_task.get_log_path(), "wt") as log_file:
        print_exception(exc_type, exc_value, traceback, file=log_file)
    db_task.delete()

    return False

############################# Internal implementation for server API
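# _FrameExtractor decodes a video into a temporary directory of numbered JPEG
# frames with ffmpeg. The user-facing quality scale (1 = worst, 95 = best) is
# mapped linearly onto ffmpeg's -q:v scale (2 = best, 31 = worst), hence the
# range inversion below.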
class _FrameExtractor:
    def __init__(self, source_path, compress_quality, flip_flag=False):
        # translate the inverted range 1:95 to 2:31
        translated_quality = 96 - compress_quality
        translated_quality = round((((translated_quality - 1) * (31 - 2)) / (95 - 1)) + 2)
        self.output = tempfile.mkdtemp(prefix='cvat-', suffix='.data')
        target_path = os.path.join(self.output, '%d.jpg')
        output_opts = '-start_number 0 -b:v 10000k -vsync 0 -an -y -q:v ' + str(translated_quality)
        if flip_flag:
            # two successive 90-degree transposes rotate each frame by 180 degrees
            output_opts += ' -vf "transpose=2,transpose=2"'

        ff = FFmpeg(
            inputs = {source_path: None},
            outputs = {target_path: output_opts})
        ff.run()

    def getframepath(self, k):
        return "{0}/{1}.jpg".format(self.output, k)

    def __del__(self):
        if self.output:
            shutil.rmtree(self.output)

    def __getitem__(self, k):
        return self.getframepath(k)

    def __iter__(self):
        i = 0
        while os.path.exists(self.getframepath(i)):
            yield self[i]
            i += 1

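# Classify a path as 'video', 'archive', 'image', 'directory' or 'empty'
# (i.e. unsupported) from its guessed MIME type.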
def _get_mime(name):
    mime = mimetypes.guess_type(name)
    mime_type = mime[0]
    encoding = mime[1]
    # zip, rar, tar, tar.gz, tar.bz2, 7z, cpio
    supportedArchives = ['application/zip', 'application/x-rar-compressed',
        'application/x-tar', 'application/x-7z-compressed', 'application/x-cpio',
        'gzip', 'bzip2']
    if mime_type is not None:
        if mime_type.startswith('video'):
            return 'video'
        elif mime_type in supportedArchives or encoding in supportedArchives:
            return 'archive'
        elif mime_type.startswith('image'):
            return 'image'
        else:
            return 'empty'
    else:
        if os.path.isdir(name):
            return 'directory'
        else:
            return 'empty'

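# Frames are sharded into a two-level directory tree so no single directory
# grows too large: frame // 10000 and frame // 100 become the parent
# directories, e.g. frame 12345 maps to <base_dir>/1/123/12345.jpg.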
def _get_frame_path(frame, base_dir):
    d1 = str(frame // 10000)
    d2 = str(frame // 100)
    path = os.path.join(d1, d2, str(frame) + '.jpg')
    if base_dir:
        path = os.path.join(base_dir, path)

    return path

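# Parse a label specification string into {label: {attribute: spec}}. Tokens
# are split shell-style: a bare token starts a new label, and a token of the
# form <prefix><type>=<name>:<values> (prefix '~' or '@', values comma
# separated, single quotes grouping values that contain spaces) attaches an
# attribute to the most recent label. A hypothetical spec for illustration:
#   car @select=model:a4,'bmw x5' ~checkbox=parked:false person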
def _parse_labels(labels):
    parsed_labels = {}
    last_label = ""
    for token in shlex.split(labels):
        if token[0] != "~" and token[0] != "@":
            parsed_labels[token] = {}
            last_label = token
        else:
            match = re.match(r'^([~@])(\w+)=(\w+):(.+)$', token)
            prefix = match.group(1)
            atype = match.group(2)
            aname = match.group(3)
            values = list(csv.reader(StringIO(match.group(4)), quotechar="'"))[0]
            parsed_labels[last_label][aname] = {'prefix':prefix, 'type':atype, 'values':values}

    return parsed_labels

def _parse_db_labels(db_labels):
    result = []
    for db_label in db_labels:
        result += [db_label.name]
        result += [attr.text for attr in db_label.attributespec_set.all()]

    return _parse_labels(" ".join(result))

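# _create_thread reads the following keys from params (as used below):
# 'storage' ('local' or 'share'), 'SOURCE_PATHS'/'TARGET_PATHS' (parallel path
# lists), 'flip_flag' ('true'/'false' string), 'compress_quality',
# 'segment_size', 'overlap_size' and 'labels'.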
@transaction.atomic
def _create_thread(tid, params):
    # TODO: Improve the function logic. Paths from a share storage should be
    # filtered before they are copied to the server.
    db_task = models.Task.objects.select_for_update().get(pk=tid)
    upload_dir = db_task.get_upload_dirname()
    output_dir = db_task.get_data_dirname()

    with open(db_task.get_log_path(), 'w') as log_file:
        storage = params['storage']
        mode = 'annotation'

        for source_path, target_path in zip(params['SOURCE_PATHS'], params['TARGET_PATHS']):
            filepath = target_path if storage == 'local' else source_path
            mime = _get_mime(filepath)
            if mime == 'empty':
                continue

            if mime == 'video':
                mode = 'interpolation'
            else:
                mode = 'annotation'

            if len(params['TARGET_PATHS']) > 1 and (mime == 'video' or mime == 'archive'):
                for tmp_path in params['SOURCE_PATHS']:
                    if tmp_path not in filepath:
                        raise Exception('Only images can be loaded in plural quantity. {} was found'.format(mime.capitalize()))

            if storage == 'share' and not os.path.exists(target_path):
                if mime == 'directory':
                    copy_tree(source_path, os.path.join(upload_dir, os.path.basename(source_path)))
                else:
                    dirname = os.path.dirname(target_path)
                    if not os.path.exists(dirname):
                        os.makedirs(dirname)
                    shutil.copyfile(source_path, target_path)

            if mime == 'archive':
                Archive(target_path).extractall(upload_dir)
                os.remove(target_path)

        flip_flag = params['flip_flag'].lower() == 'true'
        compress_quality = int(params.get('compress_quality', 50))

        if mode == 'interpolation':
            # The last element of params['TARGET_PATHS'] must be the video
            # (the paths are sorted by length beforehand); earlier elements
            # (if any) are its parent directories.
            extractor = _FrameExtractor(params['TARGET_PATHS'][-1], compress_quality, flip_flag)
            for frame, image_orig_path in enumerate(extractor):
                image_dest_path = _get_frame_path(frame, output_dir)
                db_task.size += 1
                dirname = os.path.dirname(image_dest_path)
                if not os.path.exists(dirname):
                    os.makedirs(dirname)
                shutil.copyfile(image_orig_path, image_dest_path)
        else:
            extensions = ['.jpg', '.png', '.bmp', '.jpeg']
            filenames = []
            for root, _, files in os.walk(upload_dir):
                fullnames = map(lambda f: os.path.join(root, f), files)
                filtnames = filter(lambda f: os.path.splitext(f)[1].lower()
                    in extensions, fullnames)
                filenames.extend(filtnames)
            filenames.sort()

            # Compress input images
            compressed_names = []
            for name in filenames:
                compressed_name = os.path.splitext(name)[0] + '.jpg'
                image = Image.open(name)
                image = image.convert('RGB')
                image.save(compressed_name, quality=compress_quality, optimize=True)
                compressed_names.append(compressed_name)
                if compressed_name != name:
                    os.remove(name)
            filenames = compressed_names

            if not filenames:
                raise Exception("No files ending with {}".format(extensions))

            for frame, image_orig_path in enumerate(filenames):
                image_dest_path = _get_frame_path(frame, output_dir)
                image_orig_path = os.path.abspath(image_orig_path)
                if flip_flag:
                    image = Image.open(image_orig_path)
                    image = image.transpose(Image.ROTATE_180)
                    image.save(image_orig_path)
                db_task.size += 1
                dirname = os.path.dirname(image_dest_path)
                if not os.path.exists(dirname):
                    os.makedirs(dirname)
                os.symlink(image_orig_path, image_dest_path)
            log_file.write("Formatted {0} images\n".format(len(filenames)))

        default_segment_length = sys.maxsize # greater than any task size; splitting into segments is disabled by default
        segment_length = int(params.get('segment_size', default_segment_length))
        global_logger.info("segment length for task #{} is {}".format(db_task.id, segment_length))

        if mode == 'interpolation':
            default_overlap = 5
        else:
            default_overlap = 0
        overlap = min(int(params.get('overlap_size', default_overlap)), segment_length - 1)
        db_task.overlap = min(db_task.size, overlap)
        global_logger.info("segment overlap for task #{} is {}".format(db_task.id, db_task.overlap))

        # each new segment starts segment_length - overlap frames after the
        # previous one, so consecutive segments share 'overlap' frames
        segment_step = segment_length - db_task.overlap
        for x in range(0, db_task.size, segment_step):
            start_frame = x
            stop_frame = min(x + segment_length - 1, db_task.size - 1)
            global_logger.info("new segment for task #{}: start_frame = {}, stop_frame = {}".format(db_task.id, start_frame, stop_frame))

            db_segment = models.Segment()
            db_segment.task = db_task
            db_segment.start_frame = start_frame
            db_segment.stop_frame = stop_frame
            db_segment.save()

            db_job = models.Job()
            db_job.segment = db_segment
            db_job.save()

        labels = params['labels']
        global_logger.info("labels with attributes for task #{}: {}".format(db_task.id, labels))

        db_label = None
        for token in shlex.split(labels):
            if token[0] != "~" and token[0] != "@":
                db_label = models.Label()
                db_label.task = db_task
                db_label.name = token
                db_label.save()
            elif db_label is not None:
                db_attrspec = models.AttributeSpec()
                db_attrspec.label = db_label
                db_attrspec.text = token
                db_attrspec.save()
            else:
                raise ValueError("Invalid labels format {}".format(labels))

        db_task.mode = mode
        db_task.save()