New models UI (#5635)

main
Kirill Lakhov 3 years ago committed by GitHub
parent 91b36ce393
commit 3775bc2557
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -37,6 +37,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Windows Installation Instructions adjusted to work around <https://github.com/nuclio/nuclio/issues/1821> - Windows Installation Instructions adjusted to work around <https://github.com/nuclio/nuclio/issues/1821>
- The contour detection function for semantic segmentation (<https://github.com/opencv/cvat/pull/4665>) - The contour detection function for semantic segmentation (<https://github.com/opencv/cvat/pull/4665>)
- Delete newline character when generating a webhook signature (<https://github.com/opencv/cvat/pull/5622>) - Delete newline character when generating a webhook signature (<https://github.com/opencv/cvat/pull/5622>)
- DL models UI (<https://github.com/opencv/cvat/pull/5635>)
### Deprecated ### Deprecated
- TDB - TDB

@ -1,6 +1,6 @@
{ {
"name": "cvat-core", "name": "cvat-core",
"version": "8.0.0", "version": "8.1.0",
"description": "Part of Computer Vision Tool which presents an interface for client-side integration", "description": "Part of Computer Vision Tool which presents an interface for client-side integration",
"main": "src/api.ts", "main": "src/api.ts",
"scripts": { "scripts": {

@ -1,5 +1,5 @@
// Copyright (C) 2019-2022 Intel Corporation // Copyright (C) 2019-2022 Intel Corporation
// Copyright (C) 2022 CVAT.ai Corporation // Copyright (C) 2022-2023 CVAT.ai Corporation
// //
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
@ -37,6 +37,7 @@ export default function implementAPI(cvat) {
cvat.lambda.cancel.implementation = lambdaManager.cancel.bind(lambdaManager); cvat.lambda.cancel.implementation = lambdaManager.cancel.bind(lambdaManager);
cvat.lambda.listen.implementation = lambdaManager.listen.bind(lambdaManager); cvat.lambda.listen.implementation = lambdaManager.listen.bind(lambdaManager);
cvat.lambda.requests.implementation = lambdaManager.requests.bind(lambdaManager); cvat.lambda.requests.implementation = lambdaManager.requests.bind(lambdaManager);
cvat.lambda.providers.implementation = lambdaManager.providers.bind(lambdaManager);
cvat.server.about.implementation = async () => { cvat.server.about.implementation = async () => {
const result = await serverProxy.server.about(); const result = await serverProxy.server.about();

@ -190,18 +190,22 @@ function build() {
const result = await PluginRegistry.apiWrapper(cvat.lambda.call, task, model, args); const result = await PluginRegistry.apiWrapper(cvat.lambda.call, task, model, args);
return result; return result;
}, },
async cancel(requestID) { async cancel(requestID, functionID) {
const result = await PluginRegistry.apiWrapper(cvat.lambda.cancel, requestID); const result = await PluginRegistry.apiWrapper(cvat.lambda.cancel, requestID, functionID);
return result; return result;
}, },
async listen(requestID, onChange) { async listen(requestID, functionID, onChange) {
const result = await PluginRegistry.apiWrapper(cvat.lambda.listen, requestID, onChange); const result = await PluginRegistry.apiWrapper(cvat.lambda.listen, requestID, functionID, onChange);
return result; return result;
}, },
async requests() { async requests() {
const result = await PluginRegistry.apiWrapper(cvat.lambda.requests); const result = await PluginRegistry.apiWrapper(cvat.lambda.requests);
return result; return result;
}, },
async providers() {
const result = await PluginRegistry.apiWrapper(cvat.lambda.providers);
return result;
},
}, },
logger: loggerStorage, logger: loggerStorage,
config: { config: {

@ -1,5 +1,5 @@
// Copyright (C) 2019-2022 Intel Corporation // Copyright (C) 2019-2022 Intel Corporation
// Copyright (C) 2022 CVAT.ai Corporation // Copyright (C) 2022-2023 CVAT.ai Corporation
// //
// SPDX-License-Identifier = MIT // SPDX-License-Identifier = MIT
@ -132,10 +132,23 @@ export enum HistoryActions {
RESTORED_FRAME = 'Restored frame', RESTORED_FRAME = 'Restored frame',
} }
export enum ModelType { export enum ModelKind {
DETECTOR = 'detector', DETECTOR = 'detector',
INTERACTOR = 'interactor', INTERACTOR = 'interactor',
TRACKER = 'tracker', TRACKER = 'tracker',
CLASSIFIER = 'classifier',
REID = 'reid',
}
export enum ModelProviders {
CVAT = 'cvat',
}
export enum ModelReturnType {
RECTANGLE = 'rectangle',
TAG = 'tag',
POLYGON = 'polygon',
MASK = 'mask',
} }
export const colors = [ export const colors = [

@ -1,12 +1,25 @@
// Copyright (C) 2019-2022 Intel Corporation // Copyright (C) 2019-2022 Intel Corporation
// Copyright (C) 2022 CVAT.ai Corporation // Copyright (C) 2022-2023 CVAT.ai Corporation
// //
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
import serverProxy from './server-proxy'; import serverProxy from './server-proxy';
import { ArgumentError } from './exceptions'; import { ArgumentError } from './exceptions';
import MLModel from './ml-model'; import MLModel from './ml-model';
import { RQStatus } from './enums'; import { ModelProviders, RQStatus } from './enums';
export interface ModelProvider {
name: string;
icon: string;
attributes: Record<string, string>;
}
interface ModelProxy {
run: (body: any) => Promise<any>;
call: (modelID: string | number, body: any) => Promise<any>;
status: (requestID: string) => Promise<any>;
cancel: (requestID: string) => Promise<any>;
}
class LambdaManager { class LambdaManager {
private listening: any; private listening: any;
@ -18,18 +31,16 @@ class LambdaManager {
} }
async list(): Promise<MLModel[]> { async list(): Promise<MLModel[]> {
if (Array.isArray(this.cachedList)) { const lambdaFunctions = await serverProxy.lambda.list();
return [...this.cachedList]; const functions = await serverProxy.functions.list();
}
const result = await serverProxy.lambda.list(); const result = [...lambdaFunctions, ...functions];
const models = []; const models = [];
for (const model of result) { for (const model of result) {
models.push( models.push(
new MLModel({ new MLModel({
...model, ...model,
type: model.kind,
}), }),
); );
} }
@ -59,7 +70,7 @@ class LambdaManager {
function: model.id, function: model.id,
}; };
const result = await serverProxy.lambda.run(body); const result = await LambdaManager.getModelProxy(model).run(body);
return result.id; return result.id;
} }
@ -73,32 +84,43 @@ class LambdaManager {
task: taskID, task: taskID,
}; };
const result = await serverProxy.lambda.call(model.id, body); const result = await LambdaManager.getModelProxy(model).call(model.id, body);
return result; return result;
} }
async requests() { async requests() {
const result = await serverProxy.lambda.requests(); const lambdaRequests = await serverProxy.lambda.requests();
const functionsRequests = await serverProxy.functions.requests();
const result = [...lambdaRequests, ...functionsRequests];
return result.filter((request) => ['queued', 'started'].includes(request.status)); return result.filter((request) => ['queued', 'started'].includes(request.status));
} }
async cancel(requestID): Promise<void> { async cancel(requestID, functionID): Promise<void> {
if (typeof requestID !== 'string') { if (typeof requestID !== 'string') {
throw new ArgumentError(`Request id argument is required to be a string. But got ${requestID}`); throw new ArgumentError(`Request id argument is required to be a string. But got ${requestID}`);
} }
const model = this.cachedList.find((_model) => _model.id === functionID);
if (!model) {
throw new ArgumentError('Incorrect Function Id provided');
}
if (this.listening[requestID]) { if (this.listening[requestID]) {
clearTimeout(this.listening[requestID].timeout); clearTimeout(this.listening[requestID].timeout);
delete this.listening[requestID]; delete this.listening[requestID];
} }
await serverProxy.lambda.cancel(requestID);
await LambdaManager.getModelProxy(model).cancel(requestID);
} }
async listen(requestID, onUpdate): Promise<void> { async listen(requestID, functionID, onUpdate): Promise<void> {
const model = this.cachedList.find((_model) => _model.id === functionID);
if (!model) {
throw new ArgumentError('Incorrect Function Id provided');
}
const timeoutCallback = async (): Promise<void> => { const timeoutCallback = async (): Promise<void> => {
try { try {
this.listening[requestID].timeout = null; this.listening[requestID].timeout = null;
const response = await serverProxy.lambda.status(requestID); const response = await LambdaManager.getModelProxy(model).status(requestID);
if (response.status === RQStatus.QUEUED || response.status === RQStatus.STARTED) { if (response.status === RQStatus.QUEUED || response.status === RQStatus.STARTED) {
onUpdate(response.status, response.progress || 0); onUpdate(response.status, response.progress || 0);
@ -123,9 +145,28 @@ class LambdaManager {
this.listening[requestID] = { this.listening[requestID] = {
onUpdate, onUpdate,
functionID,
timeout: setTimeout(timeoutCallback, 2000), timeout: setTimeout(timeoutCallback, 2000),
}; };
} }
async providers(): Promise<ModelProvider[]> {
const providersData: Record<string, Record<string, string>> = await serverProxy.functions.providers();
const providers = Object.entries(providersData).map(([provider, attributes]) => {
const { icon } = attributes;
delete attributes.icon;
return {
name: provider,
icon,
attributes,
};
});
return providers;
}
private static getModelProxy(model: MLModel): ModelProxy {
return model.provider === ModelProviders.CVAT ? serverProxy.lambda : serverProxy.functions;
}
} }
export default new LambdaManager(); export default new LambdaManager();

@ -1,9 +1,12 @@
// Copyright (C) 2019-2022 Intel Corporation // Copyright (C) 2019-2022 Intel Corporation
// Copyright (C) 2022 CVAT.ai Corporation // Copyright (C) 2022-2023 CVAT.ai Corporation
// //
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
import { ModelType } from './enums'; import { isBrowser, isNode } from 'browser-or-node';
import serverProxy from './server-proxy';
import PluginRegistry from './plugins';
import { ModelProviders, ModelKind, ModelReturnType } from './enums';
interface ModelAttribute { interface ModelAttribute {
name: string; name: string;
@ -26,19 +29,27 @@ interface ModelTip {
} }
interface SerializedModel { interface SerializedModel {
id: string; id?: string | number;
name: string; name?: string;
labels: string[]; labels?: string[];
version: number; version?: number;
attributes: Record<string, ModelAttribute>; attributes?: Record<string, ModelAttribute>;
framework: string; framework?: string;
description: string; description?: string;
type: ModelType; kind?: ModelKind;
type?: string;
return_type?: ModelReturnType;
owner?: any;
provider?: string;
api_key?: string;
url?: string;
help_message?: string; help_message?: string;
animated_gif?: string; animated_gif?: string;
min_pos_points?: number; min_pos_points?: number;
min_neg_points?: number; min_neg_points?: number;
startswith_box?: boolean; startswith_box?: boolean;
created_date?: string;
updated_date?: string;
} }
export default class MLModel { export default class MLModel {
@ -49,7 +60,7 @@ export default class MLModel {
this.serialized = { ...serialized }; this.serialized = { ...serialized };
} }
public get id(): string { public get id(): string | number {
return this.serialized.id; return this.serialized.id;
} }
@ -77,8 +88,8 @@ export default class MLModel {
return this.serialized.description; return this.serialized.description;
} }
public get type(): ModelType { public get kind(): ModelKind {
return this.serialized.type; return this.serialized.kind;
} }
public get params(): ModelParams { public get params(): ModelParams {
@ -104,8 +115,110 @@ export default class MLModel {
}; };
} }
public get owner(): string {
return this.serialized?.owner?.username || '';
}
public get provider(): string {
return this.serialized?.provider || ModelProviders.CVAT;
}
public get isDeletable(): boolean {
return this.provider !== ModelProviders.CVAT;
}
public get createdDate(): string | undefined {
return this.serialized?.created_date;
}
public get updatedDate(): string | undefined {
return this.serialized?.updated_date;
}
public get url(): string | undefined {
return this.serialized?.url;
}
public get returnType(): ModelReturnType | undefined {
return this.serialized?.return_type;
}
// Used to set a callback when the tool is blocked in UI // Used to set a callback when the tool is blocked in UI
public set onChangeToolsBlockerState(onChangeToolsBlockerState: (event: string) => void) { public set onChangeToolsBlockerState(onChangeToolsBlockerState: (event: string) => void) {
this.changeToolsBlockerStateCallback = onChangeToolsBlockerState; this.changeToolsBlockerStateCallback = onChangeToolsBlockerState;
} }
public async save(): Promise<MLModel> {
const result = await PluginRegistry.apiWrapper.call(this, MLModel.prototype.save);
return result;
}
public async delete(): Promise<MLModel> {
const result = await PluginRegistry.apiWrapper.call(this, MLModel.prototype.delete);
return result;
}
public async getPreview(): Promise<string> {
const result = await PluginRegistry.apiWrapper.call(this, MLModel.prototype.getPreview);
return result;
}
} }
Object.defineProperties(MLModel.prototype.save, {
implementation: {
writable: false,
enumerable: false,
value: async function implementation(): Promise<MLModel> {
const modelData = {
provider: this.provider,
url: this.serialized.url,
api_key: this.serialized.api_key,
};
const model = await serverProxy.functions.create(modelData);
return new MLModel(model);
},
},
});
Object.defineProperties(MLModel.prototype.delete, {
implementation: {
writable: false,
enumerable: false,
value: async function implementation(): Promise<void> {
if (this.isDeletable) {
await serverProxy.functions.delete(this.id);
}
},
},
});
Object.defineProperties(MLModel.prototype.getPreview, {
implementation: {
writable: false,
enumerable: false,
value: async function implementation(): Promise<string | ArrayBuffer> {
if (this.provider === ModelProviders.CVAT) {
return '';
}
return new Promise((resolve, reject) => {
serverProxy.functions
.getPreview(this.id)
.then((result) => {
if (isNode) {
resolve(global.Buffer.from(result, 'binary').toString('base64'));
} else if (isBrowser) {
const reader = new FileReader();
reader.onload = () => {
resolve(reader.result);
};
reader.readAsDataURL(result);
}
})
.catch((error) => {
reject(error);
});
});
},
},
});

@ -1550,10 +1550,74 @@ async function getAnnotations(session, id) {
} catch (errorData) { } catch (errorData) {
throw generateError(errorData); throw generateError(errorData);
} }
return response.data;
}
async function getFunctions() {
const { backendAPI } = config;
try {
const response = await Axios.get(`${backendAPI}/functions`, {
proxy: config.proxy,
});
return response.data.results;
} catch (errorData) {
if (errorData.response.status === 404) {
return [];
}
throw generateError(errorData);
}
}
async function getFunctionPreview(modelID) {
const { backendAPI } = config;
let response = null;
try {
const url = `${backendAPI}/functions/${modelID}/preview`;
response = await Axios.get(url, {
proxy: config.proxy,
responseType: 'blob',
});
} catch (errorData) {
const code = errorData.response ? errorData.response.status : errorData.code;
throw new ServerError(`Could not get preview for the model ${modelID} from the server`, code);
}
return response.data; return response.data;
} }
async function getFunctionProviders() {
const { backendAPI } = config;
try {
const response = await Axios.get(`${backendAPI}/functions/info`, {
proxy: config.proxy,
});
return response.data;
} catch (errorData) {
if (errorData.response.status === 404) {
return [];
}
throw generateError(errorData);
}
}
async function deleteFunction(functionId: number) {
const { backendAPI } = config;
try {
await Axios.delete(`${backendAPI}/functions/${functionId}`, {
proxy: config.proxy,
headers: {
'Content-Type': 'application/json',
},
});
} catch (errorData) {
throw generateError(errorData);
}
}
// Session is 'task' or 'job' // Session is 'task' or 'job'
async function updateAnnotations(session, id, data, action) { async function updateAnnotations(session, id, data, action) {
const { backendAPI } = config; const { backendAPI } = config;
@ -1580,10 +1644,26 @@ async function updateAnnotations(session, id, data, action) {
} catch (errorData) { } catch (errorData) {
throw generateError(errorData); throw generateError(errorData);
} }
return response.data; return response.data;
} }
async function runFunctionRequest(body) {
const { backendAPI } = config;
try {
const response = await Axios.post(`${backendAPI}/functions/requests/`, JSON.stringify(body), {
proxy: config.proxy,
headers: {
'Content-Type': 'application/json',
},
});
return response.data;
} catch (errorData) {
throw generateError(errorData);
}
}
// Session is 'task' or 'job' // Session is 'task' or 'job'
async function uploadAnnotations( async function uploadAnnotations(
session, session,
@ -1604,7 +1684,6 @@ async function uploadAnnotations(
}; };
const url = `${backendAPI}/${session}s/${id}/annotations`; const url = `${backendAPI}/${session}s/${id}/annotations`;
async function wait() { async function wait() {
return new Promise<void>((resolve, reject) => { return new Promise<void>((resolve, reject) => {
async function requestStatus() { async function requestStatus() {
@ -1666,7 +1745,6 @@ async function uploadAnnotations(
throw generateError(errorData); throw generateError(errorData);
} }
} }
try { try {
return await wait(); return await wait();
} catch (errorData) { } catch (errorData) {
@ -1674,6 +1752,19 @@ async function uploadAnnotations(
} }
} }
async function getFunctionRequestStatus(requestID) {
const { backendAPI } = config;
try {
const response = await Axios.get(`${backendAPI}/functions/requests/${requestID}`, {
proxy: config.proxy,
});
return response.data;
} catch (errorData) {
throw generateError(errorData);
}
}
// Session is 'task' or 'job' // Session is 'task' or 'job'
async function dumpAnnotations(id, name, format) { async function dumpAnnotations(id, name, format) {
const { backendAPI } = config; const { backendAPI } = config;
@ -1703,11 +1794,40 @@ async function dumpAnnotations(id, name, format) {
reject(generateError(errorData)); reject(generateError(errorData));
}); });
} }
setTimeout(request); setTimeout(request);
}); });
} }
async function cancelFunctionRequest(requestId) {
const { backendAPI } = config;
try {
await Axios.delete(`${backendAPI}/functions/requests/${requestId}`, {
method: 'DELETE',
});
} catch (errorData) {
throw generateError(errorData);
}
}
async function createFunction(functionData: any) {
const params = enableOrganization();
const { backendAPI } = config;
try {
const response = await Axios.post(`${backendAPI}/functions`, JSON.stringify(functionData), {
proxy: config.proxy,
params,
headers: {
'Content-Type': 'application/json',
},
});
return response.data;
} catch (errorData) {
throw generateError(errorData);
}
}
async function saveLogs(logs) { async function saveLogs(logs) {
const { backendAPI } = config; const { backendAPI } = config;
@ -1723,6 +1843,40 @@ async function saveLogs(logs) {
} }
} }
async function callFunction(funId, body) {
const { backendAPI } = config;
try {
const response = await Axios.post(`${backendAPI}/functions/${funId}/run`, JSON.stringify(body), {
proxy: config.proxy,
headers: {
'Content-Type': 'application/json',
},
});
return response.data;
} catch (errorData) {
throw generateError(errorData);
}
}
async function getFunctionsRequests() {
const { backendAPI } = config;
try {
const response = await Axios.get(`${backendAPI}/functions/requests/`, {
proxy: config.proxy,
});
return response.data;
} catch (errorData) {
if (errorData.response.status === 404) {
return [];
}
throw generateError(errorData);
}
}
async function getLambdaFunctions() { async function getLambdaFunctions() {
const { backendAPI } = config; const { backendAPI } = config;
@ -1732,6 +1886,9 @@ async function getLambdaFunctions() {
}); });
return response.data; return response.data;
} catch (errorData) { } catch (errorData) {
if (errorData.response.status === 503) {
return [];
}
throw generateError(errorData); throw generateError(errorData);
} }
} }
@ -2427,6 +2584,19 @@ export default Object.freeze({
cancel: cancelLambdaRequest, cancel: cancelLambdaRequest,
}), }),
functions: Object.freeze({
list: getFunctions,
status: getFunctionRequestStatus,
requests: getFunctionsRequests,
run: runFunctionRequest,
call: callFunction,
create: createFunction,
providers: getFunctionProviders,
delete: deleteFunction,
cancel: cancelFunctionRequest,
getPreview: getFunctionPreview,
}),
issues: Object.freeze({ issues: Object.freeze({
create: createIssue, create: createIssue,
update: updateIssue, update: updateIssue,

@ -1,6 +1,6 @@
{ {
"name": "cvat-ui", "name": "cvat-ui",
"version": "1.47.1", "version": "1.48.0",
"description": "CVAT single-page application", "description": "CVAT single-page application",
"main": "src/index.tsx", "main": "src/index.tsx",
"scripts": { "scripts": {

@ -12,7 +12,7 @@ import { CanvasMode as Canvas3DMode } from 'cvat-canvas3d-wrapper';
import { import {
RectDrawingMethod, CuboidDrawingMethod, Canvas, CanvasMode as Canvas2DMode, RectDrawingMethod, CuboidDrawingMethod, Canvas, CanvasMode as Canvas2DMode,
} from 'cvat-canvas-wrapper'; } from 'cvat-canvas-wrapper';
import { getCore } from 'cvat-core-wrapper'; import { getCore, MLModel } from 'cvat-core-wrapper';
import logger, { LogType } from 'cvat-logger'; import logger, { LogType } from 'cvat-logger';
import { getCVATStore } from 'cvat-store'; import { getCVATStore } from 'cvat-store';
@ -22,7 +22,6 @@ import {
ContextMenuType, ContextMenuType,
DimensionType, DimensionType,
FrameSpeed, FrameSpeed,
Model,
ObjectType, ObjectType,
OpenCVTool, OpenCVTool,
Rotation, Rotation,
@ -1507,7 +1506,7 @@ export function pasteShapeAsync(): ThunkAction {
}; };
} }
export function interactWithCanvas(activeInteractor: Model | OpenCVTool, activeLabelID: number): AnyAction { export function interactWithCanvas(activeInteractor: MLModel | OpenCVTool, activeLabelID: number): AnyAction {
return { return {
type: AnnotationActionTypes.INTERACT_WITH_CANVAS, type: AnnotationActionTypes.INTERACT_WITH_CANVAS,
payload: { payload: {

@ -1,11 +1,13 @@
// Copyright (C) 2021-2022 Intel Corporation // Copyright (C) 2021-2022 Intel Corporation
// Copyright (C) 2023 CVAT.ai Corporation
// //
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
import { Dispatch, ActionCreator } from 'redux'; import { Dispatch, ActionCreator } from 'redux';
import { ActionUnion, createAction, ThunkAction } from 'utils/redux'; import { ActionUnion, createAction, ThunkAction } from 'utils/redux';
import { getCore } from 'cvat-core-wrapper'; import { getCore } from 'cvat-core-wrapper';
import { CloudStoragesQuery, CloudStorage, Indexable } from 'reducers'; import { CloudStoragesQuery, CloudStorage } from 'reducers';
import { filterNull } from 'utils/filter-null';
const cvat = getCore(); const cvat = getCore();
@ -106,12 +108,7 @@ export function getCloudStoragesAsync(query: Partial<CloudStoragesQuery>): Thunk
dispatch(cloudStoragesActions.getCloudStorages()); dispatch(cloudStoragesActions.getCloudStorages());
dispatch(cloudStoragesActions.updateCloudStoragesGettingQuery(query)); dispatch(cloudStoragesActions.updateCloudStoragesGettingQuery(query));
const filteredQuery = { ...query }; const filteredQuery = filterNull(query);
for (const key in filteredQuery) {
if ((filteredQuery as Indexable)[key] === null) {
delete (filteredQuery as Indexable)[key];
}
}
let result = null; let result = null;
try { try {

@ -1,10 +1,12 @@
// Copyright (C) 2022 Intel Corporation // Copyright (C) 2022 Intel Corporation
// Copyright (C) 2023 CVAT.ai Corporation
// //
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
import { ActionUnion, createAction, ThunkAction } from 'utils/redux'; import { ActionUnion, createAction, ThunkAction } from 'utils/redux';
import { getCore } from 'cvat-core-wrapper'; import { getCore } from 'cvat-core-wrapper';
import { Indexable, JobsQuery, Job } from 'reducers'; import { JobsQuery, Job } from 'reducers';
import { filterNull } from 'utils/filter-null';
const cvat = getCore(); const cvat = getCore();
@ -43,14 +45,9 @@ export type JobsActions = ActionUnion<typeof jobsActions>;
export const getJobsAsync = (query: JobsQuery): ThunkAction => async (dispatch) => { export const getJobsAsync = (query: JobsQuery): ThunkAction => async (dispatch) => {
try { try {
// We remove all keys with null values from the query // We remove all keys with null values from the query
const filteredQuery = { ...query }; const filteredQuery = filterNull(query);
for (const key of Object.keys(query)) {
if ((filteredQuery as Indexable)[key] === null) {
delete (filteredQuery as Indexable)[key];
}
}
dispatch(jobsActions.getJobs(filteredQuery)); dispatch(jobsActions.getJobs(filteredQuery as JobsQuery));
const jobs = await cvat.jobs.get(filteredQuery); const jobs = await cvat.jobs.get(filteredQuery);
dispatch(jobsActions.getJobsSuccess(jobs)); dispatch(jobsActions.getJobsSuccess(jobs));
} catch (error) { } catch (error) {

@ -1,15 +1,27 @@
// Copyright (C) 2020-2022 Intel Corporation // Copyright (C) 2020-2022 Intel Corporation
// Copyright (C) 2022-2023 CVAT.ai Corporation
// //
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
import { ActionUnion, createAction, ThunkAction } from 'utils/redux'; import { ActionUnion, createAction, ThunkAction } from 'utils/redux';
import { Model, ActiveInference, RQStatus } from 'reducers'; import {
import { getCore } from 'cvat-core-wrapper'; ActiveInference, RQStatus, ModelsQuery,
} from 'reducers';
import { getCore, MLModel, ModelProvider } from 'cvat-core-wrapper';
import { filterNull } from 'utils/filter-null';
const cvat = getCore();
export enum ModelsActionTypes { export enum ModelsActionTypes {
GET_MODELS = 'GET_MODELS', GET_MODELS = 'GET_MODELS',
GET_MODELS_SUCCESS = 'GET_MODELS_SUCCESS', GET_MODELS_SUCCESS = 'GET_MODELS_SUCCESS',
GET_MODELS_FAILED = 'GET_MODELS_FAILED', GET_MODELS_FAILED = 'GET_MODELS_FAILED',
CREATE_MODEL = 'CREATE_MODEL',
CREATE_MODEL_SUCCESS = 'CREATE_MODEL_SUCCESS',
CREATE_MODEL_FAILED = 'CREATE_MODEL_FAILED',
DELETE_MODEL = 'DELETE_MODEL',
DELETE_MODEL_SUCCESS = 'DELETE_MODEL_SUCCESS',
DELETE_MODEL_FAILED = 'DELETE_MODEL_FAILED',
START_INFERENCE_FAILED = 'START_INFERENCE_FAILED', START_INFERENCE_FAILED = 'START_INFERENCE_FAILED',
GET_INFERENCE_STATUS_SUCCESS = 'GET_INFERENCE_STATUS_SUCCESS', GET_INFERENCE_STATUS_SUCCESS = 'GET_INFERENCE_STATUS_SUCCESS',
GET_INFERENCE_STATUS_FAILED = 'GET_INFERENCE_STATUS_FAILED', GET_INFERENCE_STATUS_FAILED = 'GET_INFERENCE_STATUS_FAILED',
@ -18,16 +30,32 @@ export enum ModelsActionTypes {
CLOSE_RUN_MODEL_DIALOG = 'CLOSE_RUN_MODEL_DIALOG', CLOSE_RUN_MODEL_DIALOG = 'CLOSE_RUN_MODEL_DIALOG',
CANCEL_INFERENCE_SUCCESS = 'CANCEL_INFERENCE_SUCCESS', CANCEL_INFERENCE_SUCCESS = 'CANCEL_INFERENCE_SUCCESS',
CANCEL_INFERENCE_FAILED = 'CANCEL_INFERENCE_FAILED', CANCEL_INFERENCE_FAILED = 'CANCEL_INFERENCE_FAILED',
GET_MODEL_PROVIDERS = 'GET_MODEL_PROVIDERS',
GET_MODEL_PROVIDERS_SUCCESS = 'GET_MODEL_PROVIDERS_SUCCESS',
GET_MODEL_PROVIDERS_FAILED = 'GET_MODEL_PROVIDERS_FAILED',
GET_MODEL_PREVIEW = 'GET_MODEL_PREVIEW',
GET_MODEL_PREVIEW_SUCCESS = 'GET_MODEL_PREVIEW_SUCCESS',
GET_MODEL_PREVIEW_FAILED = 'GET_MODEL_PREVIEW_FAILED',
} }
export const modelsActions = { export const modelsActions = {
getModels: () => createAction(ModelsActionTypes.GET_MODELS), getModels: (query?: ModelsQuery) => createAction(ModelsActionTypes.GET_MODELS, { query }),
getModelsSuccess: (models: Model[]) => createAction(ModelsActionTypes.GET_MODELS_SUCCESS, { getModelsSuccess: (models: MLModel[]) => createAction(ModelsActionTypes.GET_MODELS_SUCCESS, {
models, models,
}), }),
getModelsFailed: (error: any) => createAction(ModelsActionTypes.GET_MODELS_FAILED, { getModelsFailed: (error: any) => createAction(ModelsActionTypes.GET_MODELS_FAILED, {
error, error,
}), }),
createModel: () => createAction(ModelsActionTypes.CREATE_MODEL),
createModelSuccess: (model: MLModel) => createAction(ModelsActionTypes.CREATE_MODEL_SUCCESS, {
model,
}),
createModelFailed: (error: any) => createAction(ModelsActionTypes.CREATE_MODEL_FAILED, { error }),
deleteModel: (model: MLModel) => createAction(ModelsActionTypes.DELETE_MODEL, { model }),
deleteModelSuccess: (modelID: string | number) => createAction(ModelsActionTypes.DELETE_MODEL_SUCCESS, { modelID }),
deleteModelFailed: (modelName: string, error: any) => (
createAction(ModelsActionTypes.DELETE_MODEL_FAILED, { modelName, error })
),
fetchMetaFailed: (error: any) => createAction(ModelsActionTypes.FETCH_META_FAILED, { error }), fetchMetaFailed: (error: any) => createAction(ModelsActionTypes.FETCH_META_FAILED, { error }),
getInferenceStatusSuccess: (taskID: number, activeInference: ActiveInference) => ( getInferenceStatusSuccess: (taskID: number, activeInference: ActiveInference) => (
createAction(ModelsActionTypes.GET_INFERENCE_STATUS_SUCCESS, { createAction(ModelsActionTypes.GET_INFERENCE_STATUS_SUCCESS, {
@ -64,18 +92,34 @@ export const modelsActions = {
taskInstance, taskInstance,
}) })
), ),
getModelProviders: () => createAction(ModelsActionTypes.GET_MODEL_PROVIDERS),
getModelProvidersSuccess: (providers: ModelProvider[]) => (
createAction(ModelsActionTypes.GET_MODEL_PROVIDERS_SUCCESS, {
providers,
})),
getModelProvidersFailed: (error: any) => createAction(ModelsActionTypes.GET_MODEL_PROVIDERS_FAILED, { error }),
getModelPreview: (modelID: string | number) => (
createAction(ModelsActionTypes.GET_MODEL_PREVIEW, { modelID })
),
getModelPreviewSuccess: (modelID: string | number, preview: string) => (
createAction(ModelsActionTypes.GET_MODEL_PREVIEW_SUCCESS, { modelID, preview })
),
getModelPreviewFailed: (modelID: string | number, error: any) => (
createAction(ModelsActionTypes.GET_MODEL_PREVIEW_FAILED, { modelID, error })
),
}; };
export type ModelsActions = ActionUnion<typeof modelsActions>; export type ModelsActions = ActionUnion<typeof modelsActions>;
const core = getCore(); const core = getCore();
export function getModelsAsync(): ThunkAction { export function getModelsAsync(query: ModelsQuery): ThunkAction {
return async (dispatch): Promise<void> => { return async (dispatch): Promise<void> => {
dispatch(modelsActions.getModels()); dispatch(modelsActions.getModels(query));
const filteredQuery = filterNull(query);
try { try {
const models = await core.lambda.list(); const models = await core.lambda.list(filteredQuery);
dispatch(modelsActions.getModelsSuccess(models)); dispatch(modelsActions.getModelsSuccess(models));
} catch (error) { } catch (error) {
dispatch(modelsActions.getModelsFailed(error)); dispatch(modelsActions.getModelsFailed(error));
@ -83,15 +127,43 @@ export function getModelsAsync(): ThunkAction {
}; };
} }
export function createModelAsync(modelData: Record<string, string>): ThunkAction {
return async function (dispatch) {
const model = new cvat.classes.MLModel(modelData);
dispatch(modelsActions.createModel());
try {
const createdModel = await model.save();
dispatch(modelsActions.createModelSuccess(createdModel));
} catch (error) {
dispatch(modelsActions.createModelFailed(error));
throw error;
}
};
}
export function deleteModelAsync(model: MLModel): ThunkAction {
return async function (dispatch) {
dispatch(modelsActions.deleteModel(model));
try {
await model.delete();
dispatch(modelsActions.deleteModelSuccess(model.id));
} catch (error) {
dispatch(modelsActions.deleteModelFailed(model.name, error));
}
};
}
interface InferenceMeta { interface InferenceMeta {
taskID: number; taskID: number;
requestID: string; requestID: string;
functionID: string | number;
} }
function listen(inferenceMeta: InferenceMeta, dispatch: (action: ModelsActions) => void): void { function listen(inferenceMeta: InferenceMeta, dispatch: (action: ModelsActions) => void): void {
const { taskID, requestID } = inferenceMeta; const { taskID, requestID, functionID } = inferenceMeta;
core.lambda core.lambda
.listen(requestID, (status: RQStatus, progress: number, message: string) => { .listen(requestID, functionID, (status: RQStatus, progress: number, message: string) => {
if (status === RQStatus.failed || status === RQStatus.unknown) { if (status === RQStatus.failed || status === RQStatus.unknown) {
dispatch( dispatch(
modelsActions.getInferenceStatusFailed( modelsActions.getInferenceStatusFailed(
@ -107,6 +179,7 @@ function listen(inferenceMeta: InferenceMeta, dispatch: (action: ModelsActions)
modelsActions.getInferenceStatusSuccess(taskID, { modelsActions.getInferenceStatusSuccess(taskID, {
status, status,
progress, progress,
functionID,
error: message, error: message,
id: requestID, id: requestID,
}), }),
@ -119,6 +192,7 @@ function listen(inferenceMeta: InferenceMeta, dispatch: (action: ModelsActions)
progress: 0, progress: 0,
error: error.toString(), error: error.toString(),
id: requestID, id: requestID,
functionID,
}), }),
); );
}); });
@ -136,6 +210,7 @@ export function getInferenceStatusAsync(): ThunkAction {
.map((request: any): object => ({ .map((request: any): object => ({
taskID: +request.function.task, taskID: +request.function.task,
requestID: request.id, requestID: request.id,
functionID: request.function.id,
})) }))
.forEach((inferenceMeta: InferenceMeta): void => { .forEach((inferenceMeta: InferenceMeta): void => {
listen(inferenceMeta, dispatchCallback); listen(inferenceMeta, dispatchCallback);
@ -146,7 +221,7 @@ export function getInferenceStatusAsync(): ThunkAction {
}; };
} }
export function startInferenceAsync(taskId: number, model: Model, body: object): ThunkAction { export function startInferenceAsync(taskId: number, model: MLModel, body: object): ThunkAction {
return async (dispatch): Promise<void> => { return async (dispatch): Promise<void> => {
try { try {
const requestID: string = await core.lambda.run(taskId, model, body); const requestID: string = await core.lambda.run(taskId, model, body);
@ -157,6 +232,7 @@ export function startInferenceAsync(taskId: number, model: Model, body: object):
listen( listen(
{ {
taskID: taskId, taskID: taskId,
functionID: model.id,
requestID, requestID,
}, },
dispatchCallback, dispatchCallback,
@ -171,10 +247,32 @@ export function cancelInferenceAsync(taskID: number): ThunkAction {
return async (dispatch, getState): Promise<void> => { return async (dispatch, getState): Promise<void> => {
try { try {
const inference = getState().models.inferences[taskID]; const inference = getState().models.inferences[taskID];
await core.lambda.cancel(inference.id); await core.lambda.cancel(inference.id, inference.functionID);
dispatch(modelsActions.cancelInferenceSuccess(taskID)); dispatch(modelsActions.cancelInferenceSuccess(taskID));
} catch (error) { } catch (error) {
dispatch(modelsActions.cancelInferenceFailed(taskID, error)); dispatch(modelsActions.cancelInferenceFailed(taskID, error));
} }
}; };
} }
export function getModelProvidersAsync(): ThunkAction {
return async function (dispatch) {
dispatch(modelsActions.getModelProviders());
try {
const providers = await cvat.lambda.providers();
dispatch(modelsActions.getModelProvidersSuccess(providers));
} catch (error) {
dispatch(modelsActions.getModelProvidersFailed(error));
}
};
}
export const getModelPreviewAsync = (model: MLModel): ThunkAction => async (dispatch) => {
dispatch(modelsActions.getModelPreview(model.id));
try {
const result = await model.getPreview();
dispatch(modelsActions.getModelPreviewSuccess(model.id, result));
} catch (error) {
dispatch(modelsActions.getModelPreviewFailed(model.id, error));
}
};

@ -1,5 +1,5 @@
// Copyright (C) 2019-2022 Intel Corporation // Copyright (C) 2019-2022 Intel Corporation
// Copyright (C) 2022 CVAT.ai Corporation // Copyright (C) 2022-2023 CVAT.ai Corporation
// //
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
@ -7,11 +7,12 @@ import { Dispatch, ActionCreator } from 'redux';
import { ActionUnion, createAction, ThunkAction } from 'utils/redux'; import { ActionUnion, createAction, ThunkAction } from 'utils/redux';
import { import {
ProjectsQuery, TasksQuery, CombinedState, Indexable, ProjectsQuery, TasksQuery, CombinedState,
} from 'reducers'; } from 'reducers';
import { getTasksAsync } from 'actions/tasks-actions'; import { getTasksAsync } from 'actions/tasks-actions';
import { getCVATStore } from 'cvat-store'; import { getCVATStore } from 'cvat-store';
import { getCore } from 'cvat-core-wrapper'; import { getCore } from 'cvat-core-wrapper';
import { filterNull } from 'utils/filter-null';
const cvat = getCore(); const cvat = getCore();
@ -99,17 +100,10 @@ export function getProjectsAsync(
dispatch(projectActions.updateProjectsGettingQuery(query, tasksQuery)); dispatch(projectActions.updateProjectsGettingQuery(query, tasksQuery));
// Clear query object from null fields // Clear query object from null fields
const filteredQuery: Partial<ProjectsQuery> = { const filteredQuery: Partial<ProjectsQuery> = filterNull({
page: 1, page: 1,
...query, ...query,
}; });
for (const key of Object.keys(filteredQuery)) {
const value = (filteredQuery as Indexable)[key];
if (value === null || typeof value === 'undefined') {
delete (filteredQuery as Indexable)[key];
}
}
let result = null; let result = null;
try { try {

@ -1,14 +1,15 @@
// Copyright (C) 2019-2022 Intel Corporation // Copyright (C) 2019-2022 Intel Corporation
// Copyright (C) 2022 CVAT.ai Corporation // Copyright (C) 2022-2023 CVAT.ai Corporation
// //
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
import { AnyAction, Dispatch, ActionCreator } from 'redux'; import { AnyAction, Dispatch, ActionCreator } from 'redux';
import { ThunkAction } from 'redux-thunk'; import { ThunkAction } from 'redux-thunk';
import { import {
TasksQuery, CombinedState, Indexable, StorageLocation, TasksQuery, CombinedState, StorageLocation,
} from 'reducers'; } from 'reducers';
import { getCore, Storage } from 'cvat-core-wrapper'; import { getCore, Storage } from 'cvat-core-wrapper';
import { filterNull } from 'utils/filter-null';
import { getInferenceStatusAsync } from './models-actions'; import { getInferenceStatusAsync } from './models-actions';
const cvat = getCore(); const cvat = getCore();
@ -74,13 +75,7 @@ export function getTasksAsync(
return async (dispatch: ActionCreator<Dispatch>): Promise<void> => { return async (dispatch: ActionCreator<Dispatch>): Promise<void> => {
dispatch(getTasks(query, updateQuery)); dispatch(getTasks(query, updateQuery));
// We remove all keys with null values from the query const filteredQuery = filterNull(query);
const filteredQuery = { ...query };
for (const key of Object.keys(query)) {
if ((filteredQuery as Indexable)[key] === null) {
delete (filteredQuery as Indexable)[key];
}
}
let result = null; let result = null;
try { try {

@ -1,11 +1,12 @@
// Copyright (C) 2022 CVAT.ai Corporation // Copyright (C) 2022-2023 CVAT.ai Corporation
// //
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
import { getCore, Webhook } from 'cvat-core-wrapper'; import { getCore, Webhook } from 'cvat-core-wrapper';
import { Dispatch, ActionCreator, Store } from 'redux'; import { Dispatch, ActionCreator, Store } from 'redux';
import { Indexable, WebhooksQuery } from 'reducers'; import { WebhooksQuery } from 'reducers';
import { ActionUnion, createAction, ThunkAction } from 'utils/redux'; import { ActionUnion, createAction, ThunkAction } from 'utils/redux';
import { filterNull } from 'utils/filter-null';
const cvat = getCore(); const cvat = getCore();
@ -47,13 +48,7 @@ export const getWebhooksAsync = (query: WebhooksQuery): ThunkAction => (
async (dispatch: ActionCreator<Dispatch>): Promise<void> => { async (dispatch: ActionCreator<Dispatch>): Promise<void> => {
dispatch(webhooksActions.getWebhooks(query)); dispatch(webhooksActions.getWebhooks(query));
// We remove all keys with null values from the query const filteredQuery = filterNull(query);
const filteredQuery = { ...query };
for (const key of Object.keys(query)) {
if ((filteredQuery as Indexable)[key] === null) {
delete (filteredQuery as Indexable)[key];
}
}
let result = null; let result = null;
try { try {

@ -1,4 +1,5 @@
// Copyright (C) 2020-2022 Intel Corporation // Copyright (C) 2020-2022 Intel Corporation
// Copyright (C) 2022-2023 CVAT.ai Corporation
// //
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
@ -26,10 +27,12 @@ import lodash from 'lodash';
import { AIToolsIcon } from 'icons'; import { AIToolsIcon } from 'icons';
import { Canvas, convertShapesForInteractor } from 'cvat-canvas-wrapper'; import { Canvas, convertShapesForInteractor } from 'cvat-canvas-wrapper';
import { getCore, Attribute, Label } from 'cvat-core-wrapper'; import {
getCore, Attribute, Label, MLModel,
} from 'cvat-core-wrapper';
import openCVWrapper from 'utils/opencv-wrapper/opencv-wrapper'; import openCVWrapper from 'utils/opencv-wrapper/opencv-wrapper';
import { import {
CombinedState, ActiveControl, Model, ObjectType, ShapeType, ToolsBlockerState, ModelAttribute, CombinedState, ActiveControl, ObjectType, ShapeType, ToolsBlockerState, ModelAttribute,
} from 'reducers'; } from 'reducers';
import { import {
interactWithCanvas, interactWithCanvas,
@ -57,9 +60,9 @@ interface StateToProps {
jobInstance: any; jobInstance: any;
isActivated: boolean; isActivated: boolean;
frame: number; frame: number;
interactors: Model[]; interactors: MLModel[];
detectors: Model[]; detectors: MLModel[];
trackers: Model[]; trackers: MLModel[];
curZOrder: number; curZOrder: number;
defaultApproxPolyAccuracy: number; defaultApproxPolyAccuracy: number;
toolsBlockerState: ToolsBlockerState; toolsBlockerState: ToolsBlockerState;
@ -67,7 +70,7 @@ interface StateToProps {
} }
interface DispatchToProps { interface DispatchToProps {
onInteractionStart(activeInteractor: Model, activeLabelID: number): void; onInteractionStart(activeInteractor: MLModel, activeLabelID: number): void;
updateAnnotations(statesToUpdate: any[]): void; updateAnnotations(statesToUpdate: any[]): void;
createAnnotations(sessionInstance: any, frame: number, statesToCreate: any[]): void; createAnnotations(sessionInstance: any, frame: number, statesToCreate: any[]): void;
fetchAnnotations(): void; fetchAnnotations(): void;
@ -133,13 +136,13 @@ interface TrackedShape {
clientID: number; clientID: number;
serverlessState: any; serverlessState: any;
shapePoints: number[]; shapePoints: number[];
trackerModel: Model; trackerModel: MLModel;
} }
interface State { interface State {
activeInteractor: Model | null; activeInteractor: MLModel | null;
activeLabelID: number; activeLabelID: number;
activeTracker: Model | null; activeTracker: MLModel | null;
convertMasksToPolygons: boolean; convertMasksToPolygons: boolean;
trackedShapes: TrackedShape[]; trackedShapes: TrackedShape[];
fetching: boolean; fetching: boolean;
@ -211,7 +214,7 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
}; };
lastestApproximatedPoints: number[][]; lastestApproximatedPoints: number[][];
latestRequest: null | { latestRequest: null | {
interactor: Model; interactor: MLModel;
data: { data: {
frame: number; frame: number;
neg_points: number[][]; neg_points: number[][];
@ -444,7 +447,7 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
this.constructFromPoints(this.interaction.lastestApproximatedPoints); this.constructFromPoints(this.interaction.lastestApproximatedPoints);
} }
} else if (shapesUpdated) { } else if (shapesUpdated) {
const interactor = activeInteractor as Model; const interactor = activeInteractor as MLModel;
this.interaction.latestRequest = { this.interaction.latestRequest = {
interactor, interactor,
data: { data: {
@ -498,7 +501,7 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
clientID, clientID,
serverlessState: null, serverlessState: null,
shapePoints: points, shapePoints: points,
trackerModel: activeTracker as Model, trackerModel: activeTracker as MLModel,
}, },
], ],
}); });
@ -527,7 +530,7 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
private setActiveInteractor = (value: string): void => { private setActiveInteractor = (value: string): void => {
const { interactors } = this.props; const { interactors } = this.props;
const [interactor] = interactors.filter((_interactor: Model) => _interactor.id === value); const [interactor] = interactors.filter((_interactor: MLModel) => _interactor.id === value);
if (interactor.version < MIN_SUPPORTED_INTERACTOR_VERSION) { if (interactor.version < MIN_SUPPORTED_INTERACTOR_VERSION) {
notification.warning({ notification.warning({
@ -544,7 +547,7 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
private setActiveTracker = (value: string): void => { private setActiveTracker = (value: string): void => {
const { trackers } = this.props; const { trackers } = this.props;
this.setState({ this.setState({
activeTracker: trackers.filter((tracker: Model) => tracker.id === value)[0], activeTracker: trackers.filter((tracker: MLModel) => tracker.id === value)[0],
}); });
}; };
@ -723,7 +726,7 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
for (const trackerID of Object.keys(trackingData.stateless)) { for (const trackerID of Object.keys(trackingData.stateless)) {
let hideMessage = null; let hideMessage = null;
try { try {
const [tracker] = trackers.filter((_tracker: Model) => _tracker.id === trackerID); const [tracker] = trackers.filter((_tracker: MLModel) => _tracker.id === trackerID);
if (!tracker) { if (!tracker) {
throw new Error(`Suitable tracker with ID ${trackerID} not found in tracker list`); throw new Error(`Suitable tracker with ID ${trackerID} not found in tracker list`);
} }
@ -770,7 +773,7 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
// 4. run tracking for all the objects // 4. run tracking for all the objects
let hideMessage = null; let hideMessage = null;
try { try {
const [tracker] = trackers.filter((_tracker: Model) => _tracker.id === trackerID); const [tracker] = trackers.filter((_tracker: MLModel) => _tracker.id === trackerID);
if (!tracker) { if (!tracker) {
throw new Error(`Suitable tracker with ID ${trackerID} not found in tracker list`); throw new Error(`Suitable tracker with ID ${trackerID} not found in tracker list`);
} }
@ -955,7 +958,7 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
onChange={this.setActiveTracker} onChange={this.setActiveTracker}
> >
{trackers.map( {trackers.map(
(tracker: Model): JSX.Element => ( (tracker: MLModel): JSX.Element => (
<Select.Option value={tracker.id} title={tracker.description} key={tracker.id}> <Select.Option value={tracker.id} title={tracker.description} key={tracker.id}>
{tracker.name} {tracker.name}
</Select.Option> </Select.Option>
@ -1030,7 +1033,7 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
onChange={this.setActiveInteractor} onChange={this.setActiveInteractor}
> >
{interactors.map( {interactors.map(
(interactor: Model): JSX.Element => ( (interactor: MLModel): JSX.Element => (
<Select.Option <Select.Option
value={interactor.id} value={interactor.id}
title={interactor.description} title={interactor.description}
@ -1161,7 +1164,7 @@ export class ToolsControlComponent extends React.PureComponent<Props, State> {
models={detectors} models={detectors}
labels={jobInstance.labels} labels={jobInstance.labels}
dimension={jobInstance.dimension} dimension={jobInstance.dimension}
runInference={async (model: Model, body: DetectorRequestBody) => { runInference={async (model: MLModel, body: DetectorRequestBody) => {
try { try {
this.setState({ mode: 'detection', fetching: true }); this.setState({ mode: 'detection', fetching: true });
const result = await core.lambda.call(jobInstance.taskId, model, { const result = await core.lambda.call(jobInstance.taskId, model, {

@ -1,4 +1,4 @@
// Copyright (C) 2022 CVAT.ai Corporation // Copyright (C) 2022-2023 CVAT.ai Corporation
// //
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
@ -14,12 +14,15 @@ import { getCloudStoragePreviewAsync } from 'actions/cloud-storage-actions';
import { import {
CombinedState, Job, Task, Project, CloudStorage, CombinedState, Job, Task, Project, CloudStorage,
} from 'reducers'; } from 'reducers';
import MLModel from 'cvat-core/src/ml-model';
import { getModelPreviewAsync } from 'actions/models-actions';
interface Props { interface Props {
job?: Job | undefined; job?: Job | undefined;
task?: Task | undefined; task?: Task | undefined;
project?: Project | undefined; project?: Project | undefined;
cloudStorage?: CloudStorage | undefined; cloudStorage?: CloudStorage | undefined;
model?: MLModel | undefined;
onClick?: (event: React.MouseEvent) => void; onClick?: (event: React.MouseEvent) => void;
loadingClassName?: string; loadingClassName?: string;
emptyPreviewClassName?: string; emptyPreviewClassName?: string;
@ -35,6 +38,7 @@ export default function Preview(props: Props): JSX.Element {
task, task,
project, project,
cloudStorage, cloudStorage,
model,
onClick, onClick,
loadingClassName, loadingClassName,
emptyPreviewClassName, emptyPreviewClassName,
@ -51,6 +55,8 @@ export default function Preview(props: Props): JSX.Element {
return state.tasks.previews[task.id]; return state.tasks.previews[task.id];
} if (cloudStorage !== undefined) { } if (cloudStorage !== undefined) {
return state.cloudStorages.previews[cloudStorage.id]; return state.cloudStorages.previews[cloudStorage.id];
} if (model !== undefined) {
return state.models.previews[model.id];
} }
return ''; return '';
}); });
@ -65,6 +71,8 @@ export default function Preview(props: Props): JSX.Element {
dispatch(getTaskPreviewAsync(task)); dispatch(getTaskPreviewAsync(task));
} else if (cloudStorage !== undefined) { } else if (cloudStorage !== undefined) {
dispatch(getCloudStoragePreviewAsync(cloudStorage)); dispatch(getCloudStoragePreviewAsync(cloudStorage));
} else if (model !== undefined) {
dispatch(getModelPreviewAsync(model));
} }
} }
}, [preview]); }, [preview]);
@ -79,7 +87,7 @@ export default function Preview(props: Props): JSX.Element {
if (preview.initialized && !preview.preview) { if (preview.initialized && !preview.preview) {
return ( return (
<div className={emptyPreviewClassName || ''} aria-hidden> <div className={emptyPreviewClassName || ''} onClick={onClick} aria-hidden>
<PictureOutlined /> <PictureOutlined />
</div> </div>
); );

@ -0,0 +1,48 @@
// Copyright (C) 2022-2023 CVAT.ai Corporation
//
// SPDX-License-Identifier: MIT
import './styles.scss';
import React, { useEffect } from 'react';
import { Row, Col } from 'antd/lib/grid';
import Text from 'antd/lib/typography/Text';
import Spin from 'antd/lib/spin';
import { CombinedState } from 'reducers';
import { useSelector, useDispatch } from 'react-redux';
import { getModelProvidersAsync } from 'actions/models-actions';
import ModelForm from './model-form';
function CreateModelPage(): JSX.Element {
const dispatch = useDispatch();
const fetching = useSelector((state: CombinedState) => state.models.providers.fetching);
const providers = useSelector((state: CombinedState) => state.models.providers.list);
useEffect(() => {
dispatch(getModelProvidersAsync());
}, []);
return (
<div className='cvat-create-model-page'>
<Row justify='center' align='middle'>
<Col>
<Text className='cvat-title'>Add a model</Text>
</Col>
</Row>
{
fetching ? (
<div className='cvat-empty-webhooks-list'>
<Spin size='large' className='cvat-spinner' />
</div>
) : (
<Row justify='center' align='top'>
<Col md={20} lg={16} xl={14} xxl={9}>
<ModelForm providers={providers} />
</Col>
</Row>
)
}
</div>
);
}
export default React.memo(CreateModelPage);

@ -0,0 +1,162 @@
// Copyright (C) 2023 CVAT.ai Corporation
//
// SPDX-License-Identifier: MIT
import './styles.scss';
import React, { useCallback, useState } from 'react';
import { Store } from 'antd/lib/form/interface';
import { Row, Col } from 'antd/lib/grid';
import Form from 'antd/lib/form';
import Button from 'antd/lib/button';
import Select from 'antd/lib/select';
import notification from 'antd/lib/notification';
import Input from 'antd/lib/input/Input';
import { CombinedState } from 'reducers';
import { useHistory } from 'react-router';
import { useDispatch, useSelector } from 'react-redux';
import { createModelAsync } from 'actions/models-actions';
import { ModelProvider } from 'cvat-core-wrapper';
import ModelProviderIcon from 'components/models-page/model-provider-icon';
interface Props {
providers: ModelProvider[];
}
function createProviderFormItems(providerAttributes: Record<string, string>): JSX.Element {
delete providerAttributes.url;
return (
<>
{
Object.entries(providerAttributes).map(([key, text]) => (
<Form.Item
key={key}
name={key}
label={text}
rules={[{ required: true, message: `Please, specify ${text}` }]}
>
<Input />
</Form.Item>
))
}
</>
);
}
function ModelForm(props: Props): JSX.Element {
const { providers } = props;
const providerList = providers.map((provider) => ({
value: provider.name,
text: provider.name.charAt(0).toUpperCase() + provider.name.slice(1),
}));
const providerMap = Object.fromEntries(providers.map((provider) => [provider.name, provider.attributes]));
const [form] = Form.useForm();
const history = useHistory();
const dispatch = useDispatch();
const fetching = useSelector((state: CombinedState) => state.models.fetching);
const [currentProviderForm, setCurrentProviderForm] = useState<JSX.Element | null>(null);
const onChangeProviderValue = useCallback((provider: string) => {
setCurrentProviderForm(createProviderFormItems(providerMap[provider]));
const emptiedKeys: Record<string, string | null> = { ...providerMap[provider] };
Object.keys(providerMap[provider]).forEach((k) => { emptiedKeys[k] = null; });
form.setFieldsValue(emptiedKeys);
}, []);
const [providerTouched, setProviderTouched] = useState(false);
const [currentUrlEmpty, setCurrentUrlEmpty] = useState(true);
const handleSubmit = useCallback(async (): Promise<void> => {
try {
const values: Store = await form.validateFields();
await dispatch(createModelAsync(values));
form.resetFields();
setCurrentProviderForm(null);
setProviderTouched(false);
setCurrentUrlEmpty(true);
notification.info({
message: 'Model has been successfully created',
className: 'cvat-notification-create-model-success',
});
// eslint-disable-next-line no-empty
} catch (e) {}
}, []);
return (
<Row className='cvat-create-model-form-wrapper'>
<Col span={24}>
<Form
form={form}
layout='vertical'
>
<Col>
<Form.Item
name='url'
label='Model URL'
rules={[{ required: true, message: 'Please, specify Model URL' }]}
>
<Input onChange={(event: React.ChangeEvent<HTMLInputElement>) => {
const { value } = event.target;
const guessedProvider = providers.find((provider) => value.includes(provider.name));
if (guessedProvider && !providerTouched) {
form.setFieldsValue({ provider: guessedProvider.name });
setCurrentProviderForm(createProviderFormItems(providerMap[guessedProvider.name]));
}
setCurrentUrlEmpty(!value);
}}
/>
</Form.Item>
</Col>
{
!currentUrlEmpty && (
<>
<Form.Item
label='Provider'
name='provider'
rules={[{ required: true, message: 'Please, specify model provider' }]}
>
<Select
virtual={false}
onChange={onChangeProviderValue}
className='cvat-select-model-provider'
onSelect={() => { setProviderTouched(true); }}
>
{
providerList.map(({ value, text }) => (
<Select.Option value={value} key={value}>
<div className='cvat-model-provider-icon'>
<ModelProviderIcon providerName={value} />
<span className='cvat-cloud-storage-select-provider'>
{text}
</span>
</div>
</Select.Option>
))
}
</Select>
</Form.Item>
{currentProviderForm}
</>
)
}
</Form>
</Col>
<Col span={24} className='cvat-create-models-actions'>
<Row justify='end'>
<Col>
<Button onClick={() => history.goBack()}>
Cancel
</Button>
</Col>
<Col offset={1}>
<Button type='primary' onClick={handleSubmit} loading={fetching} disabled={currentUrlEmpty}>
Submit
</Button>
</Col>
</Row>
</Col>
</Row>
);
}
export default React.memo(ModelForm);

@ -0,0 +1,46 @@
// Copyright (C) 2022-2023 CVAT.ai Corporation
//
// SPDX-License-Identifier: MIT
@import '../../base.scss';
.cvat-create-model-page {
width: 100%;
height: 100%;
padding-top: $grid-unit-size * 5;
.cvat-title {
font-size: 36px;
}
}
.cvat-create-model-form-wrapper {
margin-top: $grid-unit-size * 3;
height: auto;
border: 1px solid $border-color-1;
border-radius: 3px;
padding: $grid-unit-size * 3;
background: $background-color-1;
text-align: initial;
}
.cvat-create-models-actions {
margin-top: $grid-unit-size * 2;
}
.cvat-model-provider-icon {
display: flex;
img {
margin-top: 3px;
margin-right: $grid-unit-size;
width: $grid-unit-size * 2;
height: $grid-unit-size * 2;
}
}
.cvat-select-model-provider {
img {
margin-top: 7px;
}
}

@ -1,5 +1,5 @@
// Copyright (C) 2020-2022 Intel Corporation // Copyright (C) 2020-2022 Intel Corporation
// Copyright (C) 2022 CVAT.ai Corporation // Copyright (C) 2022-2023 CVAT.ai Corporation
// //
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
@ -33,9 +33,9 @@ import ExportDatasetModal from 'components/export-dataset/export-dataset-modal';
import ExportBackupModal from 'components/export-backup/export-backup-modal'; import ExportBackupModal from 'components/export-backup/export-backup-modal';
import ImportDatasetModal from 'components/import-dataset/import-dataset-modal'; import ImportDatasetModal from 'components/import-dataset/import-dataset-modal';
import ImportBackupModal from 'components/import-backup/import-backup-modal'; import ImportBackupModal from 'components/import-backup/import-backup-modal';
import ModelsPageContainer from 'containers/models-page/models-page';
import JobsPageComponent from 'components/jobs-page/jobs-page'; import JobsPageComponent from 'components/jobs-page/jobs-page';
import ModelsPageComponent from 'components/models-page/models-page';
import TasksPageContainer from 'containers/tasks-page/tasks-page'; import TasksPageContainer from 'containers/tasks-page/tasks-page';
import CreateTaskPageContainer from 'containers/create-task-page/create-task-page'; import CreateTaskPageContainer from 'containers/create-task-page/create-task-page';
@ -72,6 +72,7 @@ import appConfig from 'config';
import EmailConfirmationPage from './email-confirmation-pages/email-confirmed'; import EmailConfirmationPage from './email-confirmation-pages/email-confirmed';
import EmailVerificationSentPage from './email-confirmation-pages/email-verification-sent'; import EmailVerificationSentPage from './email-confirmation-pages/email-verification-sent';
import IncorrectEmailConfirmationPage from './email-confirmation-pages/incorrect-email-confirmation'; import IncorrectEmailConfirmationPage from './email-confirmation-pages/incorrect-email-confirmation';
import CreateModelPage from './create-model-page/create-model-page';
interface CVATAppProps { interface CVATAppProps {
loadFormats: () => void; loadFormats: () => void;
@ -330,7 +331,7 @@ class CVATApplication extends React.PureComponent<CVATAppProps & RouteComponentP
<ReactMarkdown>{title}</ReactMarkdown> <ReactMarkdown>{title}</ReactMarkdown>
), ),
duration: null, duration: null,
description: error.length > 200 ? 'Open the Browser Console to get details' : <ReactMarkdown>{error}</ReactMarkdown>, description: error.length > 300 ? 'Open the Browser Console to get details' : <ReactMarkdown>{error}</ReactMarkdown>,
}); });
// eslint-disable-next-line no-console // eslint-disable-next-line no-console
@ -449,7 +450,14 @@ class CVATApplication extends React.PureComponent<CVATAppProps & RouteComponentP
<Route exact path='/webhooks/update/:id' component={UpdateWebhookPage} /> <Route exact path='/webhooks/update/:id' component={UpdateWebhookPage} />
<Route exact path='/organization' component={OrganizationPage} /> <Route exact path='/organization' component={OrganizationPage} />
{isModelPluginActive && ( {isModelPluginActive && (
<Route exact path='/models' component={ModelsPageContainer} /> <Route
path='/models'
>
<Switch>
<Route exact path='/models' component={ModelsPageComponent} />
<Route exact path='/models/create' component={CreateModelPage} />
</Switch>
</Route>
)} )}
<Redirect <Redirect
push push

@ -1,4 +1,5 @@
// Copyright (C) 2020-2022 Intel Corporation // Copyright (C) 2020-2022 Intel Corporation
// Copyright (C) 2022-2023 CVAT.ai Corporation
// //
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
@ -14,20 +15,21 @@ import Button from 'antd/lib/button';
import Switch from 'antd/lib/switch'; import Switch from 'antd/lib/switch';
import notification from 'antd/lib/notification'; import notification from 'antd/lib/notification';
import { Model, ModelAttribute, StringObject } from 'reducers'; import { ModelAttribute, StringObject } from 'reducers';
import CVATTooltip from 'components/common/cvat-tooltip'; import CVATTooltip from 'components/common/cvat-tooltip';
import { Label as LabelInterface } from 'components/labels-editor/common'; import { Label as LabelInterface } from 'components/labels-editor/common';
import { clamp } from 'utils/math'; import { clamp } from 'utils/math';
import config from 'config'; import config from 'config';
import { MLModel, ModelKind, ModelReturnType } from 'cvat-core-wrapper';
import { DimensionType } from '../../reducers'; import { DimensionType } from '../../reducers';
interface Props { interface Props {
withCleanup: boolean; withCleanup: boolean;
models: Model[]; models: MLModel[];
labels: LabelInterface[]; labels: LabelInterface[];
dimension: DimensionType; dimension: DimensionType;
runInference(model: Model, body: object): void; runInference(model: MLModel, body: object): void;
} }
interface MappedLabel { interface MappedLabel {
@ -63,15 +65,20 @@ function DetectorRunner(props: Props): JSX.Element {
const [attrMatches, setAttrMatch] = useState<Record<string, Match>>({}); const [attrMatches, setAttrMatch] = useState<Record<string, Match>>({});
const model = models.filter((_model): boolean => _model.id === modelID)[0]; const model = models.filter((_model): boolean => _model.id === modelID)[0];
const isDetector = model && model.type === 'detector'; const isDetector = model && model.kind === ModelKind.DETECTOR;
const isReId = model && model.type === 'reid'; const isReId = model && model.kind === ModelKind.REID;
const isClassifier = model && model.kind === ModelKind.CLASSIFIER;
const convertMasksToPolygonsAvailable = isDetector &&
(!model.returnType || model.returnType === ModelReturnType.MASK);
const buttonEnabled = const buttonEnabled =
model && (model.type === 'reid' || (model.type === 'detector' && !!Object.keys(mapping).length)); model && (model.kind === ModelKind.REID ||
(model.kind === ModelKind.DETECTOR && !!Object.keys(mapping).length) ||
const modelLabels = (isDetector ? model.labels : []).filter((_label: string): boolean => !(_label in mapping)); (model.kind === ModelKind.CLASSIFIER && !!Object.keys(mapping).length));
const taskLabels = isDetector ? labels.map((label: any): string => label.name) : []; const canHaveMapping = isDetector || isClassifier;
const modelLabels = (canHaveMapping ? model.labels : []).filter((_label: string): boolean => !(_label in mapping));
if (model && model.type !== 'reid' && !model.labels.length) { const taskLabels = canHaveMapping ? labels.map((label: any): string => label.name) : [];
if (model && model.kind === ModelKind.REID && !model.labels.length) {
notification.warning({ notification.warning({
message: 'The selected model does not include any labels', message: 'The selected model does not include any labels',
}); });
@ -241,7 +248,6 @@ function DetectorRunner(props: Props): JSX.Element {
return acc; return acc;
}, {}, }, {},
); );
setMapping(defaultMapping); setMapping(defaultMapping);
setMatch({ model: null, task: null }); setMatch({ model: null, task: null });
setAttrMatch({}); setAttrMatch({});
@ -249,7 +255,7 @@ function DetectorRunner(props: Props): JSX.Element {
}} }}
> >
{models.map( {models.map(
(_model: Model): JSX.Element => ( (_model: MLModel): JSX.Element => (
<Select.Option value={_model.id} key={_model.id}> <Select.Option value={_model.id} key={_model.id}>
{_model.name} {_model.name}
</Select.Option> </Select.Option>
@ -258,7 +264,7 @@ function DetectorRunner(props: Props): JSX.Element {
</Select> </Select>
</Col> </Col>
</Row> </Row>
{isDetector && {canHaveMapping &&
Object.keys(mapping).length ? Object.keys(mapping).length ?
Object.keys(mapping).map((modelLabel: string) => { Object.keys(mapping).map((modelLabel: string) => {
const label = labels const label = labels
@ -337,7 +343,7 @@ function DetectorRunner(props: Props): JSX.Element {
</React.Fragment> </React.Fragment>
); );
}) : null} }) : null}
{isDetector && !!taskLabels.length && !!modelLabels.length ? ( {canHaveMapping && !!taskLabels.length && !!modelLabels.length ? (
<> <>
<Row justify='start' align='middle'> <Row justify='start' align='middle'>
<Col span={10}> <Col span={10}>
@ -354,7 +360,7 @@ function DetectorRunner(props: Props): JSX.Element {
</Row> </Row>
</> </>
) : null} ) : null}
{isDetector && ( {convertMasksToPolygonsAvailable && (
<div className='detector-runner-convert-masks-to-polygons-wrapper'> <div className='detector-runner-convert-masks-to-polygons-wrapper'>
<Switch <Switch
checked={convertMasksToPolygons} checked={convertMasksToPolygons}
@ -423,19 +429,25 @@ function DetectorRunner(props: Props): JSX.Element {
disabled={!buttonEnabled} disabled={!buttonEnabled}
type='primary' type='primary'
onClick={() => { onClick={() => {
const detectorRequestBody: DetectorRequestBody = { let requestBody: object = {};
mapping, if (model.kind === ModelKind.DETECTOR) {
cleanup, requestBody = {
convMaskToPoly: convertMasksToPolygons, mapping,
}; cleanup,
convMaskToPoly: convertMasksToPolygons,
runInference( };
model, } else if (model.kind === ModelKind.REID) {
model.type === 'detector' ? detectorRequestBody : { requestBody = {
threshold, threshold,
max_distance: distance, max_distance: distance,
}, };
); } else if (model.kind === ModelKind.CLASSIFIER) {
requestBody = {
mapping,
};
}
runInference(model, requestBody);
}} }}
> >
Annotate Annotate

@ -1,4 +1,5 @@
// Copyright (C) 2020-2022 Intel Corporation // Copyright (C) 2020-2022 Intel Corporation
// Copyright (C) 2022-2023 CVAT.ai Corporation
// //
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
@ -9,36 +10,39 @@ import Modal from 'antd/lib/modal';
import { ThunkDispatch } from 'utils/redux'; import { ThunkDispatch } from 'utils/redux';
import { modelsActions, startInferenceAsync } from 'actions/models-actions'; import { modelsActions, startInferenceAsync } from 'actions/models-actions';
import { Model, CombinedState } from 'reducers'; import { CombinedState } from 'reducers';
import MLModel from 'cvat-core/src/ml-model';
import DetectorRunner from './detector-runner'; import DetectorRunner from './detector-runner';
interface StateToProps { interface StateToProps {
visible: boolean; visible: boolean;
task: any; task: any;
detectors: Model[]; detectors: MLModel[];
reid: Model[]; reid: MLModel[];
classifiers: MLModel[];
} }
interface DispatchToProps { interface DispatchToProps {
runInference(task: any, model: Model, body: object): void; runInference(task: any, model: MLModel, body: object): void;
closeDialog(): void; closeDialog(): void;
} }
function mapStateToProps(state: CombinedState): StateToProps { function mapStateToProps(state: CombinedState): StateToProps {
const { models } = state; const { models } = state;
const { detectors, reid } = models; const { detectors, reid, classifiers } = models;
return { return {
visible: models.modelRunnerIsVisible, visible: models.modelRunnerIsVisible,
task: models.modelRunnerTask, task: models.modelRunnerTask,
reid, reid,
detectors, detectors,
classifiers,
}; };
} }
function mapDispatchToProps(dispatch: ThunkDispatch): DispatchToProps { function mapDispatchToProps(dispatch: ThunkDispatch): DispatchToProps {
return { return {
runInference(taskID: number, model: Model, body: object) { runInference(taskID: number, model: MLModel, body: object) {
dispatch(startInferenceAsync(taskID, model, body)); dispatch(startInferenceAsync(taskID, model, body));
}, },
closeDialog() { closeDialog() {
@ -49,10 +53,10 @@ function mapDispatchToProps(dispatch: ThunkDispatch): DispatchToProps {
function ModelRunnerDialog(props: StateToProps & DispatchToProps): JSX.Element { function ModelRunnerDialog(props: StateToProps & DispatchToProps): JSX.Element {
const { const {
reid, detectors, task, visible, runInference, closeDialog, reid, detectors, classifiers, task, visible, runInference, closeDialog,
} = props; } = props;
const models = [...reid, ...detectors]; const models = [...reid, ...detectors, ...classifiers];
return ( return (
<Modal <Modal

@ -1,51 +1,159 @@
// Copyright (C) 2020-2022 Intel Corporation // Copyright (C) 2020-2022 Intel Corporation
// Copyright (C) 2022-2023 CVAT.ai Corporation
// //
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
import React from 'react'; import React, { useCallback, useState } from 'react';
import { Row, Col } from 'antd/lib/grid'; import { Row, Col } from 'antd/lib/grid';
import Tag from 'antd/lib/tag'; import Tag from 'antd/lib/tag';
import Select from 'antd/lib/select';
import Text from 'antd/lib/typography/Text'; import Text from 'antd/lib/typography/Text';
import { Model } from 'reducers'; import { MoreOutlined } from '@ant-design/icons';
import CVATTooltip from 'components/common/cvat-tooltip'; import Modal from 'antd/lib/modal';
import { MLModel, ModelProviders } from 'cvat-core-wrapper';
import Title from 'antd/lib/typography/Title';
import Meta from 'antd/lib/card/Meta';
import Preview from 'components/common/preview';
import moment from 'moment';
import Divider from 'antd/lib/divider';
import Card from 'antd/lib/card';
import Dropdown from 'antd/lib/dropdown';
import Button from 'antd/lib/button';
import ModelActionsMenuComponent from './models-action-menu';
import ModelProviderIcon from './model-provider-icon';
interface Props { interface Props {
model: Model; model: MLModel;
} }
export default function DeployedModelItem(props: Props): JSX.Element { export default function DeployedModelItem(props: Props): JSX.Element {
const { model } = props; const { model } = props;
const { provider } = model;
const [isRemoved, setIsRemoved] = useState(false);
const [isModalShown, setIsModalShown] = useState(false);
const onOpenModel = () => {
setIsModalShown(true);
};
const onCloseModel = () => {
setIsModalShown(false);
};
const onDelete = useCallback(() => {
setIsRemoved(true);
}, []);
const created = moment(model.createdDate).fromNow();
const icon = <ModelProviderIcon providerName={provider} />;
return ( return (
<Row className='cvat-models-list-item'> <>
<Col span={3}> <Modal
<Tag color='purple'>{model.framework}</Tag> className='cvat-model-info-modal'
</Col> title='Model'
<Col span={3}> visible={isModalShown}
<CVATTooltip overlay={model.name}> onCancel={onCloseModel}
<Text className='cvat-text-color'>{model.name}</Text> footer={null}
</CVATTooltip> >
</Col> <Preview
<Col span={3} offset={1}> model={model}
<Tag color='orange'>{model.type}</Tag> loadingClassName='cvat-model-item-loading-preview'
</Col> emptyPreviewClassName='cvat-model-item-empty-preview'
<Col span={8}> previewWrapperClassName='cvat-models-item-card-preview-wrapper'
<CVATTooltip overlay={model.description}> previewClassName='cvat-models-item-card-preview'
<Text style={{ whiteSpace: 'normal', height: 'auto' }}>{model.description}</Text> />
</CVATTooltip> {icon ? <div className='cvat-model-item-provider-inner'>{icon}</div> : null}
</Col> <div className='cvat-model-info-container'>
<Col span={5} offset={1}> <Title level={3}>{model.name}</Title>
<Select showSearch placeholder='Supported labels' style={{ width: '90%' }} value='Supported labels'> <Text type='secondary'>{`Added ${created}`}</Text>
{model.labels.map( </div>
(label): JSX.Element => ( <Divider />
<Select.Option value={label} key={label}> {
{label} model.labels?.length ? (
</Select.Option> <>
), <div className='cvat-model-info-container'>
<Text className='cvat-model-info-modal-labels-title'>Labels:</Text>
</div>
<div className='cvat-model-info-container cvat-model-info-modal-labels-list'>
{model.labels.map((label) => <Tag key={label}>{label}</Tag>)}
</div>
<Divider />
</>
) : null
}
<Row justify='space-between' className='cvat-model-info-container'>
<Col span={15}>
<Row>
<Col span={8}>
<Text strong>Provider</Text>
</Col>
<Col>
<Text strong>Type</Text>
</Col>
</Row>
<Row>
<Col span={8}>
{model.provider}
</Col>
<Col>
{model.kind}
</Col>
</Row>
</Col>
{model.owner && (
<Col>
<Row>
<Col>
<Text strong>Owner</Text>
</Col>
</Row>
<Row>
<Col>
{model.owner}
</Col>
</Row>
</Col>
)}
</Row>
</Modal>
<Card
cover={(
<Preview
model={model}
loadingClassName='cvat-model-item-loading-preview'
emptyPreviewClassName='cvat-model-item-empty-preview'
previewWrapperClassName='cvat-models-item-card-preview-wrapper'
previewClassName='cvat-models-item-card-preview'
onClick={onOpenModel}
/>
)}
size='small'
className={`cvat-models-item-card ${isRemoved ? 'cvat-models-item-card-removed' : ''} `}
>
<Meta
title={(
<span onClick={onOpenModel} className='cvat-models-item-title' aria-hidden>
{model.name}
</span>
)}
description={(
<div className='cvat-models-item-description'>
<Row onClick={onOpenModel} className='cvat-models-item-text-description'>
{model.owner && (<Text strong>{model.owner}</Text>)}
<Text type='secondary'>{` Added ${created}`}</Text>
</Row>
{
model.provider !== ModelProviders.CVAT && (
<Dropdown overlay={<ModelActionsMenuComponent model={model} onDelete={onDelete} />}>
<Button type='link' size='large' icon={<MoreOutlined />} />
</Dropdown>
)
}
</div>
)} )}
</Select> />
</Col> {
</Row> icon ? <div className='cvat-model-item-provider'>{icon}</div> : null
}
</Card>
</>
); );
} }

@ -1,44 +1,35 @@
// Copyright (C) 2020-2022 Intel Corporation // Copyright (C) 2020-2022 Intel Corporation
// Copyright (C) 2022-2023 CVAT.ai Corporation
// //
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
import React from 'react'; import React from 'react';
import moment from 'moment';
import { useSelector } from 'react-redux';
import { Row, Col } from 'antd/lib/grid'; import { Row, Col } from 'antd/lib/grid';
import Text from 'antd/lib/typography/Text'; import { CombinedState } from 'reducers';
import { MLModel } from 'cvat-core-wrapper';
import { Model } from 'reducers'; import { ModelProviders } from 'cvat-core/src/enums';
import DeployedModelItem from './deployed-model-item'; import DeployedModelItem from './deployed-model-item';
interface Props { export default function DeployedModelsListComponent(): JSX.Element {
models: Model[]; const interactors = useSelector((state: CombinedState) => state.models.interactors);
} const detectors = useSelector((state: CombinedState) => state.models.detectors);
const trackers = useSelector((state: CombinedState) => state.models.trackers);
export default function DeployedModelsListComponent(props: Props): JSX.Element { const reid = useSelector((state: CombinedState) => state.models.reid);
const { models } = props; const classifiers = useSelector((state: CombinedState) => state.models.classifiers);
const models = [...interactors, ...detectors, ...trackers, ...reid, ...classifiers];
const builtInModels = models.filter((model: MLModel) => model.provider === ModelProviders.CVAT);
const externalModels = models.filter((model: MLModel) => model.provider !== ModelProviders.CVAT);
externalModels.sort((a, b) => moment(a.createdDate).valueOf() - moment(b.createdDate).valueOf());
const items = models.map((model): JSX.Element => <DeployedModelItem key={model.id} model={model} />); const renderModels = [...builtInModels, ...externalModels];
const items = renderModels.map((model): JSX.Element => <DeployedModelItem key={model.id} model={model} />);
return ( return (
<> <>
<Row justify='center' align='middle'> <Row justify='center' align='middle'>
<Col md={22} lg={18} xl={16} xxl={14} className='cvat-models-list'> <Col md={22} lg={18} xl={16} xxl={16} className='cvat-models-list'>
<Row align='middle' style={{ padding: '10px' }}>
<Col span={3}>
<Text strong>Framework</Text>
</Col>
<Col span={3}>
<Text strong>Name</Text>
</Col>
<Col span={3} offset={1}>
<Text strong>Type</Text>
</Col>
<Col span={8}>
<Text strong>Description</Text>
</Col>
<Col span={5} offset={1}>
<Text strong>Labels</Text>
</Col>
</Row>
{items} {items}
</Col> </Col>
</Row> </Row>

@ -1,4 +1,5 @@
// Copyright (C) 2020-2022 Intel Corporation // Copyright (C) 2020-2022 Intel Corporation
// Copyright (C) 2022-2023 CVAT.ai Corporation
// //
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT

@ -0,0 +1,29 @@
// Copyright (C) 2023 CVAT.ai Corporation
//
// SPDX-License-Identifier: MIT
import React from 'react';
import { ModelProvider } from 'cvat-core-wrapper';
import { CombinedState } from 'reducers';
import { useSelector } from 'react-redux';
interface Props {
providerName: string;
}
export default function ModelProviderIcon(props: Props): JSX.Element | null {
const { providerName } = props;
const providers = useSelector((state: CombinedState) => state.models.providers.list);
let icon: JSX.Element | null = null;
const providerInstance = providers.find((_provider: ModelProvider) => _provider.name === providerName);
if (providerInstance) {
icon = (
<img
src={`data:image/svg+xml;utf8,${encodeURIComponent(providerInstance.icon)}`}
alt={providerName}
/>
);
}
return icon;
}

@ -0,0 +1,66 @@
// Copyright (C) 2023 CVAT.ai Corporation
//
// SPDX-License-Identifier: MIT
import React, { useCallback } from 'react';
import { useDispatch } from 'react-redux';
import Modal from 'antd/lib/modal';
import Menu from 'antd/lib/menu';
import { MLModel, ModelProviders } from 'cvat-core-wrapper';
import { deleteModelAsync } from 'actions/models-actions';
interface Props {
model: MLModel;
onDelete: () => void;
}
export default function ModelActionsMenuComponent(props: Props): JSX.Element {
const { model, onDelete } = props;
const { provider } = model;
const cvatProvider = provider === ModelProviders.CVAT;
const dispatch = useDispatch();
const onDeleteModel = useCallback((): void => {
Modal.confirm({
title: `The model ${model.name} will be deleted`,
content: 'You will not be able to use it anymore. Continue?',
className: 'cvat-modal-confirm-remove-model',
onOk: () => {
dispatch(deleteModelAsync(model));
onDelete();
},
okButtonProps: {
type: 'primary',
danger: true,
},
okText: 'Delete',
});
}, []);
const onOpenUrl = useCallback((): void => {
window.open(model.url, '_blank');
}, []);
return (
<Menu selectable={false} className='cvat-project-actions-menu'>
{
!cvatProvider && (
<Menu.Item key='open' onClick={onOpenUrl}>
Open model URL
</Menu.Item>
)
}
{
!cvatProvider && (
<>
<Menu.Divider />
<Menu.Item key='delete' onClick={onDeleteModel}>
Delete
</Menu.Item>
</>
)
}
</Menu>
);
}

@ -0,0 +1,55 @@
// Copyright (C) 2022-2023 CVAT.ai Corporation
//
// SPDX-License-Identifier: MIT
import { Config } from 'react-awesome-query-builder';
export const config: Partial<Config> = {
fields: {
description: {
label: 'Description',
type: 'text',
valueSources: ['value'],
operators: ['like'],
},
target_url: {
label: 'Target URL',
type: 'text',
valueSources: ['value'],
operators: ['like'],
},
owner: {
label: 'Owner',
type: 'text',
valueSources: ['value'],
operators: ['equal'],
},
updated_date: {
label: 'Last updated',
type: 'datetime',
operators: ['between', 'greater', 'greater_or_equal', 'less', 'less_or_equal'],
},
type: {
label: 'Type',
type: 'select',
valueSources: ['value'],
fieldSettings: {
listValues: [
{ value: 'organization', title: 'Organization' },
{ value: 'project', title: 'Project' },
],
},
},
id: {
label: 'ID',
type: 'number',
operators: ['equal', 'between', 'greater', 'greater_or_equal', 'less', 'less_or_equal'],
fieldSettings: { min: 0 },
valueSources: ['value'],
},
},
};
export const localStorageRecentCapacity = 10;
export const localStorageRecentKeyword = 'recentlyAppliedWebhooksFilters';
export const predefinedFilterValues = {};

@ -1,33 +1,91 @@
// Copyright (C) 2020-2022 Intel Corporation // Copyright (C) 2020-2022 Intel Corporation
// Copyright (C) 2022-2023 CVAT.ai Corporation
// //
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
import './styles.scss'; import './styles.scss';
import React from 'react'; import React, { useCallback, useEffect } from 'react';
import { useHistory } from 'react-router';
import { useDispatch, useSelector } from 'react-redux';
import { getModelProvidersAsync, getModelsAsync } from 'actions/models-actions';
import { updateHistoryFromQuery } from 'components/resource-sorting-filtering';
import Spin from 'antd/lib/spin';
import DeployedModelsList from './deployed-models-list'; import DeployedModelsList from './deployed-models-list';
import EmptyListComponent from './empty-list'; import EmptyListComponent from './empty-list';
import FeedbackComponent from '../feedback/feedback'; import FeedbackComponent from '../feedback/feedback';
import { Model } from '../../reducers'; import { CombinedState } from '../../reducers';
import TopBar from './top-bar';
interface Props { function ModelsPageComponent(): JSX.Element {
interactors: Model[]; const history = useHistory();
detectors: Model[]; const dispatch = useDispatch();
trackers: Model[]; const fetching = useSelector((state: CombinedState) => state.models.fetching);
reid: Model[]; const query = useSelector((state: CombinedState) => state.models.query);
} const totalCount = useSelector((state: CombinedState) => state.models.totalCount);
const onCreateModel = useCallback(() => {
history.push('/models/create');
}, []);
export default function ModelsPageComponent(props: Props): JSX.Element { const updatedQuery = { ...query };
const { useEffect(() => {
interactors, detectors, trackers, reid, history.replace({
} = props; search: updateHistoryFromQuery(query),
});
}, [query]);
const deployedModels = [...detectors, ...interactors, ...trackers, ...reid]; useEffect(() => {
dispatch(getModelProvidersAsync());
dispatch(getModelsAsync());
}, []);
const content = totalCount ? (
<DeployedModelsList />
) : <EmptyListComponent />;
return ( return (
<div className='cvat-models-page'> <div className='cvat-models-page'>
{deployedModels.length ? <DeployedModelsList models={deployedModels} /> : <EmptyListComponent />} <TopBar
disabled
query={updatedQuery}
onCreateModel={onCreateModel}
onApplySearch={(search: string | null) => {
dispatch(
getModelsAsync({
...query,
search,
page: 1,
}),
);
}}
onApplyFilter={(filter: string | null) => {
dispatch(
getModelsAsync({
...query,
filter,
page: 1,
}),
);
}}
onApplySorting={(sorting: string | null) => {
dispatch(
getModelsAsync({
...query,
sort: sorting,
page: 1,
}),
);
}}
/>
{ fetching ? (
<div className='cvat-empty-models-list'>
<Spin size='large' className='cvat-spinner' />
</div>
) : content }
<FeedbackComponent /> <FeedbackComponent />
</div> </div>
); );
} }
export default React.memo(ModelsPageComponent);

@ -1,4 +1,5 @@
// Copyright (C) 2020-2022 Intel Corporation // Copyright (C) 2020-2022 Intel Corporation
// Copyright (C) 2022-2023 CVAT.ai Corporation
// //
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
@ -7,14 +8,10 @@
.cvat-models-page { .cvat-models-page {
padding-top: $grid-unit-size * 2; padding-top: $grid-unit-size * 2;
padding-bottom: $grid-unit-size; padding-bottom: $grid-unit-size;
height: 90%;
overflow: auto; overflow: auto;
position: fixed; position: fixed;
height: 100%;
width: 100%; width: 100%;
> div:nth-child(1) {
margin-bottom: $grid-unit-size;
}
} }
.cvat-empty-models-list { .cvat-empty-models-list {
@ -26,10 +23,12 @@
.cvat-models-list { .cvat-models-list {
height: 100%; height: 100%;
overflow-y: auto; display: flex;
flex-wrap: wrap;
} }
.cvat-models-list-item { .cvat-models-list-item {
position: relative;
width: 100%; width: 100%;
height: auto; height: auto;
border: 1px solid $border-color-1; border: 1px solid $border-color-1;
@ -59,3 +58,177 @@
overflow: hidden; overflow: hidden;
} }
} }
.cvat-models-item-card-removed {
opacity: 0.5;
pointer-events: none;
}
.cvat-models-page-top-bar {
margin: $grid-unit-size * 3 0;
> div {
display: flex;
}
}
.cvat-models-heading {
padding: $grid-unit-size * 2;
}
.cvat-models-page-filters-wrapper {
display: flex;
justify-content: space-between;
align-items: center;
width: 100%;
> div {
display: flex;
margin-right: $grid-unit-size * 4;
> button {
margin-right: $grid-unit-size;
}
}
}
.cvat-models-add-wrapper {
display: inline-block;
}
.cvat-model-delete {
position: absolute;
top: $grid-unit-size;
right: $grid-unit-size;
color: #8c8c8c;
font-size: 10px;
&:hover {
cursor: pointer;
color: #595959;
}
}
.cvat-models-item-card {
width: 25%;
border-width: 4px;
height: $grid-unit-size * 28;
overflow: hidden;
.ant-card-meta-title {
margin-bottom: 0 !important;
}
}
.cvat-model-item-loading-preview,
.cvat-model-item-empty-preview {
.ant-spin {
position: inherit;
}
font-size: 80px;
text-align: center;
height: $grid-unit-size * 18;
&:hover {
cursor: pointer;
}
}
.cvat-models-item-card-preview-wrapper {
display: flex;
justify-content: center;
align-items: center;
height: $grid-unit-size * 18;
overflow: hidden;
&:hover {
cursor: pointer;
}
}
.cvat-model-item-provider {
position: absolute;
top: $grid-unit-size;
right: $grid-unit-size;
}
.cvat-model-item-provider-inner {
@extend .cvat-model-item-provider;
right: $grid-unit-size * 2;
svg {
width: $grid-unit-size * 4;
height: $grid-unit-size * 4;
}
}
.cvat-models-item-description {
font-size: 14px;
display: flex;
justify-content: space-between;
> div > span:nth-child(2) {
margin-left: $grid-unit-size;
}
button {
position: relative;
color: black;
margin-top: -$grid-unit-size * 2;
margin-right: -$grid-unit-size;
}
&:hover {
cursor: pointer;
}
}
.cvat-model-info-modal {
.ant-modal-body {
position: relative;
padding: 0;
>.cvat-model-info-container:not(:last-child) {
padding: 0 $grid-unit-size * 3;
}
>.cvat-model-info-container:last-child {
padding: 0 $grid-unit-size * 3 $grid-unit-size * 3 $grid-unit-size * 3;
}
}
h3 {
margin-top: $grid-unit-size * 2;
}
}
.cvat-model-info-modal-labels-title {
font-size: 16px;
}
.cvat-model-info-modal-labels-list {
margin-top: $grid-unit-size;
max-height: $grid-unit-size * 18;
overflow: auto;
.ant-tag {
margin-top: $grid-unit-size;
}
}
.cvat-models-item-title {
&:hover {
cursor: pointer !important;
}
}
.cvat-models-item-text-description {
max-width: 80%;
text-overflow: ellipsis;
white-space: nowrap;
overflow: hidden;
max-height: $grid-unit-size * 3;
display: inline;
}

@ -0,0 +1,89 @@
// Copyright (C) 2022-2023 CVAT.ai Corporation
//
// SPDX-License-Identifier: MIT
import React, { useState } from 'react';
import { Row, Col } from 'antd/lib/grid';
import { PlusOutlined } from '@ant-design/icons';
import Button from 'antd/lib/button';
import { Input } from 'antd';
import { SortingComponent, ResourceFilterHOC, defaultVisibility } from 'components/resource-sorting-filtering';
import { ModelsQuery } from 'reducers';
import {
localStorageRecentKeyword, localStorageRecentCapacity, config,
} from './models-filter-configuration';
const FilteringComponent = ResourceFilterHOC(
config, localStorageRecentKeyword, localStorageRecentCapacity,
);
interface VisibleTopBarProps {
onApplyFilter(filter: string | null): void;
onApplySorting(sorting: string | null): void;
onApplySearch(search: string | null): void;
query: ModelsQuery;
onCreateModel(): void;
disabled?: boolean;
}
export default function TopBarComponent(props: VisibleTopBarProps): JSX.Element {
const {
query, onApplyFilter, onApplySorting, onApplySearch, onCreateModel, disabled,
} = props;
const [visibility, setVisibility] = useState(defaultVisibility);
return (
<Row className='cvat-models-page-top-bar' justify='center' align='middle'>
<Col md={22} lg={18} xl={16} xxl={16}>
<div className='cvat-models-page-filters-wrapper'>
<Input.Search
disabled={disabled}
enterButton
onSearch={(phrase: string) => {
onApplySearch(phrase);
}}
defaultValue={query.search || ''}
className='cvat-webhooks-page-search-bar'
placeholder='Search ...'
/>
<div>
<SortingComponent
disabled={disabled}
visible={visibility.sorting}
onVisibleChange={(visible: boolean) => (
setVisibility({ ...defaultVisibility, sorting: visible })
)}
defaultFields={query.sort?.split(',') || ['-ID']}
sortingFields={['ID', 'Target URL', 'Owner', 'Description', 'Type', 'Updated date']}
onApplySorting={onApplySorting}
/>
<FilteringComponent
disabled={disabled}
value={query.filter}
predefinedVisible={visibility.predefined}
builderVisible={visibility.builder}
recentVisible={visibility.recent}
onPredefinedVisibleChange={(visible: boolean) => (
setVisibility({ ...defaultVisibility, predefined: visible })
)}
onBuilderVisibleChange={(visible: boolean) => (
setVisibility({ ...defaultVisibility, builder: visible })
)}
onRecentVisibleChange={(visible: boolean) => (
setVisibility({
...defaultVisibility,
builder: visibility.builder,
recent: visible,
})
)}
onApplyFilter={onApplyFilter}
/>
</div>
</div>
<div className='cvat-models-add-wrapper'>
<Button onClick={onCreateModel} type='primary' className='cvat-create-model' icon={<PlusOutlined />} />
</div>
</Col>
</Row>
);
}

@ -1,5 +1,5 @@
// Copyright (C) 2022 Intel Corporation // Copyright (C) 2022 Intel Corporation
// Copyright (C) 2022 CVAT.ai Corporation // Copyright (C) 2022-2023 CVAT.ai Corporation
// //
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
@ -26,6 +26,7 @@ interface ResourceFilterProps {
recentVisible: boolean; recentVisible: boolean;
builderVisible: boolean; builderVisible: boolean;
value: string | null; value: string | null;
disabled?: boolean;
onPredefinedVisibleChange?: (visible: boolean) => void; onPredefinedVisibleChange?: (visible: boolean) => void;
onBuilderVisibleChange(visible: boolean): void; onBuilderVisibleChange(visible: boolean): void;
onRecentVisibleChange(visible: boolean): void; onRecentVisibleChange(visible: boolean): void;
@ -117,6 +118,7 @@ export default function ResourceFilterHOC(
const { const {
predefinedVisible, builderVisible, recentVisible, value, predefinedVisible, builderVisible, recentVisible, value,
onPredefinedVisibleChange, onBuilderVisibleChange, onRecentVisibleChange, onApplyFilter, onPredefinedVisibleChange, onBuilderVisibleChange, onRecentVisibleChange, onApplyFilter,
disabled,
} = props; } = props;
const user = useSelector((state: CombinedState) => state.auth.user); const user = useSelector((state: CombinedState) => state.auth.user);
@ -248,6 +250,7 @@ export default function ResourceFilterHOC(
) : null ) : null
} }
<Dropdown <Dropdown
disabled={disabled}
placement='bottomRight' placement='bottomRight'
visible={builderVisible} visible={builderVisible}
destroyPopupOnHide destroyPopupOnHide
@ -356,7 +359,7 @@ export default function ResourceFilterHOC(
</Button> </Button>
</Dropdown> </Dropdown>
<Button <Button
disabled={!(appliedFilter.built || appliedFilter.predefined || appliedFilter.recent)} disabled={!(appliedFilter.built || appliedFilter.predefined || appliedFilter.recent) || disabled}
size='small' size='small'
type='link' type='link'
onClick={() => { setAppliedFilter({ ...defaultAppliedFilter }); }} onClick={() => { setAppliedFilter({ ...defaultAppliedFilter }); }}

@ -1,4 +1,5 @@
// Copyright (C) 2022 Intel Corporation // Copyright (C) 2022 Intel Corporation
// Copyright (C) 2022-2023 CVAT.ai Corporation
// //
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
@ -17,6 +18,7 @@ interface Props {
sortingFields: string[]; sortingFields: string[];
defaultFields: string[]; defaultFields: string[];
visible: boolean; visible: boolean;
disabled?: boolean;
onVisibleChange(visible: boolean): void; onVisibleChange(visible: boolean): void;
onApplySorting(sorting: string | null): void; onApplySorting(sorting: string | null): void;
} }
@ -97,7 +99,7 @@ const SortableList = SortableContainer(
function SortingModalComponent(props: Props): JSX.Element { function SortingModalComponent(props: Props): JSX.Element {
const { const {
sortingFields: sortingFieldsProp, sortingFields: sortingFieldsProp,
defaultFields, visible, onApplySorting, onVisibleChange, defaultFields, visible, onApplySorting, onVisibleChange, disabled,
} = props; } = props;
const [appliedSorting, setAppliedSorting] = useState<Record<string, string>>( const [appliedSorting, setAppliedSorting] = useState<Record<string, string>>(
defaultFields.reduce((acc: Record<string, string>, field: string) => { defaultFields.reduce((acc: Record<string, string>, field: string) => {
@ -174,6 +176,7 @@ function SortingModalComponent(props: Props): JSX.Element {
return ( return (
<Dropdown <Dropdown
disabled={disabled}
destroyPopupOnHide destroyPopupOnHide
visible={visible} visible={visible}
placement='bottomLeft' placement='bottomLeft'

@ -1,31 +0,0 @@
// Copyright (C) 2020-2022 Intel Corporation
//
// SPDX-License-Identifier: MIT
import { connect } from 'react-redux';
import ModelsPageComponent from 'components/models-page/models-page';
import { Model, CombinedState } from 'reducers';
interface StateToProps {
interactors: Model[];
detectors: Model[];
trackers: Model[];
reid: Model[];
}
function mapStateToProps(state: CombinedState): StateToProps {
const { models } = state;
const {
interactors, detectors, trackers, reid,
} = models;
return {
interactors,
detectors,
trackers,
reid,
};
}
export default connect(mapStateToProps, {})(ModelsPageComponent);

@ -1,14 +1,19 @@
// Copyright (C) 2020-2022 Intel Corporation // Copyright (C) 2020-2022 Intel Corporation
// Copyright (C) 2022-2023 CVAT.ai Corporation
// //
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
import _cvat from 'cvat-core/src/api'; import _cvat from 'cvat-core/src/api';
import ObjectState from 'cvat-core/src/object-state'; import ObjectState from 'cvat-core/src/object-state';
import Webhook from 'cvat-core/src/webhook'; import Webhook from 'cvat-core/src/webhook';
import MLModel from 'cvat-core/src/ml-model';
import { ModelProvider } from 'cvat-core/src/lambda-manager';
import { import {
Label, Attribute, RawAttribute, RawLabel, Label, Attribute, RawAttribute, RawLabel,
} from 'cvat-core/src/labels'; } from 'cvat-core/src/labels';
import { ShapeType, LabelType } from 'cvat-core/src/enums'; import {
ShapeType, LabelType, ModelKind, ModelProviders, ModelReturnType,
} from 'cvat-core/src/enums';
import { Storage, StorageData } from 'cvat-core/src/storage'; import { Storage, StorageData } from 'cvat-core/src/storage';
import { SocialAuthMethods, SocialAuthMethod } from 'cvat-core/src/auth-methods'; import { SocialAuthMethods, SocialAuthMethod } from 'cvat-core/src/auth-methods';
@ -33,6 +38,10 @@ export {
Storage, Storage,
Webhook, Webhook,
SocialAuthMethod, SocialAuthMethod,
MLModel,
ModelKind,
ModelProviders,
ModelReturnType,
}; };
export type { export type {
@ -40,4 +49,5 @@ export type {
RawLabel, RawLabel,
StorageData, StorageData,
SocialAuthMethods, SocialAuthMethods,
ModelProvider,
}; };

@ -5,7 +5,9 @@
import { Canvas3d } from 'cvat-canvas3d/src/typescript/canvas3d'; import { Canvas3d } from 'cvat-canvas3d/src/typescript/canvas3d';
import { Canvas, RectDrawingMethod, CuboidDrawingMethod } from 'cvat-canvas-wrapper'; import { Canvas, RectDrawingMethod, CuboidDrawingMethod } from 'cvat-canvas-wrapper';
import { Webhook, SocialAuthMethods } from 'cvat-core-wrapper'; import {
Webhook, SocialAuthMethods, MLModel, ModelProvider,
} from 'cvat-core-wrapper';
import { IntelligentScissors } from 'utils/opencv-wrapper/intelligent-scissors'; import { IntelligentScissors } from 'utils/opencv-wrapper/intelligent-scissors';
import { KeyMap } from 'utils/mousetrap-react'; import { KeyMap } from 'utils/mousetrap-react';
import { OpenCVTracker } from 'utils/opencv-wrapper/opencv-interfaces'; import { OpenCVTracker } from 'utils/opencv-wrapper/opencv-interfaces';
@ -324,23 +326,12 @@ export interface ModelAttribute {
input_type: 'select' | 'number' | 'checkbox' | 'radio' | 'text'; input_type: 'select' | 'number' | 'checkbox' | 'radio' | 'text';
} }
export interface Model { export interface ModelsQuery {
id: string; page: number;
name: string; id: number | null;
labels: string[]; search: string | null;
version: number; filter: string | null;
attributes: Record<string, ModelAttribute[]>; sort: string | null;
framework: string;
description: string;
type: string;
onChangeToolsBlockerState: (event: string) => void;
tip: {
message: string;
gif: string;
};
params: {
canvas: Record<string, number | boolean>;
};
} }
export type OpenCVTool = IntelligentScissors | OpenCVTracker; export type OpenCVTool = IntelligentScissors | OpenCVTracker;
@ -375,21 +366,32 @@ export interface ActiveInference {
progress: number; progress: number;
error: string; error: string;
id: string; id: string;
functionID: string | number;
} }
export interface ModelsState { export interface ModelsState {
initialized: boolean; initialized: boolean;
fetching: boolean; fetching: boolean;
creatingStatus: string; creatingStatus: string;
interactors: Model[]; interactors: MLModel[];
detectors: Model[]; detectors: MLModel[];
trackers: Model[]; trackers: MLModel[];
reid: Model[]; reid: MLModel[];
classifiers: MLModel[];
totalCount: number;
inferences: { inferences: {
[index: number]: ActiveInference; [index: number]: ActiveInference;
}; };
modelRunnerIsVisible: boolean; modelRunnerIsVisible: boolean;
modelRunnerTask: any; modelRunnerTask: any;
query: ModelsQuery;
providers: {
fetching: boolean;
list: ModelProvider[];
}
previews: {
[index: string]: Preview;
};
} }
export interface ErrorState { export interface ErrorState {
@ -452,6 +454,8 @@ export interface NotificationsState {
canceling: null | ErrorState; canceling: null | ErrorState;
metaFetching: null | ErrorState; metaFetching: null | ErrorState;
inferenceStatusFetching: null | ErrorState; inferenceStatusFetching: null | ErrorState;
creating: null | ErrorState;
deleting: null | ErrorState;
}; };
annotation: { annotation: {
saving: null | ErrorState; saving: null | ErrorState;
@ -679,7 +683,7 @@ export interface AnnotationState {
frameAngles: number[]; frameAngles: number[];
}; };
drawing: { drawing: {
activeInteractor?: Model | OpenCVTool; activeInteractor?: MLModel | OpenCVTool;
activeShapeType: ShapeType; activeShapeType: ShapeType;
activeRectDrawingMethod?: RectDrawingMethod; activeRectDrawingMethod?: RectDrawingMethod;
activeCuboidDrawingMethod?: CuboidDrawingMethod; activeCuboidDrawingMethod?: CuboidDrawingMethod;

@ -1,11 +1,13 @@
// Copyright (C) 2020-2022 Intel Corporation // Copyright (C) 2020-2022 Intel Corporation
// Copyright (C) 2022-2023 CVAT.ai Corporation
// //
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
import { BoundariesActions, BoundariesActionTypes } from 'actions/boundaries-actions'; import { BoundariesActions, BoundariesActionTypes } from 'actions/boundaries-actions';
import { ModelsActionTypes, ModelsActions } from 'actions/models-actions'; import { ModelsActionTypes, ModelsActions } from 'actions/models-actions';
import { AuthActionTypes, AuthActions } from 'actions/auth-actions'; import { AuthActionTypes, AuthActions } from 'actions/auth-actions';
import { ModelsState, Model } from '.'; import { MLModel, ModelKind } from 'cvat-core-wrapper';
import { ModelsState } from '.';
const defaultState: ModelsState = { const defaultState: ModelsState = {
initialized: false, initialized: false,
@ -15,9 +17,23 @@ const defaultState: ModelsState = {
detectors: [], detectors: [],
trackers: [], trackers: [],
reid: [], reid: [],
classifiers: [],
modelRunnerIsVisible: false, modelRunnerIsVisible: false,
modelRunnerTask: null, modelRunnerTask: null,
inferences: {}, inferences: {},
totalCount: 0,
query: {
page: 1,
id: null,
search: null,
filter: null,
sort: null,
},
providers: {
fetching: false,
list: [],
},
previews: {},
}; };
export default function (state = defaultState, action: ModelsActions | AuthActions | BoundariesActions): ModelsState { export default function (state = defaultState, action: ModelsActions | AuthActions | BoundariesActions): ModelsState {
@ -25,17 +41,28 @@ export default function (state = defaultState, action: ModelsActions | AuthActio
case ModelsActionTypes.GET_MODELS: { case ModelsActionTypes.GET_MODELS: {
return { return {
...state, ...state,
initialized: false,
fetching: true, fetching: true,
}; };
} }
case ModelsActionTypes.GET_MODELS_SUCCESS: { case ModelsActionTypes.GET_MODELS_SUCCESS: {
return { return {
...state, ...state,
interactors: action.payload.models.filter((model: Model) => ['interactor'].includes(model.type)), interactors: action.payload.models.filter((model: MLModel) => (
detectors: action.payload.models.filter((model: Model) => ['detector'].includes(model.type)), model.kind === ModelKind.INTERACTOR
trackers: action.payload.models.filter((model: Model) => ['tracker'].includes(model.type)), )),
reid: action.payload.models.filter((model: Model) => ['reid'].includes(model.type)), detectors: action.payload.models.filter((model: MLModel) => (
model.kind === ModelKind.DETECTOR
)),
trackers: action.payload.models.filter((model: MLModel) => (
model.kind === ModelKind.TRACKER
)),
reid: action.payload.models.filter((model: MLModel) => (
model.kind === ModelKind.REID
)),
classifiers: action.payload.models.filter((model: MLModel) => (
model.kind === ModelKind.CLASSIFIER
)),
totalCount: action.payload.models.length,
initialized: true, initialized: true,
fetching: false, fetching: false,
}; };
@ -47,6 +74,34 @@ export default function (state = defaultState, action: ModelsActions | AuthActio
fetching: false, fetching: false,
}; };
} }
case ModelsActionTypes.CREATE_MODEL: {
return {
...state,
fetching: true,
};
}
case ModelsActionTypes.CREATE_MODEL_FAILED: {
return {
...state,
fetching: false,
};
}
case ModelsActionTypes.CREATE_MODEL_SUCCESS: {
const mutual = {
...state,
fetching: false,
};
if (action.payload.model.kind === ModelKind.REID) {
return {
...mutual,
reid: [...state.reid, action.payload.model],
};
}
return {
...mutual,
[`${action.payload.model.kind}s`]: [...`${action.payload.model.kind}s`, action.payload.model],
};
}
case ModelsActionTypes.SHOW_RUN_MODEL_DIALOG: { case ModelsActionTypes.SHOW_RUN_MODEL_DIALOG: {
return { return {
...state, ...state,
@ -102,6 +157,78 @@ export default function (state = defaultState, action: ModelsActions | AuthActio
inferences: { ...inferences }, inferences: { ...inferences },
}; };
} }
case ModelsActionTypes.GET_MODEL_PROVIDERS: {
return {
...state,
providers: {
...state.providers,
fetching: true,
},
};
}
case ModelsActionTypes.GET_MODEL_PROVIDERS_SUCCESS: {
return {
...state,
providers: {
fetching: false,
list: action.payload.providers,
},
};
}
case ModelsActionTypes.GET_MODEL_PROVIDERS_FAILED: {
return {
...state,
providers: {
...state.providers,
fetching: false,
},
};
}
case ModelsActionTypes.GET_MODEL_PREVIEW: {
const { modelID } = action.payload;
const { previews } = state;
return {
...state,
previews: {
...previews,
[modelID]: {
preview: '',
fetching: true,
initialized: false,
},
},
};
}
case ModelsActionTypes.GET_MODEL_PREVIEW_SUCCESS: {
const { modelID, preview } = action.payload;
const { previews } = state;
return {
...state,
previews: {
...previews,
[modelID]: {
preview,
fetching: false,
initialized: true,
},
},
};
}
case ModelsActionTypes.GET_MODEL_PREVIEW_FAILED: {
const { modelID } = action.payload;
const { previews } = state;
return {
...state,
previews: {
...previews,
[modelID]: {
...previews[modelID],
fetching: false,
initialized: true,
},
},
};
}
case BoundariesActionTypes.RESET_AFTER_ERROR: case BoundariesActionTypes.RESET_AFTER_ERROR:
case AuthActionTypes.LOGOUT_SUCCESS: { case AuthActionTypes.LOGOUT_SUCCESS: {
return { ...defaultState }; return { ...defaultState };

@ -80,6 +80,8 @@ const defaultState: NotificationsState = {
canceling: null, canceling: null,
metaFetching: null, metaFetching: null,
inferenceStatusFetching: null, inferenceStatusFetching: null,
creating: null,
deleting: null,
}, },
annotation: { annotation: {
saving: null, saving: null,
@ -803,6 +805,37 @@ export default function (state = defaultState, action: AnyAction): Notifications
}, },
}; };
} }
case ModelsActionTypes.CREATE_MODEL_FAILED: {
return {
...state,
errors: {
...state.errors,
models: {
...state.errors.models,
creating: {
message: 'Could not create model',
reason: action.payload.error.toString(),
},
},
},
};
}
case ModelsActionTypes.DELETE_MODEL_FAILED: {
const { modelName } = action.payload;
return {
...state,
errors: {
...state.errors,
models: {
...state.errors.models,
deleting: {
message: `Could not delete model ${modelName}`,
reason: action.payload.error.toString(),
},
},
},
};
}
case AnnotationActionTypes.GET_JOB_FAILED: { case AnnotationActionTypes.GET_JOB_FAILED: {
return { return {
...state, ...state,

@ -0,0 +1,17 @@
// Copyright (C) 2023 CVAT.ai Corporation
//
// SPDX-License-Identifier: MIT
import { Indexable } from 'reducers';
export function filterNull<Type>(obj: Type): Type {
const filteredObject = { ...obj };
if (filteredObject) {
for (const key of Object.keys(filteredObject)) {
if ((filteredObject as Indexable)[key] === null) {
delete (filteredObject as Indexable)[key];
}
}
}
return filteredObject;
}
Loading…
Cancel
Save