@ -4,6 +4,9 @@
from tools . test import *
import os
from copy import copy
import jsonpickle
import numpy as np
class ModelHandler :
def __init__ ( self ) :
@ -11,7 +14,8 @@ class ModelHandler:
self . device = torch . device ( ' cuda ' if torch . cuda . is_available ( ) else ' cpu ' )
torch . backends . cudnn . benchmark = True
base_dir = " /opt/nuclio/SiamMask/experiments/siammask_sharp "
base_dir = os . environ . get ( " MODEL_PATH " ,
" /opt/nuclio/SiamMask/experiments/siammask_sharp " )
class configPath :
config = os . path . join ( base_dir , " config_davis.json " )
@ -21,18 +25,42 @@ class ModelHandler:
self . siammask = load_pretrain ( siammask , os . path . join ( base_dir , " SiamMask_DAVIS.pth " ) )
self . siammask . eval ( ) . to ( self . device )
def encode_state ( self , state ) :
state [ ' net.zf ' ] = state [ ' net ' ] . zf
state . pop ( ' net ' , None )
state . pop ( ' mask ' , None )
for k , v in state . items ( ) :
state [ k ] = jsonpickle . encode ( v )
return state
def decode_state ( self , state ) :
for k , v in state . items ( ) :
state [ k ] = jsonpickle . decode ( v )
state [ ' net ' ] = copy ( self . siammask )
state [ ' net ' ] . zf = state [ ' net.zf ' ]
del state [ ' net.zf ' ]
return state
def infer ( self , image , shape , state ) :
image = np . array ( image )
if state is None : # init tracking
x , y , w , h = shape
target_pos = np . array ( [ x + w / 2 , y + h / 2 ] )
target_sz = np . array ( [ w , h ] )
state = siamese_init ( image , target_pos , target_sz , self . siammask ,
xtl , ytl , xbr , ybr = shape
target_pos = np . array ( [ ( xtl + xbr ) / 2 , ( ytl + ybr ) / 2 ] )
target_sz = np . array ( [ xbr - xtl , ybr - ytl ] )
siammask = copy ( self . siammask ) # don't modify self.siammask
state = siamese_init ( image , target_pos , target_sz , siammask ,
self . config [ ' hp ' ] , device = self . device )
state = self . encode_state ( state )
else : # track
state = siamese_track ( state , image , mask_enable = True , refine_enable = True ,
device = self . device )
shape = state [ ' ploygon ' ] . flatten ( )
state = self . decode_state ( state )
state = siamese_track ( state , image , mask_enable = True ,
refine_enable = True , device = self . device )
shape = state [ ' ploygon ' ] . flatten ( ) . tolist ( )
state = self . encode_state ( state )
return { " shape " : shape , " state " : state }