You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1250 lines
49 KiB
TypeScript
1250 lines
49 KiB
TypeScript
// Copyright (C) 2020-2022 Intel Corporation
|
|
//
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
import React, { ReactPortal } from 'react';
|
|
import ReactDOM from 'react-dom';
|
|
import { connect } from 'react-redux';
|
|
import Icon, {
|
|
EnvironmentFilled,
|
|
EnvironmentOutlined,
|
|
LoadingOutlined,
|
|
QuestionCircleOutlined,
|
|
} from '@ant-design/icons';
|
|
import Popover from 'antd/lib/popover';
|
|
import Select from 'antd/lib/select';
|
|
import Button from 'antd/lib/button';
|
|
import Modal from 'antd/lib/modal';
|
|
import Text from 'antd/lib/typography/Text';
|
|
import Tabs from 'antd/lib/tabs';
|
|
import { Row, Col } from 'antd/lib/grid';
|
|
import notification from 'antd/lib/notification';
|
|
import message from 'antd/lib/message';
|
|
import Dropdown from 'antd/lib/dropdown';
|
|
import lodash from 'lodash';
|
|
|
|
import { AIToolsIcon } from 'icons';
|
|
import { Canvas, convertShapesForInteractor } from 'cvat-canvas-wrapper';
|
|
import getCore from 'cvat-core-wrapper';
|
|
import openCVWrapper from 'utils/opencv-wrapper/opencv-wrapper';
|
|
import {
|
|
CombinedState, ActiveControl, Model, ObjectType, ShapeType, ToolsBlockerState, ModelAttribute,
|
|
} from 'reducers/interfaces';
|
|
import {
|
|
interactWithCanvas,
|
|
switchNavigationBlocked as switchNavigationBlockedAction,
|
|
fetchAnnotationsAsync,
|
|
updateAnnotationsAsync,
|
|
createAnnotationsAsync,
|
|
} from 'actions/annotation-actions';
|
|
import DetectorRunner, { DetectorRequestBody } from 'components/model-runner-modal/detector-runner';
|
|
import LabelSelector from 'components/label-selector/label-selector';
|
|
import CVATTooltip from 'components/common/cvat-tooltip';
|
|
import { Attribute, Label } from 'components/labels-editor/common';
|
|
|
|
import ApproximationAccuracy, {
|
|
thresholdFromAccuracy,
|
|
} from 'components/annotation-page/standard-workspace/controls-side-bar/approximation-accuracy';
|
|
import { switchToolsBlockerState } from 'actions/settings-actions';
|
|
import withVisibilityHandling from './handle-popover-visibility';
|
|
import ToolsTooltips from './interactor-tooltips';
|
|
|
|
/** Props that `mapStateToProps` derives from the Redux store. */
interface StateToProps {
    /** Canvas instance taken from `annotation.canvas.instance`. */
    canvasInstance: Canvas;
    /** Job instance taken from `annotation.job.instance`. */
    jobInstance: any;
    /** Labels of the current job (`annotation.job.labels`). */
    labels: any[];
    /** Object states from `annotation.annotations.states`. */
    states: any[];
    /** Label chosen for drawing (`annotation.drawing.activeLabelID`). */
    activeLabelID: number;
    /** `true` when the active canvas control is `ActiveControl.AI_TOOLS`. */
    isActivated: boolean;
    /** Current frame number of the player. */
    frame: number;
    /** `deleted` flag of the current frame data. */
    frameIsDeleted: boolean;
    /** Current z-layer (`annotation.annotations.zLayer.cur`). */
    curZOrder: number;
    /** Available models, grouped by kind. */
    interactors: Model[];
    detectors: Model[];
    trackers: Model[];
    /** Workspace settings. */
    defaultApproxPolyAccuracy: number;
    toolsBlockerState: ToolsBlockerState;
}
|
|
|
|
/** Action dispatchers injected through `mapDispatchToProps`. */
interface DispatchToProps {
    /** Bound to `interactWithCanvas`; called when interaction or tracking is started from the popover. */
    onInteractionStart(activeInteractor: Model, activeLabelID: number): void;
    /** Bound to `fetchAnnotationsAsync`. */
    fetchAnnotations(): void;
    /** Bound to `createAnnotationsAsync`. */
    createAnnotations(sessionInstance: any, frame: number, statesToCreate: any[]): void;
    /** Bound to `updateAnnotationsAsync`. */
    updateAnnotations(statesToUpdate: any[]): void;
    /** Bound to `switchNavigationBlockedAction`; toggled around tracker server requests. */
    switchNavigationBlocked(navigationBlocked: boolean): void;
    /** Bound to `switchToolsBlockerState`. */
    onSwitchToolsBlockerState(toolsBlockerState: ToolsBlockerState): void;
}
|
|
|
|
// Shared cvat-core handle; this file uses it for `core.lambda.call`, `core.plugins.register`
// and `core.classes.ObjectState`.
const core = getCore();
|
|
// antd Popover wrapped by `withVisibilityHandling` under the 'tools-control' key.
// NOTE(review): the wrapper's exact behaviour lives in ./handle-popover-visibility — confirm there.
const CustomPopover = withVisibilityHandling(Popover, 'tools-control');
|
|
|
|
/**
 * Selects the slices of the Redux store that the AI tools control needs.
 *
 * @param state combined application state
 * @returns props consumed by `ToolsControlComponent`
 */
