@@ -8,6 +8,7 @@ import itertools
 import fnmatch
 import os
 import sys
+from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Union
 from rest_framework.serializers import ValidationError
 import rq
 import re
@@ -60,6 +61,17 @@ def rq_handler(job, exc_type, exc_value, traceback):
 ############################# Internal implementation for server API

+JobFileMapping = List[List[str]]
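+# Example (hypothetical file names): a mapping for two jobs,
+# 3 frames in the first and 2 in the second:
+#   [["img_0.jpg", "img_1.jpg", "img_2.jpg"], ["img_3.jpg", "img_4.jpg"]]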
+
+class SegmentParams(NamedTuple):
+    start_frame: int
+    stop_frame: int
+
+class SegmentsParams(NamedTuple):
+    segments: Iterator[SegmentParams]
+    segment_size: int
+    overlap: int
+
 def _copy_data_from_source(server_files, upload_dir, server_dir=None):
     job = rq.get_current_job()
     job.meta['status'] = 'Data are being copied from source..'
@@ -79,7 +91,32 @@ def _copy_data_from_source(server_files, upload_dir, server_dir=None):
                 os.makedirs(target_dir)
             shutil.copyfile(source_path, target_path)
-def _get_task_segment_data(db_task, data_size):
+def _get_task_segment_data(
+    db_task: models.Task,
+    *,
+    data_size: Optional[int] = None,
+    job_file_mapping: Optional[JobFileMapping] = None,
+) -> SegmentsParams:
+    if job_file_mapping is not None:
+        def _segments():
+            # It is assumed here that files are already saved ordered in the task
+            # Here we just need to create segments by the job sizes
+            start_frame = 0
+            for jf in job_file_mapping:
+                segment_size = len(jf)
+                stop_frame = start_frame + segment_size - 1
+                yield SegmentParams(start_frame, stop_frame)
+                start_frame = stop_frame + 1
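+
+        # e.g. a mapping with jobs of 3 and 2 files yields
+        # SegmentParams(0, 2), then SegmentParams(3, 4)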
+        segments = _segments()
+        segment_size = 0
+        overlap = 0
+    else:
+        # The segments have equal parameters
+        if data_size is None:
+            data_size = db_task.data.size
+
+        segment_size = db_task.segment_size
+        segment_step = segment_size
+        if segment_size == 0 or segment_size > data_size:
@@ -94,22 +131,28 @@ def _get_task_segment_data(db_task, data_size):
+        overlap = min(db_task.overlap, segment_size // 2)
+        segment_step -= overlap

-    return segment_step, segment_size, overlap
-
-def _save_task_to_db(db_task, extractor):
+        segments = (
+            SegmentParams(start_frame, min(start_frame + segment_size - 1, data_size - 1))
+            for start_frame in range(0, data_size, segment_step)
+        )
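+        # e.g. data_size=10, segment_size=4, overlap=1 gives step 3 and
+        # segments (0, 3), (3, 6), (6, 9), (9, 9)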
+
+    return SegmentsParams(segments, segment_size, overlap)
+
+def _save_task_to_db(db_task: models.Task, *, job_file_mapping: Optional[JobFileMapping] = None):
     job = rq.get_current_job()
     job.meta['status'] = 'Task is being saved in database'
     job.save_meta()

-    segment_step, segment_size, overlap = _get_task_segment_data(db_task, db_task.data.size)
+    segments, segment_size, overlap = _get_task_segment_data(
+        db_task=db_task, job_file_mapping=job_file_mapping
+    )
     db_task.segment_size = segment_size
     db_task.overlap = overlap

-    for start_frame in range(0, db_task.data.size, segment_step):
-        stop_frame = min(start_frame + segment_size - 1, db_task.data.size - 1)
-        slogger.glob.info("New segment for task #{}: start_frame = {}, \
-            stop_frame = {}".format(db_task.id, start_frame, stop_frame))
+    for segment_idx, (start_frame, stop_frame) in enumerate(segments):
+        slogger.glob.info("New segment for task #{}: idx = {}, start_frame = {}, \
+            stop_frame = {}".format(db_task.id, segment_idx, start_frame, stop_frame))

         db_segment = models.Segment()
         db_segment.task = db_task
@@ -214,6 +257,41 @@ def _validate_data(counter, manifest_files=None):

     return counter, task_modes[0]

+def _validate_job_file_mapping(
+    db_task: models.Task, data: Dict[str, Any]
+) -> Optional[JobFileMapping]:
+    job_file_mapping = data.get('job_file_mapping', None)
+
+    if job_file_mapping is None:
+        return None
+    elif not list(itertools.chain.from_iterable(job_file_mapping)):
+        raise ValidationError("job_file_mapping cannot be empty")
+
+    if db_task.segment_size:
+        raise ValidationError("job_file_mapping cannot be used with segment_size")
+
+    if (data.get('sorting_method', db_task.data.sorting_method)
+        != models.SortingMethod.LEXICOGRAPHICAL
+    ):
+        raise ValidationError("job_file_mapping cannot be used with sorting_method")
+
+    if data.get('start_frame', db_task.data.start_frame):
+        raise ValidationError("job_file_mapping cannot be used with start_frame")
+
+    if data.get('stop_frame', db_task.data.stop_frame):
+        raise ValidationError("job_file_mapping cannot be used with stop_frame")
+
+    if data.get('frame_filter', db_task.data.frame_filter):
+        raise ValidationError("job_file_mapping cannot be used with frame_filter")
+
+    if db_task.data.get_frame_step() != 1:
+        raise ValidationError("job_file_mapping cannot be used with frame step")
+
+    if data.get('filename_pattern'):
+        raise ValidationError("job_file_mapping cannot be used with filename_pattern")
+
+    return job_file_mapping
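+
+# For illustration (hypothetical values): a payload like
+#   {'job_file_mapping': [['a.png', 'b.png'], ['c.png']]}
+# passes these checks only if all frame/sorting options are left at defaults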
+
 def _validate_manifest(manifests, root_dir, is_in_cloud, db_cloud_storage, data_storage_method):
     if manifests:
         if len(manifests) != 1:
@@ -325,12 +403,20 @@ def _create_task_manifest_based_on_cloud_storage_manifest(
     manifest.create(sorted_content)

 @transaction.atomic
-def _create_thread(db_task, data, isBackupRestore=False, isDatasetImport=False):
+def _create_thread(
+    db_task: Union[int, models.Task],
+    data: Dict[str, Any],
+    *,
+    isBackupRestore: bool = False,
+    isDatasetImport: bool = False,
+) -> None:
     if isinstance(db_task, int):
         db_task = models.Task.objects.select_for_update().get(pk=db_task)

     slogger.glob.info("create task #{}".format(db_task.id))

+    job_file_mapping = _validate_job_file_mapping(db_task, data)
+
     db_data = db_task.data
     upload_dir = db_data.get_upload_dirname() if db_data.storage != models.StorageChoice.SHARE else settings.SHARE_ROOT
     is_data_in_cloud = db_data.storage == models.StorageChoice.CLOUD_STORAGE
@@ -387,10 +473,16 @@ def _create_thread(db_task, data, isBackupRestore=False, isDatasetImport=False):
     media = _count_files(data)
     media, task_mode = _validate_data(media, manifest_files)

+    if job_file_mapping is not None and task_mode != 'annotation':
+        raise ValidationError("job_file_mapping can't be used with sequence-based data like videos")
+
     if data['server_files']:
         if db_data.storage == models.StorageChoice.LOCAL:
             _copy_data_from_source(data['server_files'], upload_dir, data.get('server_files_path'))
         elif is_data_in_cloud:
+            if job_file_mapping is not None:
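+                # Note: the flattened mapping defines the resulting frame order:
+                # files of the first job come first, then the second, and so on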
+                sorted_media = list(itertools.chain.from_iterable(job_file_mapping))
+            else:
+                sorted_media = sort(media['image'], data['sorting_method'])

             # Define task manifest content based on cloud storage manifest content and uploaded files
@@ -486,24 +578,44 @@ def _create_thread(db_task, data, isBackupRestore=False, isDatasetImport=False):
         extractor.filter(lambda x: not re.search(r'(^|{0})related_images{0}'.format(os.sep), x))
     related_images = detect_related_images(extractor.absolute_source_paths, upload_dir)

-    if isBackupRestore and not isinstance(extractor, MEDIA_TYPES['video']['extractor']) and db_data.storage_method == models.StorageMethodChoice.CACHE and \
-            db_data.sorting_method in {models.SortingMethod.RANDOM, models.SortingMethod.PREDEFINED} and validate_dimension.dimension != models.DimensionType.DIM_3D:
+    # Sort the files
+    if (isBackupRestore and (
+            not isinstance(extractor, MEDIA_TYPES['video']['extractor'])
+            and db_data.storage_method == models.StorageMethodChoice.CACHE
+            and db_data.sorting_method in {models.SortingMethod.RANDOM, models.SortingMethod.PREDEFINED}
+            and validate_dimension.dimension != models.DimensionType.DIM_3D
+        ) or job_file_mapping
+    ):
+        sorted_media_files = []
+
+        if job_file_mapping:
+            sorted_media_files.extend(itertools.chain.from_iterable(job_file_mapping))
+        else:
+            # we should sort media_files according to the manifest content sequence
+            # and we should do this in general after validation step for 3D data and after filtering from related_images
+            manifest = ImageManifestManager(db_data.get_manifest_path())
+            manifest.set_index()
-        sorted_media_files = []
+
+            for idx in range(len(extractor.absolute_source_paths)):
+                properties = manifest[idx]
+                image_name = properties.get('name', None)
+                image_extension = properties.get('extension', None)

-                full_image_path = os.path.join(upload_dir, f"{image_name}{image_extension}") if image_name and image_extension else None
-                if full_image_path and full_image_path in extractor:
+                full_image_path = f"{image_name}{image_extension}" if image_name and image_extension else None
+                if full_image_path:
+                    sorted_media_files.append(full_image_path)
+
+        sorted_media_files = [os.path.join(upload_dir, fn) for fn in sorted_media_files]
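+
+        # Every mapped file must be present in the extracted media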
+        for file_path in sorted_media_files:
+            if not file_path in extractor:
+                raise ValidationError(
+                    f"Can't find file '{os.path.basename(file_path)}' in the input files"
+                )

         media_files = sorted_media_files.copy()
+        del sorted_media_files
+
         data['sorting_method'] = models.SortingMethod.PREDEFINED
         extractor.reconcile(
             source_files=media_files,
@@ -720,4 +832,4 @@ def _create_thread(db_task, data, isBackupRestore=False, isDatasetImport=False):
             db_data.start_frame + (db_data.size - 1) * db_data.get_frame_step())
     slogger.glob.info("Found frames {} for Data #{}".format(db_data.size, db_data.id))

-    _save_task_to_db(db_task, extractor)
+    _save_task_to_db(db_task, job_file_mapping=job_file_mapping)