You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
548 lines
20 KiB
TypeScript
548 lines
20 KiB
TypeScript
// Copyright (C) 2020 Intel Corporation
|
|
//
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
import { ActionUnion, createAction, ThunkAction } from 'utils/redux';
|
|
import {
|
|
Model,
|
|
ModelType,
|
|
ModelFiles,
|
|
ActiveInference,
|
|
CombinedState,
|
|
} from 'reducers/interfaces';
|
|
import getCore from 'cvat-core';
|
|
|
|
export enum PreinstalledModels {
|
|
RCNN = 'RCNN Object Detector',
|
|
MaskRCNN = 'Mask RCNN Object Detector',
|
|
}
|
|
|
|
export enum ModelsActionTypes {
|
|
GET_MODELS = 'GET_MODELS',
|
|
GET_MODELS_SUCCESS = 'GET_MODELS_SUCCESS',
|
|
GET_MODELS_FAILED = 'GET_MODELS_FAILED',
|
|
DELETE_MODEL = 'DELETE_MODEL',
|
|
DELETE_MODEL_SUCCESS = 'DELETE_MODEL_SUCCESS',
|
|
DELETE_MODEL_FAILED = 'DELETE_MODEL_FAILED',
|
|
CREATE_MODEL = 'CREATE_MODEL',
|
|
CREATE_MODEL_SUCCESS = 'CREATE_MODEL_SUCCESS',
|
|
CREATE_MODEL_FAILED = 'CREATE_MODEL_FAILED',
|
|
CREATE_MODEL_STATUS_UPDATED = 'CREATE_MODEL_STATUS_UPDATED',
|
|
START_INFERENCE_FAILED = 'START_INFERENCE_FAILED',
|
|
GET_INFERENCE_STATUS_SUCCESS = 'GET_INFERENCE_STATUS_SUCCESS',
|
|
GET_INFERENCE_STATUS_FAILED = 'GET_INFERENCE_STATUS_FAILED',
|
|
FETCH_META_FAILED = 'FETCH_META_FAILED',
|
|
SHOW_RUN_MODEL_DIALOG = 'SHOW_RUN_MODEL_DIALOG',
|
|
CLOSE_RUN_MODEL_DIALOG = 'CLOSE_RUN_MODEL_DIALOG',
|
|
CANCEL_INFERENCE_SUCCESS = 'CANCEL_INFERENCE_SUCCESS',
|
|
CANCEL_INFERENCE_FAILED = 'CANCEL_INFERENCE_FAILED',
|
|
}
|
|
|
|
export const modelsActions = {
|
|
getModels: () => createAction(ModelsActionTypes.GET_MODELS),
|
|
getModelsSuccess: (models: Model[]) => createAction(
|
|
ModelsActionTypes.GET_MODELS_SUCCESS, {
|
|
models,
|
|
},
|
|
),
|
|
getModelsFailed: (error: any) => createAction(
|
|
ModelsActionTypes.GET_MODELS_FAILED, {
|
|
error,
|
|
},
|
|
),
|
|
deleteModelSuccess: (id: number) => createAction(
|
|
ModelsActionTypes.DELETE_MODEL_SUCCESS, {
|
|
id,
|
|
},
|
|
),
|
|
deleteModelFailed: (id: number, error: any) => createAction(
|
|
ModelsActionTypes.DELETE_MODEL_FAILED, {
|
|
error, id,
|
|
},
|
|
),
|
|
createModel: () => createAction(ModelsActionTypes.CREATE_MODEL),
|
|
createModelSuccess: () => createAction(ModelsActionTypes.CREATE_MODEL_SUCCESS),
|
|
createModelFailed: (error: any) => createAction(
|
|
ModelsActionTypes.CREATE_MODEL_FAILED, {
|
|
error,
|
|
},
|
|
),
|
|
createModelUpdateStatus: (status: string) => createAction(
|
|
ModelsActionTypes.CREATE_MODEL_STATUS_UPDATED, {
|
|
status,
|
|
},
|
|
),
|
|
fetchMetaFailed: (error: any) => createAction(ModelsActionTypes.FETCH_META_FAILED, { error }),
|
|
getInferenceStatusSuccess: (taskID: number, activeInference: ActiveInference) => createAction(
|
|
ModelsActionTypes.GET_INFERENCE_STATUS_SUCCESS, {
|
|
taskID,
|
|
activeInference,
|
|
},
|
|
),
|
|
getInferenceStatusFailed: (taskID: number, error: any) => createAction(
|
|
ModelsActionTypes.GET_INFERENCE_STATUS_FAILED, {
|
|
taskID,
|
|
error,
|
|
},
|
|
),
|
|
startInferenceFailed: (taskID: number, error: any) => createAction(
|
|
ModelsActionTypes.START_INFERENCE_FAILED, {
|
|
taskID,
|
|
error,
|
|
},
|
|
),
|
|
cancelInferenceSuccess: (taskID: number) => createAction(
|
|
ModelsActionTypes.CANCEL_INFERENCE_SUCCESS, {
|
|
taskID,
|
|
},
|
|
),
|
|
cancelInferenceFaild: (taskID: number, error: any) => createAction(
|
|
ModelsActionTypes.CANCEL_INFERENCE_FAILED, {
|
|
taskID,
|
|
error,
|
|
},
|
|
),
|
|
closeRunModelDialog: () => createAction(ModelsActionTypes.CLOSE_RUN_MODEL_DIALOG),
|
|
showRunModelDialog: (taskInstance: any) => createAction(
|
|
ModelsActionTypes.SHOW_RUN_MODEL_DIALOG, {
|
|
taskInstance,
|
|
},
|
|
),
|
|
};
|
|
|
|
export type ModelsActions = ActionUnion<typeof modelsActions>;
|
|
|
|
const core = getCore();
|
|
const baseURL = core.config.backendAPI.slice(0, -7);
|
|
|
|
export function getModelsAsync(): ThunkAction {
|
|
return async (dispatch, getState): Promise<void> => {
|
|
const state: CombinedState = getState();
|
|
const OpenVINO = state.plugins.list.AUTO_ANNOTATION;
|
|
const RCNN = state.plugins.list.TF_ANNOTATION;
|
|
const MaskRCNN = state.plugins.list.TF_SEGMENTATION;
|
|
|
|
dispatch(modelsActions.getModels());
|
|
const models: Model[] = [];
|
|
|
|
try {
|
|
if (OpenVINO) {
|
|
const response = await core.server.request(
|
|
`${baseURL}/auto_annotation/meta/get`, {
|
|
method: 'POST',
|
|
headers: {
|
|
'Content-Type': 'application/json',
|
|
},
|
|
data: JSON.stringify([]),
|
|
},
|
|
);
|
|
|
|
|
|
for (const model of response.models) {
|
|
models.push({
|
|
id: model.id,
|
|
ownerID: model.owner,
|
|
primary: model.primary,
|
|
name: model.name,
|
|
uploadDate: model.uploadDate,
|
|
updateDate: model.updateDate,
|
|
labels: [...model.labels],
|
|
});
|
|
}
|
|
}
|
|
|
|
if (RCNN) {
|
|
models.push({
|
|
id: null,
|
|
ownerID: null,
|
|
primary: true,
|
|
name: PreinstalledModels.RCNN,
|
|
uploadDate: '',
|
|
updateDate: '',
|
|
labels: ['surfboard', 'car', 'skateboard', 'boat', 'clock',
|
|
'cat', 'cow', 'knife', 'apple', 'cup', 'tv',
|
|
'baseball_bat', 'book', 'suitcase', 'tennis_racket',
|
|
'stop_sign', 'couch', 'cell_phone', 'keyboard',
|
|
'cake', 'tie', 'frisbee', 'truck', 'fire_hydrant',
|
|
'snowboard', 'bed', 'vase', 'teddy_bear',
|
|
'toaster', 'wine_glass', 'traffic_light',
|
|
'broccoli', 'backpack', 'carrot', 'potted_plant',
|
|
'donut', 'umbrella', 'parking_meter', 'bottle',
|
|
'sandwich', 'motorcycle', 'bear', 'banana',
|
|
'person', 'scissors', 'elephant', 'dining_table',
|
|
'toothbrush', 'toilet', 'skis', 'bowl', 'sheep',
|
|
'refrigerator', 'oven', 'microwave', 'train',
|
|
'orange', 'mouse', 'laptop', 'bench', 'bicycle',
|
|
'fork', 'kite', 'zebra', 'baseball_glove', 'bus',
|
|
'spoon', 'horse', 'handbag', 'pizza', 'sports_ball',
|
|
'airplane', 'hair_drier', 'hot_dog', 'remote',
|
|
'sink', 'dog', 'bird', 'giraffe', 'chair',
|
|
],
|
|
});
|
|
}
|
|
|
|
if (MaskRCNN) {
|
|
models.push({
|
|
id: null,
|
|
ownerID: null,
|
|
primary: true,
|
|
name: PreinstalledModels.MaskRCNN,
|
|
uploadDate: '',
|
|
updateDate: '',
|
|
labels: ['BG', 'person', 'bicycle', 'car', 'motorcycle', 'airplane',
|
|
'bus', 'train', 'truck', 'boat', 'traffic light',
|
|
'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird',
|
|
'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear',
|
|
'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
|
|
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
|
|
'kite', 'baseball bat', 'baseball glove', 'skateboard',
|
|
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
|
|
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
|
|
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
|
|
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
|
|
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
|
|
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster',
|
|
'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
|
|
'teddy bear', 'hair drier', 'toothbrush',
|
|
],
|
|
});
|
|
}
|
|
} catch (error) {
|
|
dispatch(modelsActions.getModelsFailed(error));
|
|
return;
|
|
}
|
|
|
|
dispatch(modelsActions.getModelsSuccess(models));
|
|
};
|
|
}
|
|
|
|
export function deleteModelAsync(id: number): ThunkAction {
|
|
return async (dispatch): Promise<void> => {
|
|
try {
|
|
await core.server.request(`${baseURL}/auto_annotation/delete/${id}`, {
|
|
method: 'DELETE',
|
|
});
|
|
} catch (error) {
|
|
dispatch(modelsActions.deleteModelFailed(id, error));
|
|
return;
|
|
}
|
|
|
|
dispatch(modelsActions.deleteModelSuccess(id));
|
|
};
|
|
}
|
|
|
|
export function createModelAsync(name: string, files: ModelFiles, global: boolean): ThunkAction {
|
|
return async (dispatch): Promise<void> => {
|
|
async function checkCallback(id: string): Promise<void> {
|
|
try {
|
|
const data = await core.server.request(
|
|
`${baseURL}/auto_annotation/check/${id}`, {
|
|
method: 'GET',
|
|
},
|
|
);
|
|
|
|
switch (data.status) {
|
|
case 'failed':
|
|
dispatch(modelsActions.createModelFailed(
|
|
`Checking request has returned the "${data.status}" status. Message: ${data.error}`,
|
|
));
|
|
break;
|
|
case 'unknown':
|
|
dispatch(modelsActions.createModelFailed(
|
|
`Checking request has returned the "${data.status}" status.`,
|
|
));
|
|
break;
|
|
case 'finished':
|
|
dispatch(modelsActions.createModelSuccess());
|
|
break;
|
|
default:
|
|
if ('progress' in data) {
|
|
modelsActions.createModelUpdateStatus(data.progress);
|
|
}
|
|
setTimeout(checkCallback.bind(null, id), 1000);
|
|
}
|
|
} catch (error) {
|
|
dispatch(modelsActions.createModelFailed(error));
|
|
}
|
|
}
|
|
|
|
dispatch(modelsActions.createModel());
|
|
const data = new FormData();
|
|
data.append('name', name);
|
|
data.append('storage', typeof files.bin === 'string' ? 'shared' : 'local');
|
|
data.append('shared', global.toString());
|
|
Object.keys(files).reduce((acc, key: string): FormData => {
|
|
acc.append(key, files[key]);
|
|
return acc;
|
|
}, data);
|
|
|
|
try {
|
|
dispatch(modelsActions.createModelUpdateStatus('Request is beign sent..'));
|
|
const response = await core.server.request(
|
|
`${baseURL}/auto_annotation/create`, {
|
|
method: 'POST',
|
|
data,
|
|
contentType: false,
|
|
processData: false,
|
|
},
|
|
);
|
|
|
|
dispatch(modelsActions.createModelUpdateStatus('Request is being processed..'));
|
|
setTimeout(checkCallback.bind(null, response.id), 1000);
|
|
} catch (error) {
|
|
dispatch(modelsActions.createModelFailed(error));
|
|
}
|
|
};
|
|
}
|
|
|
|
interface InferenceMeta {
|
|
active: boolean;
|
|
taskID: number;
|
|
requestID: string;
|
|
modelType: ModelType;
|
|
}
|
|
|
|
const timers: any = {};
|
|
|
|
async function timeoutCallback(
|
|
url: string,
|
|
taskID: number,
|
|
modelType: ModelType,
|
|
dispatch: (action: ModelsActions) => void,
|
|
): Promise<void> {
|
|
try {
|
|
delete timers[taskID];
|
|
|
|
const response = await core.server.request(url, {
|
|
method: 'GET',
|
|
});
|
|
|
|
const activeInference: ActiveInference = {
|
|
status: response.status,
|
|
progress: +response.progress || 0,
|
|
error: response.error || response.stderr || '',
|
|
modelType,
|
|
};
|
|
|
|
|
|
if (activeInference.status === 'unknown') {
|
|
dispatch(modelsActions.getInferenceStatusFailed(
|
|
taskID,
|
|
new Error(
|
|
`Inference status for the task ${taskID} is unknown.`,
|
|
),
|
|
));
|
|
|
|
return;
|
|
}
|
|
|
|
if (activeInference.status === 'failed') {
|
|
dispatch(modelsActions.getInferenceStatusFailed(
|
|
taskID,
|
|
new Error(
|
|
`Inference status for the task ${taskID} is failed. ${activeInference.error}`,
|
|
),
|
|
));
|
|
|
|
return;
|
|
}
|
|
|
|
if (activeInference.status !== 'finished') {
|
|
timers[taskID] = setTimeout(
|
|
timeoutCallback.bind(
|
|
null,
|
|
url,
|
|
taskID,
|
|
modelType,
|
|
dispatch,
|
|
), 3000,
|
|
);
|
|
}
|
|
|
|
dispatch(modelsActions.getInferenceStatusSuccess(taskID, activeInference));
|
|
} catch (error) {
|
|
dispatch(modelsActions.getInferenceStatusFailed(taskID, new Error(
|
|
`Server request for the task ${taskID} was failed`,
|
|
)));
|
|
}
|
|
}
|
|
|
|
function subscribe(
|
|
inferenceMeta: InferenceMeta,
|
|
dispatch: (action: ModelsActions) => void,
|
|
): void {
|
|
if (!(inferenceMeta.taskID in timers)) {
|
|
let requestURL = `${baseURL}`;
|
|
if (inferenceMeta.modelType === ModelType.OPENVINO) {
|
|
requestURL = `${requestURL}/auto_annotation/check`;
|
|
} else if (inferenceMeta.modelType === ModelType.RCNN) {
|
|
requestURL = `${requestURL}/tensorflow/annotation/check/task`;
|
|
} else if (inferenceMeta.modelType === ModelType.MASK_RCNN) {
|
|
requestURL = `${requestURL}/tensorflow/segmentation/check/task`;
|
|
}
|
|
requestURL = `${requestURL}/${inferenceMeta.requestID}`;
|
|
timers[inferenceMeta.taskID] = setTimeout(
|
|
timeoutCallback.bind(
|
|
null,
|
|
requestURL,
|
|
inferenceMeta.taskID,
|
|
inferenceMeta.modelType,
|
|
dispatch,
|
|
),
|
|
);
|
|
}
|
|
}
|
|
|
|
export function getInferenceStatusAsync(tasks: number[]): ThunkAction {
|
|
return async (dispatch, getState): Promise<void> => {
|
|
function parse(response: any, modelType: ModelType): InferenceMeta[] {
|
|
return Object.keys(response).map((key: string): InferenceMeta => ({
|
|
taskID: +key,
|
|
requestID: response[key].rq_id || key,
|
|
active: typeof (response[key].active) === 'undefined' ? ['queued', 'started']
|
|
.includes(response[key].status.toLowerCase()) : response[key].active,
|
|
modelType,
|
|
}));
|
|
}
|
|
|
|
const state: CombinedState = getState();
|
|
const OpenVINO = state.plugins.list.AUTO_ANNOTATION;
|
|
const RCNN = state.plugins.list.TF_ANNOTATION;
|
|
const MaskRCNN = state.plugins.list.TF_SEGMENTATION;
|
|
|
|
const dispatchCallback = (action: ModelsActions): void => {
|
|
dispatch(action);
|
|
};
|
|
|
|
try {
|
|
if (OpenVINO) {
|
|
const response = await core.server.request(
|
|
`${baseURL}/auto_annotation/meta/get`, {
|
|
method: 'POST',
|
|
data: JSON.stringify(tasks),
|
|
headers: {
|
|
'Content-Type': 'application/json',
|
|
},
|
|
},
|
|
);
|
|
|
|
parse(response.run, ModelType.OPENVINO)
|
|
.filter((inferenceMeta: InferenceMeta): boolean => inferenceMeta.active)
|
|
.forEach((inferenceMeta: InferenceMeta): void => {
|
|
subscribe(inferenceMeta, dispatchCallback);
|
|
});
|
|
}
|
|
|
|
if (RCNN) {
|
|
const response = await core.server.request(
|
|
`${baseURL}/tensorflow/annotation/meta/get`, {
|
|
method: 'POST',
|
|
data: JSON.stringify(tasks),
|
|
headers: {
|
|
'Content-Type': 'application/json',
|
|
},
|
|
},
|
|
);
|
|
|
|
parse(response, ModelType.RCNN)
|
|
.filter((inferenceMeta: InferenceMeta): boolean => inferenceMeta.active)
|
|
.forEach((inferenceMeta: InferenceMeta): void => {
|
|
subscribe(inferenceMeta, dispatchCallback);
|
|
});
|
|
}
|
|
|
|
if (MaskRCNN) {
|
|
const response = await core.server.request(
|
|
`${baseURL}/tensorflow/segmentation/meta/get`, {
|
|
method: 'POST',
|
|
data: JSON.stringify(tasks),
|
|
headers: {
|
|
'Content-Type': 'application/json',
|
|
},
|
|
},
|
|
);
|
|
|
|
parse(response, ModelType.MASK_RCNN)
|
|
.filter((inferenceMeta: InferenceMeta): boolean => inferenceMeta.active)
|
|
.forEach((inferenceMeta: InferenceMeta): void => {
|
|
subscribe(inferenceMeta, dispatchCallback);
|
|
});
|
|
}
|
|
} catch (error) {
|
|
dispatch(modelsActions.fetchMetaFailed(error));
|
|
}
|
|
};
|
|
}
|
|
|
|
export function startInferenceAsync(
|
|
taskInstance: any,
|
|
model: Model,
|
|
mapping: {
|
|
[index: string]: string;
|
|
},
|
|
cleanOut: boolean,
|
|
): ThunkAction {
|
|
return async (dispatch): Promise<void> => {
|
|
try {
|
|
if (model.name === PreinstalledModels.RCNN) {
|
|
await core.server.request(
|
|
`${baseURL}/tensorflow/annotation/create/task/${taskInstance.id}`,
|
|
);
|
|
} else if (model.name === PreinstalledModels.MaskRCNN) {
|
|
await core.server.request(
|
|
`${baseURL}/tensorflow/segmentation/create/task/${taskInstance.id}`,
|
|
);
|
|
} else {
|
|
await core.server.request(
|
|
`${baseURL}/auto_annotation/start/${model.id}/${taskInstance.id}`, {
|
|
method: 'POST',
|
|
data: JSON.stringify({
|
|
reset: cleanOut,
|
|
labels: mapping,
|
|
}),
|
|
headers: {
|
|
'Content-Type': 'application/json',
|
|
},
|
|
},
|
|
);
|
|
}
|
|
|
|
dispatch(getInferenceStatusAsync([taskInstance.id]));
|
|
} catch (error) {
|
|
dispatch(modelsActions.startInferenceFailed(taskInstance.id, error));
|
|
}
|
|
};
|
|
}
|
|
|
|
export function cancelInferenceAsync(taskID: number): ThunkAction {
|
|
return async (dispatch, getState): Promise<void> => {
|
|
try {
|
|
const inference = getState().models.inferences[taskID];
|
|
if (inference) {
|
|
if (inference.modelType === ModelType.OPENVINO) {
|
|
await core.server.request(
|
|
`${baseURL}/auto_annotation/cancel/${taskID}`,
|
|
);
|
|
} else if (inference.modelType === ModelType.RCNN) {
|
|
await core.server.request(
|
|
`${baseURL}/tensorflow/annotation/cancel/task/${taskID}`,
|
|
);
|
|
} else if (inference.modelType === ModelType.MASK_RCNN) {
|
|
await core.server.request(
|
|
`${baseURL}/tensorflow/segmentation/cancel/task/${taskID}`,
|
|
);
|
|
}
|
|
|
|
if (timers[taskID]) {
|
|
clearTimeout(timers[taskID]);
|
|
delete timers[taskID];
|
|
}
|
|
}
|
|
|
|
dispatch(modelsActions.cancelInferenceSuccess(taskID));
|
|
} catch (error) {
|
|
dispatch(modelsActions.cancelInferenceFaild(taskID, error));
|
|
}
|
|
};
|
|
}
|