function mapStateToProps(state: CombinedState): StateToProps {
    const { annotation, models, settings } = state;
    const {
        job, canvas, player, annotations, drawing,
    } = annotation;
    const { workspace } = settings;

    return {
        interactors: models.interactors,
        detectors: models.detectors,
        trackers: models.trackers,
        // the control is "activated" while AI tools own the canvas
        isActivated: canvas.activeControl === ActiveControl.AI_TOOLS,
        activeLabelID: drawing.activeLabelID,
        labels: job.labels,
        states: annotations.states,
        canvasInstance: canvas.instance as Canvas,
        jobInstance: job.instance,
        frame: player.frame.number,
        curZOrder: annotations.zLayer.cur,
        defaultApproxPolyAccuracy: workspace.defaultApproxPolyAccuracy,
        toolsBlockerState: workspace.toolsBlockerState,
        frameIsDeleted: player.frame.data.deleted,
    };
}
|
|
|
|
/** Action creators bound to dispatch; keys match `DispatchToProps`. */
const mapDispatchToProps = {
    // canvas / navigation
    onInteractionStart: interactWithCanvas,
    switchNavigationBlocked: switchNavigationBlockedAction,
    onSwitchToolsBlockerState: switchToolsBlockerState,
    // annotations
    fetchAnnotations: fetchAnnotationsAsync,
    createAnnotations: createAnnotationsAsync,
    updateAnnotations: updateAnnotationsAsync,
};
|
|
|
|
// Full component props: store-derived values plus bound action creators.
type Props = StateToProps & DispatchToProps;
|
|
/** Bookkeeping record for one object that is followed by a tracker model. */
interface TrackedShape {
    /** Client ID of the tracked object state. */
    clientID: number;
    /** Tracker model used for this object. */
    trackerModel: Model;
    /** Points the tracker state corresponds to; compared with the object's current points. */
    shapePoints: number[];
    /** Tracker state returned by the lambda call; `null` until the tracker has been initialized. */
    serverlessState: any;
}
|
|
|
|
/** Local state of `ToolsControlComponent`. */
interface State {
    /** Which tool is currently driving the canvas interaction. */
    mode: 'detection' | 'interaction' | 'tracking';
    /** Interactor selected in the popover, if any. */
    activeInteractor: Model | null;
    /** Tracker selected in the popover, if any. */
    activeTracker: Model | null;
    /** Label selected in the popover. */
    activeLabelID: number;
    /** Objects currently followed by a tracker. */
    trackedShapes: TrackedShape[];
    /** `true` while a server request is running. */
    fetching: boolean;
    /** `true` once an interactor response with a non-empty result was received. */
    pointsRecieved: boolean;
    /** Accuracy value converted to a threshold by `thresholdFromAccuracy`. */
    approxPolyAccuracy: number;
    /** Portals produced by `collectTrackerPortals`. */
    portals: React.ReactPortal[];
}
|
|
|
|
/**
 * Collapses a flat list of coordinates `[x0, y0, x1, y1, ...]` into its bounding box.
 *
 * @param shape flat coordinate list; even indices are x, odd indices are y
 * @returns `[xMin, yMin, xMax, yMax]`; for an empty list the initial sentinel values
 *          (`MAX_SAFE_INTEGER` for minima, `MIN_SAFE_INTEGER` for maxima) are returned unchanged
 */
function trackedRectangleMapper(shape: number[]): number[] {
    let xMin = Number.MAX_SAFE_INTEGER;
    let yMin = Number.MAX_SAFE_INTEGER;
    let xMax = Number.MIN_SAFE_INTEGER;
    let yMax = Number.MIN_SAFE_INTEGER;

    for (let position = 0; position < shape.length; position++) {
        const coordinate = shape[position];
        if (position % 2 === 0) {
            // even position: x coordinate
            xMin = Math.min(xMin, coordinate);
            xMax = Math.max(xMax, coordinate);
        } else {
            // odd position: y coordinate
            yMin = Math.min(yMin, coordinate);
            yMax = Math.max(yMax, coordinate);
        }
    }

    return [xMin, yMin, xMax, yMax];
}
|
|
|
|
/**
 * Registers a cvat-core plugin with a `leave` hook on `Job.prototype.annotations.clear`
 * and returns a setter for the callback that the hook invokes.
 * NOTE(review): presumably the core calls `leave` after annotations are cleared — confirm
 * against the cvat-core plugin system.
 *
 * @returns a function that installs (or, with `null`, removes) the callback
 */
