Integration with an internal training server (#2785)

Co-authored-by: Boris Sekachev <boris.sekachev@intel.com>
Co-authored-by: Nikita Manovich <nikita.manovich@intel.com>
main
Dmitry Agapov 5 years ago committed by GitHub
parent babf1a3f54
commit d2a1d12fba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,5 +1,5 @@
{
"python.pythonPath": ".env/bin/python",
"eslint.enable": true,
"eslint.probe": [
"javascript",
"typescript",

@ -573,7 +573,7 @@ function build() {
* @param {module:API.cvat.classes.Task} task task to be annotated
* @param {module:API.cvat.classes.MLModel} model model used to get annotation
* @param {object} [args] extra arguments
* @returns {string} requestID
* @returns {object[]} annotations
* @throws {module:API.cvat.exceptions.ServerError}
* @throws {module:API.cvat.exceptions.PluginError}
* @throws {module:API.cvat.exceptions.ArgumentError}

@ -33,6 +33,7 @@
created_date: undefined,
updated_date: undefined,
task_subsets: undefined,
training_project: undefined,
};
for (const property in data) {
@ -64,6 +65,9 @@
}
data.task_subsets = Array.from(subsetsSet);
}
if (initialData.training_project) {
data.training_project = JSON.parse(JSON.stringify(initialData.training_project));
}
Object.defineProperties(
this,
@ -94,6 +98,7 @@
data.name = value;
},
},
/**
* @name status
* @type {module:API.cvat.enums.TaskStatus}
@ -217,9 +222,21 @@
subsets: {
get: () => [...data.task_subsets],
},
_internalData: {
get: () => data,
},
training_project: {
get: () => data.training_project,
set: (training) => {
if (training) {
data.training_project = JSON.parse(JSON.stringify(training));
} else {
data.training_project = training;
}
},
},
}),
);
}
@ -261,12 +278,17 @@
};
Project.prototype.save.implementation = async function () {
let trainingProject;
if (this.training_project) {
trainingProject = JSON.parse(JSON.stringify(this.training_project));
}
if (typeof this.id !== 'undefined') {
const projectData = {
name: this.name,
assignee_id: this.assignee ? this.assignee.id : null,
bug_tracker: this.bugTracker,
labels: [...this._internalData.labels.map((el) => el.toJSON())],
training_project: trainingProject,
};
await serverProxy.projects.save(this.id, projectData);
@ -276,6 +298,7 @@
const projectSpec = {
name: this.name,
labels: [...this.labels.map((el) => el.toJSON())],
training_project: trainingProject,
};
if (this.bugTracker) {

@ -9,6 +9,31 @@
const config = require('./config');
const DownloadWorker = require('./download.worker');
function waitFor(frequencyHz, predicate) {
return new Promise((resolve, reject) => {
if (typeof predicate !== 'function') {
reject(new Error(`Predicate must be a function, got ${typeof predicate}`));
}
const internalWait = () => {
let result = false;
try {
result = predicate();
} catch (error) {
reject(error);
}
if (result) {
resolve();
} else {
setTimeout(internalWait, 1000 / frequencyHz);
}
};
setTimeout(internalWait);
});
}
function generateError(errorData) {
if (errorData.response) {
const message = `${errorData.message}. ${JSON.stringify(errorData.response.data) || ''}.`;
@ -993,6 +1018,96 @@
}
}
function predictorStatus(projectId) {
const { backendAPI } = config;
return new Promise((resolve, reject) => {
async function request() {
try {
const response = await Axios.get(`${backendAPI}/predict/status?project=${projectId}`);
return response.data;
} catch (errorData) {
throw generateError(errorData);
}
}
const timeoutCallback = async () => {
let data = null;
try {
data = await request();
if (data.status === 'queued') {
setTimeout(timeoutCallback, 1000);
} else if (data.status === 'done') {
resolve(data);
} else {
throw new Error(`Unknown status was received "${data.status}"`);
}
} catch (error) {
reject(error);
}
};
setTimeout(timeoutCallback);
});
}
function predictAnnotations(taskId, frame) {
return new Promise((resolve, reject) => {
const { backendAPI } = config;
async function request() {
try {
const response = await Axios.get(
`${backendAPI}/predict/frame?task=${taskId}&frame=${frame}`,
);
return response.data;
} catch (errorData) {
throw generateError(errorData);
}
}
const timeoutCallback = async () => {
let data = null;
try {
data = await request();
if (data.status === 'queued') {
setTimeout(timeoutCallback, 1000);
} else if (data.status === 'done') {
predictAnnotations.latestRequest.fetching = false;
resolve(data.annotation);
} else {
throw new Error(`Unknown status was received "${data.status}"`);
}
} catch (error) {
predictAnnotations.latestRequest.fetching = false;
reject(error);
}
};
const closureId = Date.now();
predictAnnotations.latestRequest.id = closureId;
const predicate = () => !predictAnnotations.latestRequest.fetching || predictAnnotations.latestRequest.id !== closureId;
if (predictAnnotations.latestRequest.fetching) {
waitFor(5, predicate).then(() => {
if (predictAnnotations.latestRequest.id !== closureId) {
resolve(null);
} else {
predictAnnotations.latestRequest.fetching = true;
setTimeout(timeoutCallback);
}
});
} else {
predictAnnotations.latestRequest.fetching = true;
setTimeout(timeoutCallback);
}
});
}
predictAnnotations.latestRequest = {
fetching: false,
id: null,
};
async function installedApps() {
const { backendAPI } = config;
try {
@ -1123,6 +1238,14 @@
}),
writable: false,
},
predictor: {
value: Object.freeze({
status: predictorStatus,
predict: predictAnnotations,
}),
writable: false,
},
}),
);
}

@ -10,7 +10,7 @@
const {
getFrame, getRanges, getPreview, clear: clearFrames, getContextImage,
} = require('./frames');
const { ArgumentError } = require('./exceptions');
const { ArgumentError, DataError } = require('./exceptions');
const { TaskStatus } = require('./enums');
const { Label } = require('./labels');
const User = require('./user');
@ -258,6 +258,19 @@
},
writable: true,
}),
predictor: Object.freeze({
value: {
async status() {
const result = await PluginRegistry.apiWrapper.call(this, prototype.predictor.status);
return result;
},
async predict(frame) {
const result = await PluginRegistry.apiWrapper.call(this, prototype.predictor.predict, frame);
return result;
},
},
writable: true,
}),
});
}
@ -665,6 +678,40 @@
* @instance
* @async
*/
/**
* @typedef {Object} PredictorStatus
* @property {string} message - message for a user to be displayed somewhere
* @property {number} projectScore - model accuracy
* @global
*/
/**
* Namespace is used for an interaction with events
* @namespace predictor
* @memberof Session
*/
/**
* Subscribe to updates of a ML model binded to the project
* @method status
* @memberof Session.predictor
* @throws {module:API.cvat.exceptions.PluginError}
* @throws {module:API.cvat.exceptions.ServerError}
* @returns {PredictorStatus}
* @instance
* @async
*/
/**
* Get predictions from a ML model binded to the project
* @method predict
* @memberof Session.predictor
* @param {number} frame - number of frame to inference
* @throws {module:API.cvat.exceptions.PluginError}
* @throws {module:API.cvat.exceptions.ArgumentError}
* @throws {module:API.cvat.exceptions.ServerError}
* @throws {module:API.cvat.exceptions.DataError}
* @returns {object[] | null} annotations
* @instance
* @async
*/
}
}
@ -865,6 +912,11 @@
this.logger = {
log: Object.getPrototypeOf(this).logger.log.bind(this),
};
this.predictor = {
status: Object.getPrototypeOf(this).predictor.status.bind(this),
predict: Object.getPrototypeOf(this).predictor.predict.bind(this),
};
}
/**
@ -1554,6 +1606,11 @@
this.logger = {
log: Object.getPrototypeOf(this).logger.log.bind(this),
};
this.predictor = {
status: Object.getPrototypeOf(this).predictor.status.bind(this),
predict: Object.getPrototypeOf(this).predictor.predict.bind(this),
};
}
/**
@ -1741,6 +1798,11 @@
return rangesData;
};
Job.prototype.frames.preview.implementation = async function () {
const frameData = await getPreview(this.task.id);
return frameData;
};
// TODO: Check filter for annotations
Job.prototype.annotations.get.implementation = async function (frame, allTracks, filters) {
if (!Array.isArray(filters)) {
@ -1897,6 +1959,16 @@
return result;
};
Job.prototype.predictor.status.implementation = async function () {
const result = await this.task.predictor.status();
return result;
};
Job.prototype.predictor.predict.implementation = async function (frame) {
const result = await this.task.predictor.predict(frame);
return result;
};
Task.prototype.close.implementation = function closeTask() {
clearFrames(this.id);
for (const job of this.jobs) {
@ -2028,11 +2100,6 @@
return result;
};
Job.prototype.frames.preview.implementation = async function () {
const frameData = await getPreview(this.task.id);
return frameData;
};
Task.prototype.frames.ranges.implementation = async function () {
const rangesData = await getRanges(this.id);
return rangesData;
@ -2199,6 +2266,39 @@
return result;
};
Task.prototype.predictor.status.implementation = async function () {
if (!Number.isInteger(this.projectId)) {
throw new DataError('The task must belong to a project to use the feature');
}
const result = await serverProxy.predictor.status(this.projectId);
return {
message: result.message,
progress: result.progress,
projectScore: result.score,
timeRemaining: result.time_remaining,
mediaAmount: result.media_amount,
annotationAmount: result.annotation_amount,
};
};
Task.prototype.predictor.predict.implementation = async function (frame) {
if (!Number.isInteger(frame) || frame < 0) {
throw new ArgumentError(`Frame must be a positive integer. Got: "${frame}"`);
}
if (frame >= this.size) {
throw new ArgumentError(`The frame with number ${frame} is out of the task`);
}
if (!Number.isInteger(this.projectId)) {
throw new DataError('The task must belong to a project to use the feature');
}
const result = await serverProxy.predictor.predict(this.id, frame);
return result;
};
Job.prototype.frames.contextImage.implementation = async function (taskId, frameId) {
const result = await getContextImage(taskId, frameId);
return result;

@ -28953,6 +28953,11 @@
"resolved": "https://registry.npmjs.org/react-is/-/react-is-16.11.0.tgz",
"integrity": "sha512-gbBVYR2p8mnriqAwWx9LbuUrShnAuSCNnuPGyc7GJrMVQtPDAh8iLpv7FRuMPFb56KkaVZIYSz1PrjI9q0QPCw=="
},
"react-moment": {
"version": "1.1.1",
"resolved": "https://registry.npmjs.org/react-moment/-/react-moment-1.1.1.tgz",
"integrity": "sha512-WjwvxBSnmLMRcU33do0KixDB+9vP3e84eCse+rd+HNklAMNWyRgZTDEQlay/qK6lcXFPRuEIASJTpEt6pyK7Ww=="
},
"react-redux": {
"version": "7.2.2",
"resolved": "https://registry.npmjs.org/react-redux/-/react-redux-7.2.2.tgz",

@ -77,6 +77,7 @@
"react-color": "^2.19.3",
"react-cookie": "^4.0.3",
"react-dom": "^16.14.0",
"react-moment": "^1.1.1",
"react-redux": "^7.2.2",
"react-resizable": "^1.11.1",
"@types/react-resizable": "^1.7.2",

@ -1,4 +1,4 @@
// Copyright (C) 2021 Intel Corporation
// Copyright (C) 2020-2021 Intel Corporation
//
// SPDX-License-Identifier: MIT
@ -190,6 +190,10 @@ export enum AnnotationActionTypes {
SWITCH_REQUEST_REVIEW_DIALOG = 'SWITCH_REQUEST_REVIEW_DIALOG',
SWITCH_SUBMIT_REVIEW_DIALOG = 'SWITCH_SUBMIT_REVIEW_DIALOG',
SET_FORCE_EXIT_ANNOTATION_PAGE_FLAG = 'SET_FORCE_EXIT_ANNOTATION_PAGE_FLAG',
UPDATE_PREDICTOR_STATE = 'UPDATE_PREDICTOR_STATE',
GET_PREDICTIONS = 'GET_PREDICTIONS',
GET_PREDICTIONS_FAILED = 'GET_PREDICTIONS_FAILED',
GET_PREDICTIONS_SUCCESS = 'GET_PREDICTIONS_SUCCESS',
HIDE_SHOW_CONTEXT_IMAGE = 'HIDE_SHOW_CONTEXT_IMAGE',
GET_CONTEXT_IMAGE = 'GET_CONTEXT_IMAGE',
}
@ -612,6 +616,87 @@ export function switchPlay(playing: boolean): AnyAction {
};
}
export function getPredictionsAsync(): ThunkAction {
return async (dispatch: ActionCreator<Dispatch>): Promise<void> => {
const {
annotations: {
states: currentStates,
zLayer: { cur: curZOrder },
},
predictor: { enabled, annotatedFrames },
} = getStore().getState().annotation;
const {
filters, frame, showAllInterpolationTracks, jobInstance: job,
} = receiveAnnotationsParameters();
if (!enabled || currentStates.length || annotatedFrames.includes(frame)) return;
dispatch({
type: AnnotationActionTypes.GET_PREDICTIONS,
payload: {},
});
let annotations = [];
try {
annotations = await job.predictor.predict(frame);
// current frame could be changed during a request above, need to fetch it from store again
const { number: currentFrame } = getStore().getState().annotation.player.frame;
if (frame !== currentFrame || annotations === null) {
// another request has already been sent or user went to another frame
// we do not need dispatch predictions success action
return;
}
annotations = annotations.map(
(data: any): any =>
new cvat.classes.ObjectState({
shapeType: data.type,
label: job.task.labels.filter((label: any): boolean => label.id === data.label)[0],
points: data.points,
objectType: ObjectType.SHAPE,
frame,
occluded: false,
source: 'auto',
attributes: {},
zOrder: curZOrder,
}),
);
dispatch({
type: AnnotationActionTypes.GET_PREDICTIONS_SUCCESS,
payload: { frame },
});
} catch (error) {
dispatch({
type: AnnotationActionTypes.GET_PREDICTIONS_FAILED,
payload: {
error,
},
});
}
try {
await job.annotations.put(annotations);
const states = await job.annotations.get(frame, showAllInterpolationTracks, filters);
const history = await job.actions.get();
dispatch({
type: AnnotationActionTypes.CREATE_ANNOTATIONS_SUCCESS,
payload: {
states,
history,
},
});
} catch (error) {
dispatch({
type: AnnotationActionTypes.CREATE_ANNOTATIONS_FAILED,
payload: {
error,
},
});
}
};
}
export function changeFrameAsync(toFrame: number, fillBuffer?: boolean, frameStep?: number): ThunkAction {
return async (dispatch: ActionCreator<Dispatch>): Promise<void> => {
const state: CombinedState = getStore().getState();
@ -689,6 +774,7 @@ export function changeFrameAsync(toFrame: number, fillBuffer?: boolean, frameSte
delay,
},
});
dispatch(getPredictionsAsync());
} catch (error) {
if (error !== 'not needed') {
dispatch({
@ -934,9 +1020,11 @@ export function getJobAsync(tid: number, jid: number, initialFrame: number, init
loadJobEvent.close(await jobInfoGenerator(job));
const openTime = Date.now();
dispatch({
type: AnnotationActionTypes.GET_JOB_SUCCESS,
payload: {
openTime,
job,
issues,
reviews,
@ -950,10 +1038,38 @@ export function getJobAsync(tid: number, jid: number, initialFrame: number, init
maxZ,
},
});
if (job.task.dimension === DimensionType.DIM_3D) {
const workspace = Workspace.STANDARD3D;
dispatch(changeWorkspace(workspace));
}
const updatePredictorStatus = async (): Promise<void> => {
// get current job
const currentState: CombinedState = getStore().getState();
const { openTime: currentOpenTime, instance: currentJob } = currentState.annotation.job;
if (currentJob === null || currentJob.id !== job.id || currentOpenTime !== openTime) {
// the job was closed, changed or reopened
return;
}
try {
const status = await job.predictor.status();
dispatch({
type: AnnotationActionTypes.UPDATE_PREDICTOR_STATE,
payload: status,
});
setTimeout(updatePredictorStatus, 60 * 1000);
} catch (error) {
dispatch({
type: AnnotationActionTypes.UPDATE_PREDICTOR_STATE,
payload: { error },
});
setTimeout(updatePredictorStatus, 20 * 1000);
}
};
updatePredictorStatus();
dispatch(changeFrameAsync(frameNumber, false));
} catch (error) {
dispatch({
@ -1516,6 +1632,14 @@ export function setForceExitAnnotationFlag(forceExit: boolean): AnyAction {
};
}
export function switchPredictor(predictorEnabled: boolean): AnyAction {
return {
type: AnnotationActionTypes.UPDATE_PREDICTOR_STATE,
payload: {
enabled: predictorEnabled,
},
};
}
export function hideShowContextImage(hidden: boolean): AnyAction {
return {
type: AnnotationActionTypes.HIDE_SHOW_CONTEXT_IMAGE,

@ -1,8 +1,10 @@
// Copyright (C) 2020 Intel Corporation
// Copyright (C) 2020-2021 Intel Corporation
//
// SPDX-License-Identifier: MIT
import { ActionUnion, createAction, ThunkAction, ThunkDispatch } from 'utils/redux';
import {
ActionUnion, createAction, ThunkAction, ThunkDispatch,
} from 'utils/redux';
import getCore from 'cvat-core-wrapper';
import { LogType } from 'cvat-logger';
import { computeZRange } from './annotation-actions';

@ -0,0 +1,56 @@
<?xml version="1.0" encoding="iso-8859-1"?>
<!-- The icon received from: https://www.svgrepo.com/svg/25187/brain -->
<!-- License: CC0 Creative Commons License -->
<svg version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px"
viewBox="0 0 463 463" width="40px" height="40px" style="enable-background:new 0 0 463 463;" xml:space="preserve">
<g>
<path d="M151.245,222.446C148.054,237.039,135.036,248,119.5,248c-4.142,0-7.5,3.357-7.5,7.5s3.358,7.5,7.5,7.5
c23.774,0,43.522-17.557,46.966-40.386c14.556-1.574,27.993-8.06,38.395-18.677c2.899-2.959,2.85-7.708-0.109-10.606
c-2.958-2.897-7.707-2.851-10.606,0.108C184.947,202.829,172.643,208,159.5,208c-26.743,0-48.5-21.757-48.5-48.5
c0-4.143-3.358-7.5-7.5-7.5s-7.5,3.357-7.5,7.5C96,191.715,120.119,218.384,151.245,222.446z"/>
<path d="M183,287.5c0-4.143-3.358-7.5-7.5-7.5c-35.014,0-63.5,28.486-63.5,63.5c0,0.362,0.013,0.725,0.019,1.088
C109.23,344.212,106.39,344,103.5,344c-4.142,0-7.5,3.357-7.5,7.5s3.358,7.5,7.5,7.5c26.743,0,48.5,21.757,48.5,48.5
c0,4.143,3.358,7.5,7.5,7.5s7.5-3.357,7.5-7.5c0-26.611-16.462-49.437-39.731-58.867c-0.178-1.699-0.269-3.418-0.269-5.133
c0-26.743,21.757-48.5,48.5-48.5C179.642,295,183,291.643,183,287.5z"/>
<path d="M439,223.5c0-17.075-6.82-33.256-18.875-45.156c1.909-6.108,2.875-12.426,2.875-18.844
c0-30.874-22.152-56.659-51.394-62.329C373.841,91.6,375,85.628,375,79.5c0-19.557-11.883-36.387-28.806-43.661
C317.999,13.383,287.162,0,263.5,0c-13.153,0-24.817,6.468-32,16.384C224.317,6.468,212.653,0,199.5,0
c-23.662,0-54.499,13.383-82.694,35.839C99.883,43.113,88,59.943,88,79.5c0,6.128,1.159,12.1,3.394,17.671
C62.152,102.841,40,128.626,40,159.5c0,6.418,0.965,12.735,2.875,18.844C30.82,190.244,24,206.425,24,223.5
c0,13.348,4.149,25.741,11.213,35.975C27.872,270.087,24,282.466,24,295.5c0,23.088,12.587,44.242,32.516,55.396
C56.173,353.748,56,356.626,56,359.5c0,31.144,20.315,58.679,49.79,68.063C118.611,449.505,141.965,463,167.5,463
c27.995,0,52.269-16.181,64-39.674c11.731,23.493,36.005,39.674,64,39.674c25.535,0,48.889-13.495,61.71-35.437
c29.475-9.385,49.79-36.92,49.79-68.063c0-2.874-0.173-5.752-0.516-8.604C426.413,339.742,439,318.588,439,295.5
c0-13.034-3.872-25.413-11.213-36.025C434.851,249.241,439,236.848,439,223.5z M167.5,448c-21.029,0-40.191-11.594-50.009-30.256
c-0.973-1.849-2.671-3.208-4.688-3.751C88.19,407.369,71,384.961,71,359.5c0-3.81,0.384-7.626,1.141-11.344
c0.702-3.447-1.087-6.92-4.302-8.35C50.32,332.018,39,314.626,39,295.5c0-8.699,2.256-17.014,6.561-24.379
C56.757,280.992,71.436,287,87.5,287c4.142,0,7.5-3.357,7.5-7.5s-3.358-7.5-7.5-7.5C60.757,272,39,250.243,39,223.5
c0-14.396,6.352-27.964,17.428-37.221c2.5-2.09,3.365-5.555,2.14-8.574C56.2,171.869,55,165.744,55,159.5
c0-26.743,21.757-48.5,48.5-48.5s48.5,21.757,48.5,48.5c0,4.143,3.358,7.5,7.5,7.5s7.5-3.357,7.5-7.5
c0-33.642-26.302-61.243-59.421-63.355C104.577,91.127,103,85.421,103,79.5c0-13.369,8.116-24.875,19.678-29.859
c0.447-0.133,0.885-0.307,1.308-0.527C127.568,47.752,131.447,47,135.5,47c12.557,0,23.767,7.021,29.256,18.325
c1.81,3.727,6.298,5.281,10.023,3.47c3.726-1.809,5.28-6.296,3.47-10.022c-6.266-12.903-18.125-22.177-31.782-25.462
C165.609,21.631,184.454,15,199.5,15c13.509,0,24.5,10.99,24.5,24.5v97.051c-6.739-5.346-15.25-8.551-24.5-8.551
c-4.142,0-7.5,3.357-7.5,7.5s3.358,7.5,7.5,7.5c13.509,0,24.5,10.99,24.5,24.5v180.279c-9.325-12.031-22.471-21.111-37.935-25.266
c-3.999-1.071-8.114,1.297-9.189,5.297c-1.075,4.001,1.297,8.115,5.297,9.189C206.8,343.616,224,366.027,224,391.5
C224,422.654,198.654,448,167.5,448z M395.161,339.807c-3.215,1.43-5.004,4.902-4.302,8.35c0.757,3.718,1.141,7.534,1.141,11.344
c0,25.461-17.19,47.869-41.803,54.493c-2.017,0.543-3.716,1.902-4.688,3.751C335.691,436.406,316.529,448,295.5,448
c-31.154,0-56.5-25.346-56.5-56.5c0-2.109-0.098-4.2-0.281-6.271c0.178-0.641,0.281-1.314,0.281-2.012V135.5
c0-13.51,10.991-24.5,24.5-24.5c4.142,0,7.5-3.357,7.5-7.5s-3.358-7.5-7.5-7.5c-9.25,0-17.761,3.205-24.5,8.551V39.5
c0-13.51,10.991-24.5,24.5-24.5c15.046,0,33.891,6.631,53.033,18.311c-13.657,3.284-25.516,12.559-31.782,25.462
c-1.81,3.727-0.256,8.214,3.47,10.022c3.726,1.81,8.213,0.257,10.023-3.47C303.733,54.021,314.943,47,327.5,47
c4.053,0,7.933,0.752,11.514,2.114c0.422,0.22,0.86,0.393,1.305,0.526C351.883,54.624,360,66.13,360,79.5
c0,5.921-1.577,11.627-4.579,16.645C322.302,98.257,296,125.858,296,159.5c0,4.143,3.358,7.5,7.5,7.5s7.5-3.357,7.5-7.5
c0-26.743,21.757-48.5,48.5-48.5s48.5,21.757,48.5,48.5c0,6.244-1.2,12.369-3.567,18.205c-1.225,3.02-0.36,6.484,2.14,8.574
C417.648,195.536,424,209.104,424,223.5c0,26.743-21.757,48.5-48.5,48.5c-4.142,0-7.5,3.357-7.5,7.5s3.358,7.5,7.5,7.5
c16.064,0,30.743-6.008,41.939-15.879c4.306,7.365,6.561,15.68,6.561,24.379C424,314.626,412.68,332.018,395.161,339.807z"/>
<path d="M359.5,240c-15.536,0-28.554-10.961-31.745-25.554C358.881,210.384,383,183.715,383,151.5c0-4.143-3.358-7.5-7.5-7.5
s-7.5,3.357-7.5,7.5c0,26.743-21.757,48.5-48.5,48.5c-13.143,0-25.447-5.171-34.646-14.561c-2.898-2.958-7.647-3.007-10.606-0.108
s-3.008,7.647-0.109,10.606c10.402,10.617,23.839,17.103,38.395,18.677C315.978,237.443,335.726,255,359.5,255
c4.142,0,7.5-3.357,7.5-7.5S363.642,240,359.5,240z"/>
<path d="M335.5,328c-2.89,0-5.73,0.212-8.519,0.588c0.006-0.363,0.019-0.726,0.019-1.088c0-35.014-28.486-63.5-63.5-63.5
c-4.142,0-7.5,3.357-7.5,7.5s3.358,7.5,7.5,7.5c26.743,0,48.5,21.757,48.5,48.5c0,1.714-0.091,3.434-0.269,5.133
C288.462,342.063,272,364.889,272,391.5c0,4.143,3.358,7.5,7.5,7.5s7.5-3.357,7.5-7.5c0-26.743,21.757-48.5,48.5-48.5
c4.142,0,7.5-3.357,7.5-7.5S339.642,328,335.5,328z"/>
</g>
</svg>

After

Width:  |  Height:  |  Size: 5.5 KiB

@ -1,4 +1,4 @@
// Copyright (C) 2020 Intel Corporation
// Copyright (C) 2020-2021 Intel Corporation
//
// SPDX-License-Identifier: MIT
@ -10,6 +10,7 @@ import Radio, { RadioChangeEvent } from 'antd/lib/radio';
import Slider from 'antd/lib/slider';
import Checkbox, { CheckboxChangeEvent } from 'antd/lib/checkbox';
import Collapse from 'antd/lib/collapse';
import Button from 'antd/lib/button';
import ColorPicker from 'components/annotation-page/standard-workspace/objects-side-bar/color-picker';
import { ColorizeIcon } from 'icons';
@ -26,7 +27,6 @@ import {
changeShowBitmap as changeShowBitmapAction,
changeShowProjections as changeShowProjectionsAction,
} from 'actions/settings-actions';
import Button from 'antd/lib/button';
interface StateToProps {
appearanceCollapsed: boolean;
@ -152,7 +152,14 @@ function AppearanceBlock(props: Props): JSX.Element {
activeKey={appearanceCollapsed ? [] : ['appearance']}
className='cvat-objects-appearance-collapse'
>
<Collapse.Panel header={<Text strong className='cvat-objects-appearance-collapse-header'>Appearance</Text>} key='appearance'>
<Collapse.Panel
header={(
<Text strong className='cvat-objects-appearance-collapse-header'>
Appearance
</Text>
)}
key='appearance'
>
<div className='cvat-objects-appearance-content'>
<Text type='secondary'>Color by</Text>
<Radio.Group

@ -3,13 +3,14 @@
// SPDX-License-Identifier: MIT
import React from 'react';
import GlobalHotKeys, { KeyMap } from 'utils/mousetrap-react';
import Text from 'antd/lib/typography/Text';
import Checkbox, { CheckboxChangeEvent } from 'antd/lib/checkbox';
import Select, { SelectValue } from 'antd/lib/select';
import Radio, { RadioChangeEvent } from 'antd/lib/radio';
import Input from 'antd/lib/input';
import GlobalHotKeys, { KeyMap } from 'utils/mousetrap-react';
import consts from 'consts';
interface InputElementParameters {

@ -3,9 +3,9 @@
// SPDX-License-Identifier: MIT
import React from 'react';
import GlobalHotKeys, { KeyMap } from 'utils/mousetrap-react';
import Layout from 'antd/lib/layout';
import GlobalHotKeys, { KeyMap } from 'utils/mousetrap-react';
import { ActiveControl, Rotation } from 'reducers/interfaces';
import { Canvas } from 'cvat-canvas-wrapper';

@ -4,9 +4,9 @@
import React from 'react';
import Layout from 'antd/lib/layout';
import GlobalHotKeys, { KeyMap } from 'utils/mousetrap-react';
import { ActiveControl, Rotation } from 'reducers/interfaces';
import GlobalHotKeys, { KeyMap } from 'utils/mousetrap-react';
import { Canvas } from 'cvat-canvas-wrapper';
import ControlVisibilityObserver, { ExtraControlsControl } from './control-visibility-observer';

@ -81,6 +81,52 @@
}
}
button.cvat-predictor-button {
&.cvat-predictor-inprogress {
> span {
> svg {
fill: $inprogress-progress-color;
}
}
}
&.cvat-predictor-fetching {
> span {
> svg {
animation-duration: 500ms;
animation-name: predictorBlinking;
animation-iteration-count: infinite;
@keyframes predictorBlinking {
0% {
fill: $inprogress-progress-color;
}
50% {
fill: $completed-progress-color;
}
100% {
fill: $inprogress-progress-color;
}
}
}
}
}
&.cvat-predictor-disabled {
opacity: 0.5;
&:active {
pointer-events: none;
}
> span[role='img'] {
transform: scale(0.8) !important;
}
}
}
.cvat-annotation-disabled-header-button {
@extend .cvat-annotation-header-button;

@ -4,12 +4,12 @@
import React, { useState, useEffect } from 'react';
import { useSelector } from 'react-redux';
import GlobalHotKeys, { KeyMap } from 'utils/mousetrap-react';
import { Row, Col } from 'antd/lib/grid';
import Text from 'antd/lib/typography/Text';
import Select from 'antd/lib/select';
import { CombinedState } from 'reducers/interfaces';
import GlobalHotKeys, { KeyMap } from 'utils/mousetrap-react';
import { shift } from 'utils/math';
interface ShortcutLabelMap {

@ -20,11 +20,11 @@ import {
changeFrameAsync,
rememberObject,
} from 'actions/annotation-actions';
import GlobalHotKeys, { KeyMap } from 'utils/mousetrap-react';
import { Canvas } from 'cvat-canvas-wrapper';
import { CombinedState, ObjectType } from 'reducers/interfaces';
import LabelSelector from 'components/label-selector/label-selector';
import getCore from 'cvat-core-wrapper';
import GlobalHotKeys, { KeyMap } from 'utils/mousetrap-react';
import ShortcutsSelect from './shortcuts-select';
const cvat = getCore();

@ -1,4 +1,4 @@
// Copyright (C) 2020 Intel Corporation
// Copyright (C) 2020-2021 Intel Corporation
//
// SPDX-License-Identifier: MIT
@ -11,7 +11,9 @@ import Timeline from 'antd/lib/timeline';
import Dropdown from 'antd/lib/dropdown';
import AnnotationMenuContainer from 'containers/annotation-page/top-bar/annotation-menu';
import { MainMenuIcon, SaveIcon, UndoIcon, RedoIcon } from 'icons';
import {
MainMenuIcon, SaveIcon, UndoIcon, RedoIcon,
} from 'icons';
interface Props {
saving: boolean;

@ -7,28 +7,141 @@ import { Col } from 'antd/lib/grid';
import Icon from '@ant-design/icons';
import Select from 'antd/lib/select';
import Button from 'antd/lib/button';
import Text from 'antd/lib/typography/Text';
import Tooltip from 'antd/lib/tooltip';
import Moment from 'react-moment';
import moment from 'moment';
import { useSelector } from 'react-redux';
import { FilterIcon, FullscreenIcon, InfoIcon } from 'icons';
import { CombinedState, DimensionType, Workspace } from 'reducers/interfaces';
import {
FilterIcon, FullscreenIcon, InfoIcon, BrainIcon,
} from 'icons';
import {
CombinedState, DimensionType, Workspace, PredictorState,
} from 'reducers/interfaces';
interface Props {
workspace: Workspace;
predictor: PredictorState;
isTrainingActive: boolean;
showStatistics(): void;
switchPredictor(predictorEnabled: boolean): void;
showFilters(): void;
changeWorkspace(workspace: Workspace): void;
jobInstance: any;
}
function RightGroup(props: Props): JSX.Element {
const {
showFilters, showStatistics, changeWorkspace, workspace, jobInstance,
showStatistics,
changeWorkspace,
switchPredictor,
workspace,
predictor,
jobInstance,
isTrainingActive,
showFilters,
} = props;
predictor.annotationAmount = predictor.annotationAmount ? predictor.annotationAmount : 0;
predictor.mediaAmount = predictor.mediaAmount ? predictor.mediaAmount : 0;
const formattedScore = `${(predictor.projectScore * 100).toFixed(0)}%`;
const predictorTooltip = (
<div className='cvat-predictor-tooltip'>
<span>Adaptive auto annotation is</span>
{predictor.enabled ? (
<Text type='success' strong>
{' active'}
</Text>
) : (
<Text type='warning' strong>
{' inactive'}
</Text>
)}
<br />
<span>
Annotations amount:
{predictor.annotationAmount}
</span>
<br />
<span>
Media amount:
{predictor.mediaAmount}
</span>
<br />
{predictor.annotationAmount > 0 ? (
<span>
Model mAP is
{' '}
{formattedScore}
<br />
</span>
) : null}
{predictor.error ? (
<Text type='danger'>
{predictor.error.toString()}
<br />
</Text>
) : null}
{predictor.message ? (
<span>
Status:
{' '}
{predictor.message}
<br />
</span>
) : null}
{predictor.timeRemaining > 0 ? (
<span>
Time Remaining:
{' '}
<Moment date={moment().add(-predictor.timeRemaining, 's')} format='hh:mm:ss' trim durationFromNow />
<br />
</span>
) : null}
{predictor.progress > 0 ? (
<span>
Progress:
{predictor.progress.toFixed(1)}
{' '}
%
</span>
) : null}
</div>
);
let predictorClassName = 'cvat-annotation-header-button cvat-predictor-button';
if (!!predictor.error || !predictor.projectScore) {
predictorClassName += ' cvat-predictor-disabled';
} else if (predictor.enabled) {
if (predictor.fetching) {
predictorClassName += ' cvat-predictor-fetching';
}
predictorClassName += ' cvat-predictor-inprogress';
}
const filters = useSelector((state: CombinedState) => state.annotation.annotations.filters);
return (
<Col className='cvat-annotation-header-right-group'>
{isTrainingActive && (
<Button
type='link'
className={predictorClassName}
onClick={() => {
switchPredictor(!predictor.enabled);
}}
>
<Tooltip title={predictorTooltip}>
<Icon component={BrainIcon} />
</Tooltip>
{predictor.annotationAmount ? `mAP ${formattedScore}` : 'not trained'}
</Button>
)}
<Button
type='link'
className='cvat-annotation-header-button'

@ -1,4 +1,4 @@
// Copyright (C) 2021 Intel Corporation
// Copyright (C) 2020-2021 Intel Corporation
//
// SPDX-License-Identifier: MIT
@ -6,7 +6,7 @@ import React from 'react';
import Input from 'antd/lib/input';
import { Col, Row } from 'antd/lib/grid';
import { Workspace } from 'reducers/interfaces';
import { PredictorState, Workspace } from 'reducers/interfaces';
import LeftGroup from './left-group';
import PlayerButtons from './player-buttons';
import PlayerNavigation from './player-navigation';
@ -35,7 +35,10 @@ interface Props {
prevButtonType: string;
nextButtonType: string;
focusFrameInputShortcut: string;
predictor: PredictorState;
isTrainingActive: boolean;
changeWorkspace(workspace: Workspace): void;
switchPredictor(predictorEnabled: boolean): void;
showStatistics(): void;
showFilters(): void;
onSwitchPlay(): void;
@ -80,8 +83,10 @@ export default function AnnotationTopBarComponent(props: Props): JSX.Element {
backwardShortcut,
prevButtonType,
nextButtonType,
predictor,
focusFrameInputShortcut,
showStatistics,
switchPredictor,
showFilters,
changeWorkspace,
onSwitchPlay,
@ -100,6 +105,7 @@ export default function AnnotationTopBarComponent(props: Props): JSX.Element {
onUndoClick,
onRedoClick,
jobInstance,
isTrainingActive,
} = props;
return (
@ -151,10 +157,13 @@ export default function AnnotationTopBarComponent(props: Props): JSX.Element {
</Row>
</Col>
<RightGroup
jobInstance={jobInstance}
predictor={predictor}
workspace={workspace}
switchPredictor={switchPredictor}
jobInstance={jobInstance}
changeWorkspace={changeWorkspace}
showStatistics={showStatistics}
isTrainingActive={isTrainingActive}
showFilters={showFilters}
/>
</Row>

@ -1,4 +1,4 @@
// Copyright (C) 2020 Intel Corporation
// Copyright (C) 2020-2021 Intel Corporation
//
// SPDX-License-Identifier: MIT
@ -48,7 +48,8 @@ function ChangePasswordFormComponent({ fetching, onSubmit }: Props): JSX.Element
{
required: true,
message: 'Please input new password!',
}, validatePassword,
},
validatePassword,
]}
>
<Input.Password
@ -66,7 +67,8 @@ function ChangePasswordFormComponent({ fetching, onSubmit }: Props): JSX.Element
{
required: true,
message: 'Please confirm your new password!',
}, validateConfirmation('newPassword1'),
},
validateConfirmation('newPassword1'),
]}
>
<Input.Password

@ -1,4 +1,4 @@
// Copyright (C) 2020 Intel Corporation
// Copyright (C) 2020-2021 Intel Corporation
//
// SPDX-License-Identifier: MIT
@ -43,7 +43,9 @@ function mapDispatchToProps(dispatch: any): DispatchToProps {
}
function ChangePasswordComponent(props: ChangePasswordPageComponentProps): JSX.Element {
const { fetching, onChangePassword, visible, onClose } = props;
const {
fetching, onChangePassword, visible, onClose,
} = props;
return (
<Modal

@ -1,12 +1,13 @@
// Copyright (C) 2020 Intel Corporation
// Copyright (C) 2020-2021 Intel Corporation
//
// SPDX-License-Identifier: MIT
import React, {
useState, useRef, useEffect, RefObject,
RefObject, useContext, useEffect, useRef, useState,
} from 'react';
import { useDispatch, useSelector } from 'react-redux';
import { useHistory } from 'react-router';
import { Switch, Select } from 'antd';
import { Col, Row } from 'antd/lib/grid';
import Text from 'antd/lib/typography/Text';
import Form, { FormInstance } from 'antd/lib/form';
@ -18,6 +19,9 @@ import patterns from 'utils/validation-patterns';
import { CombinedState } from 'reducers/interfaces';
import LabelsEditor from 'components/labels-editor/labels-editor';
import { createProjectAsync } from 'actions/projects-actions';
import CreateProjectContext from './create-project.context';
const { Option } = Select;
function NameConfigurationForm({ formRef }: { formRef: RefObject<FormInstance> }): JSX.Element {
return (
@ -39,6 +43,59 @@ function NameConfigurationForm({ formRef }: { formRef: RefObject<FormInstance> }
);
}
function AdaptiveAutoAnnotationForm({ formRef }: { formRef: RefObject<FormInstance> }): JSX.Element {
const { projectClass, trainingEnabled } = useContext(CreateProjectContext);
const projectClassesForTraining = ['OD'];
return (
<Form layout='vertical' ref={formRef}>
<Form.Item name='project_class' hasFeedback label='Class'>
<Select value={projectClass.value} onChange={(v) => projectClass.set(v)}>
<Option value=''>--Not Selected--</Option>
<Option value='OD'>Detection</Option>
</Select>
</Form.Item>
<Form.Item name='enabled' label='Adaptive auto annotation' initialValue={false}>
<Switch
disabled={!projectClassesForTraining.includes(projectClass.value)}
checked={trainingEnabled.value}
onClick={() => trainingEnabled.set(!trainingEnabled.value)}
/>
</Form.Item>
<Form.Item
name='host'
label='Host'
rules={[
{
validator: (_, value, callback): void => {
if (value && !patterns.validateURL.pattern.test(value)) {
callback('Training server host must be url.');
} else {
callback();
}
},
},
]}
>
<Input placeholder='https://example.host' disabled={!trainingEnabled.value} />
</Form.Item>
<Row gutter={16}>
<Col span={12}>
<Form.Item name='username' label='Username'>
<Input placeholder='UserName' disabled={!trainingEnabled.value} />
</Form.Item>
</Col>
<Col span={12}>
<Form.Item name='password' label='Password'>
<Input.Password placeholder='Pa$$w0rd' disabled={!trainingEnabled.value} />
</Form.Item>
</Col>
</Row>
</Form>
);
}
function AdvanvedConfigurationForm({ formRef }: { formRef: RefObject<FormInstance> }): JSX.Element {
return (
<Form layout='vertical' ref={formRef}>
@ -69,12 +126,15 @@ export default function CreateProjectContent(): JSX.Element {
const [projectLabels, setProjectLabels] = useState<any[]>([]);
const shouldShowNotification = useRef(false);
const nameFormRef = useRef<FormInstance>(null);
const adaptiveAutoAnnotationFormRef = useRef<FormInstance>(null);
const advancedFormRef = useRef<FormInstance>(null);
const dispatch = useDispatch();
const history = useHistory();
const newProjectId = useSelector((state: CombinedState) => state.projects.activities.creates.id);
const { isTrainingActive } = useContext(CreateProjectContext);
useEffect(() => {
if (Number.isInteger(newProjectId) && shouldShowNotification.current) {
const btn = <Button onClick={() => history.push(`/projects/${newProjectId}`)}>Open project</Button>;
@ -102,7 +162,16 @@ export default function CreateProjectContent(): JSX.Element {
if (nameFormRef.current && advancedFormRef.current) {
const basicValues = await nameFormRef.current.validateFields();
const advancedValues = await advancedFormRef.current.validateFields();
const adaptiveAutoAnnotationValues = await adaptiveAutoAnnotationFormRef.current?.validateFields();
projectData.name = basicValues.name;
projectData.training_project = null;
if (adaptiveAutoAnnotationValues) {
projectData.training_project = {};
for (const [field, value] of Object.entries(adaptiveAutoAnnotationValues)) {
projectData.training_project[field] = value;
}
}
for (const [field, value] of Object.entries(advancedValues)) {
projectData[field] = value;
}
@ -120,6 +189,11 @@ export default function CreateProjectContent(): JSX.Element {
<Col span={24}>
<NameConfigurationForm formRef={nameFormRef} />
</Col>
{isTrainingActive.value && (
<Col span={24}>
<AdaptiveAutoAnnotationForm formRef={adaptiveAutoAnnotationFormRef} />
</Col>
)}
<Col span={24}>
<Text className='cvat-text-color'>Labels:</Text>
<LabelsEditor

@ -1,21 +1,56 @@
// Copyright (C) 2020 Intel Corporation
// Copyright (C) 2020-2021 Intel Corporation
//
// SPDX-License-Identifier: MIT
import './styles.scss';
import React from 'react';
import React, { useState } from 'react';
import { Row, Col } from 'antd/lib/grid';
import Text from 'antd/lib/typography/Text';
import { connect } from 'react-redux';
import CreateProjectContent from './create-project-content';
import { CombinedState } from '../../reducers/interfaces';
import CreateProjectContext, { ICreateProjectContext } from './create-project.context';
export default function CreateProjectPageComponent(): JSX.Element {
function CreateProjectPageComponent(props: StateToProps): JSX.Element {
const { isTrainingActive } = props;
const [projectClass, setProjectClass] = useState('');
const [trainingEnabled, setTrainingEnabled] = useState(false);
const [isTrainingActiveState] = useState(isTrainingActive);
const defaultContext: ICreateProjectContext = {
projectClass: {
value: projectClass,
set: setProjectClass,
},
trainingEnabled: {
value: trainingEnabled,
set: setTrainingEnabled,
},
isTrainingActive: {
value: isTrainingActiveState,
},
};
return (
<Row justify='center' align='top' className='cvat-create-task-form-wrapper'>
<Col md={20} lg={16} xl={14} xxl={9}>
<Text className='cvat-title'>Create a new project</Text>
<CreateProjectContent />
</Col>
</Row>
<CreateProjectContext.Provider value={defaultContext}>
<Row justify='center' align='top' className='cvat-create-task-form-wrapper'>
<Col md={20} lg={16} xl={14} xxl={9}>
<Text className='cvat-title'>Create a new project</Text>
<CreateProjectContent />
</Col>
</Row>
</CreateProjectContext.Provider>
);
}
interface StateToProps {
isTrainingActive: boolean;
}
function mapStateToProps(state: CombinedState): StateToProps {
return {
isTrainingActive: state.plugins.list.PREDICT,
};
}
export default connect(mapStateToProps)(CreateProjectPageComponent);

@ -0,0 +1,31 @@
// Copyright (C) 2020-2021 Intel Corporation
//
// SPDX-License-Identifier: MIT
import { createContext, Dispatch, SetStateAction } from 'react';
export interface IState<T> {
value: T;
set?: Dispatch<SetStateAction<T>>;
}
export function getDefaultState<T>(v: T): IState<T> {
return {
value: v,
// eslint-disable-next-line @typescript-eslint/no-unused-vars
set: (value: SetStateAction<T>): void => {},
};
}
export interface ICreateProjectContext {
projectClass: IState<string>;
trainingEnabled: IState<boolean>;
isTrainingActive: IState<boolean>;
}
export const defaultState: ICreateProjectContext = {
projectClass: getDefaultState<string>(''),
trainingEnabled: getDefaultState<boolean>(false),
isTrainingActive: getDefaultState<boolean>(false),
};
export default createContext<ICreateProjectContext>(defaultState);

@ -1,4 +1,4 @@
// Copyright (C) 2020 Intel Corporation
// Copyright (C) 2020-2021 Intel Corporation
//
// SPDX-License-Identifier: MIT

@ -3,8 +3,8 @@
// SPDX-License-Identifier: MIT
import { connect } from 'react-redux';
import { KeyMap } from 'utils/mousetrap-react';
import CanvasWrapperComponent from 'components/annotation-page/canvas/canvas-wrapper';
import {
confirmCanvasReady,

@ -2,7 +2,6 @@
//
// SPDX-License-Identifier: MIT
import { KeyMap } from 'utils/mousetrap-react';
import { connect } from 'react-redux';
import { Canvas } from 'cvat-canvas-wrapper';
@ -19,6 +18,7 @@ import {
} from 'actions/annotation-actions';
import ControlsSideBarComponent from 'components/annotation-page/review-workspace/controls-side-bar/controls-side-bar';
import { ActiveControl, CombinedState, Rotation } from 'reducers/interfaces';
import { KeyMap } from 'utils/mousetrap-react';
interface StateToProps {
canvasInstance: Canvas;

@ -1,4 +1,4 @@
// Copyright (C) 2020 Intel Corporation
// Copyright (C) 2020-2021 Intel Corporation
//
// SPDX-License-Identifier: MIT
@ -68,7 +68,9 @@ function mapDispatchToProps(dispatch: any): DispatchToProps {
type Props = StateToProps & DispatchToProps;
class PropagateConfirmContainer extends React.PureComponent<Props> {
private propagateObject = (): void => {
const { propagateObject, objectState, propagateFrames, frameNumber, stopFrame, jobInstance } = this.props;
const {
propagateObject, objectState, propagateFrames, frameNumber, stopFrame, jobInstance,
} = this.props;
const propagateUpToFrame = Math.min(frameNumber + propagateFrames, stopFrame);
propagateObject(jobInstance, objectState, frameNumber + 1, propagateUpToFrame);
@ -87,7 +89,9 @@ class PropagateConfirmContainer extends React.PureComponent<Props> {
};
public render(): JSX.Element {
const { frameNumber, stopFrame, propagateFrames, cancel, objectState } = this.props;
const {
frameNumber, stopFrame, propagateFrames, cancel, objectState,
} = this.props;
const propagateUpToFrame = Math.min(frameNumber + propagateFrames, stopFrame);

@ -18,6 +18,8 @@ import {
searchAnnotationsAsync,
searchEmptyFrameAsync,
setForceExitAnnotationFlag as setForceExitAnnotationFlagAction,
switchPredictor as switchPredictorAction,
getPredictionsAsync,
showFilters as showFiltersAction,
showStatistics as showStatisticsAction,
switchPlay,
@ -25,7 +27,9 @@ import {
} from 'actions/annotation-actions';
import AnnotationTopBarComponent from 'components/annotation-page/top-bar/top-bar';
import { Canvas } from 'cvat-canvas-wrapper';
import { CombinedState, FrameSpeed, Workspace } from 'reducers/interfaces';
import {
CombinedState, FrameSpeed, Workspace, PredictorState,
} from 'reducers/interfaces';
import GlobalHotKeys, { KeyMap } from 'utils/mousetrap-react';
interface StateToProps {
@ -48,6 +52,8 @@ interface StateToProps {
normalizedKeyMap: Record<string, string>;
canvasInstance: Canvas;
forceExit: boolean;
predictor: PredictorState;
isTrainingActive: boolean;
}
interface DispatchToProps {
@ -62,6 +68,7 @@ interface DispatchToProps {
searchEmptyFrame(sessionInstance: any, frameFrom: number, frameTo: number): void;
setForceExitAnnotationFlag(forceExit: boolean): void;
changeWorkspace(workspace: Workspace): void;
switchPredictor(predictorEnabled: boolean): void;
}
function mapStateToProps(state: CombinedState): StateToProps {
@ -78,12 +85,14 @@ function mapStateToProps(state: CombinedState): StateToProps {
job: { instance: jobInstance },
canvas: { ready: canvasIsReady, instance: canvasInstance },
workspace,
predictor,
},
settings: {
player: { frameSpeed, frameStep },
workspace: { autoSave, autoSaveInterval },
},
shortcuts: { keyMap, normalizedKeyMap },
plugins: { list },
} = state;
return {
@ -106,6 +115,8 @@ function mapStateToProps(state: CombinedState): StateToProps {
normalizedKeyMap,
canvasInstance,
forceExit,
predictor,
isTrainingActive: list.PREDICT,
};
}
@ -146,6 +157,12 @@ function mapDispatchToProps(dispatch: any): DispatchToProps {
setForceExitAnnotationFlag(forceExit: boolean): void {
dispatch(setForceExitAnnotationFlagAction(forceExit));
},
switchPredictor(predictorEnabled: boolean): void {
dispatch(switchPredictorAction(predictorEnabled));
if (predictorEnabled) {
dispatch(getPredictionsAsync());
}
},
};
}
@ -497,11 +514,14 @@ class AnnotationTopBarContainer extends React.PureComponent<Props, State> {
redoAction,
workspace,
canvasIsReady,
searchAnnotations,
changeWorkspace,
keyMap,
normalizedKeyMap,
canvasInstance,
predictor,
searchAnnotations,
changeWorkspace,
switchPredictor,
isTrainingActive,
} = this.props;
const preventDefault = (event: KeyboardEvent | undefined): void => {
@ -611,6 +631,8 @@ class AnnotationTopBarContainer extends React.PureComponent<Props, State> {
onInputChange={this.onChangePlayerInputValue}
onURLIconClick={this.onURLIconClick}
changeWorkspace={changeWorkspace}
switchPredictor={switchPredictor}
predictor={predictor}
workspace={workspace}
playing={playing}
saving={saving}
@ -636,6 +658,7 @@ class AnnotationTopBarContainer extends React.PureComponent<Props, State> {
onUndoClick={this.undo}
onRedoClick={this.redo}
jobInstance={jobInstance}
isTrainingActive={isTrainingActive}
/>
</>
);

@ -1,4 +1,4 @@
// Copyright (C) 2020 Intel Corporation
// Copyright (C) 2020-2021 Intel Corporation
//
// SPDX-License-Identifier: MIT
@ -69,7 +69,9 @@ export class FileManagerContainer extends React.PureComponent<Props> {
}
public render(): JSX.Element {
const { treeData, getTreeData, withRemote, onChangeActiveKey } = this.props;
const {
treeData, getTreeData, withRemote, onChangeActiveKey,
} = this.props;
return (
<FileManagerComponent

@ -1,4 +1,4 @@
// Copyright (C) 2020 Intel Corporation
// Copyright (C) 2020-2021 Intel Corporation
//
// SPDX-License-Identifier: MIT
@ -16,7 +16,9 @@ interface StateToProps {
function mapStateToProps(state: CombinedState): StateToProps {
const { models } = state;
const { interactors, detectors, trackers, reid } = models;
const {
interactors, detectors, trackers, reid,
} = models;
return {
interactors,

@ -1,4 +1,4 @@
// Copyright (C) 2020 Intel Corporation
// Copyright (C) 2020-2021 Intel Corporation
//
// SPDX-License-Identifier: MIT
@ -31,9 +31,9 @@ function mapStateToProps(state: CombinedState): StateToProps {
gettingQuery: tasks.gettingQuery,
numberOfTasks: state.tasks.count,
numberOfVisibleTasks: state.tasks.current.length,
numberOfHiddenTasks: tasks.hideEmpty
? tasks.current.filter((task: Task): boolean => !task.instance.jobs.length).length
: 0,
numberOfHiddenTasks: tasks.hideEmpty ?
tasks.current.filter((task: Task): boolean => !task.instance.jobs.length).length :
0,
};
}

@ -47,6 +47,7 @@ import SVGCubeIcon from './assets/cube-icon.svg';
import SVGResetPerspectiveIcon from './assets/reset-perspective.svg';
import SVGColorizeIcon from './assets/colorize-icon.svg';
import SVGAITools from './assets/ai-tools-icon.svg';
import SVGBrain from './assets/brain.svg';
import SVGOpenCV from './assets/opencv.svg';
import SVGFilterIcon from './assets/object-filter-icon.svg';
@ -93,5 +94,6 @@ export const CubeIcon = React.memo((): JSX.Element => <SVGCubeIcon />);
export const ResetPerspectiveIcon = React.memo((): JSX.Element => <SVGResetPerspectiveIcon />);
export const AIToolsIcon = React.memo((): JSX.Element => <SVGAITools />);
export const ColorizeIcon = React.memo((): JSX.Element => <SVGColorizeIcon />);
export const BrainIcon = React.memo((): JSX.Element => <SVGBrain />);
export const OpenCVIcon = React.memo((): JSX.Element => <SVGOpenCV />);
export const FilterIcon = React.memo((): JSX.Element => <SVGFilterIcon />);

@ -1,4 +1,4 @@
// Copyright (C) 2021 Intel Corporation
// Copyright (C) 2020-2021 Intel Corporation
//
// SPDX-License-Identifier: MIT
@ -38,6 +38,7 @@ const defaultState: AnnotationState = {
activeControl: ActiveControl.CURSOR,
},
job: {
openTime: null,
labels: [],
requestedId: null,
instance: null,
@ -108,6 +109,14 @@ const defaultState: AnnotationState = {
requestReviewDialogVisible: false,
submitReviewDialogVisible: false,
tabContentHeight: 0,
predictor: {
enabled: false,
error: null,
message: '',
projectScore: 0,
fetching: false,
annotatedFrames: [],
},
workspace: Workspace.STANDARD,
};
@ -129,6 +138,7 @@ export default (state = defaultState, action: AnyAction): AnnotationState => {
const {
job,
states,
openTime,
frameNumber: number,
frameFilename: filename,
colors,
@ -148,6 +158,7 @@ export default (state = defaultState, action: AnyAction): AnnotationState => {
...state,
job: {
...state.job,
openTime,
fetching: false,
instance: job,
labels: job.task.labels,
@ -1093,6 +1104,47 @@ export default (state = defaultState, action: AnyAction): AnnotationState => {
workspace,
};
}
case AnnotationActionTypes.UPDATE_PREDICTOR_STATE: {
const { payload } = action;
return {
...state,
predictor: {
...state.predictor,
...payload,
},
};
}
case AnnotationActionTypes.GET_PREDICTIONS: {
return {
...state,
predictor: {
...state.predictor,
fetching: true,
},
};
}
case AnnotationActionTypes.GET_PREDICTIONS_SUCCESS: {
const { frame } = action.payload;
const annotatedFrames = [...state.predictor.annotatedFrames, frame];
return {
...state,
predictor: {
...state.predictor,
fetching: false,
annotatedFrames,
},
};
}
case AnnotationActionTypes.GET_PREDICTIONS_FAILED: {
return {
...state,
predictor: {
...state.predictor,
fetching: false,
},
};
}
case AnnotationActionTypes.RESET_CANVAS: {
return {
...state,

@ -111,6 +111,7 @@ export enum SupportedPlugins {
GIT_INTEGRATION = 'GIT_INTEGRATION',
ANALYTICS = 'ANALYTICS',
MODELS = 'MODELS',
PREDICT = 'PREDICT',
}
export type PluginsList = {
@ -301,6 +302,9 @@ export interface NotificationsState {
commentingIssue: null | ErrorState;
submittingReview: null | ErrorState;
};
predictor: {
prediction: null | ErrorState;
};
};
messages: {
tasks: {
@ -367,6 +371,18 @@ export enum Rotation {
CLOCKWISE90,
}
export interface PredictorState {
timeRemaining: number;
progress: number;
projectScore: number;
message: string;
error: Error | null;
enabled: boolean;
fetching: boolean;
annotationAmount: number;
mediaAmount: number;
}
export interface AnnotationState {
activities: {
loads: {
@ -388,6 +404,7 @@ export interface AnnotationState {
activeControl: ActiveControl;
};
job: {
openTime: null | number;
labels: any[];
requestedId: number | null;
instance: any | null | undefined;
@ -462,6 +479,7 @@ export interface AnnotationState {
appearanceCollapsed: boolean;
tabContentHeight: number;
workspace: Workspace;
predictor: PredictorState;
aiToolsRef: MutableRefObject<any>;
}

@ -102,6 +102,9 @@ const defaultState: NotificationsState = {
resolvingIssue: null,
submittingReview: null,
},
predictor: {
prediction: null,
},
},
messages: {
tasks: {
@ -1104,6 +1107,21 @@ export default function (state = defaultState, action: AnyAction): Notifications
},
};
}
case AnnotationActionTypes.GET_PREDICTIONS_FAILED: {
return {
...state,
errors: {
...state.errors,
predictor: {
...state.errors.predictor,
prediction: {
message: 'Could not fetch prediction data',
reason: action.payload.error,
},
},
},
};
}
case BoundariesActionTypes.RESET_AFTER_ERROR:
case AuthActionTypes.LOGOUT_SUCCESS: {
return { ...defaultState };

@ -1,4 +1,4 @@
// Copyright (C) 2020 Intel Corporation
// Copyright (C) 2020-2021 Intel Corporation
//
// SPDX-License-Identifier: MIT
@ -13,6 +13,7 @@ const defaultState: PluginsState = {
GIT_INTEGRATION: false,
ANALYTICS: false,
MODELS: false,
PREDICT: false,
},
};

@ -0,0 +1,48 @@
# Generated by Django 3.1.7 on 2021-04-02 13:17
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
('engine', '0038_manifest'),
]
operations = [
migrations.CreateModel(
name='TrainingProject',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('host', models.CharField(max_length=256)),
('username', models.CharField(max_length=256)),
('password', models.CharField(max_length=256)),
('training_id', models.CharField(max_length=64)),
('enabled', models.BooleanField(null=True)),
('project_class', models.CharField(blank=True, choices=[('OD', 'Object Detection')], max_length=2, null=True)),
],
),
migrations.CreateModel(
name='TrainingProjectLabel',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('training_label_id', models.CharField(max_length=64)),
('cvat_label', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='training_project_label', to='engine.label')),
],
),
migrations.CreateModel(
name='TrainingProjectImage',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('idx', models.PositiveIntegerField()),
('training_image_id', models.CharField(max_length=64)),
('task', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='engine.task')),
],
),
migrations.AddField(
model_name='project',
name='training_project',
field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='engine.trainingproject'),
),
]

@ -2,15 +2,16 @@
#
# SPDX-License-Identifier: MIT
from enum import Enum
import re
import os
import re
from enum import Enum
from django.db import models
from django.conf import settings
from django.contrib.auth.models import User
from django.core.files.storage import FileSystemStorage
from django.db import models
from django.utils.translation import gettext_lazy as _
class SafeCharField(models.CharField):
def get_prep_value(self, value):
@ -19,6 +20,7 @@ class SafeCharField(models.CharField):
return value[:self.max_length]
return value
class DimensionType(str, Enum):
DIM_3D = '3d'
DIM_2D = '2d'
@ -152,6 +154,7 @@ class Video(models.Model):
class Meta:
default_permissions = ()
class Image(models.Model):
data = models.ForeignKey(Data, on_delete=models.CASCADE, related_name="images", null=True)
path = models.CharField(max_length=1024, default='')
@ -162,17 +165,32 @@ class Image(models.Model):
class Meta:
default_permissions = ()
class TrainingProject(models.Model):
class ProjectClass(models.TextChoices):
DETECTION = 'OD', _('Object Detection')
host = models.CharField(max_length=256)
username = models.CharField(max_length=256)
password = models.CharField(max_length=256)
training_id = models.CharField(max_length=64)
enabled = models.BooleanField(null=True)
project_class = models.CharField(max_length=2, choices=ProjectClass.choices, null=True, blank=True)
class Project(models.Model):
name = SafeCharField(max_length=256)
owner = models.ForeignKey(User, null=True, blank=True,
on_delete=models.SET_NULL, related_name="+")
assignee = models.ForeignKey(User, null=True, blank=True,
on_delete=models.SET_NULL, related_name="+")
on_delete=models.SET_NULL, related_name="+")
assignee = models.ForeignKey(User, null=True, blank=True,
on_delete=models.SET_NULL, related_name="+")
bug_tracker = models.CharField(max_length=2000, blank=True, default="")
created_date = models.DateTimeField(auto_now_add=True)
updated_date = models.DateTimeField(auto_now_add=True)
status = models.CharField(max_length=32, choices=StatusChoice.choices(),
default=StatusChoice.ANNOTATION)
default=StatusChoice.ANNOTATION)
training_project = models.ForeignKey(TrainingProject, null=True, blank=True, on_delete=models.SET_NULL)
def get_project_dirname(self):
return os.path.join(settings.PROJECTS_ROOT, str(self.id))
@ -210,7 +228,7 @@ class Task(models.Model):
# Zero means that there are no limits (default)
segment_size = models.PositiveIntegerField(default=0)
status = models.CharField(max_length=32, choices=StatusChoice.choices(),
default=StatusChoice.ANNOTATION)
default=StatusChoice.ANNOTATION)
data = models.ForeignKey(Data, on_delete=models.CASCADE, null=True, related_name="tasks")
dimension = models.CharField(max_length=2, choices=DimensionType.choices(), default=DimensionType.DIM_2D)
subset = models.CharField(max_length=64, blank=True, default="")
@ -237,6 +255,13 @@ class Task(models.Model):
def __str__(self):
return self.name
class TrainingProjectImage(models.Model):
task = models.ForeignKey(Task, on_delete=models.CASCADE)
idx = models.PositiveIntegerField()
training_image_id = models.CharField(max_length=64)
# Redefined a couple of operation for FileSystemStorage to avoid renaming
# or other side effects.
class MyFileSystemStorage(FileSystemStorage):
@ -319,6 +344,12 @@ class Label(models.Model):
default_permissions = ()
unique_together = ('task', 'name')
class TrainingProjectLabel(models.Model):
cvat_label = models.ForeignKey(Label, on_delete=models.CASCADE, related_name='training_project_label')
training_label_id = models.CharField(max_length=64)
class AttributeType(str, Enum):
CHECKBOX = 'checkbox'
RADIO = 'radio'

@ -9,9 +9,11 @@ import shutil
from rest_framework import serializers, exceptions
from django.contrib.auth.models import User, Group
from cvat.apps.dataset_manager.formats.utils import get_label_color
from cvat.apps.engine import models
from cvat.apps.engine.log import slogger
from cvat.apps.dataset_manager.formats.utils import get_label_color
class BasicUserSerializer(serializers.ModelSerializer):
def validate(self, data):
@ -415,6 +417,7 @@ class TaskSerializer(WriteOnceMixin, serializers.ModelSerializer):
raise serializers.ValidationError('All label names must be unique for the task')
return value
class ProjectSearchSerializer(serializers.ModelSerializer):
class Meta:
model = models.Project
@ -423,17 +426,25 @@ class ProjectSearchSerializer(serializers.ModelSerializer):
ordering = ['-id']
class TrainingProjectSerializer(serializers.ModelSerializer):
class Meta:
model = models.TrainingProject
fields = ('host', 'username', 'password', 'enabled', 'project_class')
write_once_fields = ('host', 'username', 'password', 'project_class')
class ProjectWithoutTaskSerializer(serializers.ModelSerializer):
labels = LabelSerializer(many=True, source='label_set', partial=True, default=[])
owner = BasicUserSerializer(required=False)
owner_id = serializers.IntegerField(write_only=True, allow_null=True, required=False)
assignee = BasicUserSerializer(allow_null=True, required=False)
assignee_id = serializers.IntegerField(write_only=True, allow_null=True, required=False)
training_project = TrainingProjectSerializer(required=False, allow_null=True)
class Meta:
model = models.Project
fields = ('url', 'id', 'name', 'labels', 'owner', 'assignee', 'owner_id', 'assignee_id',
'bug_tracker', 'created_date', 'updated_date', 'status')
fields = ('url', 'id', 'name', 'labels', 'tasks', 'owner', 'assignee', 'owner_id', 'assignee_id',
'bug_tracker', 'created_date', 'updated_date', 'status', 'training_project')
read_only_fields = ('created_date', 'updated_date', 'status', 'owner', 'asignee')
ordering = ['-id']
@ -456,7 +467,17 @@ class ProjectSerializer(ProjectWithoutTaskSerializer):
# pylint: disable=no-self-use
def create(self, validated_data):
labels = validated_data.pop('label_set')
db_project = models.Project.objects.create(**validated_data)
training_data = validated_data.pop('training_project', {})
if training_data.get('enabled'):
host = training_data.pop('host').strip('/')
username = training_data.pop('username').strip()
password = training_data.pop('password').strip()
tr_p = models.TrainingProject.objects.create(**training_data,
host=host, username=username, password=password)
db_project = models.Project.objects.create(**validated_data,
training_project=tr_p)
else:
db_project = models.Project.objects.create(**validated_data)
label_names = list()
for label in labels:
attributes = label.pop('attributespec_set')
@ -472,7 +493,6 @@ class ProjectSerializer(ProjectWithoutTaskSerializer):
shutil.rmtree(project_path)
os.makedirs(db_project.get_project_logs_dirname())
db_project.save()
return db_project
# pylint: disable=no-self-use
@ -530,6 +550,7 @@ class PluginsSerializer(serializers.Serializer):
GIT_INTEGRATION = serializers.BooleanField()
ANALYTICS = serializers.BooleanField()
MODELS = serializers.BooleanField()
PREDICT = serializers.BooleanField()
class DataMetaSerializer(serializers.ModelSerializer):
frames = FrameMetaSerializer(many=True, allow_null=True)

@ -13,6 +13,7 @@ from django.views.generic import RedirectView
from django.conf import settings
from cvat.apps.restrictions.views import RestrictionsViewSet
from cvat.apps.authentication.decorators import login_required
from cvat.apps.training.views import PredictView
schema_view = get_schema_view(
openapi.Info(
@ -53,6 +54,7 @@ router.register('reviews', views.ReviewViewSet)
router.register('issues', views.IssueViewSet)
router.register('comments', views.CommentViewSet)
router.register('restrictions', RestrictionsViewSet, basename='restrictions')
router.register('predict', PredictView, basename='predict')
urlpatterns = [
# Entry point for a client

@ -2,23 +2,23 @@
#
# SPDX-License-Identifier: MIT
import io
import os
import os.path as osp
import io
import shutil
import traceback
from datetime import datetime
from distutils.util import strtobool
from tempfile import mkstemp
import cv2
import cv2
import django_rq
from django.shortcuts import get_object_or_404
from django.apps import apps
from django.conf import settings
from django.contrib.auth.models import User
from django.db import IntegrityError
from django.http import HttpResponse
from django.shortcuts import get_object_or_404
from django.utils import timezone
from django.utils.decorators import method_decorator
from django_filters import rest_framework as filters
@ -35,7 +35,7 @@ from rest_framework.response import Response
from sendfile import sendfile
import cvat.apps.dataset_manager as dm
import cvat.apps.dataset_manager.views # pylint: disable=unused-import
import cvat.apps.dataset_manager.views # pylint: disable=unused-import
from cvat.apps.authentication import auth
from cvat.apps.dataset_manager.bindings import CvatImportError
from cvat.apps.dataset_manager.serializers import DatasetFormatsSerializer
@ -53,7 +53,6 @@ from cvat.apps.engine.serializers import (
CombinedReviewSerializer, IssueSerializer, CombinedIssueSerializer, CommentSerializer
)
from cvat.apps.engine.utils import av_scan_paths
from . import models, task
from .log import clogger, slogger
@ -188,6 +187,7 @@ class ServerViewSet(viewsets.ViewSet):
'GIT_INTEGRATION': apps.is_installed('cvat.apps.dataset_repo'),
'ANALYTICS': False,
'MODELS': False,
'PREDICT': apps.is_installed('cvat.apps.training')
}
if strtobool(os.environ.get("CVAT_ANALYTICS", '0')):
response['ANALYTICS'] = True
@ -290,6 +290,7 @@ class ProjectViewSet(auth.ProjectGetQuerySetMixin, viewsets.ModelViewSet):
context={"request": request})
return Response(serializer.data)
class TaskFilter(filters.FilterSet):
project = filters.CharFilter(field_name="project__name", lookup_expr="icontains")
name = filters.CharFilter(field_name="name", lookup_expr="icontains")
@ -1109,3 +1110,5 @@ def _export_annotations(db_task, rq_id, request, format_name, action, callback,
meta={ 'request_time': timezone.localtime() },
result_ttl=ttl, failure_ttl=ttl)
return Response(status=status.HTTP_202_ACCEPTED)

@ -0,0 +1 @@
default_app_config = 'cvat.apps.training.apps.TrainingConfig'

@ -0,0 +1,362 @@
import uuid
from abc import ABC, abstractmethod
from collections import OrderedDict
from functools import wraps
from typing import Callable, List, Union
import requests
from cacheops import cache, CacheMiss
from cvat.apps.engine.models import TrainingProject, ShapeType
class TrainingServerAPIAbs(ABC):
def __init__(self, host, username, password):
self.host = host
self.username = username
self.password = password
@abstractmethod
def get_server_status(self):
pass
@abstractmethod
def create_project(self, name: str, description: str = '', project_class: TrainingProject.ProjectClass = None,
labels: List[dict] = None):
pass
@abstractmethod
def upload_annotations(self, project_id: str, frames_data: List[dict]):
pass
@abstractmethod
def get_project_status(self, project_id: str) -> dict:
pass
@abstractmethod
def get_annotation(self, project_id: str, image_id: str, width: int, height: int, frame: int,
labels_mapping: dict) -> dict:
pass
def retry(amount: int = 2) -> Callable:
def dec(func: Callable) -> Callable:
@wraps(func)
def wrapper(*args, **kwargs):
__amount = amount
while __amount > 0:
__amount -= 1
try:
result = func(*args, **kwargs)
return result
except Exception:
pass
return wrapper
return dec
class TrainingServerAPI(TrainingServerAPIAbs):
TRAINING_CLASS = {
TrainingProject.ProjectClass.DETECTION: "DETECTION"
}
@staticmethod
def __convert_annotation_from_cvat(shapes):
data = []
for shape in shapes:
x0, y0, x1, y1 = shape['points']
x = x0 / shape['width']
y = y0 / shape['height']
width = (x1 - x0) / shape['width']
height = (y1 - y0) / shape['height']
data.append({
"id": str(uuid.uuid4()),
"shapes": [
{
"type": "rect",
"geometry": {
"x": x,
"y": y,
"width": width,
"height": height,
"points": None,
}
}
],
"editor": None,
"labels": [
{
"id": shape['third_party_label_id'],
"probability": 1.0,
},
],
})
return data
@staticmethod
def __convert_annotation_to_cvat(annotation: dict, image_width: int, image_height: int, frame: int,
labels_mapping: dict) -> List[OrderedDict]:
shapes = []
for i, annotation in enumerate(annotation.get('data', [])):
label_id = annotation['labels'][0]['id']
if not labels_mapping.get(label_id):
continue
shape = annotation['shapes'][0]
if shape['type'] != 'rect':
continue
x = shape['geometry']['x']
y = shape['geometry']['y']
w = shape['geometry']['width']
h = shape['geometry']['height']
x0 = x * image_width
y0 = y * image_height
x1 = image_width * w + x0
y1 = image_height * h + y0
shapes.append(OrderedDict([
('type', ShapeType.RECTANGLE),
('occluded', False),
('z_order', 0),
('points', [x0, y0, x1, y1]),
('id', i),
('frame', int(frame)),
('label', labels_mapping.get(label_id)),
('group', 0),
('source', 'auto'),
('attributes', {})
]))
return shapes
@retry()
def __create_project(self, name: str, description: str = None,
labels: List[dict] = None, tasks: List[dict] = None) -> dict:
url = f'{self.host}/v2/projects'
headers = {
'Context-Type': 'application/json',
'Authorization': f'bearer_token {self.token}',
}
tasks[1]['properties'] = [
{
"id": "labels",
"user_value": labels
}
]
data = {
'name': name,
'description': description,
"dimensions": [],
"group_type": "normal",
'pipeline': {
'connections': [{
'from': {
**tasks[0]['output_ports'][0],
'task_id': tasks[0]['temp_id'],
},
'to': {
**tasks[1]['input_ports'][0],
'task_id': tasks[1]['temp_id'],
}
}],
'tasks': tasks,
},
"pipeline_representation": 'Detection',
"type": "project",
}
response = self.request(method='POST', url=url, json=data, headers=headers)
return response
@retry()
def __get_annotation(self, project_id: str, image_id: str) -> dict:
url = f'{self.host}/v2/projects/{project_id}/media/images/{image_id}/results/online'
headers = {
'Authorization': f'bearer_token {self.token}',
}
response = self.request(method='GET', url=url, headers=headers)
return response
@retry()
def __get_job_status(self, project_id: str) -> dict:
url = f'{self.host}/v2/projects/{project_id}/jobs'
headers = {
'Authorization': f'bearer_token {self.token}',
}
response = self.request(method='GET', url=url, headers=headers)
return response
@retry()
def __get_project_summary(self, project_id: str) -> dict:
url = f'{self.host}/v2/projects/{project_id}/statistics/summary'
headers = {
'Authorization': f'bearer_token {self.token}',
}
response = self.request(method='GET', url=url, headers=headers)
return response
@retry()
def __get_project(self, project_id: str) -> dict:
url = f'{self.host}/v2/projects/{project_id}'
headers = {
'Authorization': f'bearer_token {self.token}',
}
response = self.request(method='GET', url=url, headers=headers)
return response
@retry()
def __get_server_status(self) -> dict:
url = f'{self.host}/v2/status'
headers = {
'Authorization': f'bearer_token {self.token}',
}
response = self.request(method='GET', url=url, headers=headers)
return response
@retry()
def __get_tasks(self) -> List[dict]:
url = f'{self.host}/v2/tasks'
headers = {
'Authorization': f'bearer_token {self.token}',
}
response = self.request(method='GET', url=url, headers=headers)
return response
def __delete_token(self):
cache.delete(self.token_key)
@retry()
def __upload_annotation(self, project_id: str, image_id: str, annotation: List[dict]):
url = f'{self.host}/v2/projects/{project_id}/media/images/{image_id}/annotations'
headers = {
'Authorization': f'bearer_token {self.token}',
'Content-Type': 'application/json'
}
data = {
'image_id': image_id,
'data': annotation
}
response = self.request(method='POST', url=url, headers=headers, json=data)
return response
@retry()
def __upload_image(self, project_id: str, buffer) -> dict:
url = f'{self.host}/v2/projects/{project_id}/media/images'
files = {'file': buffer}
headers = {
'Authorization': f'bearer_token {self.token}',
}
response = self.request(method='POST', url=url, headers=headers, files=files)
return response
@property
def project_id_key(self):
return f'{self.host}_{self.username}_project_id'
@property
def token(self) -> str:
def get_token(host: str, username: str, password: str) -> dict:
url = f'{host}/v2/authentication'
data = {
'username': (None, username),
'password': (None, password),
}
r = requests.post(url=url, files=data, verify=False) # nosec
return r.json()
try:
token = cache.get(self.token_key)
except CacheMiss:
response = get_token(self.host, self.username, self.password)
token = response.get('secure_token', '')
expires_in = response.get('expires_in', 3600)
cache.set(cache_key=self.token_key, data=token, timeout=expires_in)
return token
@property
def token_key(self):
return f'{self.host}_{self.username}_token'
def request(self, method: str, url: str, **kwargs) -> Union[list, dict, str]:
response = requests.request(method=method, url=url, verify=False, **kwargs)
if response.status_code == 401:
self.__delete_token()
raise Exception("401")
result = response.json()
return result
def create_project(self, name: str, description: str = '', project_class: TrainingProject.ProjectClass = None,
labels: List[dict] = None) -> dict:
all_tasks = self.__get_tasks()
task_type = self.TRAINING_CLASS.get(project_class)
task_algo = 'Retinanet - TF2'
tasks = [
next(({'temp_id': '_1_', **task}
for task in all_tasks
if task['task_type'] == 'DATASET'), {}),
next(({'temp_id': '_2_', **task}
for task in all_tasks
if task['task_type'] == task_type and
task['algorithm_name'] == task_algo), {}),
]
labels = [{
'name': label['name'],
'temp_id': label['name']
} for label in labels]
r = self.__create_project(name=name, description=description, tasks=tasks, labels=labels)
return r
def get_server_status(self) -> dict:
return self.__get_server_status()
def upload_annotations(self, project_id: str, frames_data: List[dict]):
for frame in frames_data:
annotation = self.__convert_annotation_from_cvat(frame['shapes'])
self.__upload_annotation(project_id=project_id, image_id=frame['third_party_id'], annotation=annotation)
def upload_image(self, training_id: str, buffer):
response = self.__upload_image(project_id=training_id, buffer=buffer)
return response.get('id')
def get_project_status(self, project_id) -> dict:
summary = self.__get_project_summary(project_id=project_id)
if not summary or not isinstance(summary, list):
return {'message': 'Not available'}
jobs = self.__get_job_status(project_id=project_id)
media_amount = next(item.get('value', 0) for item in summary if item.get('key') == 'Media')
annotation_amount = next(item.get('value', 0) for item in summary if item.get('key') == 'Annotation')
score = next(item.get('value', 0) for item in summary if item.get('key') == 'Score')
job_items = jobs.get('items', 0)
if len(job_items) == 0 and score == 0:
message = 'Not started'
elif len(job_items) == 0 and score > 0:
message = ''
else:
message = 'In progress'
progress = 0 if len(job_items) == 0 else job_items[0]["status"]["progress"]
time_remaining = 0 if len(job_items) == 0 else job_items[0]["status"]['time_remaining']
result = {
'media_amount': media_amount if media_amount else 0,
'annotation_amount': annotation_amount,
'score': score,
'message': message,
'progress': progress,
'time_remaining': time_remaining,
}
return result
def get_annotation(self, project_id: str, image_id: str, width: int, height: int, frame: int,
labels_mapping: dict) -> List[OrderedDict]:
annotation = self.__get_annotation(project_id=project_id, image_id=image_id)
cvat_annotation = self.__convert_annotation_to_cvat(annotation=annotation, image_width=width,
image_height=height, frame=frame,
labels_mapping=labels_mapping)
return cvat_annotation
def get_labels(self, project_id: str) -> List[dict]:
project = self.__get_project(project_id=project_id)
labels = [{
'id': label['id'],
'name': label['name']
} for label in project.get('labels')]
return labels

@ -0,0 +1,11 @@
from django.apps import AppConfig
class TrainingConfig(AppConfig):
name = 'cvat.apps.training'
def ready(self):
# Required to define signals in application
import cvat.apps.training.signals
# Required in order to silent "unused-import" in pyflake
assert cvat.apps.training.signals

@ -0,0 +1,186 @@
from collections import OrderedDict
from typing import List
from cacheops import cache
from django_rq import job
from cvat.apps import dataset_manager as dm
from cvat.apps.engine.frame_provider import FrameProvider
from cvat.apps.engine.models import (
Project,
Task,
TrainingProjectImage,
Label,
Image,
TrainingProjectLabel,
Data,
Job,
ShapeType,
)
from cvat.apps.training.apis import TrainingServerAPI
@job
def save_prediction_server_status_to_cache_job(cache_key,
cvat_project_id,
timeout=60):
cvat_project = Project.objects.get(pk=cvat_project_id)
api = TrainingServerAPI(host=cvat_project.training_project.host, username=cvat_project.training_project.username,
password=cvat_project.training_project.password)
status = api.get_project_status(project_id=cvat_project.training_project.training_id)
resp = {
**status,
'status': 'done'
}
cache.set(cache_key=cache_key, data=resp, timeout=timeout)
@job
def save_frame_prediction_to_cache_job(cache_key: str,
task_id: int,
frame: int,
timeout: int = 60):
task = Task.objects.get(pk=task_id)
training_project_image = TrainingProjectImage.objects.filter(idx=frame, task=task).first()
if not training_project_image:
cache.set(cache_key=cache_key, data={
'annotation': [],
'status': 'done'
}, timeout=timeout)
return
cvat_labels = Label.objects.filter(project__id=task.project_id).all()
training_project = Project.objects.get(pk=task.project_id).training_project
api = TrainingServerAPI(host=training_project.host,
username=training_project.username,
password=training_project.password)
image = Image.objects.get(frame=frame, data=task.data)
labels_mapping = {
TrainingProjectLabel.objects.get(cvat_label=cvat_label).training_label_id: cvat_label.id
for cvat_label in cvat_labels
}
annotation = api.get_annotation(project_id=training_project.training_id,
image_id=training_project_image.training_image_id,
width=image.width,
height=image.height,
labels_mapping=labels_mapping,
frame=frame)
resp = {
'annotation': annotation,
'status': 'done'
}
cache.set(cache_key=cache_key, data=resp, timeout=timeout)
@job
def upload_images_job(task_id: int):
if TrainingProjectImage.objects.filter(task_id=task_id).count() is 0:
task = Task.objects.get(pk=task_id)
frame_provider = FrameProvider(task.data)
frames = frame_provider.get_frames()
api = TrainingServerAPI(
host=task.project.training_project.host,
username=task.project.training_project.username,
password=task.project.training_project.password,
)
for i, (buffer, _) in enumerate(frames):
training_image_id = api.upload_image(training_id=task.project.training_project.training_id, buffer=buffer)
if training_image_id:
TrainingProjectImage.objects.create(task=task, idx=i,
training_image_id=training_image_id)
def __add_fields_to_shape(shape: dict, frame: int, data: Data, labels_mapping: dict) -> dict:
image = Image.objects.get(frame=frame, data=data)
return {
**shape,
'height': image.height,
'width': image.width,
'third_party_label_id': labels_mapping[shape['label_id']],
}
@job
def upload_annotation_to_training_project_job(job_id: int):
cvat_job = Job.objects.get(pk=job_id)
cvat_project = cvat_job.segment.task.project
training_project = cvat_project.training_project
start = cvat_job.segment.start_frame
stop = cvat_job.segment.stop_frame
data = dm.task.get_job_data(job_id)
shapes: List[OrderedDict] = data.get('shapes', [])
frames_data = []
api = TrainingServerAPI(
host=cvat_project.training_project.host,
username=cvat_project.training_project.username,
password=cvat_project.training_project.password,
)
cvat_labels = Label.objects.filter(project=cvat_project).all()
labels_mapping = {
cvat_label.id: TrainingProjectLabel.objects.get(cvat_label=cvat_label).training_label_id
for cvat_label in cvat_labels
}
for frame in range(start, stop + 1):
frame_shapes = list(
map(
lambda x: __add_fields_to_shape(x, frame, cvat_job.segment.task.data, labels_mapping),
filter(
lambda x: x['frame'] == frame and x['type'] == ShapeType.RECTANGLE,
shapes,
)
)
)
if frame_shapes:
training_project_image = TrainingProjectImage.objects.get(task=cvat_job.segment.task, idx=frame)
frames_data.append({
'third_party_id': training_project_image.training_image_id,
'shapes': frame_shapes
})
api.upload_annotations(project_id=training_project.training_id, frames_data=frames_data)
@job
def create_training_project_job(project_id: int):
cvat_project = Project.objects.get(pk=project_id)
training_project = cvat_project.training_project
api = TrainingServerAPI(
host=cvat_project.training_project.host,
username=cvat_project.training_project.username,
password=cvat_project.training_project.password,
)
create_training_project(cvat_project=cvat_project, training_project=training_project, api=api)
def create_training_project(cvat_project, training_project, api):
labels = cvat_project.label_set.all()
training_project_resp = api.create_project(
name=f'{cvat_project.name}_cvat',
project_class=training_project.project_class,
labels=[{'name': label.name} for label in labels]
)
if training_project_resp.get('id'):
training_project.training_id = training_project_resp['id']
training_project.save()
for cvat_label in labels:
training_label = list(filter(lambda x: x['name'] == cvat_label.name, training_project_resp.get('labels', [])))
if training_label:
TrainingProjectLabel.objects.create(cvat_label=cvat_label, training_label_id=training_label[0]['id'])
async def upload_images(cvat_project_id, training_id, api):
project = Project.objects.get(pk=cvat_project_id)
tasks: List[Task] = project.tasks.all()
for task in tasks:
frame_provider = FrameProvider(task)
frames = frame_provider.get_frames()
for i, (buffer, _) in enumerate(frames):
training_image_id = api.upload_image(training_id=training_id, buffer=buffer)
if training_image_id:
TrainingProjectImage.objects.create(project=project, task=task, idx=i,
training_image_id=training_image_id)

@ -0,0 +1,30 @@
from django.db.models.signals import post_save
from django.dispatch import receiver
from cvat.apps.engine.models import Job, StatusChoice, Project, Task
from cvat.apps.training.jobs import (
create_training_project_job,
upload_images_job,
upload_annotation_to_training_project_job,
)
@receiver(post_save, sender=Project, dispatch_uid="create_training_project")
def create_training_project(instance: Project, **kwargs):
if instance.training_project:
create_training_project_job.delay(instance.id)
@receiver(post_save, sender=Task, dispatch_uid='upload_images_to_training_project')
def upload_images_to_training_project(instance: Task, **kwargs):
if (instance.status == StatusChoice.ANNOTATION and
instance.data and instance.data.size != 0 and \
instance.project_id and instance.project.training_project):
upload_images_job.delay(instance.id)
@receiver(post_save, sender=Job, dispatch_uid="upload_annotation_to_training_project")
def upload_annotation_to_training_project(instance: Job, **kwargs):
if instance.status == StatusChoice.COMPLETED:
upload_annotation_to_training_project_job.delay(instance.id)

@ -0,0 +1,11 @@
from django.urls import path, include
from rest_framework import routers
from cvat.apps.training.views import PredictView
router = routers.DefaultRouter(trailing_slash=False)
router.register('', PredictView, basename='predict')
urlpatterns = [
path('', include((router.urls, 'predict'), namespace='predict'))
]

@ -0,0 +1,68 @@
from cacheops import cache, CacheMiss
from drf_yasg.utils import swagger_auto_schema
from rest_framework import viewsets, status
from rest_framework.decorators import action
from rest_framework.permissions import IsAuthenticated, SAFE_METHODS
from rest_framework.response import Response
from cvat.apps.authentication import auth
from cvat.apps.engine.models import Project
from cvat.apps.training.jobs import save_frame_prediction_to_cache_job, save_prediction_server_status_to_cache_job
class PredictView(viewsets.ViewSet):
def get_permissions(self):
http_method = self.request.method
permissions = [IsAuthenticated]
if http_method in SAFE_METHODS:
permissions.append(auth.ProjectAccessPermission)
else:
permissions.append(auth.AdminRolePermission)
return [perm() for perm in permissions]
@swagger_auto_schema(method='get', operation_summary='Returns prediction for image')
@action(detail=False, methods=['GET'], url_path='frame')
def predict_image(self, request):
frame = self.request.query_params.get('frame')
task_id = self.request.query_params.get('task')
if not task_id:
return Response(data='query param "task" empty or not provided', status=status.HTTP_400_BAD_REQUEST)
if not frame:
return Response(data='query param "frame" empty or not provided', status=status.HTTP_400_BAD_REQUEST)
cache_key = f'predict_image_{task_id}_{frame}'
try:
resp = cache.get(cache_key)
except CacheMiss:
save_frame_prediction_to_cache_job.delay(cache_key, task_id=task_id,
frame=frame)
resp = {
'status': 'queued',
}
cache.set(cache_key=cache_key, data=resp, timeout=60)
return Response(resp)
@swagger_auto_schema(method='get',
operation_summary='Returns information of the tasks of the project with the selected id')
@action(detail=False, methods=['GET'], url_path='status')
def predict_status(self, request):
project_id = self.request.query_params.get('project')
if not project_id:
return Response(data='query param "project" empty or not provided', status=status.HTTP_400_BAD_REQUEST)
project = Project.objects.get(pk=project_id)
if not project.training_project:
Response({'status': 'done'})
cache_key = f'predict_status_{project_id}'
try:
resp = cache.get(cache_key)
except CacheMiss:
save_prediction_server_status_to_cache_job.delay(cache_key, cvat_project_id=project_id)
resp = {
'status': 'queued',
}
cache.set(cache_key=cache_key, data=resp, timeout=60)
return Response(resp)

@ -20,6 +20,8 @@ import fcntl
import shutil
import subprocess
import mimetypes
from distutils.util import strtobool
mimetypes.add_type("application/wasm", ".wasm", True)
from pathlib import Path
@ -129,6 +131,9 @@ INSTALLED_APPS = [
'rest_auth.registration'
]
if strtobool(os.environ.get("ADAPTIVE_AUTO_ANNOTATION", 'false')):
INSTALLED_APPS.append('cvat.apps.training')
SITE_ID = 1
REST_FRAMEWORK = {

@ -64,4 +64,4 @@ class PatchedDiscoverRunner(DiscoverRunner):
for config in RQ_QUEUES.values():
config["ASYNC"] = False
super().__init__(*args, **kwargs)
super().__init__(*args, **kwargs)

@ -43,3 +43,6 @@ if apps.is_installed('cvat.apps.opencv'):
if apps.is_installed('silk'):
urlpatterns.append(path('profiler/', include('silk.urls')))
if apps.is_installed('cvat.apps.training'):
urlpatterns.append(path('api/v1/predict/', include('cvat.apps.training.urls')))

@ -42,6 +42,7 @@ services:
ALLOWED_HOSTS: '*'
CVAT_REDIS_HOST: 'cvat_redis'
CVAT_POSTGRES_HOST: 'cvat_db'
ADAPTIVE_AUTO_ANNOTATION: 'false'
volumes:
- cvat_data:/home/django/data
- cvat_keys:/home/django/keys

Loading…
Cancel
Save