@ -14,8 +14,10 @@ import InputNumber from 'antd/lib/input-number';
import Button from 'antd/lib/button' ;
import notification from 'antd/lib/notification' ;
import { Model , StringObject } from 'reducers/interfaces' ;
import { Model , ModelAttribute , StringObject } from 'reducers/interfaces' ;
import CVATTooltip from 'components/common/cvat-tooltip' ;
import { Label as LabelInterface } from 'components/labels-editor/common' ;
import { clamp } from 'utils/math' ;
import consts from 'consts' ;
import { DimensionType } from '../../reducers/interfaces' ;
@ -23,28 +25,40 @@ import { DimensionType } from '../../reducers/interfaces';
interface Props {
withCleanup : boolean ;
models : Model [ ] ;
labels : any [ ] ;
labels : LabelInterface [ ] ;
dimension : DimensionType ;
runInference ( model : Model , body : object ) : void ;
}
interface MappedLabel {
name : string ;
attributes : StringObject ;
}
type MappedLabelsList = Record < string , MappedLabel > ;
export interface DetectorRequestBody {
mapping : MappedLabelsList ;
cleanup : boolean ;
}
interface Match {
model : string | null ;
task : string | null ;
}
function DetectorRunner ( props : Props ) : JSX . Element {
const {
models , withCleanup , labels , dimension , runInference ,
} = props ;
const [ modelID , setModelID ] = useState < string | null > ( null ) ;
const [ mapping , setMapping ] = useState < StringObject > ( { } ) ;
const [ mapping , setMapping ] = useState < MappedLabelsLis t> ( { } ) ;
const [ threshold , setThreshold ] = useState < number > ( 0.5 ) ;
const [ distance , setDistance ] = useState < number > ( 50 ) ;
const [ cleanup , setCleanup ] = useState < boolean > ( false ) ;
const [ match , setMatch ] = useState < {
model : string | null ;
task : string | null ;
} > ( {
model : null ,
task : null ,
} ) ;
const [ match , setMatch ] = useState < Match > ( { model : null , task : null } ) ;
const [ attrMatches , setAttrMatch ] = useState < Record < string , Match > > ( { } ) ;
const model = models . filter ( ( _model ) : boolean = > _model . id === modelID ) [ 0 ] ;
const isDetector = model && model . type === 'detector' ;
@ -57,24 +71,47 @@ function DetectorRunner(props: Props): JSX.Element {
if ( model && model . type !== 'reid' && ! model . labels . length ) {
notification . warning ( {
message : 'The selected model does not include any lab l es',
message : 'The selected model does not include any lab el s',
} ) ;
}
function matchAttributes (
labelAttributes : LabelInterface [ 'attributes' ] ,
modelAttributes : ModelAttribute [ ] ,
) : StringObject {
if ( Array . isArray ( labelAttributes ) && Array . isArray ( modelAttributes ) ) {
return labelAttributes
. reduce ( ( attrAcc : StringObject , attr : any ) : StringObject = > {
if ( modelAttributes . some ( ( mAttr ) = > mAttr . name === attr . name ) ) {
attrAcc [ attr . name ] = attr . name ;
}
return attrAcc ;
} , { } ) ;
}
return { } ;
}
function updateMatch ( modelLabel : string | null , taskLabel : string | null ) : void {
if ( match . model && taskLabel ) {
const newmatch : { [ index : string ] : string } = { } ;
newmatch [ match . model ] = taskLabel ;
setMapping ( { . . . mapping , . . . newmatch } ) ;
function addMatch ( modelLbl : string , taskLbl : string ) : void {
const newMatch : MappedLabelsList = { } ;
const label = labels . find ( ( l ) = > l . name === taskLbl ) as LabelInterface ;
const currentModel = models . filter ( ( _model ) : boolean = > _model . id === modelID ) [ 0 ] ;
const attributes = matchAttributes ( label . attributes , currentModel . attributes [ modelLbl ] ) ;
newMatch [ modelLbl ] = { name : taskLbl , attributes } ;
setMapping ( { . . . mapping , . . . newMatch } ) ;
setMatch ( { model : null , task : null } ) ;
}
if ( match . model && taskLabel ) {
addMatch ( match . model , taskLabel ) ;
return ;
}
if ( match . task && modelLabel ) {
const newmatch : { [ index : string ] : string } = { } ;
newmatch [ modelLabel ] = match . task ;
setMapping ( { . . . mapping , . . . newmatch } ) ;
setMatch ( { model : null , task : null } ) ;
addMatch ( modelLabel , match . task ) ;
return ;
}
@ -84,14 +121,72 @@ function DetectorRunner(props: Props): JSX.Element {
} ) ;
}
function updateAttrMatch ( modelLabel : string , modelAttrLabel : string | null , taskAttrLabel : string | null ) : void {
function addAttributeMatch ( modelAttr : string , attrLabel : string ) : void {
const newMatch : StringObject = { } ;
newMatch [ modelAttr ] = attrLabel ;
mapping [ modelLabel ] . attributes = { . . . mapping [ modelLabel ] . attributes , . . . newMatch } ;
delete attrMatches [ modelLabel ] ;
setAttrMatch ( { . . . attrMatches } ) ;
}
const modelAttr = attrMatches [ modelLabel ] ? . model ;
if ( modelAttr && taskAttrLabel ) {
addAttributeMatch ( modelAttr , taskAttrLabel ) ;
return ;
}
const taskAttrModel = attrMatches [ modelLabel ] ? . task ;
if ( taskAttrModel && modelAttrLabel ) {
addAttributeMatch ( modelAttrLabel , taskAttrModel ) ;
return ;
}
attrMatches [ modelLabel ] = {
model : modelAttrLabel ,
task : taskAttrLabel ,
} ;
setAttrMatch ( { . . . attrMatches } ) ;
}
function renderMappingRow (
color : string ,
leftLabel : string ,
rightLabel : string ,
removalTitle : string ,
onClick : ( ) = > void ,
className = '' ,
) : JSX . Element {
return (
< Row key = { leftLabel } justify = 'start' align = 'middle' >
< Col span = { 10 } className = { className } >
< Tag color = { color } > { leftLabel } < / Tag >
< / Col >
< Col span = { 10 } offset = { 1 } className = { className } >
< Tag color = { color } > { rightLabel } < / Tag >
< / Col >
< Col offset = { 1 } >
< CVATTooltip title = { removalTitle } >
< DeleteOutlined
className = 'cvat-danger-circle-icon'
onClick = { onClick }
/ >
< / CVATTooltip >
< / Col >
< / Row >
) ;
}
function renderSelector (
value : string ,
tooltip : string ,
labelsToRender : string [ ] ,
onChange : ( label : string ) = > void ,
className = '' ,
) : JSX . Element {
return (
< CVATTooltip title = { tooltip } >
< CVATTooltip title = { tooltip } className = { className } >
< Select
value = { value }
onChange = { onChange }
@ -130,16 +225,24 @@ function DetectorRunner(props: Props): JSX.Element {
disabled = { dimension !== DimensionType . DIM_2D }
style = { { width : '100%' } }
onChange = { ( _modelID : string ) : void = > {
const newmodel = models . filter ( ( _model ) : boolean = > _model . id === _modelID ) [ 0 ] ;
const newmapping = labels . reduce ( ( acc : StringObject , label : any ) : StringObject = > {
if ( newmodel . labels . includes ( label . name ) ) {
acc [ label . name ] = label . name ;
}
return acc ;
} , { } ) ;
setMapping ( newmapping ) ;
const chosenModel = models . filter ( ( _model ) : boolean = > _model . id === _modelID ) [ 0 ] ;
const defaultMapping = labels . reduce (
( acc : MappedLabelsList , label : LabelInterface ) : MappedLabelsList = > {
if ( chosenModel . labels . includes ( label . name ) ) {
acc [ label . name ] = {
name : label.name ,
attributes : matchAttributes (
label . attributes , chosenModel . attributes [ label . name ] ,
) ,
} ;
}
return acc ;
} , { } ,
) ;
setMapping ( defaultMapping ) ;
setMatch ( { model : null , task : null } ) ;
setAttrMatch ( { } ) ;
setModelID ( _modelID ) ;
} }
>
@ -154,45 +257,92 @@ function DetectorRunner(props: Props): JSX.Element {
< / Col >
< / Row >
{ isDetector &&
! ! Object . keys ( mapping ) . length &&
Object . keys ( mapping ) . length ?
Object . keys ( mapping ) . map ( ( modelLabel : string ) = > {
const label = labels . filter ( ( _label : any ) : boolean = > _label . name === mapping [ modelLabel ] ) [ 0 ] ;
const label = labels
. find ( ( _label : LabelInterface ) : boolean = > (
_label . name === mapping [ modelLabel ] . name ) ) as LabelInterface ;
const color = label ? label.color : consts.NEW_LABEL_COLOR ;
const notMatchedModelAttributes = model . attributes [ modelLabel ]
. filter ( ( _attribute : ModelAttribute ) : boolean = > (
! ( _attribute . name in ( mapping [ modelLabel ] . attributes || { } ) )
) ) ;
const taskAttributes = label . attributes . map ( ( _attrLabel : any ) : string = > _attrLabel . name ) ;
return (
< Row key = { modelLabel } justify = 'start' align = 'middle' >
< Col span = { 10 } >
< Tag color = { color } > { modelLabel } < / Tag >
< / Col >
< Col span = { 10 } offset = { 1 } >
< Tag color = { color } > { mapping [ modelLabel ] } < / Tag >
< / Col >
< Col offset = { 1 } >
< CVATTooltip title = 'Remove the mapped values' >
< DeleteOutlined
className = 'cvat-danger-circle-icon'
onClick = { ( ) : void = > {
const newmapping = { . . . mapping } ;
delete newmapping [ modelLabel ] ;
setMapping ( newmapping ) ;
} }
/ >
< / CVATTooltip >
< / Col >
< / Row >
< React.Fragment key = { modelLabel } >
{
renderMappingRow ( color ,
modelLabel ,
label . name ,
'Remove the mapped label' ,
( ) : void = > {
const newMapping = { . . . mapping } ;
delete newMapping [ modelLabel ] ;
setMapping ( newMapping ) ;
const newAttrMatches = { . . . attrMatches } ;
delete newAttrMatches [ modelLabel ] ;
setAttrMatch ( { . . . newAttrMatches } ) ;
} )
}
{
Object . keys ( mapping [ modelLabel ] . attributes || { } )
. map ( ( mappedModelAttr : string ) = > (
renderMappingRow (
consts . NEW_LABEL_COLOR ,
mappedModelAttr ,
mapping [ modelLabel ] . attributes [ mappedModelAttr ] ,
'Remove the mapped attribute' ,
( ) : void = > {
const newMapping = { . . . mapping } ;
delete mapping [ modelLabel ] . attributes [ mappedModelAttr ] ;
setMapping ( newMapping ) ;
} ,
'cvat-run-model-label-attribute-block' ,
)
) )
}
{ notMatchedModelAttributes . length && taskAttributes . length ? (
< Row justify = 'start' align = 'middle' >
< Col span = { 10 } >
{ renderSelector (
attrMatches [ modelLabel ] ? . model || '' ,
'Model attr labels' , notMatchedModelAttributes . map ( ( l ) = > l . name ) ,
( modelAttrLabel : string ) = > updateAttrMatch (
modelLabel , modelAttrLabel , null ,
) ,
'cvat-run-model-label-attribute-block' ,
) }
< / Col >
< Col span = { 10 } offset = { 1 } >
{ renderSelector (
attrMatches [ modelLabel ] ? . task || '' ,
'Task attr labels' , taskAttributes ,
( taskAttrLabel : string ) = > updateAttrMatch (
modelLabel , null , taskAttrLabel ,
) ,
'cvat-run-model-label-attribute-block' ,
) }
< / Col >
< Col span = { 1 } offset = { 1 } >
< CVATTooltip title = 'Specify an attribute mapping between model label and task label attributes' >
< QuestionCircleOutlined className = 'cvat-info-circle-icon' / >
< / CVATTooltip >
< / Col >
< / Row >
) : null }
< / React.Fragment >
) ;
} ) }
{ isDetector && ! ! taskLabels . length && ! ! modelLabels . length && (
} ) : null }
{ isDetector && ! ! taskLabels . length && ! ! modelLabels . length ? (
< >
< Row justify = 'start' align = 'middle' >
< Col span = { 10 } >
{ renderSelector (
match . model || '' , 'Model labels' , modelLabels , ( modelLabel : string ) = > updateMatch ( modelLabel , null ) ,
) }
{ renderSelector ( match . model || '' , 'Model labels' , modelLabels , ( modelLabel : string ) = > updateMatch ( modelLabel , null ) ) }
< / Col >
< Col span = { 10 } offset = { 1 } >
{ renderSelector (
match . task || '' , 'Task labels' , taskLabels , ( taskLabel : string ) = > updateMatch ( null , taskLabel ) ,
) }
{ renderSelector ( match . task || '' , 'Task labels' , taskLabels , ( taskLabel : string ) = > updateMatch ( null , taskLabel ) ) }
< / Col >
< Col span = { 1 } offset = { 1 } >
< CVATTooltip title = 'Specify a label mapping between model labels and task labels' >
@ -201,8 +351,8 @@ function DetectorRunner(props: Props): JSX.Element {
< / Col >
< / Row >
< / >
) }
{ isDetector && withCleanup && (
) : null }
{ isDetector && withCleanup ? (
< div >
< Checkbox
checked = { cleanup }
@ -211,8 +361,8 @@ function DetectorRunner(props: Props): JSX.Element {
Clean old annotations
< / Checkbox >
< / div >
) }
{ isReId && (
) : null }
{ isReId ? (
< div >
< Row align = 'middle' justify = 'start' >
< Col >
@ -254,18 +404,25 @@ function DetectorRunner(props: Props): JSX.Element {
< / Col >
< / Row >
< / div >
) }
) : null }
< Row align = 'middle' justify = 'end' >
< Col >
< Button
disabled = { ! buttonEnabled }
type = 'primary'
onClick = { ( ) = > {
runInference ( model , model . type === 'detector' ?
{ mapping , cleanup } : {
const detectorRequestBody : DetectorRequestBody = {
mapping ,
cleanup ,
} ;
runInference (
model ,
model . type === 'detector' ? detectorRequestBody : {
threshold ,
max_distance : distance ,
} ) ;
} ,
) ;
} }
>
Annotate