function registerPlugin(): (callback: null | (() => void)) => void {
    let handler: null | (() => void) = null;

    // hook object attached to `annotations.clear`; passes the result through untouched
    const clear = {
        leave(self: any, result: any) {
            if (typeof handler === 'function') {
                handler();
            }
            return result;
        },
    };

    const plugin = {
        name: 'Remove annotations listener',
        description: 'Tracker needs to know when annotations is reset in the job',
        cvat: {
            classes: {
                Job: {
                    prototype: {
                        annotations: { clear },
                    },
                },
            },
        },
    };

    core.plugins.register(plugin);

    return (callback: null | (() => void)): void => {
        handler = callback;
    };
}
|
|
|
|
// Setter for the callback fired by the plugin registered in `registerPlugin`;
// the plugin itself is registered once, when this module is evaluated.
const onRemoveAnnotations = registerPlugin();
|
|
|
|
export class ToolsControlComponent extends React.PureComponent<Props, State> {
|
|
/** Mutable bookkeeping of the current interaction session (kept outside of React state). */
private interaction: {
    /** Session identifier; `null` until the first canvas event of a session. */
    id: string | null;
    /** Set when the user cancels or finishes, so that late responses are ignored. */
    isAborted: boolean;
    /** Raw points of the latest server response. */
    latestResponse: number[][];
    /** Latest response after polygon approximation. */
    latestResult: number[][];
    /** Pending request that has not been sent yet, if any. */
    latestRequest: {
        interactor: Model;
        data: {
            frame: number;
            neg_points: number[][];
            pos_points: number[][];
        };
    } | null;
    /** Closes the "waiting for response" message, if one is shown. */
    hideMessage: (() => void) | null;
};
|
|
|
|
/**
 * Seeds the state with the first available interactor, tracker and label
 * and creates an empty interaction session record.
 */
public constructor(props: Props) {
    super(props);

    const {
        interactors, trackers, labels, defaultApproxPolyAccuracy,
    } = props;

    this.state = {
        mode: 'interaction',
        activeInteractor: interactors.length > 0 ? interactors[0] : null,
        activeTracker: trackers.length > 0 ? trackers[0] : null,
        activeLabelID: labels.length > 0 ? labels[0].id : null,
        approxPolyAccuracy: defaultApproxPolyAccuracy,
        fetching: false,
        pointsRecieved: false,
        trackedShapes: [],
        portals: [],
    };

    this.interaction = {
        id: null,
        isAborted: false,
        latestRequest: null,
        latestResponse: [],
        latestResult: [],
        hideMessage: null,
    };
}
|
|
|
|
/** Installs the annotations-reset callback, builds the tracker portals and subscribes to canvas events. */
public componentDidMount(): void {
    const { canvasInstance } = this.props;

    // forget every tracked shape when the plugin hook reports that annotations were cleared
    onRemoveAnnotations((): void => {
        this.setState({ trackedShapes: [] });
    });

    const portals = this.collectTrackerPortals();
    this.setState({ portals });

    canvasInstance.html().addEventListener('canvas.interacted', this.interactionListener);
    canvasInstance.html().addEventListener('canvas.canceled', this.cancelListener);
}
|
|
|
|
/**
 * Reacts to prop/state changes:
 *  - rebuilds tracker portals when the object states or the active tracker change;
 *  - on deactivation removes the context-menu blocker and hides the pending message;
 *  - on activation resets the interaction session and the approximation accuracy;
 *  - re-approximates the latest interactor response when the accuracy changes;
 *  - lets `checkTrackedStates` decide whether tracking has to run for a new frame.
 *
 * @param prevProps props before the update
 * @param prevState state before the update
 */
public componentDidUpdate(prevProps: Props, prevState: State): void {
    const {
        isActivated, defaultApproxPolyAccuracy, canvasInstance, states,
    } = this.props;
    const { approxPolyAccuracy, mode, activeTracker } = this.state;

    if (prevProps.states !== states || prevState.activeTracker !== activeTracker) {
        this.setState({
            portals: this.collectTrackerPortals(),
        });
    }

    if (prevProps.isActivated && !isActivated) {
        window.removeEventListener('contextmenu', this.contextmenuDisabler);
        // hide interaction message if exists
        if (this.interaction.hideMessage) {
            this.interaction.hideMessage();
            this.interaction.hideMessage = null;
        }
    } else if (!prevProps.isActivated && isActivated) {
        // reset flags when start interaction/tracking
        this.interaction = {
            id: null,
            isAborted: false,
            latestResponse: [],
            latestResult: [],
            latestRequest: null,
            hideMessage: null,
        };

        this.setState({
            approxPolyAccuracy: defaultApproxPolyAccuracy,
            pointsRecieved: false,
        });
        window.addEventListener('contextmenu', this.contextmenuDisabler);
    }

    const accuracyChanged = prevState.approxPolyAccuracy !== approxPolyAccuracy;
    if (accuracyChanged && isActivated && mode === 'interaction' && this.interaction.latestResponse.length) {
        this.approximateResponsePoints(this.interaction.latestResponse)
            .then((points: number[][]) => {
                this.interaction.latestResult = points;
                canvasInstance.interact({
                    enabled: true,
                    intermediateShape: {
                        shapeType: ShapeType.POLYGON,
                        points: this.interaction.latestResult.flat(),
                    },
                    onChangeToolsBlockerState: this.onChangeToolsBlockerState,
                });
            })
            .catch((error: unknown) => {
                // approximation may reject (e.g. OpenCV initialization failure);
                // report it instead of leaving an unhandled rejection
                notification.error({
                    description: String(error),
                    message: 'Interaction error occurred',
                });
            });
    }

    this.checkTrackedStates(prevProps);
}
|
|
|
|
/**
 * Releases everything the component registered outside of React:
 * the annotations-reset callback, the canvas listeners, the window-level
 * context-menu blocker and a still visible loading message.
 */
public componentWillUnmount(): void {
    const { canvasInstance } = this.props;
    onRemoveAnnotations(null);
    canvasInstance.html().removeEventListener('canvas.interacted', this.interactionListener);
    canvasInstance.html().removeEventListener('canvas.canceled', this.cancelListener);

    // The blocker is attached while the control is activated; if the component is
    // unmounted in that state it would otherwise stay on `window` forever.
    // Removing a listener that is not registered is a no-op.
    window.removeEventListener('contextmenu', this.contextmenuDisabler);

    if (this.interaction.hideMessage) {
        this.interaction.hideMessage();
        this.interaction.hideMessage = null;
    }
}
|
|
|
|
/** Suppresses the browser context menu when the event target's class list mentions 'ant-modal'. */
private contextmenuDisabler = (e: MouseEvent): void => {
    const classes = (e.target as Element | null)?.classList;
    const insideModal = !!classes && classes.toString().includes('ant-modal');

    if (insideModal) {
        e.preventDefault();
    }
};
|
|
|
|
/** Handles 'canvas.canceled': when a request is in flight, marks the session aborted and clears `fetching`. */
private cancelListener = async (): Promise<void> => {
    const { fetching: requestInFlight } = this.state;
    if (!requestInFlight) {
        return;
    }

    // user pressed ESC
    this.setState({ fetching: false });
    this.interaction.isAborted = true;
};
|
|
|
|
/**
 * Sends the pending interactor request (if any) for the given session, stores the
 * approximated result, draws it on the canvas as an intermediate polygon and then
 * re-schedules itself to pick up a request queued in the meantime.
 *
 * @param interactionId session the caller belongs to; the call is ignored when another
 *        session is active, when nothing is queued, or when a request is already running
 */
private runInteractionRequest = async (interactionId: string): Promise<void> => {
    const { jobInstance, canvasInstance } = this.props;
    const { fetching } = this.state;

    const { id, latestRequest } = this.interaction;
    if (id !== interactionId || !latestRequest || fetching) {
        // current interaction request is not relevant (new interaction session has started)
        // or a user didn't add more points
        // or one server request is on processing
        return;
    }

    const { interactor, data } = latestRequest;
    this.interaction.latestRequest = null;

    try {
        // name the interactor the request is actually sent to, not whatever is selected right now
        this.interaction.hideMessage = message.loading(`Waiting a response from ${interactor.name}..`, 0);
        try {
            // run server request
            this.setState({ fetching: true });
            const response = await core.lambda.call(jobInstance.taskId, interactor, data);
            // approximation with cv.approxPolyDP
            const approximated = await this.approximateResponsePoints(response);

            if (this.interaction.id !== interactionId || this.interaction.isAborted) {
                // new interaction session or the session is aborted
                return;
            }

            this.interaction.latestResponse = response;
            this.interaction.latestResult = approximated;

            this.setState({ pointsRecieved: !!response.length });
        } finally {
            // only the session that opened the message may close it
            if (this.interaction.id === interactionId && this.interaction.hideMessage) {
                this.interaction.hideMessage();
                this.interaction.hideMessage = null;
            }

            this.setState({ fetching: false });
        }

        if (this.interaction.latestResult.length) {
            canvasInstance.interact({
                enabled: true,
                intermediateShape: {
                    shapeType: ShapeType.POLYGON,
                    points: this.interaction.latestResult.flat(),
                },
                onChangeToolsBlockerState: this.onChangeToolsBlockerState,
            });
        }

        // process a request that was queued while this one was running
        setTimeout(() => this.runInteractionRequest(interactionId));
    } catch (err: unknown) {
        notification.error({
            description: String(err),
            message: 'Interaction error occurred',
        });
    }
};
|
|
|
|
/**
 * Handles 'canvas.interacted' in interaction mode: either finalizes the object from the
 * latest result (when the canvas reports `isDone`) or queues a new interactor request
 * built from the updated points.
 */
private onInteraction = (e: Event): void => {
    const { frame, isActivated } = this.props;
    const { activeInteractor } = this.state;

    if (!isActivated) {
        return;
    }

    // lazily open a session on the first event
    const sessionID = this.interaction.id || lodash.uniqueId('interaction_');
    this.interaction.id = sessionID;

    const { shapesUpdated, isDone, shapes } = (e as CustomEvent).detail;

    if (isDone) {
        // make an object from current result
        // do not make one more request
        // prevent future requests if possible
        this.interaction.isAborted = true;
        this.interaction.latestRequest = null;

        const result = this.interaction.latestResult;
        if (result.length) {
            this.constructFromPoints(result);
        }
        return;
    }

    if (!shapesUpdated) {
        return;
    }

    this.interaction.latestRequest = {
        interactor: activeInteractor as Model,
        data: {
            frame,
            pos_points: convertShapesForInteractor(shapes, 0),
            neg_points: convertShapesForInteractor(shapes, 2),
        },
    };

    this.runInteractionRequest(sessionID);
};
|
|
|
|
/**
 * Handles 'canvas.interacted' in tracking mode: once the rectangle is finished, creates
 * a track object from it, registers it as a tracked shape and refreshes annotations.
 */
private onTracking = async (e: Event): Promise<void> => {
    const { activeTracker, activeLabelID } = this.state;
    const {
        isActivated, jobInstance, frame, curZOrder, fetchAnnotations,
    } = this.props;

    if (!isActivated) {
        return;
    }

    const label = jobInstance.labels.find((_label: any): boolean => _label.id === activeLabelID);

    const { isDone, shapesUpdated } = (e as CustomEvent).detail;
    if (!isDone || !shapesUpdated) {
        return;
    }

    try {
        const { points } = (e as CustomEvent).detail.shapes[0];
        const state = new core.classes.ObjectState({
            shapeType: ShapeType.RECTANGLE,
            objectType: ObjectType.TRACK,
            zOrder: curZOrder,
            label,
            points,
            frame,
            occluded: false,
            attributes: {},
            descriptions: [`Trackable (${activeTracker?.name})`],
        });

        const [clientID] = await jobInstance.annotations.put([state]);
        const trackedShape: TrackedShape = {
            clientID,
            serverlessState: null,
            shapePoints: points,
            trackerModel: activeTracker as Model,
        };

        // The list may have changed while `put` was pending, so append to the
        // state at update time instead of a snapshot taken before the await.
        this.setState((prevState: State) => ({
            trackedShapes: [...prevState.trackedShapes, trackedShape],
        }));

        // update annotations on a canvas
        fetchAnnotations();
    } catch (err: unknown) {
        notification.error({
            description: String(err),
            message: 'Tracking error occurred',
        });
    }
};
|
|
|
|
/** Dispatches 'canvas.interacted' events to the handler of the current mode. */
private interactionListener = async (e: Event): Promise<void> => {
    const { mode } = this.state;

    switch (mode) {
        case 'interaction':
            await this.onInteraction(e);
            break;
        case 'tracking':
            await this.onTracking(e);
            break;
        default:
            // 'detection' does not use canvas interaction events
            break;
    }
};
|
|
|
|
/** Stores the interactor whose `id` equals the selected value. */
private setActiveInteractor = (value: string): void => {
    const { interactors } = this.props;
    const [selected] = interactors.filter(({ id }: Model) => id === value);

    this.setState({ activeInteractor: selected });
};
|
|
|
|
/** Stores the tracker whose `id` equals the selected value. */
private setActiveTracker = (value: string): void => {
    const { trackers } = this.props;
    const [selected] = trackers.filter(({ id }: Model) => id === value);

    this.setState({ activeTracker: selected });
};
|
|
|
|
/** Locks algorithms on 'keydown' and unlocks them on 'keyup' while the control is activated. */
private onChangeToolsBlockerState = (event: string): void => {
    const { isActivated, onSwitchToolsBlockerState } = this.props;

    if (!isActivated) {
        return;
    }

    switch (event) {
        case 'keydown':
            onSwitchToolsBlockerState({ algorithmsLocked: true });
            break;
        case 'keyup':
            onSwitchToolsBlockerState({ algorithmsLocked: false });
            break;
        default:
            break;
    }
};
|
|
|
|
/**
 * Builds one portal per rectangle track that has a sidebar item in the DOM.
 * Each portal renders a toggle icon: filled when the object is in `trackedShapes`
 * (click disables tracking), outlined otherwise (click enables tracking with the
 * active tracker).
 *
 * @returns the portals; an empty list when no tracker is selected
 */
private collectTrackerPortals(): React.ReactPortal[] {
    const { states, fetchAnnotations } = this.props;
    const { trackedShapes, activeTracker } = this.state;

    if (!activeTracker) {
        return [];
    }

    const trackedClientIDs = new Set(trackedShapes.map((trackedShape: TrackedShape) => trackedShape.clientID));

    const reportError = (error: unknown): void => {
        notification.error({
            message: 'Tracking error',
            description: String(error),
        });
    };

    return states
        .filter((objectState: any) => objectState.objectType === 'track' && objectState.shapeType === 'rectangle')
        .map((objectState: any): React.ReactPortal | null => {
            const { clientID } = objectState;
            const selectorID = `#cvat-objects-sidebar-state-item-${clientID}`;
            const anchor = window.document.querySelector(
                `${selectorID} .cvat-object-item-button-prev-keyframe`,
            );

            // the icon is mounted two levels above the "previous keyframe" button
            const container = anchor?.parentElement?.parentElement;
            if (!container) {
                return null;
            }

            const disableTracking = (): void => {
                /* eslint no-param-reassign: ["error", { "props": false }] */
                objectState.descriptions = [];
                objectState.save().then(() => {
                    // derive from the state at update time: the list captured when the
                    // portal was built may be outdated by the time the icon is clicked
                    this.setState((prevState: State) => ({
                        trackedShapes: prevState.trackedShapes.filter(
                            (trackedShape: TrackedShape) => trackedShape.clientID !== clientID,
                        ),
                    }));
                    // refresh only after the save, so rebuilt portals see the new list
                    fetchAnnotations();
                }).catch(reportError);
            };

            const enableTracking = (): void => {
                objectState.descriptions = [`Trackable (${activeTracker.name})`];
                objectState.save().then(() => {
                    this.setState((prevState: State) => ({
                        trackedShapes: [
                            ...prevState.trackedShapes.filter(
                                (trackedShape: TrackedShape) => trackedShape.clientID !== clientID,
                            ),
                            {
                                clientID,
                                serverlessState: null,
                                shapePoints: objectState.points,
                                trackerModel: activeTracker,
                            },
                        ],
                    }));
                    fetchAnnotations();
                }).catch(reportError);
            };

            return ReactDOM.createPortal(
                <Col>
                    {trackedClientIDs.has(clientID) ? (
                        <CVATTooltip overlay='Disable tracking'>
                            <EnvironmentFilled onClick={disableTracking} />
                        </CVATTooltip>
                    ) : (
                        <CVATTooltip overlay={`Enable tracking using ${activeTracker.name}`}>
                            <EnvironmentOutlined onClick={enableTracking} />
                        </CVATTooltip>
                    )}
                </Col>,
                container,
            );
        })
        .filter((portal: React.ReactPortal | null): portal is React.ReactPortal => portal !== null);
}
|
|
|
|
/**
 * Runs the trackers after a frame change.
 *
 * Steps (only when the frame changed and there are tracked shapes):
 *  1. pick tracked objects whose previous keyframe is `frame - 1` and that have no
 *     keyframe at or after `frame`;
 *  2. split them per tracker into "statefull" (cached tracker state matches the current
 *     points) and "stateless" (state must be initialized first);
 *  3. initialize states for the stateless group on `frame - 1`;
 *  4. request new shapes for every statefull group, save them and refresh annotations.
 * Navigation is blocked while server requests are running.
 *
 * @param prevProps props before the update, used to detect the frame change
 */
private async checkTrackedStates(prevProps: Props): Promise<void> {
    const {
        frame,
        jobInstance,
        states: objectStates,
        trackers,
        fetchAnnotations,
        switchNavigationBlocked,
    } = this.props;
    const { trackedShapes } = this.state;

    type AccumulatorType = {
        statefull: {
            [index: string]: {
                // tracker id
                clientIDs: number[];
                states: any[];
                shapes: number[][];
            };
        };
        stateless: {
            [index: string]: {
                // tracker id
                clientIDs: number[];
                shapes: number[][];
            };
        };
    };

    if (prevProps.frame === frame || !trackedShapes.length) {
        return;
    }

    let withServerRequest = false;

    // 1. find all trackable objects on the current frame
    // 2. divide them into two groups: with relevant state, without relevant state
    const trackingData = trackedShapes.reduce<AccumulatorType>(
        (acc: AccumulatorType, trackedShape: TrackedShape): AccumulatorType => {
            const {
                serverlessState, shapePoints, clientID, trackerModel,
            } = trackedShape;
            const clientState = objectStates.find((_state: any): boolean => _state.clientID === clientID);

            if (
                !clientState ||
                clientState.keyframes.prev !== frame - 1 ||
                clientState.keyframes.last >= frame
            ) {
                return acc;
            }

            if (clientState.outside) {
                return acc;
            }

            const { points } = clientState;
            withServerRequest = true;
            // the cached tracker state is usable only if the shape was not edited since
            const stateIsRelevant =
                serverlessState !== null &&
                points.length === shapePoints.length &&
                points.every((coord: number, i: number) => coord === shapePoints[i]);

            if (stateIsRelevant) {
                const container = acc.statefull[trackerModel.id] || {
                    clientIDs: [],
                    shapes: [],
                    states: [],
                };
                container.clientIDs.push(clientID);
                container.shapes.push(points);
                container.states.push(serverlessState);
                acc.statefull[trackerModel.id] = container;
            } else {
                const container = acc.stateless[trackerModel.id] || {
                    clientIDs: [],
                    shapes: [],
                };
                container.clientIDs.push(clientID);
                container.shapes.push(points);
                acc.stateless[trackerModel.id] = container;
            }

            return acc;
        },
        {
            statefull: {},
            stateless: {},
        },
    );

    const findTracker = (trackerID: string): Model => {
        const tracker = trackers.find((_tracker: Model) => _tracker.id === trackerID);
        if (!tracker) {
            throw new Error(`Suitable tracker with ID ${trackerID} not found in tracker list`);
        }
        return tracker;
    };

    try {
        if (withServerRequest) {
            switchNavigationBlocked(true);
        }

        // 3. get relevant state for the second group
        for (const trackerID of Object.keys(trackingData.stateless)) {
            let hideMessage: (() => void) | null = null;
            try {
                const tracker = findTracker(trackerID);
                const trackableObjects = trackingData.stateless[trackerID];
                const numOfObjects = trackableObjects.clientIDs.length;
                hideMessage = message.loading(
                    `${tracker.name}: states are being initialized for ${numOfObjects} ${
                        numOfObjects > 1 ? 'objects' : 'object'
                    } ..`,
                    0,
                );
                // eslint-disable-next-line no-await-in-loop
                const response = await core.lambda.call(jobInstance.taskId, tracker, {
                    frame: frame - 1,
                    shapes: trackableObjects.shapes,
                });

                const { states: serverlessStates } = response;
                const statefullContainer = trackingData.statefull[trackerID] || {
                    clientIDs: [],
                    shapes: [],
                    states: [],
                };

                statefullContainer.clientIDs.push(...trackableObjects.clientIDs);
                statefullContainer.shapes.push(...trackableObjects.shapes);
                statefullContainer.states.push(...serverlessStates);
                trackingData.statefull[trackerID] = statefullContainer;
                delete trackingData.stateless[trackerID];
            } catch (error: unknown) {
                notification.error({
                    message: 'Tracker initialization error',
                    description: String(error),
                });
            } finally {
                if (hideMessage) hideMessage();
            }
        }

        for (const trackerID of Object.keys(trackingData.statefull)) {
            // 4. run tracking for all the objects
            let hideMessage: (() => void) | null = null;
            try {
                const tracker = findTracker(trackerID);
                const trackableObjects = trackingData.statefull[trackerID];
                const numOfObjects = trackableObjects.clientIDs.length;
                hideMessage = message.loading(
                    `${tracker.name}: ${numOfObjects} ${
                        numOfObjects > 1 ? 'objects are' : 'object is'
                    } being tracked..`,
                    0,
                );
                // eslint-disable-next-line no-await-in-loop
                const response = await core.lambda.call(jobInstance.taskId, tracker, {
                    frame: frame - 1,
                    shapes: trackableObjects.shapes,
                    states: trackableObjects.states,
                });

                response.shapes = response.shapes.map(trackedRectangleMapper);

                const saves: Promise<void>[] = trackableObjects.clientIDs.map(
                    (clientID: number, i: number): Promise<void> => {
                        const shape = response.shapes[i];
                        const state = response.states[i];
                        const objectState = objectStates.find(
                            (_state: any): boolean => _state.clientID === clientID,
                        );
                        const trackedShape = trackedShapes.find(
                            (_trackedShape: TrackedShape) => _trackedShape.clientID === clientID,
                        );
                        if (!objectState || !trackedShape) {
                            throw new Error(`Tracked object with client ID ${clientID} not found`);
                        }

                        objectState.points = shape;
                        return objectState.save().then(() => {
                            // remember the state the tracker returned for exactly these points
                            trackedShape.serverlessState = state;
                            trackedShape.shapePoints = shape;
                        });
                    },
                );

                // Wait for the saves so that a failure is reported below and the
                // annotations are refreshed only after the new shapes are stored.
                // eslint-disable-next-line no-await-in-loop
                await Promise.all(saves);
            } catch (error: unknown) {
                notification.error({
                    message: 'Tracking error',
                    description: String(error),
                });
            } finally {
                if (hideMessage) hideMessage();
                fetchAnnotations();
            }
        }
    } finally {
        if (withServerRequest) {
            switchNavigationBlocked(false);
        }
    }
}
|
|
|
|
/**
 * Creates a polygon shape on the current frame from the given points and submits it
 * through `createAnnotations`.
 *
 * @param points list of `[x, y]` pairs; flattened before being stored
 */
private constructFromPoints(points: number[][]): void {
    const {
        frame, labels, curZOrder, jobInstance, activeLabelID, createAnnotations,
    } = this.props;

    let shapeLabel = null;
    if (labels.length) {
        [shapeLabel] = labels.filter((candidate: any) => candidate.id === activeLabelID);
    }

    const polygon = new core.classes.ObjectState({
        frame,
        objectType: ObjectType.SHAPE,
        label: shapeLabel,
        shapeType: ShapeType.POLYGON,
        points: points.flat(),
        occluded: false,
        zOrder: curZOrder,
    });

    createAnnotations(jobInstance, frame, [polygon]);
}
|
|
|
|
/**
 * Simplifies a polygon with the OpenCV wrapper using the current accuracy setting.
 * Inputs with three points or fewer are returned as is. The wrapper is initialized
 * on first use, with a loading message shown meanwhile.
 *
 * @param points polygon vertices
 * @returns the approximated vertices
 */
private async approximateResponsePoints(points: number[][]): Promise<number[][]> {
    const { approxPolyAccuracy } = this.state;

    if (points.length <= 3) {
        return points;
    }

    if (!openCVWrapper.isInitialized) {
        const hideInitMessage = message.loading('OpenCV.js initialization..', 0);
        try {
            await openCVWrapper.initialize(() => {});
        } finally {
            hideInitMessage();
        }
    }

    return openCVWrapper.contours.approxPoly(points, thresholdFromAccuracy(approxPolyAccuracy));
}
|
|
|
|
/** Renders the "Label" caption and the label selector bound to `state.activeLabelID`. */
private renderLabelBlock(): JSX.Element {
    const { labels } = this.props;
    const { activeLabelID } = this.state;

    const selectLabel = (value: any): void => this.setState({ activeLabelID: value.id });

    return (
        <>
            <Row justify='start'>
                <Col>
                    <Text className='cvat-text-color'>Label</Text>
                </Col>
            </Row>
            <Row justify='center'>
                <Col span={24}>
                    <LabelSelector
                        style={{ width: '100%' }}
                        labels={labels}
                        value={activeLabelID}
                        onChange={selectLabel}
                    />
                </Col>
            </Row>
        </>
    );
}
|
|
|
|
/**
 * Renders the tracker selector and the "Track" button, or a warning when no trackers
 * are available. The button is disabled without an active tracker, during a request,
 * and on the job's stop frame.
 */
private renderTrackerBlock(): JSX.Element {
    const {
        trackers, canvasInstance, jobInstance, frame, onInteractionStart,
    } = this.props;
    const { activeTracker, activeLabelID, fetching } = this.state;

    if (!trackers.length) {
        return (
            <Row justify='center' align='middle' style={{ marginTop: '5px' }}>
                <Col>
                    <Text type='warning' className='cvat-text-color'>
                        No available trackers found
                    </Text>
                </Col>
            </Row>
        );
    }

    // switches to tracking mode and asks the canvas for a rectangle
    const startTracking = (): void => {
        this.setState({ mode: 'tracking' });

        if (!activeTracker) {
            return;
        }

        canvasInstance.cancel();
        canvasInstance.interact({
            shapeType: 'rectangle',
            enabled: true,
        });

        onInteractionStart(activeTracker, activeLabelID);
        const { onSwitchToolsBlockerState } = this.props;
        onSwitchToolsBlockerState({ buttonVisible: false });
    };

    const trackerOptions = trackers.map(
        (tracker: Model): JSX.Element => (
            <Select.Option value={tracker.id} title={tracker.description} key={tracker.id}>
                {tracker.name}
            </Select.Option>
        ),
    );

    const trackDisabled = !activeTracker || fetching || frame === jobInstance.stopFrame;

    return (
        <>
            <Row justify='start'>
                <Col>
                    <Text className='cvat-text-color'>Tracker</Text>
                </Col>
            </Row>
            <Row align='middle' justify='center'>
                <Col span={24}>
                    <Select
                        style={{ width: '100%' }}
                        defaultValue={trackers[0].name}
                        onChange={this.setActiveTracker}
                    >
                        {trackerOptions}
                    </Select>
                </Col>
            </Row>
            <Row align='middle' justify='end'>
                <Col>
                    <Button
                        type='primary'
                        loading={fetching}
                        className='cvat-tools-track-button'
                        disabled={trackDisabled}
                        onClick={startTracking}
                    >
                        Track
                    </Button>
                </Col>
            </Row>
        </>
    );
}
|
|
|
|
/**
 * Renders the interactor selector, its help dropdown and the "Interact" button,
 * or a warning when no interactors are available.
 */
private renderInteractorBlock(): JSX.Element {
    const { interactors, canvasInstance, onInteractionStart } = this.props;
    const { activeInteractor, activeLabelID, fetching } = this.state;

    if (!interactors.length) {
        return (
            <Row justify='center' align='middle' style={{ marginTop: '5px' }}>
                <Col>
                    <Text type='warning' className='cvat-text-color'>
                        No available interactors found
                    </Text>
                </Col>
            </Row>
        );
    }

    // -1 when no interactor is selected, so the tooltip hides the negative-points hint
    const minNegVertices = activeInteractor ? (activeInteractor.params.canvas.minNegVertices as number) : -1;

    // switches to interaction mode and asks the canvas for points
    const startInteraction = (): void => {
        this.setState({ mode: 'interaction' });

        if (!activeInteractor) {
            return;
        }

        canvasInstance.cancel();
        activeInteractor.onChangeToolsBlockerState = this.onChangeToolsBlockerState;
        canvasInstance.interact({
            shapeType: 'points',
            enabled: true,
            ...activeInteractor.params.canvas,
        });
        onInteractionStart(activeInteractor, activeLabelID);
    };

    const interactorOptions = interactors.map(
        (interactor: Model): JSX.Element => (
            <Select.Option
                value={interactor.id}
                title={interactor.description}
                key={interactor.id}
            >
                {interactor.name}
            </Select.Option>
        ),
    );

    const tips = (
        <ToolsTooltips
            name={activeInteractor?.name}
            withNegativePoints={minNegVertices >= 0}
            {...(activeInteractor?.tip || {})}
        />
    );

    return (
        <>
            <Row justify='start'>
                <Col>
                    <Text className='cvat-text-color'>Interactor</Text>
                </Col>
            </Row>
            <Row align='middle' justify='space-between'>
                <Col span={22}>
                    <Select
                        style={{ width: '100%' }}
                        defaultValue={interactors[0].name}
                        onChange={this.setActiveInteractor}
                    >
                        {interactorOptions}
                    </Select>
                </Col>
                <Col span={2} className='cvat-interactors-tips-icon-container'>
                    <Dropdown overlay={tips}>
                        <QuestionCircleOutlined />
                    </Dropdown>
                </Col>
            </Row>
            <Row align='middle' justify='end'>
                <Col>
                    <Button
                        type='primary'
                        loading={fetching}
                        className='cvat-tools-interact-button'
                        disabled={!activeInteractor || fetching}
                        onClick={startInteraction}
                    >
                        Interact
                    </Button>
                </Col>
            </Row>
        </>
    );
}
|
|
|
|
/**
 * Renders the detector runner, or a warning when no detectors are available.
 * Running a detector calls the lambda function for the current frame, converts each
 * result whose label is present both in the job and in the request mapping into a
 * shape, and submits the shapes through `createAnnotations`. Attribute values are
 * kept only when they are compatible with the job attribute definition.
 */
private renderDetectorBlock(): JSX.Element {
    const {
        jobInstance, detectors, curZOrder, frame, createAnnotations,
    } = this.props;

    if (!detectors.length) {
        return (
            <Row justify='center' align='middle' style={{ marginTop: '5px' }}>
                <Col>
                    <Text type='warning' className='cvat-text-color'>
                        No available detectors found
                    </Text>
                </Col>
            </Row>
        );
    }

    // label name -> attribute name -> attribute id
    const attrsMap: Record<string, Record<string, number>> = {};
    jobInstance.labels.forEach((label: any) => {
        attrsMap[label.name] = {};
        label.attributes.forEach((attr: any) => {
            attrsMap[label.name][attr.name] = attr.id;
        });
    });

    /**
     * Decides whether a value produced for a model attribute can be stored in a job attribute.
     * Returns `false` when either attribute definition is missing.
     */
    function checkAttributesCompatibility(
        functionAttribute: ModelAttribute | undefined,
        dbAttribute: Attribute | undefined,
        value: string,
    ): boolean {
        if (!dbAttribute || !functionAttribute) {
            return false;
        }

        // job attributes are read through `inputType` here, unlike the declared `Attribute` type
        const { inputType } = (dbAttribute as any as { inputType: string });
        if (functionAttribute.input_type === inputType) {
            if (functionAttribute.input_type === 'number') {
                const [min, max, step] = dbAttribute.values;
                return !Number.isNaN(+value) && +value >= +min && +value <= +max && !(+value % +step);
            }

            if (functionAttribute.input_type === 'checkbox') {
                return ['true', 'false'].includes(value.toLowerCase());
            }

            if (['select', 'radio'].includes(functionAttribute.input_type)) {
                return dbAttribute.values.includes(value);
            }

            return true;
        }

        switch (functionAttribute.input_type) {
            case 'number':
                return dbAttribute.values.includes(value) || inputType === 'text';
            case 'text':
                // use the same `inputType` as every other branch of this function
                return ['select', 'radio'].includes(inputType) && dbAttribute.values.includes(value);
            case 'select':
                return (inputType === 'radio' && dbAttribute.values.includes(value)) || inputType === 'text';
            case 'radio':
                return (inputType === 'select' && dbAttribute.values.includes(value)) || inputType === 'text';
            case 'checkbox':
                return dbAttribute.values.includes(value) || inputType === 'text';
            default:
                return false;
        }
    }

    return (
        <DetectorRunner
            withCleanup={false}
            models={detectors}
            labels={jobInstance.labels}
            dimension={jobInstance.dimension}
            runInference={async (model: Model, body: DetectorRequestBody) => {
                try {
                    this.setState({ mode: 'detection', fetching: true });
                    const result = await core.lambda.call(jobInstance.taskId, model, { ...body, frame });
                    const states = result.map(
                        (data: any): any => {
                            const jobLabel = (jobInstance.labels as Label[])
                                .find((jLabel: Label): boolean => jLabel.name === data.label);
                            const [modelLabel] = Object.entries(body.mapping)
                                .find(([, { name }]) => name === data.label) || [];

                            // skip results whose label is unknown to the job or to the mapping
                            if (!jobLabel || !modelLabel) return null;

                            return new core.classes.ObjectState({
                                shapeType: data.type,
                                label: jobLabel,
                                points: data.points,
                                objectType: ObjectType.SHAPE,
                                frame,
                                occluded: false,
                                source: 'auto',
                                attributes: (data.attributes as { name: string, value: string }[])
                                    .reduce((acc, attr) => {
                                        const [modelAttr] = Object.entries(body.mapping[modelLabel].attributes)
                                            .find((value: string[]) => value[1] === attr.name) || [];
                                        const areCompatible = checkAttributesCompatibility(
                                            model.attributes[modelLabel].find((mAttr) => mAttr.name === modelAttr),
                                            jobLabel.attributes.find((jobAttr: Attribute) => (
                                                jobAttr.name === attr.name
                                            )),
                                            attr.value,
                                        );

                                        if (areCompatible) {
                                            acc[attrsMap[data.label][attr.name]] = attr.value;
                                        }

                                        return acc;
                                    }, {} as Record<number, string>),
                                zOrder: curZOrder,
                            });
                        },
                    ).filter((state: any) => state);

                    createAnnotations(jobInstance, frame, states);
                    const { onSwitchToolsBlockerState } = this.props;
                    onSwitchToolsBlockerState({ buttonVisible: false });
                } catch (error: any) {
                    notification.error({
                        description: error.toString(),
                        message: 'Detection error occurred',
                    });
                } finally {
                    this.setState({ fetching: false });
                }
            }}
        />
    );
}
|
|
|
|
/**
 * Builds the body of the AI tools popover: a title row followed by a
 * card-style tab set with interactors, detectors and trackers.
 */
private renderPopoverContent(): JSX.Element {
    const title = (
        <Row justify='start'>
            <Col>
                <Text className='cvat-text-color' strong>
                    AI Tools
                </Text>
            </Col>
        </Row>
    );

    // Panes are created in display order, so the render helpers are
    // invoked in the sequence: label, interactor, detector, label, tracker.
    const interactorsPane = (
        <Tabs.TabPane key='interactors' tab='Interactors'>
            {this.renderLabelBlock()}
            {this.renderInteractorBlock()}
        </Tabs.TabPane>
    );
    const detectorsPane = (
        <Tabs.TabPane key='detectors' tab='Detectors'>
            {this.renderDetectorBlock()}
        </Tabs.TabPane>
    );
    const trackersPane = (
        <Tabs.TabPane key='trackers' tab='Trackers'>
            {this.renderLabelBlock()}
            {this.renderTrackerBlock()}
        </Tabs.TabPane>
    );

    return (
        <div className='cvat-tools-control-popover-content'>
            {title}
            <Tabs type='card' tabBarGutter={8}>
                {interactorsPane}
                {detectorsPane}
                {trackersPane}
            </Tabs>
        </div>
    );
}
|
|
|
|
/**
 * Renders the AI tools control.
 *
 * Returns null when no models are available, a disabled icon when there
 * are no labels or the current frame is deleted, and otherwise the icon
 * with its popover plus any interaction/detection overlays and portals.
 */
public render(): JSX.Element | null {
    const {
        interactors, detectors, trackers, isActivated, canvasInstance, labels, frameIsDeleted,
    } = this.props;
    const {
        fetching, approxPolyAccuracy, pointsRecieved, mode, portals,
    } = this.state;

    // No models of any kind: the control is not rendered at all
    const totalModels = interactors.length + detectors.length + trackers.length;
    if (totalModels === 0) {
        return null;
    }

    // No labels or a deleted frame: show the icon in its disabled state only
    if (!labels.length || frameIsDeleted) {
        return <Icon className=' cvat-tools-control cvat-disabled-canvas-control' component={AIToolsIcon} />;
    }

    // While a tool is active the popover overlay is hidden
    const popoverProps = isActivated ? { overlayStyle: { display: 'none' } } : {};

    // While a tool is active a click on the icon cancels the interaction
    const stopInteraction = (): void => {
        canvasInstance.interact({ enabled: false });
    };
    const iconProps = isActivated ?
        { className: 'cvat-tools-control cvat-active-canvas-control', onClick: stopInteraction } :
        { className: 'cvat-tools-control' };

    let interactionContent: JSX.Element | null = null;
    if (isActivated && mode === 'interaction' && pointsRecieved) {
        interactionContent = (
            <>
                <ApproximationAccuracy
                    approxPolyAccuracy={approxPolyAccuracy}
                    onChange={(value: number) => {
                        this.setState({ approxPolyAccuracy: value });
                    }}
                />
            </>
        );
    }

    let detectionContent: JSX.Element | null = null;
    if (fetching && mode === 'detection') {
        // Non-closable modal shown while a detector request is in flight
        detectionContent = (
            <Modal
                title='Making a server request'
                zIndex={Number.MAX_SAFE_INTEGER}
                visible
                destroyOnClose
                closable={false}
                footer={[]}
            >
                <Text>Waiting for a server response..</Text>
                <LoadingOutlined style={{ marginLeft: '10px' }} />
            </Modal>
        );
    }

    return (
        <>
            <CustomPopover {...popoverProps} placement='right' content={this.renderPopoverContent()}>
                <Icon {...iconProps} component={AIToolsIcon} />
            </CustomPopover>
            {interactionContent}
            {detectionContent}
            {portals}
        </>
    );
}
|
|
}
|
|
|
|
// Default export: ToolsControlComponent wrapped by react-redux `connect`
// with mapStateToProps and mapDispatchToProps.
export default connect(mapStateToProps, mapDispatchToProps)(ToolsControlComponent);
|