Added SiamMask CUDA implementation (tracking), reworked tracking approach (#3571)

main
Boris Sekachev 4 years ago committed by GitHub
parent 02172ad55c
commit dbdcd4fd86
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -46,6 +46,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Support cloud storage status (<https://github.com/openvinotoolkit/cvat/pull/3386>)
- Support cloud storage preview (<https://github.com/openvinotoolkit/cvat/pull/3386>)
- cvat-core: support cloud storages (<https://github.com/openvinotoolkit/cvat/pull/3313>)
- Added GPU implementation for SiamMask, reworked tracking approach (<https://github.com/openvinotoolkit/cvat/pull/3571>)
### Changed
@ -53,6 +54,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- "Selected opacity" slider now defines opacity level for shapes being drawnSelected opacity (<https://github.com/openvinotoolkit/cvat/pull/3473>)
- Cloud storage creating and updating (<https://github.com/openvinotoolkit/cvat/pull/3386>)
- Way of working with cloud storage content (<https://github.com/openvinotoolkit/cvat/pull/3386>)
- UI tracking has been reworked (<https://github.com/openvinotoolkit/cvat/pull/3571>)
### Removed

@ -1,6 +1,6 @@
{
"name": "cvat-canvas",
"version": "2.7.0",
"version": "2.8.0",
"lockfileVersion": 1,
"requires": true,
"dependencies": {

@ -1,6 +1,6 @@
{
"name": "cvat-canvas",
"version": "2.7.0",
"version": "2.8.0",
"description": "Part of Computer Vision Annotation Tool which presents its canvas library",
"main": "src/canvas.ts",
"scripts": {

@ -46,6 +46,12 @@ polyline.cvat_shape_drawing_opacity {
pointer-events: none;
}
.cvat_canvas_text_description {
font-size: 14px;
fill: yellow;
font-style: oblique 40deg;
}
.cvat_canvas_crosshair {
stroke: red;
}

@ -1458,6 +1458,7 @@ export class CanvasViewImpl implements CanvasView, Listener {
shapeType: state.shapeType,
points: [...state.points],
attributes: { ...state.attributes },
descriptions: [...state.descriptions],
zOrder: state.zOrder,
pinned: state.pinned,
updated: state.updated,
@ -1544,7 +1545,14 @@ export class CanvasViewImpl implements CanvasView, Listener {
}
}
if (drawnState.label.id !== state.label.id) {
const stateDescriptions = state.descriptions;
const drawnStateDescriptions = drawnState.descriptions;
if (
drawnState.label.id !== state.label.id
|| drawnStateDescriptions.length !== stateDescriptions.length
|| drawnStateDescriptions.some((desc: string, id: number): boolean => desc !== stateDescriptions[id])
) {
// need to remove created text and create it again
if (text) {
text.remove();
@ -1967,7 +1975,7 @@ export class CanvasViewImpl implements CanvasView, Listener {
private addText(state: any): SVG.Text {
const { undefinedAttrValue } = this.configuration;
const {
label, clientID, attributes, source,
label, clientID, attributes, source, descriptions,
} = state;
const attrNames = label.attributes.reduce((acc: any, val: any): void => {
acc[val.id] = val.name;
@ -1977,13 +1985,25 @@ export class CanvasViewImpl implements CanvasView, Listener {
return this.adoptedText
.text((block): void => {
block.tspan(`${label.name} ${clientID} (${source})`).style('text-transform', 'uppercase');
for (const desc of descriptions) {
block
.tspan(`${desc}`)
.attr({
dy: '1em',
x: 0,
})
.addClass('cvat_canvas_text_description');
}
for (const attrID of Object.keys(attributes)) {
const value = attributes[attrID] === undefinedAttrValue ? '' : attributes[attrID];
block.tspan(`${attrNames[attrID]}: ${value}`).attr({
attrID,
dy: '1em',
x: 0,
});
block
.tspan(`${attrNames[attrID]}: ${value}`)
.attr({
attrID,
dy: '1em',
x: 0,
})
.addClass('cvat_canvas_text_attribute');
}
})
.move(0, 0)

@ -45,6 +45,7 @@ export interface DrawnState {
shapeType: string;
points?: number[];
attributes: Record<number, string>;
descriptions: string[];
zOrder?: number;
pinned?: boolean;
updated: number;

@ -722,6 +722,8 @@
checkObjectType('state occluded', state.occluded, 'boolean', null);
checkObjectType('state points', state.points, null, Array);
checkObjectType('state zOrder', state.zOrder, 'integer', null);
checkObjectType('state descriptions', state.descriptions, null, Array);
state.descriptions.forEach((desc) => checkObjectType('state description', desc, 'string'));
for (const coord of state.points) {
checkObjectType('point coordinate', coord, 'number', null);
@ -736,6 +738,7 @@
if (state.objectType === 'shape') {
constructed.shapes.push({
attributes,
descriptions: state.descriptions,
frame: state.frame,
group: 0,
label_id: state.label.id,
@ -748,6 +751,7 @@
} else if (state.objectType === 'track') {
constructed.tracks.push({
attributes: attributes.filter((attr) => !labelAttributes[attr.spec_id].mutable),
descriptions: state.descriptions,
frame: state.frame,
group: 0,
source: state.source,

@ -332,6 +332,14 @@
}
}
if (updated.descriptions) {
if (!Array.isArray(data.descriptions) || data.descriptions.some((desc) => typeof desc !== 'string')) {
throw new ArgumentError(
`Descriptions are expected to be an array of strings but got ${data.descriptions}`,
);
}
}
if (updated.points) {
checkObjectType('points', data.points, null, Array);
checkNumberOfPoints(this.shapeType, data.points);
@ -402,17 +410,7 @@
}
updateTimestamp(updated) {
const anyChanges = updated.label
|| updated.attributes
|| updated.points
|| updated.outside
|| updated.occluded
|| updated.keyframe
|| updated.zOrder
|| updated.hidden
|| updated.lock
|| updated.pinned;
const anyChanges = Object.keys(updated).some((key) => !!updated[key]);
if (anyChanges) {
this.updated = Date.now();
}
@ -446,11 +444,16 @@
constructor(data, clientID, color, injection) {
super(data, clientID, color, injection);
this.frameMeta = injection.frameMeta;
this.descriptions = data.descriptions || [];
this.hidden = false;
this.pinned = true;
this.shapeType = null;
}
_saveDescriptions(descriptions) {
this.descriptions = [...descriptions];
}
_savePinned(pinned, frame) {
const undoPinned = this.pinned;
const redoPinned = pinned;
@ -533,6 +536,7 @@
zOrder: this.zOrder,
points: [...this.points],
attributes: { ...this.attributes },
descriptions: [...this.descriptions],
label: this.label,
group: this.groupObject,
color: this.color,
@ -643,6 +647,10 @@
this._saveAttributes(data.attributes, frame);
}
if (updated.descriptions) {
this._saveDescriptions(data.descriptions);
}
if (updated.points && fittedPoints.length) {
this._savePoints(fittedPoints, frame);
}
@ -760,6 +768,7 @@
return {
...this.getPosition(frame, prev, next),
attributes: this.getAttributes(frame),
descriptions: [...this.descriptions],
group: this.groupObject,
objectType: ObjectType.TRACK,
shapeType: this.shapeType,
@ -1204,6 +1213,10 @@
this._saveAttributes(data.attributes, frame);
}
if (updated.descriptions) {
this._saveDescriptions(data.descriptions);
}
if (updated.keyframe) {
this._saveKeyframe(frame, data.keyframe);
}

@ -1,4 +1,4 @@
// Copyright (C) 2019-2020 Intel Corporation
// Copyright (C) 2019-2021 Intel Corporation
//
// SPDX-License-Identifier: MIT
@ -25,6 +25,7 @@ const { Source } = require('./enums');
const data = {
label: null,
attributes: {},
descriptions: [],
points: null,
outside: null,
@ -55,6 +56,7 @@ const { Source } = require('./enums');
value: function reset() {
this.label = false;
this.attributes = false;
this.descriptions = false;
this.points = false;
this.outside = false;
@ -70,6 +72,7 @@ const { Source } = require('./enums');
return reset;
},
writable: false,
enumerable: false,
});
Object.defineProperties(
@ -353,6 +356,30 @@ const { Source } = require('./enums');
}
},
},
descriptions: {
/**
* Additional text information displayed on canvas
* @name descripttions
* @type {string[]}
* @memberof module:API.cvat.classes.ObjectState
* @throws {module:API.cvat.exceptions.ArgumentError}
* @instance
*/
get: () => [...data.descriptions],
set: (descriptions) => {
if (
!Array.isArray(descriptions)
|| descriptions.some((description) => typeof description !== 'string')
) {
throw new ArgumentError(
`Descriptions are expected to be an array of strings but got ${data.descriptions}`,
);
}
data.updateFlags.descriptions = true;
data.descriptions = [...descriptions];
},
},
}),
);
@ -386,6 +413,12 @@ const { Source } = require('./enums');
if (Array.isArray(serialized.points)) {
this.points = serialized.points;
}
if (
Array.isArray(serialized.descriptions)
&& serialized.descriptions.every((desc) => typeof desc === 'string')
) {
this.descriptions = serialized.descriptions;
}
if (typeof serialized.attributes === 'object') {
this.attributes = serialized.attributes;
}
@ -429,7 +462,7 @@ const { Source } = require('./enums');
}
// Updates element in collection which contains it
ObjectState.prototype.save.implementation = async function () {
ObjectState.prototype.save.implementation = function () {
if (this.__internal && this.__internal.save) {
return this.__internal.save();
}
@ -438,7 +471,7 @@ const { Source } = require('./enums');
};
// Delete element from a collection which contains it
ObjectState.prototype.delete.implementation = async function (frame, force) {
ObjectState.prototype.delete.implementation = function (frame, force) {
if (this.__internal && this.__internal.delete) {
if (!Number.isInteger(+frame) || +frame < 0) {
throw new ArgumentError('Frame argument must be a non negative integer');

@ -2,7 +2,6 @@
//
// SPDX-License-Identifier: MIT
import { MutableRefObject } from 'react';
import {
ActionCreator, AnyAction, Dispatch, Store,
} from 'redux';
@ -183,7 +182,6 @@ export enum AnnotationActionTypes {
SAVE_LOGS_SUCCESS = 'SAVE_LOGS_SUCCESS',
SAVE_LOGS_FAILED = 'SAVE_LOGS_FAILED',
INTERACT_WITH_CANVAS = 'INTERACT_WITH_CANVAS',
SET_AI_TOOLS_REF = 'SET_AI_TOOLS_REF',
GET_DATA_FAILED = 'GET_DATA_FAILED',
SWITCH_REQUEST_REVIEW_DIALOG = 'SWITCH_REQUEST_REVIEW_DIALOG',
SWITCH_SUBMIT_REVIEW_DIALOG = 'SWITCH_SUBMIT_REVIEW_DIALOG',
@ -196,6 +194,7 @@ export enum AnnotationActionTypes {
GET_CONTEXT_IMAGE = 'GET_CONTEXT_IMAGE',
GET_CONTEXT_IMAGE_SUCCESS = 'GET_CONTEXT_IMAGE_SUCCESS',
GET_CONTEXT_IMAGE_FAILED = 'GET_CONTEXT_IMAGE_FAILED',
SWITCH_NAVIGATION_BLOCKED = 'SWITCH_NAVIGATION_BLOCKED',
}
export function saveLogsAsync(): ThunkAction {
@ -258,12 +257,14 @@ export function fetchAnnotationsAsync(): ThunkAction {
filters, frame, showAllInterpolationTracks, jobInstance,
} = receiveAnnotationsParameters();
const states = await jobInstance.annotations.get(frame, showAllInterpolationTracks, filters);
const history = await jobInstance.actions.get();
const [minZ, maxZ] = computeZRange(states);
dispatch({
type: AnnotationActionTypes.FETCH_ANNOTATIONS_SUCCESS,
payload: {
states,
history,
minZ,
maxZ,
},
@ -1460,15 +1461,6 @@ export function interactWithCanvas(activeInteractor: Model | OpenCVTool, activeL
};
}
export function setAIToolsRef(ref: MutableRefObject<any>): AnyAction {
return {
type: AnnotationActionTypes.SET_AI_TOOLS_REF,
payload: {
aiToolsRef: ref,
},
};
}
export function repeatDrawShapeAsync(): ThunkAction {
return async (dispatch: ActionCreator<Dispatch>): Promise<void> => {
const {
@ -1660,3 +1652,12 @@ export function getContextImageAsync(): ThunkAction {
}
};
}
export function switchNavigationBlocked(navigationBlocked: boolean): AnyAction {
return {
type: AnnotationActionTypes.SWITCH_NAVIGATION_BLOCKED,
payload: {
navigationBlocked,
},
};
}

@ -17,6 +17,7 @@ import {
changeFrameAsync,
updateAnnotationsAsync,
} from 'actions/annotation-actions';
import isAbleToChangeFrame from 'utils/is-able-to-change-frame';
import GlobalHotKeys, { KeyMap } from 'utils/mousetrap-react';
import { ThunkDispatch } from 'utils/redux';
import AppearanceBlock from 'components/annotation-page/appearance-block';
@ -266,7 +267,7 @@ function AttributeAnnotationSidebar(props: StateToProps & DispatchToProps): JSX.
if (activeObjectState && activeObjectState.objectType === ObjectType.TRACK) {
const frame =
typeof activeObjectState.keyframes.next === 'number' ? activeObjectState.keyframes.next : null;
if (frame !== null && canvasInstance.isAbleToChangeFrame()) {
if (frame !== null && isAbleToChangeFrame()) {
changeFrame(frame);
}
}
@ -276,7 +277,7 @@ function AttributeAnnotationSidebar(props: StateToProps & DispatchToProps): JSX.
if (activeObjectState && activeObjectState.objectType === ObjectType.TRACK) {
const frame =
typeof activeObjectState.keyframes.prev === 'number' ? activeObjectState.keyframes.prev : null;
if (frame !== null && canvasInstance.isAbleToChangeFrame()) {
if (frame !== null && isAbleToChangeFrame()) {
changeFrame(frame);
}
}

@ -2,9 +2,15 @@
//
// SPDX-License-Identifier: MIT
import React, { MutableRefObject } from 'react';
import React, { ReactPortal } from 'react';
import ReactDOM from 'react-dom';
import { connect } from 'react-redux';
import Icon, { LoadingOutlined, QuestionCircleOutlined } from '@ant-design/icons';
import Icon, {
EnvironmentFilled,
EnvironmentOutlined,
LoadingOutlined,
QuestionCircleOutlined,
} from '@ant-design/icons';
import Popover from 'antd/lib/popover';
import Select from 'antd/lib/select';
import Button from 'antd/lib/button';
@ -14,14 +20,11 @@ import Tabs from 'antd/lib/tabs';
import { Row, Col } from 'antd/lib/grid';
import notification from 'antd/lib/notification';
import message from 'antd/lib/message';
import Progress from 'antd/lib/progress';
import InputNumber from 'antd/lib/input-number';
import Dropdown from 'antd/lib/dropdown';
import lodash from 'lodash';
import { AIToolsIcon } from 'icons';
import { Canvas, convertShapesForInteractor } from 'cvat-canvas-wrapper';
import range from 'utils/range';
import getCore from 'cvat-core-wrapper';
import openCVWrapper from 'utils/opencv-wrapper/opencv-wrapper';
import {
@ -29,12 +32,15 @@ import {
} from 'reducers/interfaces';
import {
interactWithCanvas,
switchNavigationBlocked as switchNavigationBlockedAction,
fetchAnnotationsAsync,
updateAnnotationsAsync,
createAnnotationsAsync,
} from 'actions/annotation-actions';
import DetectorRunner from 'components/model-runner-modal/detector-runner';
import LabelSelector from 'components/label-selector/label-selector';
import CVATTooltip from 'components/common/cvat-tooltip';
import ApproximationAccuracy, {
thresholdFromAccuracy,
} from 'components/annotation-page/standard-workspace/controls-side-bar/approximation-accuracy';
@ -54,7 +60,6 @@ interface StateToProps {
detectors: Model[];
trackers: Model[];
curZOrder: number;
aiToolsRef: MutableRefObject<any>;
defaultApproxPolyAccuracy: number;
toolsBlockerState: ToolsBlockerState;
}
@ -64,7 +69,8 @@ interface DispatchToProps {
updateAnnotations(statesToUpdate: any[]): void;
createAnnotations(sessionInstance: any, frame: number, statesToCreate: any[]): void;
fetchAnnotations(): void;
onSwitchToolsBlockerState(toolsBlockerState: ToolsBlockerState):void;
onSwitchToolsBlockerState(toolsBlockerState: ToolsBlockerState): void;
switchNavigationBlocked(navigationBlocked: boolean): void;
}
const core = getCore();
@ -92,7 +98,6 @@ function mapStateToProps(state: CombinedState): StateToProps {
jobInstance,
frame,
curZOrder: annotation.annotations.zLayer.cur,
aiToolsRef: annotation.aiToolsRef,
defaultApproxPolyAccuracy: settings.workspace.defaultApproxPolyAccuracy,
toolsBlockerState,
};
@ -104,21 +109,81 @@ const mapDispatchToProps = {
createAnnotations: createAnnotationsAsync,
fetchAnnotations: fetchAnnotationsAsync,
onSwitchToolsBlockerState: switchToolsBlockerState,
switchNavigationBlocked: switchNavigationBlockedAction,
};
type Props = StateToProps & DispatchToProps;
interface TrackedShape {
clientID: number;
serverlessState: any;
shapePoints: number[];
trackerModel: Model;
}
interface State {
activeInteractor: Model | null;
activeLabelID: number;
activeTracker: Model | null;
trackingProgress: number | null;
trackingFrames: number;
trackedShapes: TrackedShape[];
fetching: boolean;
pointsRecieved: boolean;
approxPolyAccuracy: number;
mode: 'detection' | 'interaction' | 'tracking';
portals: React.ReactPortal[];
}
function trackedRectangleMapper(shape: number[]): number[] {
return shape.reduce(
(acc: number[], value: number, index: number): number[] => {
if (index % 2) {
// y
acc[1] = Math.min(acc[1], value);
acc[3] = Math.max(acc[3], value);
} else {
// x
acc[0] = Math.min(acc[0], value);
acc[2] = Math.max(acc[2], value);
}
return acc;
},
[Number.MAX_SAFE_INTEGER, Number.MAX_SAFE_INTEGER, Number.MIN_SAFE_INTEGER, Number.MIN_SAFE_INTEGER],
);
}
function registerPlugin(): (callback: null | (() => void)) => void {
let onTrigger: null | (() => void) = null;
const listener = {
name: 'Remove annotations listener',
description: 'Tracker needs to know when annotations is reset in the job',
cvat: {
classes: {
Job: {
prototype: {
annotations: {
clear: {
leave(self: any, result: any) {
if (typeof onTrigger === 'function') {
onTrigger();
}
return result;
},
},
},
},
},
},
},
};
core.plugins.register(listener);
return (callback: null | (() => void)) => {
onTrigger = callback;
};
}
const onRemoveAnnotations = registerPlugin();
export class ToolsControlComponent extends React.PureComponent<Props, State> {
private interaction: {
id: string | null;
@ -143,11 +208,11 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
activeTracker: props.trackers.length ? props.trackers[0] : null,
activeLabelID: props.labels.length ? props.labels[0].id : null,
approxPolyAccuracy: props.defaultApproxPolyAccuracy,
trackingProgress: null,
trackingFrames: 10,
trackedShapes: [],
fetching: false,
pointsRecieved: false,
mode: 'interaction',
portals: [],
};
this.interaction = {
@ -161,15 +226,29 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
}
public componentDidMount(): void {
const { canvasInstance, aiToolsRef } = this.props;
aiToolsRef.current = this;
const { canvasInstance } = this.props;
onRemoveAnnotations(() => {
this.setState({ trackedShapes: [] });
});
this.setState({
portals: this.collectTrackerPortals(),
});
canvasInstance.html().addEventListener('canvas.interacted', this.interactionListener);
canvasInstance.html().addEventListener('canvas.canceled', this.cancelListener);
}
public componentDidUpdate(prevProps: Props, prevState: State): void {
const { isActivated, defaultApproxPolyAccuracy, canvasInstance } = this.props;
const { approxPolyAccuracy, mode } = this.state;
const {
isActivated, defaultApproxPolyAccuracy, canvasInstance, states,
} = this.props;
const { approxPolyAccuracy, mode, activeTracker } = this.state;
if (prevProps.states !== states || prevState.activeTracker !== activeTracker) {
this.setState({
portals: this.collectTrackerPortals(),
});
}
if (prevProps.isActivated && !isActivated) {
window.removeEventListener('contextmenu', this.contextmenuDisabler);
@ -211,11 +290,13 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
});
}
}
this.checkTrackedStates(prevProps);
}
public componentWillUnmount(): void {
const { canvasInstance, aiToolsRef } = this.props;
aiToolsRef.current = undefined;
const { canvasInstance } = this.props;
onRemoveAnnotations(null);
canvasInstance.html().removeEventListener('canvas.interacted', this.interactionListener);
canvasInstance.html().removeEventListener('canvas.canceled', this.cancelListener);
}
@ -339,6 +420,7 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
};
private onTracking = async (e: Event): Promise<void> => {
const { trackedShapes, activeTracker } = this.state;
const {
isActivated, jobInstance, frame, curZOrder, fetchAnnotations,
} = this.props;
@ -365,18 +447,25 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
points,
frame,
occluded: false,
source: 'auto',
attributes: {},
descriptions: [`Trackable (${activeTracker?.name})`],
});
const [clientID] = await jobInstance.annotations.put([state]);
this.setState({
trackedShapes: [
...trackedShapes,
{
clientID,
serverlessState: null,
shapePoints: points,
trackerModel: activeTracker as Model,
},
],
});
// update annotations on a canvas
fetchAnnotations();
const states = await jobInstance.annotations.get(frame);
const [objectState] = states.filter((_state: any): boolean => _state.clientID === clientID);
await this.trackState(objectState);
} catch (err) {
notification.error({
description: err.toString(),
@ -411,7 +500,7 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
});
};
private onChangeToolsBlockerState = (event:string):void => {
private onChangeToolsBlockerState = (event: string): void => {
const { isActivated, onSwitchToolsBlockerState } = this.props;
if (isActivated && event === 'keydown') {
onSwitchToolsBlockerState({ algorithmsLocked: true });
@ -420,6 +509,275 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
}
};
private collectTrackerPortals(): React.ReactPortal[] {
const { states, fetchAnnotations } = this.props;
const { trackedShapes, activeTracker } = this.state;
const trackedClientIDs = trackedShapes.map((trackedShape: TrackedShape) => trackedShape.clientID);
const portals = !activeTracker ?
[] :
states
.filter((objectState) => objectState.objectType === 'track' && objectState.shapeType === 'rectangle')
.map((objectState: any): React.ReactPortal | null => {
const { clientID } = objectState;
const selectorID = `#cvat-objects-sidebar-state-item-${clientID}`;
let targetElement = window.document.querySelector(
`${selectorID} .cvat-object-item-button-prev-keyframe`,
) as HTMLElement;
const isTracked = trackedClientIDs.includes(clientID);
if (targetElement) {
targetElement = targetElement.parentElement?.parentElement as HTMLElement;
return ReactDOM.createPortal(
<Col>
{isTracked ? (
<CVATTooltip overlay='Disable tracking'>
<EnvironmentFilled
onClick={() => {
const filteredStates = trackedShapes.filter(
(trackedShape: TrackedShape) =>
trackedShape.clientID !== clientID,
);
/* eslint no-param-reassign: ["error", { "props": false }] */
objectState.descriptions = [];
objectState.save().then(() => {
this.setState({
trackedShapes: filteredStates,
});
});
fetchAnnotations();
}}
/>
</CVATTooltip>
) : (
<CVATTooltip overlay={`Enable tracking using ${activeTracker.name}`}>
<EnvironmentOutlined
onClick={() => {
objectState.descriptions = [`Trackable (${activeTracker.name})`];
objectState.save().then(() => {
this.setState({
trackedShapes: [
...trackedShapes,
{
clientID,
serverlessState: null,
shapePoints: objectState.points,
trackerModel: activeTracker,
},
],
});
});
fetchAnnotations();
}}
/>
</CVATTooltip>
)}
</Col>,
targetElement,
);
}
return null;
})
.filter((portal: ReactPortal | null) => portal !== null);
return portals as React.ReactPortal[];
}
private async checkTrackedStates(prevProps: Props): Promise<void> {
const {
frame,
jobInstance,
states: objectStates,
trackers,
fetchAnnotations,
switchNavigationBlocked,
} = this.props;
const { trackedShapes } = this.state;
let withServerRequest = false;
type AccumulatorType = {
statefull: {
[index: string]: {
// tracker id
clientIDs: number[];
states: any[];
shapes: number[][];
};
};
stateless: {
[index: string]: {
// tracker id
clientIDs: number[];
shapes: number[][];
};
};
};
if (prevProps.frame !== frame && trackedShapes.length) {
// 1. find all trackable objects on the current frame
// 2. devide them into two groups: with relevant state, without relevant state
const trackingData = trackedShapes.reduce<AccumulatorType>(
(acc: AccumulatorType, trackedShape: TrackedShape): AccumulatorType => {
const {
serverlessState, shapePoints, clientID, trackerModel,
} = trackedShape;
const [clientState] = objectStates.filter((_state: any): boolean => _state.clientID === clientID);
if (
!clientState ||
clientState.keyframes.prev !== frame - 1 ||
clientState.keyframes.last >= frame
) {
return acc;
}
if (clientState && !clientState.outside) {
const { points } = clientState;
withServerRequest = true;
const stateIsRelevant =
serverlessState !== null &&
points.length === shapePoints.length &&
points.every((coord: number, i: number) => coord === shapePoints[i]);
if (stateIsRelevant) {
const container = acc.statefull[trackerModel.id] || {
clientIDs: [],
shapes: [],
states: [],
};
container.clientIDs.push(clientID);
container.shapes.push(points);
container.states.push(serverlessState);
acc.statefull[trackerModel.id] = container;
} else {
const container = acc.stateless[trackerModel.id] || {
clientIDs: [],
shapes: [],
};
container.clientIDs.push(clientID);
container.shapes.push(points);
acc.stateless[trackerModel.id] = container;
}
}
return acc;
},
{
statefull: {},
stateless: {},
},
);
try {
if (withServerRequest) {
switchNavigationBlocked(true);
}
// 3. get relevant state for the second group
for (const trackerID of Object.keys(trackingData.stateless)) {
let hideMessage = null;
try {
const [tracker] = trackers.filter((_tracker: Model) => _tracker.id === trackerID);
if (!tracker) {
throw new Error(`Suitable tracker with ID ${trackerID} not found in tracker list`);
}
const trackableObjects = trackingData.stateless[trackerID];
const numOfObjects = trackableObjects.clientIDs.length;
hideMessage = message.loading(
`${tracker.name}: states are being initialized for ${numOfObjects} ${
numOfObjects > 1 ? 'objects' : 'object'
} ..`,
0,
);
// eslint-disable-next-line no-await-in-loop
const response = await core.lambda.call(jobInstance.task, tracker, {
task: jobInstance.task,
frame: frame - 1,
shapes: trackableObjects.shapes,
});
const { states: serverlessStates } = response;
const statefullContainer = trackingData.statefull[trackerID] || {
clientIDs: [],
shapes: [],
states: [],
};
Array.prototype.push.apply(statefullContainer.clientIDs, trackableObjects.clientIDs);
Array.prototype.push.apply(statefullContainer.shapes, trackableObjects.shapes);
Array.prototype.push.apply(statefullContainer.states, serverlessStates);
trackingData.statefull[trackerID] = statefullContainer;
delete trackingData.stateless[trackerID];
} catch (error) {
notification.error({
message: 'Tracker initialization error',
description: error.toString(),
});
} finally {
if (hideMessage) hideMessage();
}
}
for (const trackerID of Object.keys(trackingData.statefull)) {
// 4. run tracking for all the objects
let hideMessage = null;
try {
const [tracker] = trackers.filter((_tracker: Model) => _tracker.id === trackerID);
if (!tracker) {
throw new Error(`Suitable tracker with ID ${trackerID} not found in tracker list`);
}
const trackableObjects = trackingData.statefull[trackerID];
const numOfObjects = trackableObjects.clientIDs.length;
hideMessage = message.loading(
`${tracker.name}: ${numOfObjects} ${
numOfObjects > 1 ? 'objects are' : 'object is'
} being tracked..`,
0,
);
// eslint-disable-next-line no-await-in-loop
const response = await core.lambda.call(jobInstance.task, tracker, {
task: jobInstance.task,
frame: frame - 1,
shapes: trackableObjects.shapes,
states: trackableObjects.states,
});
response.shapes = response.shapes.map(trackedRectangleMapper);
for (let i = 0; i < trackableObjects.clientIDs.length; i++) {
const clientID = trackableObjects.clientIDs[i];
const shape = response.shapes[i];
const state = response.states[i];
const [objectState] = objectStates.filter(
(_state: any): boolean => _state.clientID === clientID,
);
const [trackedShape] = trackedShapes.filter(
(_trackedShape: TrackedShape) => _trackedShape.clientID === clientID,
);
objectState.points = shape;
objectState.save().then(() => {
trackedShape.serverlessState = state;
trackedShape.shapePoints = shape;
});
}
} catch (error) {
notification.error({
message: 'Tracking error',
description: error.toString(),
});
} finally {
if (hideMessage) hideMessage();
fetchAnnotations();
}
}
} finally {
if (withServerRequest) {
switchNavigationBlocked(false);
}
}
}
}
private constructFromPoints(points: number[][]): void {
const {
frame, labels, curZOrder, jobInstance, activeLabelID, createAnnotations,
@ -457,70 +815,6 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
return points;
}
public async trackState(state: any): Promise<void> {
const { jobInstance, frame, fetchAnnotations } = this.props;
const { activeTracker, trackingFrames } = this.state;
const { clientID, points } = state;
const tracker = activeTracker as Model;
try {
this.setState({ trackingProgress: 0, fetching: true });
let response = await core.lambda.call(jobInstance.task, tracker, {
task: jobInstance.task,
frame,
shape: points,
});
for (const offset of range(1, trackingFrames + 1)) {
/* eslint-disable no-await-in-loop */
const states = await jobInstance.annotations.get(frame + offset);
const [objectState] = states.filter((_state: any): boolean => _state.clientID === clientID);
response = await core.lambda.call(jobInstance.task, tracker, {
task: jobInstance.task,
frame: frame + offset,
shape: response.points,
state: response.state,
});
const reduced = response.shape.reduce(
(acc: number[], value: number, index: number): number[] => {
if (index % 2) {
// y
acc[1] = Math.min(acc[1], value);
acc[3] = Math.max(acc[3], value);
} else {
// x
acc[0] = Math.min(acc[0], value);
acc[2] = Math.max(acc[2], value);
}
return acc;
},
[
Number.MAX_SAFE_INTEGER,
Number.MAX_SAFE_INTEGER,
Number.MIN_SAFE_INTEGER,
Number.MIN_SAFE_INTEGER,
],
);
objectState.points = reduced;
await objectState.save();
this.setState({ trackingProgress: offset / trackingFrames });
}
} finally {
this.setState({ trackingProgress: null, fetching: false });
fetchAnnotations();
}
}
public trackingAvailable(): boolean {
const { activeTracker, trackingFrames } = this.state;
const { trackers } = this.props;
return !!trackingFrames && !!trackers.length && activeTracker !== null;
}
private renderLabelBlock(): JSX.Element {
const { labels } = this.props;
const { activeLabelID } = this.state;
@ -549,9 +843,7 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
const {
trackers, canvasInstance, jobInstance, frame, onInteractionStart,
} = this.props;
const {
activeTracker, activeLabelID, fetching, trackingFrames,
} = this.state;
const { activeTracker, activeLabelID, fetching } = this.state;
if (!trackers.length) {
return (
@ -589,27 +881,6 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
</Select>
</Col>
</Row>
<Row align='middle' justify='start' style={{ marginTop: '5px' }}>
<Col>
<Text>Tracking frames</Text>
</Col>
<Col offset={2}>
<InputNumber
value={trackingFrames}
step={1}
min={1}
precision={0}
max={jobInstance.stopFrame - frame}
onChange={(value: number | undefined | string | null): void => {
if (typeof value !== 'undefined' && value !== null) {
this.setState({
trackingFrames: +value,
});
}
}}
/>
</Col>
</Row>
<Row align='middle' justify='end'>
<Col>
<Button
@ -797,10 +1068,7 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
</Text>
</Col>
</Row>
<Tabs
type='card'
tabBarGutter={8}
>
<Tabs type='card' tabBarGutter={8}>
<Tabs.TabPane key='interactors' tab='Interactors'>
{this.renderLabelBlock()}
{this.renderInteractorBlock()}
@ -822,7 +1090,7 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
interactors, detectors, trackers, isActivated, canvasInstance, labels,
} = this.props;
const {
fetching, trackingProgress, approxPolyAccuracy, pointsRecieved, mode,
fetching, approxPolyAccuracy, pointsRecieved, mode, portals,
} = this.state;
if (![...interactors, ...detectors, ...trackers].length) return null;
@ -849,8 +1117,6 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
const showAnyContent = !!labels.length;
const showInteractionContent = isActivated && mode === 'interaction' && pointsRecieved;
const showDetectionContent = fetching && mode === 'detection';
const showTrackingContent = fetching && mode === 'tracking' && trackingProgress !== null;
const formattedTrackingProgress = showTrackingContent ? +((trackingProgress as number) * 100).toFixed(0) : null;
const interactionContent: JSX.Element | null = showInteractionContent ? (
<>
@ -863,23 +1129,19 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
</>
) : null;
const trackOrDetectModal: JSX.Element | null =
showDetectionContent || showTrackingContent ? (
<Modal
title='Making a server request'
zIndex={Number.MAX_SAFE_INTEGER}
visible
destroyOnClose
closable={false}
footer={[]}
>
<Text>Waiting for a server response..</Text>
<LoadingOutlined style={{ marginLeft: '10px' }} />
{showTrackingContent ? (
<Progress percent={formattedTrackingProgress as number} status='active' />
) : null}
</Modal>
) : null;
const detectionContent: JSX.Element | null = showDetectionContent ? (
<Modal
title='Making a server request'
zIndex={Number.MAX_SAFE_INTEGER}
visible
destroyOnClose
closable={false}
footer={[]}
>
<Text>Waiting for a server response..</Text>
<LoadingOutlined style={{ marginLeft: '10px' }} />
</Modal>
) : null;
return showAnyContent ? (
<>
@ -887,7 +1149,8 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
<Icon {...dynamicIconProps} component={AIToolsIcon} />
</CustomPopover>
{interactionContent}
{trackOrDetectModal}
{detectionContent}
{portals}
</>
) : (
<Icon className=' cvat-tools-control cvat-disabled-canvas-control' component={AIToolsIcon} />

@ -43,7 +43,6 @@ interface Props {
toBackground(): void;
toForeground(): void;
resetCuboidPerspective(): void;
activateTracking(): void;
}
function ItemTopComponent(props: Props): JSX.Element {
@ -76,7 +75,6 @@ function ItemTopComponent(props: Props): JSX.Element {
toBackground,
toForeground,
resetCuboidPerspective,
activateTracking,
jobInstance,
} = props;
@ -152,7 +150,6 @@ function ItemTopComponent(props: Props): JSX.Element {
toForeground,
resetCuboidPerspective,
changeColorPickerVisible,
activateTracking,
})}
>
<MoreOutlined />

@ -7,12 +7,7 @@ import Menu from 'antd/lib/menu';
import Button from 'antd/lib/button';
import Modal from 'antd/lib/modal';
import Icon, {
LinkOutlined,
CopyOutlined,
BlockOutlined,
GatewayOutlined,
RetweetOutlined,
DeleteOutlined,
LinkOutlined, CopyOutlined, BlockOutlined, RetweetOutlined, DeleteOutlined,
} from '@ant-design/icons';
import {
@ -50,7 +45,6 @@ interface Props {
toForeground(): void;
resetCuboidPerspective(): void;
changeColorPickerVisible(visible: boolean): void;
activateTracking(): void;
jobInstance: any;
}
@ -98,20 +92,6 @@ function PropagateItem(props: ItemProps): JSX.Element {
);
}
function TrackingItem(props: ItemProps): JSX.Element {
const { toolProps, ...rest } = props;
const { activateTracking } = toolProps;
return (
<Menu.Item {...rest}>
<CVATTooltip title='Run tracking with the active tracker'>
<Button type='link' icon={<GatewayOutlined />} onClick={activateTracking}>
Track
</Button>
</CVATTooltip>
</Menu.Item>
);
}
function SwitchOrientationItem(props: ItemProps): JSX.Element {
const { toolProps, ...rest } = props;
const { switchOrientation } = toolProps;
@ -237,7 +217,6 @@ export default function ItemMenu(props: Props): JSX.Element {
CREATE_URL = 'create_url',
COPY = 'copy',
PROPAGATE = 'propagate',
TRACK = 'track',
SWITCH_ORIENTATION = 'switch_orientation',
RESET_PERSPECIVE = 'reset_perspective',
TO_BACKGROUND = 'to_background',
@ -253,9 +232,6 @@ export default function ItemMenu(props: Props): JSX.Element {
<CreateURLItem key={MenuKeys.CREATE_URL} toolProps={props} />
{!readonly && <MakeCopyItem key={MenuKeys.COPY} toolProps={props} />}
{!readonly && <PropagateItem key={MenuKeys.PROPAGATE} toolProps={props} />}
{is2D && !readonly && objectType === ObjectType.TRACK && shapeType === ShapeType.RECTANGLE && (
<TrackingItem key={MenuKeys.TRACK} toolProps={props} />
)}
{is2D && !readonly && [ShapeType.POLYGON, ShapeType.POLYLINE, ShapeType.CUBOID].includes(shapeType) && (
<SwitchOrientationItem key={MenuKeys.SWITCH_ORIENTATION} toolProps={props} />
)}

@ -39,7 +39,6 @@ interface Props {
changeColor(color: string): void;
collapse(): void;
resetCuboidPerspective(): void;
activateTracking(): void;
}
function objectItemsAreEqual(prevProps: Props, nextProps: Props): boolean {
@ -92,7 +91,6 @@ function ObjectItemComponent(props: Props): JSX.Element {
changeColor,
collapse,
resetCuboidPerspective,
activateTracking,
jobInstance,
} = props;
@ -144,7 +142,6 @@ function ObjectItemComponent(props: Props): JSX.Element {
toBackground={toBackground}
toForeground={toForeground}
resetCuboidPerspective={resetCuboidPerspective}
activateTracking={activateTracking}
/>
<ObjectButtonsContainer readonly={readonly} clientID={clientID} />
{!!attributes.length && (

@ -26,6 +26,7 @@ import { CombinedState, ObjectType } from 'reducers/interfaces';
import { adjustContextImagePosition } from 'components/annotation-page/standard-workspace/context-image/context-image';
import LabelSelector from 'components/label-selector/label-selector';
import getCore from 'cvat-core-wrapper';
import isAbleToChangeFrame from 'utils/is-able-to-change-frame';
import GlobalHotKeys, { KeyMap } from 'utils/mousetrap-react';
import ShortcutsSelect from './shortcuts-select';
@ -168,7 +169,7 @@ function TagAnnotationSidebar(props: StateToProps & DispatchToProps): JSX.Elemen
const onChangeFrame = (): void => {
const frame = Math.min(jobInstance.stopFrame, frameNumber + 1);
if (canvasInstance.isAbleToChangeFrame()) {
if (isAbleToChangeFrame()) {
changeFrame(frame);
}
};

@ -1,4 +1,4 @@
// Copyright (C) 2020 Intel Corporation
// Copyright (C) 2020-2021 Intel Corporation
//
// SPDX-License-Identifier: MIT
@ -6,7 +6,7 @@ import React from 'react';
import { connect } from 'react-redux';
import { LogType } from 'cvat-logger';
import { Canvas } from 'cvat-canvas-wrapper';
import isAbleToChangeFrame from 'utils/is-able-to-change-frame';
import { ThunkDispatch } from 'utils/redux';
import { updateAnnotationsAsync, changeFrameAsync } from 'actions/annotation-actions';
import { CombinedState } from 'reducers/interfaces';
@ -25,7 +25,6 @@ interface StateToProps {
jobInstance: any;
frameNumber: number;
normalizedKeyMap: Record<string, string>;
canvasInstance: Canvas;
outsideDisabled: boolean;
hiddenDisabled: boolean;
keyframeDisabled: boolean;
@ -44,7 +43,6 @@ function mapStateToProps(state: CombinedState, own: OwnProps): StateToProps {
player: {
frame: { number: frameNumber },
},
canvas: { instance: canvasInstance },
},
shortcuts: { normalizedKeyMap },
} = state;
@ -59,7 +57,6 @@ function mapStateToProps(state: CombinedState, own: OwnProps): StateToProps {
normalizedKeyMap,
frameNumber,
jobInstance,
canvasInstance,
outsideDisabled: typeof outsideDisabled === 'undefined' ? false : outsideDisabled,
hiddenDisabled: typeof hiddenDisabled === 'undefined' ? false : hiddenDisabled,
keyframeDisabled: typeof keyframeDisabled === 'undefined' ? false : keyframeDisabled,
@ -217,8 +214,8 @@ class ItemButtonsWrapper extends React.PureComponent<StateToProps & DispatchToPr
}
private changeFrame(frame: number): void {
const { changeFrame, canvasInstance } = this.props;
if (canvasInstance.isAbleToChangeFrame()) {
const { changeFrame } = this.props;
if (isAbleToChangeFrame()) {
changeFrame(frame);
}
}

@ -2,7 +2,7 @@
//
// SPDX-License-Identifier: MIT
import React, { MutableRefObject } from 'react';
import React from 'react';
import copy from 'copy-to-clipboard';
import { connect } from 'react-redux';
@ -22,7 +22,6 @@ import {
ActiveControl, CombinedState, ColorBy, ShapeType,
} from 'reducers/interfaces';
import ObjectStateItemComponent from 'components/annotation-page/standard-workspace/objects-side-bar/object-item';
import { ToolsControlComponent } from 'components/annotation-page/standard-workspace/controls-side-bar/tools-control';
import { shift } from 'utils/math';
import { Canvas } from 'cvat-canvas-wrapper';
import { Canvas3d } from 'cvat-canvas3d-wrapper';
@ -48,7 +47,6 @@ interface StateToProps {
minZLayer: number;
maxZLayer: number;
normalizedKeyMap: Record<string, string>;
aiToolsRef: MutableRefObject<ToolsControlComponent>;
canvasInstance: Canvas | Canvas3d;
}
@ -76,7 +74,6 @@ function mapStateToProps(state: CombinedState, own: OwnProps): StateToProps {
frame: { number: frameNumber },
},
canvas: { instance: canvasInstance, ready, activeControl },
aiToolsRef,
},
settings: {
shapes: { colorBy },
@ -105,7 +102,6 @@ function mapStateToProps(state: CombinedState, own: OwnProps): StateToProps {
minZLayer,
maxZLayer,
normalizedKeyMap,
aiToolsRef,
canvasInstance,
};
}
@ -243,13 +239,6 @@ class ObjectItemContainer extends React.PureComponent<Props> {
collapseOrExpand([objectState], !collapsed);
};
private activateTracking = (): void => {
const { objectState, readonly, aiToolsRef } = this.props;
if (!readonly && aiToolsRef.current && aiToolsRef.current.trackingAvailable()) {
aiToolsRef.current.trackState(objectState);
}
};
private changeColor = (color: string): void => {
const { objectState, colorBy, changeGroupColor } = this.props;
@ -392,7 +381,6 @@ class ObjectItemContainer extends React.PureComponent<Props> {
changeLabel={this.changeLabel}
changeAttribute={this.changeAttribute}
collapse={this.collapse}
activateTracking={this.activateTracking}
resetCuboidPerspective={() => this.resetCuboidPerspective()}
/>
);

@ -16,8 +16,7 @@ import {
copyShape as copyShapeAction,
propagateObject as propagateObjectAction,
} from 'actions/annotation-actions';
import { Canvas } from 'cvat-canvas-wrapper';
import { Canvas3d } from 'cvat-canvas3d-wrapper';
import isAbleToChangeFrame from 'utils/is-able-to-change-frame';
import {
CombinedState, StatesOrdering, ObjectType, ColorBy,
} from 'reducers/interfaces';
@ -42,7 +41,6 @@ interface StateToProps {
maxZLayer: number;
keyMap: KeyMap;
normalizedKeyMap: Record<string, string>;
canvasInstance: Canvas | Canvas3d;
}
interface DispatchToProps {
@ -70,7 +68,6 @@ function mapStateToProps(state: CombinedState): StateToProps {
player: {
frame: { number: frameNumber },
},
canvas: { instance: canvasInstance },
colors,
},
settings: {
@ -108,7 +105,6 @@ function mapStateToProps(state: CombinedState): StateToProps {
maxZLayer,
keyMap,
normalizedKeyMap,
canvasInstance,
};
}
@ -257,7 +253,6 @@ class ObjectsListContainer extends React.PureComponent<Props, State> {
minZLayer,
keyMap,
normalizedKeyMap,
canvasInstance,
colors,
colorBy,
readonly,
@ -437,7 +432,7 @@ class ObjectsListContainer extends React.PureComponent<Props, State> {
const state = activatedStated();
if (state && state.objectType === ObjectType.TRACK) {
const frame = typeof state.keyframes.next === 'number' ? state.keyframes.next : null;
if (frame !== null && canvasInstance.isAbleToChangeFrame()) {
if (frame !== null && isAbleToChangeFrame()) {
changeFrame(frame);
}
}
@ -447,7 +442,7 @@ class ObjectsListContainer extends React.PureComponent<Props, State> {
const state = activatedStated();
if (state && state.objectType === ObjectType.TRACK) {
const frame = typeof state.keyframes.prev === 'number' ? state.keyframes.prev : null;
if (frame !== null && canvasInstance.isAbleToChangeFrame()) {
if (frame !== null && isAbleToChangeFrame()) {
changeFrame(frame);
}
}

@ -30,8 +30,15 @@ import AnnotationTopBarComponent from 'components/annotation-page/top-bar/top-ba
import { Canvas } from 'cvat-canvas-wrapper';
import { Canvas3d } from 'cvat-canvas3d-wrapper';
import {
CombinedState, FrameSpeed, Workspace, PredictorState, DimensionType, ActiveControl, ToolsBlockerState,
CombinedState,
FrameSpeed,
Workspace,
PredictorState,
DimensionType,
ActiveControl,
ToolsBlockerState,
} from 'reducers/interfaces';
import isAbleToChangeFrame from 'utils/is-able-to-change-frame';
import GlobalHotKeys, { KeyMap } from 'utils/mousetrap-react';
import { switchToolsBlockerState } from 'actions/settings-actions';
@ -171,7 +178,7 @@ function mapDispatchToProps(dispatch: any): DispatchToProps {
dispatch(getPredictionsAsync());
}
},
onSwitchToolsBlockerState(toolsBlockerState: ToolsBlockerState):void{
onSwitchToolsBlockerState(toolsBlockerState: ToolsBlockerState): void {
dispatch(switchToolsBlockerState(toolsBlockerState));
},
};
@ -245,21 +252,17 @@ class AnnotationTopBarContainer extends React.PureComponent<Props, State> {
}
private undo = (): void => {
const {
undo, jobInstance, frameNumber, canvasInstance,
} = this.props;
const { undo, jobInstance, frameNumber } = this.props;
if (canvasInstance.isAbleToChangeFrame()) {
if (isAbleToChangeFrame()) {
undo(jobInstance, frameNumber);
}
};
private redo = (): void => {
const {
redo, jobInstance, frameNumber, canvasInstance,
} = this.props;
const { redo, jobInstance, frameNumber } = this.props;
if (canvasInstance.isAbleToChangeFrame()) {
if (isAbleToChangeFrame()) {
redo(jobInstance, frameNumber);
}
};
@ -484,7 +487,6 @@ class AnnotationTopBarContainer extends React.PureComponent<Props, State> {
frameDelay,
playing,
canvasIsReady,
canvasInstance,
onSwitchPlay,
onChangeFrame,
} = this.props;
@ -502,7 +504,7 @@ class AnnotationTopBarContainer extends React.PureComponent<Props, State> {
setTimeout(() => {
const { playing: stillPlaying } = this.props;
if (stillPlaying) {
if (canvasInstance.isAbleToChangeFrame()) {
if (isAbleToChangeFrame()) {
onChangeFrame(frameNumber + 1 + framesSkipped, stillPlaying, framesSkipped + 1);
} else if (jobInstance.task.dimension === DimensionType.DIM_2D) {
onSwitchPlay(false);
@ -526,22 +528,22 @@ class AnnotationTopBarContainer extends React.PureComponent<Props, State> {
}
private changeFrame(frame: number): void {
const { onChangeFrame, canvasInstance } = this.props;
if (canvasInstance.isAbleToChangeFrame()) {
const { onChangeFrame } = this.props;
if (isAbleToChangeFrame()) {
onChangeFrame(frame);
}
}
private searchAnnotations(start: number, stop: number): void {
const { canvasInstance, jobInstance, searchAnnotations } = this.props;
if (canvasInstance.isAbleToChangeFrame()) {
const { jobInstance, searchAnnotations } = this.props;
if (isAbleToChangeFrame()) {
searchAnnotations(jobInstance, start, stop);
}
}
private searchEmptyFrame(start: number, stop: number): void {
const { canvasInstance, jobInstance, searchEmptyFrame } = this.props;
if (canvasInstance.isAbleToChangeFrame()) {
const { jobInstance, searchEmptyFrame } = this.props;
if (isAbleToChangeFrame()) {
searchEmptyFrame(jobInstance, start, stop);
}
}
@ -562,7 +564,6 @@ class AnnotationTopBarContainer extends React.PureComponent<Props, State> {
canvasIsReady,
keyMap,
normalizedKeyMap,
canvasInstance,
predictor,
isTrainingActive,
activeControl,
@ -637,13 +638,13 @@ class AnnotationTopBarContainer extends React.PureComponent<Props, State> {
},
SEARCH_FORWARD: (event: KeyboardEvent | undefined) => {
preventDefault(event);
if (frameNumber + 1 <= stopFrame && canvasIsReady && canvasInstance.isAbleToChangeFrame()) {
if (frameNumber + 1 <= stopFrame && canvasIsReady && isAbleToChangeFrame()) {
searchAnnotations(jobInstance, frameNumber + 1, stopFrame);
}
},
SEARCH_BACKWARD: (event: KeyboardEvent | undefined) => {
preventDefault(event);
if (frameNumber - 1 >= startFrame && canvasIsReady && canvasInstance.isAbleToChangeFrame()) {
if (frameNumber - 1 >= startFrame && canvasIsReady && isAbleToChangeFrame()) {
searchAnnotations(jobInstance, frameNumber - 1, startFrame);
}
},

@ -2,7 +2,6 @@
//
// SPDX-License-Identifier: MIT
import React from 'react';
import { AnyAction } from 'redux';
import { AnnotationActionTypes } from 'actions/annotation-actions';
import { AuthActionTypes } from 'actions/auth-actions';
@ -64,6 +63,7 @@ const defaultState: AnnotationState = {
},
playing: false,
frameAngles: [],
navigationBlocked: false,
contextImage: {
fetching: false,
data: null,
@ -107,7 +107,6 @@ const defaultState: AnnotationState = {
collecting: false,
data: null,
},
aiToolsRef: React.createRef(),
colors: [],
sidebarCollapsed: false,
appearanceCollapsed: false,
@ -976,7 +975,9 @@ export default (state = defaultState, action: AnyAction): AnnotationState => {
}
case AnnotationActionTypes.FETCH_ANNOTATIONS_SUCCESS: {
const { activatedStateID } = state.annotations;
const { states, minZ, maxZ } = action.payload;
const {
states, history, minZ, maxZ,
} = action.payload;
return {
...state,
@ -984,6 +985,7 @@ export default (state = defaultState, action: AnyAction): AnnotationState => {
...state.annotations,
activatedStateID: updateActivatedStateID(states, activatedStateID),
states,
history,
zLayer: {
min: minZ,
max: maxZ,
@ -1205,6 +1207,15 @@ export default (state = defaultState, action: AnyAction): AnnotationState => {
},
};
}
case AnnotationActionTypes.SWITCH_NAVIGATION_BLOCKED: {
return {
...state,
player: {
...state.player,
navigationBlocked: action.payload.navigationBlocked,
},
};
}
case AnnotationActionTypes.CLOSE_JOB:
case AuthActionTypes.LOGOUT_SUCCESS: {
return { ...defaultState };

@ -2,7 +2,6 @@
//
// SPDX-License-Identifier: MIT
import { MutableRefObject } from 'react';
import { Canvas3d } from 'cvat-canvas3d/src/typescript/canvas3d';
import { Canvas, RectDrawingMethod, CuboidDrawingMethod } from 'cvat-canvas-wrapper';
import { IntelligentScissors } from 'utils/opencv-wrapper/intelligent-scissors';
@ -512,6 +511,7 @@ export interface AnnotationState {
delay: number;
changeTime: number | null;
};
navigationBlocked: boolean;
playing: boolean;
frameAngles: number[];
contextImage: {
@ -570,7 +570,6 @@ export interface AnnotationState {
appearanceCollapsed: boolean;
workspace: Workspace;
predictor: PredictorState;
aiToolsRef: MutableRefObject<any>;
}
export enum Workspace {

@ -0,0 +1,13 @@
// Copyright (C) 2021 Intel Corporation
//
// SPDX-License-Identifier: MIT
import { getCVATStore } from 'cvat-store';
import { CombinedState } from 'reducers/interfaces';
export default function isAbleToChangeFrame(): boolean {
const store = getCVATStore();
const state: CombinedState = store.getState();
return state.annotation.canvas.instance.isAbleToChangeFrame() && !state.annotation.player.navigationBlocked;
}

@ -1,27 +0,0 @@
// Copyright (C) 2020 Intel Corporation
//
// SPDX-License-Identifier: MIT
export default function range(x: number, y?: number): number[] {
if (typeof x !== 'undefined' && typeof y !== 'undefined') {
if (typeof x !== 'number' && typeof y !== 'number') {
throw new Error(`Range() expects number arguments. Got ${typeof x}, ${typeof y}`);
}
if (x >= y) {
throw new Error(`Range() expects the first argument less or equal than the second. Got ${x}, ${y}`);
}
return Array.from(Array(y - x), (_: number, i: number) => i + x);
}
if (typeof x !== 'undefined') {
if (typeof x !== 'number') {
throw new Error(`Range() expects number arguments. Got ${typeof x}`);
}
return [...Array(x).keys()];
}
return [];
}

@ -191,8 +191,8 @@ class LambdaFunction:
elif self.kind == LambdaType.TRACKER:
payload.update({
"image": self._get_image(db_task, data["frame"], quality),
"shape": data.get("shape", None),
"state": data.get("state", None)
"shapes": data.get("shapes", []),
"states": data.get("states", [])
})
else:
raise ValidationError(

@ -0,0 +1,73 @@
metadata:
name: pth-foolwood-siammask
namespace: cvat
annotations:
name: SiamMask
type: tracker
spec:
framework: pytorch
spec:
description: Fast Online Object Tracking and Segmentation
runtime: 'python:3.6'
handler: main:handler
eventTimeout: 30s
env:
- name: PYTHONPATH
value: /opt/nuclio/SiamMask:/opt/nuclio/SiamMask/experiments/siammask_sharp
build:
image: cvat/pth.foolwood.siammask
baseImage: nvidia/cuda:11.1-devel-ubuntu20.04
directives:
preCopy:
- kind: ENV
value: PATH="/root/miniconda3/bin:${PATH}"
- kind: ARG
value: PATH="/root/miniconda3/bin:${PATH}"
- kind: RUN
value: apt update && apt install -y --no-install-recommends wget git ca-certificates libglib2.0-0 libsm6 libxrender1 libxext6 && rm -rf /var/lib/apt/lists/*
- kind: RUN
value: wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh &&
chmod +x Miniconda3-latest-Linux-x86_64.sh && ./Miniconda3-latest-Linux-x86_64.sh -b &&
rm -f Miniconda3-latest-Linux-x86_64.sh
- kind: WORKDIR
value: /opt/nuclio
- kind: RUN
value: conda create -y -n siammask python=3.7
- kind: SHELL
value: '["conda", "run", "-n", "siammask", "/bin/bash", "-c"]'
- kind: RUN
value: git clone https://github.com/foolwood/SiamMask.git
- kind: RUN
value: pip install -r SiamMask/requirements.txt jsonpickle
- kind: RUN
value: pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
- kind: RUN
value: conda install -y gcc_linux-64
- kind: RUN
value: cd SiamMask && bash make.sh && cd -
- kind: RUN
value: wget -P SiamMask/experiments/siammask_sharp http://www.robots.ox.ac.uk/~qwang/SiamMask_DAVIS.pth
- kind: ENTRYPOINT
value: '["conda", "run", "-n", "siammask"]'
triggers:
myHttpTrigger:
maxWorkers: 2
kind: 'http'
workerAvailabilityTimeoutMilliseconds: 10000
attributes:
maxRequestBodySize: 33554432 # 32MB
resources:
limits:
nvidia.com/gpu: 1
platform:
attributes:
restartPolicy:
name: always
maximumRetryCount: 3
mountMode: volume

@ -18,10 +18,20 @@ spec:
build:
image: cvat/pth.foolwood.siammask
baseImage: continuumio/miniconda3
baseImage: ubuntu:20.04
directives:
preCopy:
- kind: ENV
value: PATH="/root/miniconda3/bin:${PATH}"
- kind: ARG
value: PATH="/root/miniconda3/bin:${PATH}"
- kind: RUN
value: apt update && apt install -y --no-install-recommends wget git ca-certificates libglib2.0-0 libsm6 libxrender1 libxext6 && rm -rf /var/lib/apt/lists/*
- kind: RUN
value: wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh &&
chmod +x Miniconda3-latest-Linux-x86_64.sh && ./Miniconda3-latest-Linux-x86_64.sh -b &&
rm -f Miniconda3-latest-Linux-x86_64.sh
- kind: WORKDIR
value: /opt/nuclio
- kind: RUN

@ -17,11 +17,18 @@ def handler(context, event):
context.logger.info("Run SiamMask model")
data = event.body
buf = io.BytesIO(base64.b64decode(data["image"]))
shape = data.get("shape")
state = data.get("state")
shapes = data.get("shapes")
states = data.get("states")
image = Image.open(buf)
results = context.user_data.model.infer(image, shape, state)
results = {
'shapes': [],
'states': []
}
for i, shape in enumerate(shapes):
shape, state = context.user_data.model.infer(image, shape, states[i] if i < len(states) else None)
results['shapes'].append(shape)
results['states'].append(state)
return context.Response(body=json.dumps(results), headers={},
content_type='application/json', status_code=200)

@ -62,5 +62,5 @@ class ModelHandler:
shape = state['ploygon'].flatten().tolist()
state = self.encode_state(state)
return {"shape": shape, "state": state}
return shape, state

Loading…
Cancel
Save