diff --git a/cvat-core/src/core-types.ts b/cvat-core/src/core-types.ts new file mode 100644 index 00000000..cb30268b --- /dev/null +++ b/cvat-core/src/core-types.ts @@ -0,0 +1,49 @@ +// Copyright (C) 2023 CVAT.ai Corporation +// +// SPDX-License-Identifier: MIT + +import { ModelKind, ModelReturnType } from './enums'; + +export interface ModelAttribute { + name: string; + values: string[]; + input_type: 'select' | 'number' | 'checkbox' | 'radio' | 'text'; +} + +export interface ModelParams { + canvas: { + minPosVertices?: number; + minNegVertices?: number; + startWithBox?: boolean; + onChangeToolsBlockerState?: (event: string) => void; + }; +} + +export interface ModelTip { + message: string; + gif: string; +} + +export interface SerializedModel { + id?: string | number; + name?: string; + labels?: string[]; + version?: number; + attributes?: Record; + framework?: string; + description?: string; + kind?: ModelKind; + type?: string; + return_type?: ModelReturnType; + owner?: any; + provider?: string; + api_key?: string; + url?: string; + help_message?: string; + animated_gif?: string; + min_pos_points?: number; + min_neg_points?: number; + startswith_box?: boolean; + created_date?: string; + updated_date?: string; +} diff --git a/cvat-core/src/lambda-manager.ts b/cvat-core/src/lambda-manager.ts index c186190b..a6565b8d 100644 --- a/cvat-core/src/lambda-manager.ts +++ b/cvat-core/src/lambda-manager.ts @@ -30,9 +30,11 @@ class LambdaManager { this.cachedList = null; } - async list(): Promise { + async list(): Promise<{ models: MLModel[], count: number }> { const lambdaFunctions = await serverProxy.lambda.list(); - const functions = await serverProxy.functions.list(); + + const functionsResult = await serverProxy.functions.list(); + const { results: functions, count: functionsCount } = functionsResult; const result = [...lambdaFunctions, ...functions]; const models = []; @@ -46,7 +48,7 @@ class LambdaManager { } this.cachedList = models; - return models; + return { models, count: lambdaFunctions.length + functionsCount }; } async run(taskID: number, model: MLModel, args: any) { diff --git a/cvat-core/src/ml-model.ts b/cvat-core/src/ml-model.ts index fb605235..2d70f6aa 100644 --- a/cvat-core/src/ml-model.ts +++ b/cvat-core/src/ml-model.ts @@ -7,50 +7,9 @@ import { isBrowser, isNode } from 'browser-or-node'; import serverProxy from './server-proxy'; import PluginRegistry from './plugins'; import { ModelProviders, ModelKind, ModelReturnType } from './enums'; - -interface ModelAttribute { - name: string; - values: string[]; - input_type: 'select' | 'number' | 'checkbox' | 'radio' | 'text'; -} - -interface ModelParams { - canvas: { - minPosVertices?: number; - minNegVertices?: number; - startWithBox?: boolean; - onChangeToolsBlockerState?: (event: string) => void; - }; -} - -interface ModelTip { - message: string; - gif: string; -} - -interface SerializedModel { - id?: string | number; - name?: string; - labels?: string[]; - version?: number; - attributes?: Record; - framework?: string; - description?: string; - kind?: ModelKind; - type?: string; - return_type?: ModelReturnType; - owner?: any; - provider?: string; - api_key?: string; - url?: string; - help_message?: string; - animated_gif?: string; - min_pos_points?: number; - min_neg_points?: number; - startswith_box?: boolean; - created_date?: string; - updated_date?: string; -} +import { + SerializedModel, ModelAttribute, ModelParams, ModelTip, +} from './core-types'; export default class MLModel { private serialized: SerializedModel; diff --git a/cvat-core/src/server-proxy.ts b/cvat-core/src/server-proxy.ts index 60135647..df18a0f9 100644 --- a/cvat-core/src/server-proxy.ts +++ b/cvat-core/src/server-proxy.ts @@ -13,6 +13,7 @@ import { isEmail } from './common'; import config from './config'; import DownloadWorker from './download.worker'; import { ServerError } from './exceptions'; +import { FunctionsResponseBody } from './server-response-types'; type Params = { org: number | string, @@ -1604,17 +1605,20 @@ async function getAnnotations(session, id) { return response.data; } -async function getFunctions() { +async function getFunctions(): Promise { const { backendAPI } = config; try { const response = await Axios.get(`${backendAPI}/functions`, { proxy: config.proxy, }); - return response.data.results; + return response.data; } catch (errorData) { if (errorData.response.status === 404) { - return []; + return { + results: [], + count: 0, + }; } throw generateError(errorData); } diff --git a/cvat-core/src/server-response-types.ts b/cvat-core/src/server-response-types.ts new file mode 100644 index 00000000..42e158a7 --- /dev/null +++ b/cvat-core/src/server-response-types.ts @@ -0,0 +1,10 @@ +// Copyright (C) 2023 CVAT.ai Corporation +// +// SPDX-License-Identifier: MIT + +import { SerializedModel } from 'core-types'; + +export interface FunctionsResponseBody { + results: SerializedModel[]; + count: number; +} diff --git a/cvat-ui/src/actions/models-actions.ts b/cvat-ui/src/actions/models-actions.ts index 3cdf9faa..4e42a830 100644 --- a/cvat-ui/src/actions/models-actions.ts +++ b/cvat-ui/src/actions/models-actions.ts @@ -40,8 +40,8 @@ export enum ModelsActionTypes { export const modelsActions = { getModels: (query?: ModelsQuery) => createAction(ModelsActionTypes.GET_MODELS, { query }), - getModelsSuccess: (models: MLModel[]) => createAction(ModelsActionTypes.GET_MODELS_SUCCESS, { - models, + getModelsSuccess: (models: MLModel[], count: number) => createAction(ModelsActionTypes.GET_MODELS_SUCCESS, { + models, count, }), getModelsFailed: (error: any) => createAction(ModelsActionTypes.GET_MODELS_FAILED, { error, @@ -113,14 +113,15 @@ export type ModelsActions = ActionUnion; const core = getCore(); -export function getModelsAsync(query: ModelsQuery): ThunkAction { - return async (dispatch): Promise => { +export function getModelsAsync(query?: ModelsQuery): ThunkAction { + return async (dispatch, getState): Promise => { dispatch(modelsActions.getModels(query)); - const filteredQuery = filterNull(query); + const filteredQuery = filterNull(query || getState().models.query); try { - const models = await core.lambda.list(filteredQuery); - dispatch(modelsActions.getModelsSuccess(models)); + const result = await core.lambda.list(filteredQuery); + const { models, count } = result; + dispatch(modelsActions.getModelsSuccess(models, count)); } catch (error) { dispatch(modelsActions.getModelsFailed(error)); } diff --git a/cvat-ui/src/components/models-page/deployed-models-list.tsx b/cvat-ui/src/components/models-page/deployed-models-list.tsx index 5bc279d0..c4c903c8 100644 --- a/cvat-ui/src/components/models-page/deployed-models-list.tsx +++ b/cvat-ui/src/components/models-page/deployed-models-list.tsx @@ -3,17 +3,22 @@ // // SPDX-License-Identifier: MIT -import React, { useState } from 'react'; +import React from 'react'; import moment from 'moment'; -import { useSelector } from 'react-redux'; +import { useSelector, useDispatch } from 'react-redux'; import { Row, Col } from 'antd/lib/grid'; import Pagination from 'antd/lib/pagination'; -import { CombinedState } from 'reducers'; +import { CombinedState, ModelsQuery } from 'reducers'; import { MLModel } from 'cvat-core-wrapper'; import { ModelProviders } from 'cvat-core/src/enums'; +import { getModelsAsync } from 'actions/models-actions'; import DeployedModelItem from './deployed-model-item'; -const PAGE_SIZE = 12; +export const PAGE_SIZE = 12; + +interface Props { + query: ModelsQuery; +} function setUpModelsList(models: MLModel[], newPage: number): MLModel[] { const builtInModels = models.filter((model: MLModel) => model.provider === ModelProviders.CVAT); @@ -23,14 +28,17 @@ function setUpModelsList(models: MLModel[], newPage: number): MLModel[] { return renderModels.slice((newPage - 1) * PAGE_SIZE, newPage * PAGE_SIZE); } -export default function DeployedModelsListComponent(): JSX.Element { +export default function DeployedModelsListComponent(props: Props): JSX.Element { const interactors = useSelector((state: CombinedState) => state.models.interactors); const detectors = useSelector((state: CombinedState) => state.models.detectors); const trackers = useSelector((state: CombinedState) => state.models.trackers); const reid = useSelector((state: CombinedState) => state.models.reid); const classifiers = useSelector((state: CombinedState) => state.models.classifiers); const totalCount = useSelector((state: CombinedState) => state.models.totalCount); - const [page, setPage] = useState(1); + + const dispatch = useDispatch(); + const { query } = props; + const { page } = query; const models = [...interactors, ...detectors, ...trackers, ...reid, ...classifiers]; const items = setUpModelsList(models, page) .map((model): JSX.Element => ); @@ -46,7 +54,10 @@ export default function DeployedModelsListComponent(): JSX.Element { { - setPage(newPage); + dispatch(getModelsAsync({ + ...query, + page: newPage, + })); }} showSizeChanger={false} total={totalCount} diff --git a/cvat-ui/src/components/models-page/models-page.tsx b/cvat-ui/src/components/models-page/models-page.tsx index 3a567dd6..0b08b06a 100644 --- a/cvat-ui/src/components/models-page/models-page.tsx +++ b/cvat-ui/src/components/models-page/models-page.tsx @@ -10,11 +10,12 @@ import { useDispatch, useSelector } from 'react-redux'; import { getModelProvidersAsync, getModelsAsync } from 'actions/models-actions'; import { updateHistoryFromQuery } from 'components/resource-sorting-filtering'; import Spin from 'antd/lib/spin'; +import notification from 'antd/lib/notification'; -import DeployedModelsList from './deployed-models-list'; +import { CombinedState, Indexable } from 'reducers'; +import DeployedModelsList, { PAGE_SIZE } from './deployed-models-list'; import EmptyListComponent from './empty-list'; import FeedbackComponent from '../feedback/feedback'; -import { CombinedState } from '../../reducers'; import TopBar from './top-bar'; function ModelsPageComponent(): JSX.Element { @@ -29,19 +30,33 @@ function ModelsPageComponent(): JSX.Element { }, []); const updatedQuery = { ...query }; + const queryParams = new URLSearchParams(history.location.search); + for (const key of Object.keys(updatedQuery)) { + (updatedQuery as Indexable)[key] = queryParams.get(key) || null; + if (key === 'page') { + updatedQuery.page = updatedQuery.page ? +updatedQuery.page : 1; + } + } useEffect(() => { history.replace({ search: updateHistoryFromQuery(query), }); }, [query]); + const pageOutOfBounds = updatedQuery.page > Math.ceil(totalCount / PAGE_SIZE); useEffect(() => { dispatch(getModelProvidersAsync()); dispatch(getModelsAsync(updatedQuery)); + if (pageOutOfBounds) { + notification.error({ + message: 'Could not fetch models', + description: 'Invalid page', + }); + } }, []); - const content = totalCount ? ( - + const content = (totalCount && !pageOutOfBounds) ? ( + ) : ; return ( diff --git a/cvat-ui/src/reducers/models-reducer.ts b/cvat-ui/src/reducers/models-reducer.ts index 3bf78de9..09aff700 100644 --- a/cvat-ui/src/reducers/models-reducer.ts +++ b/cvat-ui/src/reducers/models-reducer.ts @@ -42,6 +42,10 @@ export default function (state = defaultState, action: ModelsActions | AuthActio return { ...state, fetching: true, + query: { + ...state.query, + ...action.payload.query, + }, }; } case ModelsActionTypes.GET_MODELS_SUCCESS: { @@ -62,7 +66,7 @@ export default function (state = defaultState, action: ModelsActions | AuthActio classifiers: action.payload.models.filter((model: MLModel) => ( model.kind === ModelKind.CLASSIFIER )), - totalCount: action.payload.models.length, + totalCount: action.payload.count, initialized: true, fetching: false, };