Add OpenCV MIL tracker tool (#4200)

main
Kirill Lakhov 4 years ago committed by GitHub
parent 237f98d5f8
commit 87be7bcf53
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -31,6 +31,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add several flags to task creation CLI (<https://github.com/openvinotoolkit/cvat/pull/4119>)
- Add YOLOv5 serverless function for automatic annotation (<https://github.com/openvinotoolkit/cvat/pull/4178>)
- Basic page with jobs list, basic filtration to this list (<https://github.com/openvinotoolkit/cvat/pull/4258>)
- Added OpenCV.js TrackerMIL as tracking tool (<https://github.com/openvinotoolkit/cvat/pull/4200>)
### Changed
- Users don't have access to a task object anymore if they are assigned only to some jobs of the task (<https://github.com/openvinotoolkit/cvat/pull/3788>)

@ -1,12 +1,12 @@
{
"name": "cvat-ui",
"version": "1.34.2",
"version": "1.35.0",
"lockfileVersion": 2,
"requires": true,
"packages": {
"": {
"name": "cvat-ui",
"version": "1.34.2",
"version": "1.35.0",
"license": "MIT",
"dependencies": {
"@ant-design/icons": "^4.6.3",

@ -1,6 +1,6 @@
{
"name": "cvat-ui",
"version": "1.34.2",
"version": "1.35.0",
"description": "CVAT single-page application",
"main": "src/index.tsx",
"scripts": {

@ -27,6 +27,7 @@ import {
Workspace,
} from 'reducers/interfaces';
import { updateJobAsync } from './tasks-actions';
import { switchToolsBlockerState } from './settings-actions';
interface AnnotationsParameters {
filters: string[];
@ -1503,12 +1504,13 @@ export function repeatDrawShapeAsync(): ThunkAction {
let activeControl = ActiveControl.CURSOR;
if (activeInteractor && canvasInstance instanceof Canvas) {
if (activeInteractor.type === 'tracker') {
if (activeInteractor.type.includes('tracker')) {
canvasInstance.interact({
enabled: true,
shapeType: 'rectangle',
});
dispatch(interactWithCanvas(activeInteractor, activeLabelID));
dispatch(switchToolsBlockerState({ buttonVisible: false }));
} else {
canvasInstance.interact({
enabled: true,

@ -1,4 +1,4 @@
// Copyright (C) 2021 Intel Corporation
// Copyright (C) 2021-2022 Intel Corporation
//
// SPDX-License-Identifier: MIT
@ -11,7 +11,9 @@ import Text from 'antd/lib/typography/Text';
import Tabs from 'antd/lib/tabs';
import Button from 'antd/lib/button';
import Progress from 'antd/lib/progress';
import Select from 'antd/lib/select';
import notification from 'antd/lib/notification';
import message from 'antd/lib/message';
import { OpenCVIcon } from 'icons';
import { Canvas, convertShapesForInteractor } from 'cvat-canvas-wrapper';
@ -27,13 +29,14 @@ import {
updateAnnotationsAsync,
createAnnotationsAsync,
changeFrameAsync,
switchNavigationBlocked as switchNavigationBlockedAction,
} from 'actions/annotation-actions';
import LabelSelector from 'components/label-selector/label-selector';
import CVATTooltip from 'components/common/cvat-tooltip';
import ApproximationAccuracy, {
thresholdFromAccuracy,
} from 'components/annotation-page/standard-workspace/controls-side-bar/approximation-accuracy';
import { ImageProcessing } from 'utils/opencv-wrapper/opencv-interfaces';
import { ImageProcessing, OpenCVTracker, TrackerModel } from 'utils/opencv-wrapper/opencv-interfaces';
import { switchToolsBlockerState } from 'actions/settings-actions';
import withVisibilityHandling from './handle-popover-visibility';
@ -58,6 +61,13 @@ interface DispatchToProps {
fetchAnnotations(): void;
changeFrame(toFrame: number, fillBuffer?: boolean, frameStep?: number, forceUpdate?: boolean):void;
onSwitchToolsBlockerState(toolsBlockerState: ToolsBlockerState):void;
switchNavigationBlocked(navigationBlocked: boolean): void;
}
// Bookkeeping record for one annotation object that is followed by a client-side tracker
interface TrackedShape {
// clientID returned by jobInstance.annotations.put() for the created track;
// used to find the matching object state among the states of the current frame
clientID: number;
// points the tracker model was last synchronized with; compared against the
// object state points to detect that the shape was changed and the model must be re-initialized
shapePoints: number[];
// tracker model instance created for this shape by the active tracker
trackerModel: TrackerModel;
}
interface State {
@ -67,6 +77,10 @@ interface State {
activeLabelID: number;
approxPolyAccuracy: number;
activeImageModifiers: ImageModifier[];
mode: 'interaction' | 'tracking';
trackedShapes: TrackedShape[];
activeTracker: OpenCVTracker | null;
trackers: OpenCVTracker[]
}
interface ImageModifier {
@ -117,6 +131,7 @@ const mapDispatchToProps = {
createAnnotations: createAnnotationsAsync,
changeFrame: changeFrameAsync,
onSwitchToolsBlockerState: switchToolsBlockerState,
switchNavigationBlocked: switchNavigationBlockedAction,
};
class OpenCVControlComponent extends React.PureComponent<Props & DispatchToProps, State> {
@ -138,6 +153,10 @@ class OpenCVControlComponent extends React.PureComponent<Props & DispatchToProps
approxPolyAccuracy: defaultApproxPolyAccuracy,
activeLabelID: labels.length ? labels[0].id : null,
activeImageModifiers: [],
mode: 'interaction',
trackedShapes: [],
trackers: openCVWrapper.isInitialized ? Object.values(openCVWrapper.tracking) : [],
activeTracker: openCVWrapper.isInitialized ? Object.values(openCVWrapper.tracking)[0] : null,
};
}
@ -184,6 +203,7 @@ class OpenCVControlComponent extends React.PureComponent<Props & DispatchToProps
!!this.activeTool?.switchBlockMode) {
this.activeTool.switchBlockMode(toolsBlockerState.algorithmsLocked);
}
this.checkTrackedStates(prevProps);
}
public componentWillUnmount(): void {
@ -193,6 +213,18 @@ class OpenCVControlComponent extends React.PureComponent<Props & DispatchToProps
}
/**
 * Forwards an incoming event to the handler that matches the current mode
 * of the control: `onInteraction` for 'interaction', `onTracking` for 'tracking'.
 *
 * @param e event to be forwarded as is
 */
private interactionListener = async (e: Event): Promise<void> => {
    // the mode is read once, so at most one handler runs per event
    switch (this.state.mode) {
        case 'interaction':
            await this.onInteraction(e);
            break;
        case 'tracking':
            await this.onTracking(e);
            break;
        default:
            break;
    }
};
private onInteraction = async (e: Event): Promise<void> => {
const { approxPolyAccuracy } = this.state;
const {
createAnnotations, isActivated, jobInstance, frame, labels, curZOrder, canvasInstance, toolsBlockerState,
@ -263,6 +295,76 @@ class OpenCVControlComponent extends React.PureComponent<Props & DispatchToProps
}
};
/**
 * Handles an event received in 'tracking' mode: creates a new rectangle track
 * from the first shape of the event and attaches a tracker model to it.
 *
 * Nothing happens when the control is not activated, when the event is not
 * finished (`isDone`) or did not update shapes (`shapesUpdated`), or when
 * there is no active tracker.
 *
 * Any error raised while initializing the tracker or saving the annotation
 * is shown to the user as a notification and is not rethrown.
 *
 * @param e event whose `detail` carries `isDone`, `shapesUpdated` and `shapes`
 */
private onTracking = async (e: Event): Promise<void> => {
    const {
        isActivated, jobInstance, frame, curZOrder, fetchAnnotations,
    } = this.props;

    if (!isActivated) {
        return;
    }

    const { activeLabelID, activeTracker } = this.state;
    const [label] = jobInstance.labels.filter((_label: any): boolean => _label.id === activeLabelID);

    const { isDone, shapesUpdated } = (e as CustomEvent).detail;
    if (!isDone || !shapesUpdated || !activeTracker) {
        return;
    }

    try {
        const { points } = (e as CustomEvent).detail.shapes[0];
        const imageData = this.getCanvasImageData();
        // every tracked shape gets its own model instance
        const trackerModel = activeTracker.model();
        trackerModel.init(imageData, points);
        const state = new core.classes.ObjectState({
            shapeType: ShapeType.RECTANGLE,
            objectType: ObjectType.TRACK,
            zOrder: curZOrder,
            label,
            points,
            frame,
            occluded: false,
            attributes: {},
            descriptions: [`Trackable (${activeTracker.name})`],
        });
        const [clientID] = await jobInstance.annotations.put([state]);

        // Use the functional form: the component state may have changed while
        // the annotation was being saved, so a list captured before the await
        // could silently drop shapes registered in the meantime.
        this.setState((prevState) => ({
            trackedShapes: [
                ...prevState.trackedShapes,
                {
                    clientID,
                    trackerModel,
                    shapePoints: points,
                },
            ],
        }));

        // update annotations on a canvas
        fetchAnnotations();
    } catch (err) {
        notification.error({
            description: err.toString(),
            message: 'Tracking error occured',
        });
    }
};
/**
 * Reads the pixels of the element with id `cvat_canvas_background`.
 *
 * @returns image data covering the full width and height of that canvas
 * @throws Error when the element is not found or its 2D context is unavailable
 */
private getCanvasImageData = ():ImageData => {
    const background = window.document.getElementById('cvat_canvas_background') as HTMLCanvasElement | null;
    if (background === null) {
        throw new Error('Element #cvat_canvas_background was not found');
    }

    const context2d = background.getContext('2d');
    if (context2d === null) {
        throw new Error('Canvas context is empty');
    }

    return context2d.getImageData(0, 0, background.width, background.height);
};
private onChangeToolsBlockerState = (event:string):void => {
const {
isActivated, toolsBlockerState, onSwitchToolsBlockerState, canvasInstance,
@ -283,24 +385,14 @@ class OpenCVControlComponent extends React.PureComponent<Props & DispatchToProps
const {
frameData, states, curZOrder, canvasInstance, frame,
} = this.props;
try {
if (activeImageModifiers.length !== 0 && activeImageModifiers[0].modifier.currentProcessedImage !== frame) {
this.enableCanvasForceUpdate();
const canvas: HTMLCanvasElement | undefined = window.document.getElementById('cvat_canvas_background') as
| HTMLCanvasElement
| undefined;
if (!canvas) {
throw new Error('Element #cvat_canvas_background was not found');
}
const { width, height } = canvas;
const context = canvas.getContext('2d');
if (!context) {
throw new Error('Canvas context is empty');
}
const imageData = context.getImageData(0, 0, width, height);
const newImageData = activeImageModifiers.reduce((oldImageData, activeImageModifier) => (
activeImageModifier.modifier.processImage(oldImageData, frame)
), imageData);
const imageData = this.getCanvasImageData();
const newImageData = activeImageModifiers
.reduce((oldImageData, activeImageModifier) => activeImageModifier
.modifier.processImage(oldImageData, frame), imageData);
const imageBitmap = await createImageBitmap(newImageData);
frameData.imageData = imageBitmap;
canvasInstance.setup(frameData, states, curZOrder);
@ -316,6 +408,117 @@ class OpenCVControlComponent extends React.PureComponent<Props & DispatchToProps
}
};
/**
 * Runs one tracking step for a single shape on the given image.
 *
 * The work is executed inside a `setTimeout` callback. If the points of the
 * object state differ from the points the tracker was last synchronized with,
 * the tracker model is re-initialized with the object state points first.
 * When the tracker reports an update, the new points are written to the
 * object state and saved.
 *
 * @param imageData image the tracker is updated with
 * @param shape tracked shape record; its `shapePoints` are mutated to follow the object state
 * @param objectState annotation object state that corresponds to the shape
 * @returns promise resolved once the step (including saving, if any) is finished;
 * rejected when the tracker or the saving fails
 */
private applyTracking = (imageData: ImageData, shape: TrackedShape,
    objectState: any): Promise<void> => new Promise((resolve, reject) => {
    setTimeout(() => {
        try {
            const stateIsRelevant =
                objectState.points.length === shape.shapePoints.length &&
                objectState.points.every(
                    (coord: number, index: number) => coord === shape.shapePoints[index],
                );
            if (!stateIsRelevant) {
                shape.trackerModel.reinit(objectState.points);
                shape.shapePoints = objectState.points;
            }

            const { updated, points } = shape.trackerModel.update(imageData);
            if (!updated) {
                resolve();
                return;
            }

            objectState.points = points;
            // Settle only after saving is finished: resolving earlier would turn
            // a later reject() into a no-op and hide saving errors from the caller.
            objectState.save().then(() => {
                shape.shapePoints = points;
                resolve();
            }).catch((error: unknown) => {
                reject(error);
            });
        } catch (error) {
            reject(error);
        }
    });
});
/**
 * Makes the tracker with the given name the active one.
 *
 * @param value name of a tracker from `state.trackers`; when no tracker has
 * this name the active tracker is reset to `null` (the state type does not allow `undefined`)
 */
private setActiveTracker = (value: string): void => {
    const { trackers } = this.state;
    const selected = trackers.find((tracker: OpenCVTracker): boolean => tracker.name === value);
    this.setState({ activeTracker: selected ?? null });
};
/**
 * Runs client-side tracking for tracked shapes after the frame has changed.
 *
 * A tracked shape is processed only if its object state is present among the
 * states of the current frame, `keyframes.prev` equals the previous frame and
 * `keyframes.last` is less than the current frame. Selected shapes are grouped
 * by the name of their tracker model; a loading message is shown per group.
 *
 * Navigation is blocked while tracking is scheduled; annotations are re-fetched
 * and navigation is unblocked from a `setTimeout` callback queued after the
 * tracking callbacks.
 *
 * @param prevProps props before the update, used to detect the frame change
 */
private checkTrackedStates(prevProps: Props): void {
    const {
        frame,
        states: objectStates,
        fetchAnnotations,
        switchNavigationBlocked,
    } = this.props;
    const { trackedShapes } = this.state;
    if (prevProps.frame !== frame && trackedShapes.length) {
        type AccumulatorType = {
            [index: string]: TrackedShape[];
        };
        const trackingData = trackedShapes.reduce<AccumulatorType>(
            (acc: AccumulatorType, trackedShape: TrackedShape): AccumulatorType => {
                const [clientState] = objectStates.filter(
                    (_state: any): boolean => _state.clientID === trackedShape.clientID,
                );
                if (
                    !clientState ||
                    clientState.keyframes.prev !== frame - 1 ||
                    clientState.keyframes.last >= frame
                ) {
                    return acc;
                }

                const { name: trackerName } = trackedShape.trackerModel;
                if (!acc[trackerName]) {
                    acc[trackerName] = [];
                }
                acc[trackerName].push(trackedShape);
                return acc;
            }, {},
        );

        if (Object.keys(trackingData).length === 0) {
            return;
        }

        try {
            switchNavigationBlocked(true);
            // The canvas content is the same for every tracker group, so read it once.
            // Reading it before any loading message is shown also guarantees that
            // a failure here cannot leave a never-closing message on the screen.
            const imageData = this.getCanvasImageData();
            for (const trackerID of Object.keys(trackingData)) {
                const numOfObjects = trackingData[trackerID].length;
                const hideMessage = message.loading(
                    `${trackerID}: ${numOfObjects} ${
                        numOfObjects > 1 ? 'objects are' : 'object is'
                    } being tracked..`,
                    0,
                );
                for (const shape of trackingData[trackerID]) {
                    const [objectState] = objectStates.filter(
                        (_state: any): boolean => _state.clientID === shape.clientID,
                    );
                    this.applyTracking(imageData, shape, objectState)
                        .catch((trackingError) => {
                            notification.error({
                                message: 'Tracking error',
                                description: trackingError.toString(),
                            });
                        });
                }
                // queued after the tracking callbacks of this group
                setTimeout(() => {
                    if (hideMessage) hideMessage();
                });
            }
        } catch (error) {
            // do not let a synchronous failure escape from the component lifecycle
            notification.error({
                message: 'Tracking error',
                description: error.toString(),
            });
        } finally {
            setTimeout(() => {
                fetchAnnotations();
                switchNavigationBlocked(false);
            });
        }
    }
}
private async runCVAlgorithm(pressedPoints: number[], threshold: number): Promise<number[]> {
// Getting image data
const canvas: HTMLCanvasElement | undefined = window.document.getElementById('cvat_canvas_background') as
@ -407,6 +610,7 @@ class OpenCVControlComponent extends React.PureComponent<Props & DispatchToProps
<CVATTooltip title='Intelligent scissors' className='cvat-opencv-drawing-tool'>
<Button
onClick={() => {
this.setState({ mode: 'interaction' });
this.activeTool = openCVWrapper.segmentation
.intelligentScissorsFactory(this.onChangeToolsBlockerState);
canvasInstance.cancel();
@ -456,6 +660,91 @@ class OpenCVControlComponent extends React.PureComponent<Props & DispatchToProps
);
}
/**
 * Renders the content of the tracking tab: a label selector, a tracker
 * selector and the "Track" button. When there are no trackers in the state,
 * a warning text is rendered instead.
 */
private renderTrackingContent(): JSX.Element {
    const { activeLabelID, trackers, activeTracker } = this.state;
    const {
        labels, canvasInstance, onInteractionStart, frame, jobInstance,
    } = this.props;

    if (!trackers.length) {
        return (
            <Row justify='center' align='middle' className='cvat-opencv-tracker-content'>
                <Col>
                    <Text type='warning' className='cvat-text-color'>
                        No available trackers found
                    </Text>
                </Col>
            </Row>
        );
    }

    // one option per available tracker, the description is shown as a title
    const trackerOptions = trackers.map(
        (tracker: OpenCVTracker): JSX.Element => (
            <Select.Option value={tracker.name} title={tracker.description} key={tracker.name}>
                {tracker.name}
            </Select.Option>
        ),
    );

    // switches the control to the tracking mode and starts drawing a rectangle
    const startTracking = (): void => {
        this.setState({ mode: 'tracking' });
        if (!activeTracker) {
            return;
        }

        canvasInstance.cancel();
        canvasInstance.interact({
            shapeType: 'rectangle',
            enabled: true,
        });

        onInteractionStart(activeTracker as OpenCVTracker, activeLabelID);
        const { onSwitchToolsBlockerState } = this.props;
        onSwitchToolsBlockerState({ buttonVisible: false });
    };

    return (
        <>
            <Row justify='start'>
                <Col>
                    <Text className='cvat-text-color'>Label</Text>
                </Col>
            </Row>
            <Row justify='center'>
                <Col span={24}>
                    <LabelSelector
                        className='cvat-opencv-tracker-select'
                        labels={labels}
                        value={activeLabelID}
                        onChange={(value: any) => this.setState({ activeLabelID: value.id })}
                    />
                </Col>
            </Row>
            <Row justify='start'>
                <Col>
                    <Text className='cvat-text-color'>Tracker</Text>
                </Col>
            </Row>
            <Row align='middle' justify='center'>
                <Col span={24}>
                    <Select
                        className='cvat-opencv-tracker-select'
                        defaultValue={trackers[0].name}
                        onChange={this.setActiveTracker}
                    >
                        {trackerOptions}
                    </Select>
                </Col>
            </Row>
            <Row align='middle' justify='end'>
                <Col>
                    <Button
                        type='primary'
                        className='cvat-tools-track-button'
                        disabled={!activeTracker || frame === jobInstance.stopFrame}
                        onClick={startTracking}
                    >
                        Track
                    </Button>
                </Col>
            </Row>
        </>
    );
}
private renderContent(): JSX.Element {
const { libraryInitialized, initializationProgress, initializationError } = this.state;
@ -476,6 +765,9 @@ class OpenCVControlComponent extends React.PureComponent<Props & DispatchToProps
<Tabs.TabPane key='image' tab='Image' className='cvat-opencv-control-tabpane'>
{this.renderImageContent()}
</Tabs.TabPane>
<Tabs.TabPane key='tracking' tab='Tracking' className='cvat-opencv-control-tabpane'>
{this.renderTrackingContent()}
</Tabs.TabPane>
</Tabs>
) : (
<>
@ -493,7 +785,12 @@ class OpenCVControlComponent extends React.PureComponent<Props & DispatchToProps
await openCVWrapper.initialize((progress: number) => {
this.setState({ initializationProgress: progress });
});
this.setState({ libraryInitialized: true });
const trackers = Object.values(openCVWrapper.tracking);
this.setState({
libraryInitialized: true,
activeTracker: trackers[0],
trackers,
});
} catch (error) {
notification.error({
description: error.toString(),
@ -528,8 +825,8 @@ class OpenCVControlComponent extends React.PureComponent<Props & DispatchToProps
public render(): JSX.Element {
const { isActivated, canvasInstance, labels } = this.props;
const { libraryInitialized, approxPolyAccuracy } = this.state;
const dynamicPopoverProps = isActivated ?
const { libraryInitialized, approxPolyAccuracy, mode } = this.state;
const dynamcPopoverPros = isActivated ?
{
overlayStyle: {
display: 'none',
@ -553,7 +850,7 @@ class OpenCVControlComponent extends React.PureComponent<Props & DispatchToProps
) : (
<>
<CustomPopover
{...dynamicPopoverProps}
{...dynamcPopoverPros}
placement='right'
overlayClassName='cvat-opencv-control-popover'
content={this.renderContent()}
@ -567,7 +864,7 @@ class OpenCVControlComponent extends React.PureComponent<Props & DispatchToProps
>
<Icon {...dynamicIconProps} component={OpenCVIcon} />
</CustomPopover>
{isActivated ? (
{isActivated && mode !== 'tracking' ? (
<ApproximationAccuracy
approxPolyAccuracy={approxPolyAccuracy}
onChange={(value: number) => {

@ -256,6 +256,14 @@
border-color: #40a9ff;
}
// Applied to the label selector and the tracker selector of the tracking tab:
// both take the full width of their column
.cvat-opencv-tracker-select {
width: 100%;
}
// Applied to the row with the "No available trackers found" warning
.cvat-opencv-tracker-content {
margin-top: $grid-unit-size;
}
.cvat-setup-tag-popover-content,
.cvat-draw-shape-popover-content {
padding: $grid-unit-size;

@ -6,6 +6,7 @@ import { Canvas3d } from 'cvat-canvas3d/src/typescript/canvas3d';
import { Canvas, RectDrawingMethod, CuboidDrawingMethod } from 'cvat-canvas-wrapper';
import { IntelligentScissors } from 'utils/opencv-wrapper/intelligent-scissors';
import { KeyMap } from 'utils/mousetrap-react';
import { OpenCVTracker } from 'utils/opencv-wrapper/opencv-interfaces';
export type StringObject = {
[index: string]: string;
@ -284,7 +285,7 @@ export interface Model {
};
}
export type OpenCVTool = IntelligentScissors;
export type OpenCVTool = IntelligentScissors | OpenCVTracker;
export interface ToolsBlockerState {
algorithmsLocked?: boolean;

@ -1,8 +1,27 @@
// Copyright (C) 2021 Intel Corporation
// Copyright (C) 2021-2022 Intel Corporation
//
// SPDX-License-Identifier: MIT
export interface ImageProcessing {
processImage: (src: ImageData, frameNumber: number) => ImageData;
currentProcessedImage: number|undefined
currentProcessedImage: number | undefined;
}
// Result of one tracker update step
export interface TrackingResult {
// flag reported by the tracker for this step; callers apply `points` only when it is true
updated: boolean;
// new shape points; the TrackerMIL implementation returns them as
// [x, y, x + width, y + height] of the tracked rectangle
points: number[];
}
// Contract of a tracker instance bound to a single tracked shape
export interface TrackerModel {
// name of the model; used to group tracked shapes and shown in the progress message
name: string;
// initializes the tracker with an image and the points of the shape on this image
init: (src: ImageData, points: number[]) => void;
// initializes the tracker again with new points, without passing a new image
// (the TrackerMIL implementation reuses the last image it has received)
reinit: (points: number[]) => void;
// updates the tracker with a new image and returns the result of this step
update: (src: ImageData) => TrackingResult;
}
// Description of an available tracker together with a factory of its models
export interface OpenCVTracker {
// name shown in the tracker selector and used to find the tracker by the selected value
name: string;
// text shown as the title of the selector option
description: string;
// kind of the tool, e.g. 'opencv_tracker_mil'; annotation actions check
// whether it includes 'tracker'
type: string;
// creates a new tracker model instance; called once per tracked shape
model: (() => TrackerModel);
}

@ -1,12 +1,12 @@
// Copyright (C) 2020-2021 Intel Corporation
// Copyright (C) 2020-2022 Intel Corporation
//
// SPDX-License-Identifier: MIT
import getCore from 'cvat-core-wrapper';
import waitFor from '../wait-for';
import HistogramEqualizationImplementation, { HistogramEqualization } from './histogram-equalization';
import TrackerMImplementation from './tracker-mil';
import IntelligentScissorsImplementation, { IntelligentScissors } from './intelligent-scissors';
import { OpenCVTracker } from './opencv-interfaces';
const core = getCore();
const baseURL = core.config.backendAPI.slice(0, -7);
@ -23,6 +23,10 @@ export interface ImgProc {
hist: () => HistogramEqualization;
}
// Trackers provided by OpenCVWrapper through its `tracking` getter
export interface Tracking {
// tracker whose models are created from the TrackerMIL implementation
trackerMIL: OpenCVTracker;
}
export class OpenCVWrapper {
private initialized: boolean;
private cv: any;
@ -73,13 +77,8 @@ export class OpenCVWrapper {
OpenCVConstructor();
const global = window as any;
await waitFor(
100,
() =>
typeof global.cv !== 'undefined' && typeof global.cv.segmentation_IntelligentScissorsMB !== 'undefined',
);
this.cv = global.cv;
this.cv = await global.cv;
this.initialized = true;
}
@ -126,8 +125,9 @@ export class OpenCVWrapper {
}
return {
intelligentScissorsFactory: (onChangeToolsBlockerState:(event:string)=>void) =>
new IntelligentScissorsImplementation(this.cv, onChangeToolsBlockerState),
intelligentScissorsFactory:
(onChangeToolsBlockerState:
(event:string)=>void) => new IntelligentScissorsImplementation(this.cv, onChangeToolsBlockerState),
};
}
@ -139,6 +139,20 @@ export class OpenCVWrapper {
hist: () => new HistogramEqualizationImplementation(this.cv),
};
}
/**
 * Trackers available after the library has been initialized.
 * A new descriptor object is built on every access.
 *
 * @throws Error when OpenCV is not initialized yet
 */
public get tracking(): Tracking {
    if (!this.initialized) {
        throw new Error('Need to initialize OpenCV first');
    }

    const trackerMIL: OpenCVTracker = {
        // each call creates an independent tracker model
        model: () => new TrackerMImplementation(this.cv),
        name: 'TrackerMIL',
        description: 'Light client-side model useful to track simple objects',
        type: 'opencv_tracker_mil',
    };

    return { trackerMIL };
}
}
export default new OpenCVWrapper();

@ -0,0 +1,68 @@
// Copyright (C) 2022 Intel Corporation
//
// SPDX-License-Identifier: MIT
import { clamp } from 'utils/math';
import { TrackerModel, TrackingResult } from './opencv-interfaces';
export type TrackerMIL = TrackerModel;
/**
 * Tracker model backed by `cv.TrackerMIL` of the OpenCV instance
 * passed to the constructor. Tracks a single rectangle.
 */
export default class TrackerMILImplementation implements TrackerMIL {
    public name: string;
    // last image passed to init() or update(); reused by reinit()
    private imageData: ImageData | null;
    private cv: any;
    private trackerMIL: any;

    /**
     * @param cv OpenCV instance that provides `TrackerMIL`, `Rect` and `matFromImageData`
     */
    constructor(cv: any) {
        this.cv = cv;
        this.trackerMIL = new cv.TrackerMIL();
        this.imageData = null;
        this.name = 'TrackerMil';
    }

    /**
     * Initializes the tracker with a rectangle on the given image.
     *
     * @param src image the rectangle belongs to
     * @param points rectangle as [x1, y1, x2, y2]; it is cut by the image bounds
     * @throws Error when `points` is not a rectangle or when the rectangle
     * has zero width or height after cutting
     */
    public init(src: ImageData, points: number[]): void {
        if (points.length !== 4) {
            // two coordinates per point, so the number of points is a half of the length
            throw new Error(
                `TrackerMIL must be initialized with rectangle, but got ${points.length / 2} points.`,
            );
        }
        this.imageData = src;
        // cut the bounding box if it is out of image bounds
        const x1 = clamp(points[0], 0, src.width);
        const y1 = clamp(points[1], 0, src.height);
        const x2 = clamp(points[2], 0, src.width);
        const y2 = clamp(points[3], 0, src.height);
        const [width, height] = [x2 - x1, y2 - y1];
        if (width === 0 || height === 0) {
            throw new Error('TrackerMIL got rectangle out of image bounds');
        }

        let matImage = null;
        try {
            matImage = this.cv.matFromImageData(src);
            const rect = new this.cv.Rect(x1, y1, width, height);
            this.trackerMIL.init(matImage, rect);
        } finally {
            // the matrix is released explicitly, also when initialization fails
            if (matImage) matImage.delete();
        }
    }

    /**
     * Initializes the tracker again with new points on the last image
     * that was passed to init() or update().
     *
     * @param points rectangle as [x1, y1, x2, y2]
     * @throws Error when the tracker has never received an image
     */
    public reinit(points: number[]): void {
        if (!this.imageData) {
            throw new Error('TrackerMIL needs to be initialized before re-initialization');
        }
        this.init(this.imageData, points);
    }

    /**
     * Updates the tracker with a new image.
     *
     * @param src new image; it is remembered for a possible reinit()
     * @returns the flag reported by the tracker and the rectangle as
     * [x, y, x + width, y + height]
     */
    public update(src: ImageData): TrackingResult {
        this.imageData = src;
        let matImage = null;
        try {
            matImage = this.cv.matFromImageData(src);
            const [updated, rect] = this.trackerMIL.update(matImage);
            return { updated, points: [rect.x, rect.y, rect.x + rect.width, rect.y + rect.height] };
        } finally {
            if (matImage) matImage.delete();
        }
    }
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long
Loading…
Cancel
Save