From 2d522c8781f8de9d06c3a6969482ac7d8c5261e3 Mon Sep 17 00:00:00 2001 From: Boris Sekachev Date: Mon, 30 May 2022 10:24:19 +0300 Subject: [PATCH] Prepare UI for attributes configuration (#4) * Prepare UI for attributes configuration * Add padding for label attributes * Update attributes inference logic Check the attributes returned by nuclio function call and reject those that have either incompatible types or values. * Update cvat-ui version, CHANGELOG.md * Enhance automatic annotation BE logic The code in lambda_manager didn't account for attributes mappings that had different names thus returning an empty set of attributes because it couldn't find the correct match. Fix this by getting proper mapping from `attrMapping` property of the input data. * Updated CHANGELOG * Updated changelog * Adjusted code & feature * A bit adjusted layout * Minor refactoring * Fixed bug when run auto annotation without 'attributes' key * Fixed a couple of minor issues * Increased access key id length * Fixed unit tests * Merged develop * Rejected unnecessary change Co-authored-by: Artem Zhivoderov --- CHANGELOG.md | 2 +- cvat-core/src/ml-model.js | 49 ++- cvat-core/src/object-state.js | 5 +- cvat-ui/package-lock.json | 4 +- cvat-ui/package.json | 2 +- .../controls-side-bar/tools-control.tsx | 116 +++++-- .../cloud-storage-form.tsx | 4 +- .../model-runner-modal/detector-runner.tsx | 291 ++++++++++++++---- .../components/model-runner-modal/styles.scss | 6 +- .../models-page/deployed-model-item.tsx | 17 +- .../models-page/deployed-models-list.tsx | 8 +- cvat-ui/src/reducers/interfaces.ts | 7 + cvat/apps/engine/media_extractors.py | 3 +- cvat/apps/lambda_manager/tests/test_lambda.py | 58 ++-- cvat/apps/lambda_manager/views.py | 87 ++++-- 15 files changed, 473 insertions(+), 186 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bdeee72b..4525b110 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## \[2.2.0] - Unreleased ### Added -- TDB +- Support of attributes returned by serverless functions () based on () ### Changed - TDB diff --git a/cvat-core/src/ml-model.js b/cvat-core/src/ml-model.js index b2089a13..68e29bc4 100644 --- a/cvat-core/src/ml-model.js +++ b/cvat-core/src/ml-model.js @@ -1,9 +1,9 @@ -// Copyright (C) 2019-2021 Intel Corporation +// Copyright (C) 2019-2022 Intel Corporation // // SPDX-License-Identifier: MIT /** - * Class representing a machine learning model + * Class representing a serverless function * @memberof module:API.cvat.classes */ class MLModel { @@ -11,6 +11,7 @@ class MLModel { this._id = data.id; this._name = data.name; this._labels = data.labels; + this._attributes = data.attributes || []; this._framework = data.framework; this._description = data.description; this._type = data.type; @@ -28,7 +29,7 @@ class MLModel { } /** - * @returns {string} + * @type {string} * @readonly */ get id() { @@ -36,7 +37,7 @@ class MLModel { } /** - * @returns {string} + * @type {string} * @readonly */ get name() { @@ -44,7 +45,8 @@ class MLModel { } /** - * @returns {string[]} + * @description labels supported by the model + * @type {string[]} * @readonly */ get labels() { @@ -56,7 +58,21 @@ class MLModel { } /** - * @returns {string} + * @typedef ModelAttribute + * @property {string} name + * @property {string[]} values + * @property {'select'|'number'|'checkbox'|'radio'|'text'} input_type + */ + /** + * @type {Object} + * @readonly + */ + get attributes() { + return { ...this._attributes }; + } + + /** + * @type {string} * @readonly */ get framework() { @@ -64,7 +80,7 @@ class MLModel { } /** - * @returns {string} + * @type {string} * @readonly */ get description() { @@ -72,7 +88,7 @@ class MLModel { } /** - * @returns {module:API.cvat.enums.ModelType} + * @type {module:API.cvat.enums.ModelType} * @readonly */ get type() { @@ -80,7 +96,7 @@ class MLModel { } /** - * @returns {object} + * @type {object} * @readonly */ get params() { @@ -90,10 +106,9 @@ class MLModel { } /** - * @typedef {Object} MlModelTip + * @type {MlModelTip} * @property {string} message A short message for a user about the model - * @property {string} gif A gif URL to be shawn to a user as an example - * @returns {MlModelTip} + * @property {string} gif A gif URL to be shown to a user as an example * @readonly */ get tip() { @@ -101,14 +116,16 @@ class MLModel { } /** - * @callback onRequestStatusChange + * @typedef onRequestStatusChange * @param {string} event * @global - */ + */ /** - * @param {onRequestStatusChange} onRequestStatusChange Set canvas onChangeToolsBlockerState callback + * @param {onRequestStatusChange} onRequestStatusChange + * @instance + * @description Used to set a callback when the tool is blocked in UI * @returns {void} - */ + */ set onChangeToolsBlockerState(onChangeToolsBlockerState) { this._params.canvas.onChangeToolsBlockerState = onChangeToolsBlockerState; } diff --git a/cvat-core/src/object-state.js b/cvat-core/src/object-state.js index d1fc8784..5de80ded 100644 --- a/cvat-core/src/object-state.js +++ b/cvat-core/src/object-state.js @@ -1,4 +1,4 @@ -// Copyright (C) 2019-2021 Intel Corporation +// Copyright (C) 2019-2022 Intel Corporation // // SPDX-License-Identifier: MIT @@ -208,7 +208,8 @@ const { Source } = require('./enums'); rotation: { /** * @name rotation - * @type {number} angle measured by degrees + * @description angle measured by degrees + * @type {number} * @memberof module:API.cvat.classes.ObjectState * @throws {module:API.cvat.exceptions.ArgumentError} * @instance diff --git a/cvat-ui/package-lock.json b/cvat-ui/package-lock.json index 841a7f66..66beba8c 100644 --- a/cvat-ui/package-lock.json +++ b/cvat-ui/package-lock.json @@ -1,12 +1,12 @@ { "name": "cvat-ui", - "version": "1.37.1", + "version": "1.38.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "cvat-ui", - "version": "1.37.1", + "version": "1.38.0", "license": "MIT", "dependencies": { "@ant-design/icons": "^4.6.3", diff --git a/cvat-ui/package.json b/cvat-ui/package.json index 02caaace..12c76943 100644 --- a/cvat-ui/package.json +++ b/cvat-ui/package.json @@ -1,6 +1,6 @@ { "name": "cvat-ui", - "version": "1.37.1", + "version": "1.38.0", "description": "CVAT single-page application", "main": "src/index.tsx", "scripts": { diff --git a/cvat-ui/src/components/annotation-page/standard-workspace/controls-side-bar/tools-control.tsx b/cvat-ui/src/components/annotation-page/standard-workspace/controls-side-bar/tools-control.tsx index e2f85638..21541bd8 100644 --- a/cvat-ui/src/components/annotation-page/standard-workspace/controls-side-bar/tools-control.tsx +++ b/cvat-ui/src/components/annotation-page/standard-workspace/controls-side-bar/tools-control.tsx @@ -28,7 +28,7 @@ import { Canvas, convertShapesForInteractor } from 'cvat-canvas-wrapper'; import getCore from 'cvat-core-wrapper'; import openCVWrapper from 'utils/opencv-wrapper/opencv-wrapper'; import { - CombinedState, ActiveControl, Model, ObjectType, ShapeType, ToolsBlockerState, + CombinedState, ActiveControl, Model, ObjectType, ShapeType, ToolsBlockerState, ModelAttribute, } from 'reducers/interfaces'; import { interactWithCanvas, @@ -37,9 +37,10 @@ import { updateAnnotationsAsync, createAnnotationsAsync, } from 'actions/annotation-actions'; -import DetectorRunner from 'components/model-runner-modal/detector-runner'; +import DetectorRunner, { DetectorRequestBody } from 'components/model-runner-modal/detector-runner'; import LabelSelector from 'components/label-selector/label-selector'; import CVATTooltip from 'components/common/cvat-tooltip'; +import { Attribute, Label } from 'components/labels-editor/common'; import ApproximationAccuracy, { thresholdFromAccuracy, @@ -374,7 +375,7 @@ export class ToolsControlComponent extends React.PureComponent { } setTimeout(() => this.runInteractionRequest(interactionId)); - } catch (err) { + } catch (err: any) { notification.error({ description: err.toString(), message: 'Interaction error occured', @@ -466,7 +467,7 @@ export class ToolsControlComponent extends React.PureComponent { // update annotations on a canvas fetchAnnotations(); - } catch (err) { + } catch (err: any) { notification.error({ description: err.toString(), message: 'Tracking error occured', @@ -706,7 +707,7 @@ export class ToolsControlComponent extends React.PureComponent { Array.prototype.push.apply(statefullContainer.states, serverlessStates); trackingData.statefull[trackerID] = statefullContainer; delete trackingData.stateless[trackerID]; - } catch (error) { + } catch (error: any) { notification.error({ message: 'Tracker initialization error', description: error.toString(), @@ -757,7 +758,7 @@ export class ToolsControlComponent extends React.PureComponent { trackedShape.shapePoints = shape; }); } - } catch (error) { + } catch (error: any) { notification.error({ message: 'Tracking error', description: error.toString(), @@ -1022,41 +1023,106 @@ export class ToolsControlComponent extends React.PureComponent { }); }); + function checkAttributesCompatibility( + functionAttribute: ModelAttribute | undefined, + dbAttribute: Attribute | undefined, + value: string, + ): boolean { + if (!dbAttribute || !functionAttribute) { + return false; + } + + const { inputType } = (dbAttribute as any as { inputType: string }); + if (functionAttribute.input_type === inputType) { + if (functionAttribute.input_type === 'number') { + const [min, max, step] = dbAttribute.values; + return !Number.isNaN(+value) && +value >= +min && +value <= +max && !(+value % +step); + } + + if (functionAttribute.input_type === 'checkbox') { + return ['true', 'false'].includes(value.toLowerCase()); + } + + if (['select', 'radio'].includes(functionAttribute.input_type)) { + return dbAttribute.values.includes(value); + } + + return true; + } + + switch (functionAttribute.input_type) { + case 'number': + return dbAttribute.values.includes(value) || inputType === 'text'; + case 'text': + return ['select', 'radio'].includes(dbAttribute.input_type) && dbAttribute.values.includes(value); + case 'select': + return (inputType === 'radio' && dbAttribute.values.includes(value)) || inputType === 'text'; + case 'radio': + return (inputType === 'select' && dbAttribute.values.includes(value)) || inputType === 'text'; + case 'checkbox': + return dbAttribute.values.includes(value) || inputType === 'text'; + default: + return false; + } + } + return ( { + runInference={async (model: Model, body: DetectorRequestBody) => { try { this.setState({ mode: 'detection', fetching: true }); const result = await core.lambda.call(jobInstance.taskId, model, { ...body, frame }); const states = result.map( - (data: any): any => new core.classes.ObjectState({ - shapeType: data.type, - label: jobInstance.labels.filter((label: any): boolean => label.name === data.label)[0], - points: data.points, - objectType: ObjectType.SHAPE, - frame, - occluded: false, - source: 'auto', - attributes: (data.attributes as { name: string, value: string }[]) - .reduce((mapping, attr) => { - mapping[attrsMap[data.label][attr.name]] = attr.value; - return mapping; - }, {} as Record), - zOrder: curZOrder, - }), - ); + (data: any): any => { + const jobLabel = (jobInstance.labels as Label[]) + .find((jLabel: Label): boolean => jLabel.name === data.label); + const [modelLabel] = Object.entries(body.mapping) + .find(([, { name }]) => name === data.label) || []; + + if (!jobLabel || !modelLabel) return null; + + return new core.classes.ObjectState({ + shapeType: data.type, + label: jobLabel, + points: data.points, + objectType: ObjectType.SHAPE, + frame, + occluded: false, + source: 'auto', + attributes: (data.attributes as { name: string, value: string }[]) + .reduce((acc, attr) => { + const [modelAttr] = Object.entries(body.mapping[modelLabel].attributes) + .find((value: string[]) => value[1] === attr.name) || []; + const areCompatible = checkAttributesCompatibility( + model.attributes[modelLabel].find((mAttr) => mAttr.name === modelAttr), + jobLabel.attributes.find((jobAttr: Attribute) => ( + jobAttr.name === attr.name + )), + attr.value, + ); + + if (areCompatible) { + acc[attrsMap[data.label][attr.name]] = attr.value; + } + + return acc; + }, {} as Record), + zOrder: curZOrder, + }); + }, + ).filter((state: any) => state); createAnnotations(jobInstance, frame, states); const { onSwitchToolsBlockerState } = this.props; onSwitchToolsBlockerState({ buttonVisible: false }); - } catch (error) { + } catch (error: any) { notification.error({ description: error.toString(), - message: 'Detection error occured', + message: 'Detection error occurred', }); } finally { this.setState({ fetching: false }); diff --git a/cvat-ui/src/components/create-cloud-storage-page/cloud-storage-form.tsx b/cvat-ui/src/components/create-cloud-storage-page/cloud-storage-form.tsx index d4724a23..06d2719b 100644 --- a/cvat-ui/src/components/create-cloud-storage-page/cloud-storage-form.tsx +++ b/cvat-ui/src/components/create-cloud-storage-page/cloud-storage-form.tsx @@ -74,7 +74,7 @@ export default function CreateCloudStorageForm(props: Props): JSX.Element { const fakeCredentialsData = { accountName: 'X'.repeat(24), sessionToken: 'X'.repeat(300), - key: 'X'.repeat(20), + key: 'X'.repeat(128), secretKey: 'X'.repeat(40), keyFile: new File([], 'fakeKey.json'), }; @@ -332,7 +332,7 @@ export default function CreateCloudStorageForm(props: Props): JSX.Element { {...internalCommonProps} > setKeyVisibility(true)} onFocus={() => onFocusCredentialsItem('key', 'key')} diff --git a/cvat-ui/src/components/model-runner-modal/detector-runner.tsx b/cvat-ui/src/components/model-runner-modal/detector-runner.tsx index 82932810..5f24b9fc 100644 --- a/cvat-ui/src/components/model-runner-modal/detector-runner.tsx +++ b/cvat-ui/src/components/model-runner-modal/detector-runner.tsx @@ -14,8 +14,10 @@ import InputNumber from 'antd/lib/input-number'; import Button from 'antd/lib/button'; import notification from 'antd/lib/notification'; -import { Model, StringObject } from 'reducers/interfaces'; +import { Model, ModelAttribute, StringObject } from 'reducers/interfaces'; + import CVATTooltip from 'components/common/cvat-tooltip'; +import { Label as LabelInterface } from 'components/labels-editor/common'; import { clamp } from 'utils/math'; import consts from 'consts'; import { DimensionType } from '../../reducers/interfaces'; @@ -23,28 +25,40 @@ import { DimensionType } from '../../reducers/interfaces'; interface Props { withCleanup: boolean; models: Model[]; - labels: any[]; + labels: LabelInterface[]; dimension: DimensionType; runInference(model: Model, body: object): void; } +interface MappedLabel { + name: string; + attributes: StringObject; +} + +type MappedLabelsList = Record; + +export interface DetectorRequestBody { + mapping: MappedLabelsList; + cleanup: boolean; +} + +interface Match { + model: string | null; + task: string | null; +} + function DetectorRunner(props: Props): JSX.Element { const { models, withCleanup, labels, dimension, runInference, } = props; const [modelID, setModelID] = useState(null); - const [mapping, setMapping] = useState({}); + const [mapping, setMapping] = useState({}); const [threshold, setThreshold] = useState(0.5); const [distance, setDistance] = useState(50); const [cleanup, setCleanup] = useState(false); - const [match, setMatch] = useState<{ - model: string | null; - task: string | null; - }>({ - model: null, - task: null, - }); + const [match, setMatch] = useState({ model: null, task: null }); + const [attrMatches, setAttrMatch] = useState>({}); const model = models.filter((_model): boolean => _model.id === modelID)[0]; const isDetector = model && model.type === 'detector'; @@ -57,24 +71,47 @@ function DetectorRunner(props: Props): JSX.Element { if (model && model.type !== 'reid' && !model.labels.length) { notification.warning({ - message: 'The selected model does not include any lables', + message: 'The selected model does not include any labels', }); } + function matchAttributes( + labelAttributes: LabelInterface['attributes'], + modelAttributes: ModelAttribute[], + ): StringObject { + if (Array.isArray(labelAttributes) && Array.isArray(modelAttributes)) { + return labelAttributes + .reduce((attrAcc: StringObject, attr: any): StringObject => { + if (modelAttributes.some((mAttr) => mAttr.name === attr.name)) { + attrAcc[attr.name] = attr.name; + } + + return attrAcc; + }, {}); + } + + return {}; + } + function updateMatch(modelLabel: string | null, taskLabel: string | null): void { - if (match.model && taskLabel) { - const newmatch: { [index: string]: string } = {}; - newmatch[match.model] = taskLabel; - setMapping({ ...mapping, ...newmatch }); + function addMatch(modelLbl: string, taskLbl: string): void { + const newMatch: MappedLabelsList = {}; + const label = labels.find((l) => l.name === taskLbl) as LabelInterface; + const currentModel = models.filter((_model): boolean => _model.id === modelID)[0]; + const attributes = matchAttributes(label.attributes, currentModel.attributes[modelLbl]); + + newMatch[modelLbl] = { name: taskLbl, attributes }; + setMapping({ ...mapping, ...newMatch }); setMatch({ model: null, task: null }); + } + + if (match.model && taskLabel) { + addMatch(match.model, taskLabel); return; } if (match.task && modelLabel) { - const newmatch: { [index: string]: string } = {}; - newmatch[modelLabel] = match.task; - setMapping({ ...mapping, ...newmatch }); - setMatch({ model: null, task: null }); + addMatch(modelLabel, match.task); return; } @@ -84,14 +121,72 @@ function DetectorRunner(props: Props): JSX.Element { }); } + function updateAttrMatch(modelLabel: string, modelAttrLabel: string | null, taskAttrLabel: string | null): void { + function addAttributeMatch(modelAttr: string, attrLabel: string): void { + const newMatch: StringObject = {}; + newMatch[modelAttr] = attrLabel; + mapping[modelLabel].attributes = { ...mapping[modelLabel].attributes, ...newMatch }; + + delete attrMatches[modelLabel]; + setAttrMatch({ ...attrMatches }); + } + + const modelAttr = attrMatches[modelLabel]?.model; + if (modelAttr && taskAttrLabel) { + addAttributeMatch(modelAttr, taskAttrLabel); + return; + } + + const taskAttrModel = attrMatches[modelLabel]?.task; + if (taskAttrModel && modelAttrLabel) { + addAttributeMatch(modelAttrLabel, taskAttrModel); + return; + } + + attrMatches[modelLabel] = { + model: modelAttrLabel, + task: taskAttrLabel, + }; + setAttrMatch({ ...attrMatches }); + } + + function renderMappingRow( + color: string, + leftLabel: string, + rightLabel: string, + removalTitle: string, + onClick: () => void, + className = '', + ): JSX.Element { + return ( + + + {leftLabel} + + + {rightLabel} + + + + + + + + ); + } + function renderSelector( value: string, tooltip: string, labelsToRender: string[], onChange: (label: string) => void, + className = '', ): JSX.Element { return ( - + {model.labels.map( (label): JSX.Element => ( diff --git a/cvat-ui/src/components/models-page/deployed-models-list.tsx b/cvat-ui/src/components/models-page/deployed-models-list.tsx index 6db6b881..8b49cd66 100644 --- a/cvat-ui/src/components/models-page/deployed-models-list.tsx +++ b/cvat-ui/src/components/models-page/deployed-models-list.tsx @@ -1,4 +1,4 @@ -// Copyright (C) 2020 Intel Corporation +// Copyright (C) 2020-2022 Intel Corporation // // SPDX-License-Identifier: MIT @@ -29,13 +29,13 @@ export default function DeployedModelsListComponent(props: Props): JSX.Element { Name - + Type - + Description - + Labels diff --git a/cvat-ui/src/reducers/interfaces.ts b/cvat-ui/src/reducers/interfaces.ts index 64d2d466..c760c78a 100644 --- a/cvat-ui/src/reducers/interfaces.ts +++ b/cvat-ui/src/reducers/interfaces.ts @@ -255,10 +255,17 @@ export interface ShareState { root: ShareItem; } +export interface ModelAttribute { + name: string; + values: string[]; + input_type: 'select' | 'number' | 'checkbox' | 'radio' | 'text'; +} + export interface Model { id: string; name: string; labels: string[]; + attributes: Record; framework: string; description: string; type: string; diff --git a/cvat/apps/engine/media_extractors.py b/cvat/apps/engine/media_extractors.py index 716b9b66..7cf11a32 100644 --- a/cvat/apps/engine/media_extractors.py +++ b/cvat/apps/engine/media_extractors.py @@ -95,6 +95,7 @@ def rotate_within_exif(img: Image): ORIENTATION.MIRROR_HORIZONTAL_270_ROTATED ,ORIENTATION.MIRROR_HORIZONTAL_90_ROTATED, ]: img = img.transpose(Image.FLIP_LEFT_RIGHT) + return img class IMediaReader(ABC): @@ -125,8 +126,8 @@ class IMediaReader(ABC): preview = Image.open(obj) else: preview = obj - preview.thumbnail(PREVIEW_SIZE) preview = rotate_within_exif(preview) + preview.thumbnail(PREVIEW_SIZE) return preview.convert('RGB') diff --git a/cvat/apps/lambda_manager/tests/test_lambda.py b/cvat/apps/lambda_manager/tests/test_lambda.py index 4a8699ea..831ff682 100644 --- a/cvat/apps/lambda_manager/tests/test_lambda.py +++ b/cvat/apps/lambda_manager/tests/test_lambda.py @@ -324,7 +324,7 @@ class LambdaTestCase(APITestCase): "threshold": 55, "quality": "original", "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data_main_task) @@ -364,7 +364,7 @@ class LambdaTestCase(APITestCase): "task": self.main_task["id"], "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(f'{LAMBDA_REQUESTS_PATH}', self.admin, data) @@ -404,7 +404,7 @@ class LambdaTestCase(APITestCase): "threshold": 55, "quality": "original", "mapping": { - "car": "car", + "car": { "name": "car" }, }, } data_assigneed_to_user_task = { @@ -414,7 +414,7 @@ class LambdaTestCase(APITestCase): "quality": "compressed", "max_distance": 70, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } @@ -442,7 +442,7 @@ class LambdaTestCase(APITestCase): "task": self.main_task["id"], "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } @@ -461,7 +461,7 @@ class LambdaTestCase(APITestCase): "task": self.main_task["id"], "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data) @@ -474,7 +474,7 @@ class LambdaTestCase(APITestCase): "task": self.main_task["id"], "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data) @@ -488,7 +488,7 @@ class LambdaTestCase(APITestCase): "task": self.main_task["id"], "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data) @@ -514,7 +514,7 @@ class LambdaTestCase(APITestCase): "function": id_function_detector, "task": self.main_task["id"], "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data) @@ -540,7 +540,7 @@ class LambdaTestCase(APITestCase): "function": id_function_detector, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data) @@ -553,7 +553,7 @@ class LambdaTestCase(APITestCase): "task": 12345, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data) @@ -569,7 +569,7 @@ class LambdaTestCase(APITestCase): "task": self.main_task["id"], "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } @@ -584,7 +584,7 @@ class LambdaTestCase(APITestCase): "cleanup": True, "threshold": 0.55, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } data_assigneed_to_user_task = { @@ -592,7 +592,7 @@ class LambdaTestCase(APITestCase): "frame": 0, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } @@ -612,7 +612,7 @@ class LambdaTestCase(APITestCase): "frame": 0, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.user, data) @@ -753,7 +753,7 @@ class LambdaTestCase(APITestCase): "frame": 0, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } @@ -767,7 +767,7 @@ class LambdaTestCase(APITestCase): "frame": 0, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } @@ -781,7 +781,7 @@ class LambdaTestCase(APITestCase): "frame": 0, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } @@ -796,7 +796,7 @@ class LambdaTestCase(APITestCase): "frame": 0, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } @@ -814,7 +814,7 @@ class LambdaTestCase(APITestCase): "cleanup": True, "quality": quality, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } @@ -827,7 +827,7 @@ class LambdaTestCase(APITestCase): "cleanup": True, "quality": "test-error-quality", "mapping": { - "car": "car", + "car": { "name": "car" }, }, } @@ -857,7 +857,7 @@ class LambdaTestCase(APITestCase): "task": self.main_task["id"], "frame": 0, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data) @@ -879,7 +879,7 @@ class LambdaTestCase(APITestCase): "frame": 0, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data) @@ -891,7 +891,7 @@ class LambdaTestCase(APITestCase): "task": self.main_task["id"], "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data) @@ -904,7 +904,7 @@ class LambdaTestCase(APITestCase): "frame": 0, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/test-functions-wrong-id", self.admin, data) @@ -917,7 +917,7 @@ class LambdaTestCase(APITestCase): "frame": 0, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data) @@ -931,7 +931,7 @@ class LambdaTestCase(APITestCase): "frame": 12345, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data) @@ -945,7 +945,7 @@ class LambdaTestCase(APITestCase): "frame": 0, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data) @@ -959,7 +959,7 @@ class LambdaTestCase(APITestCase): "frame": 0, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_state_building}", self.admin, data) diff --git a/cvat/apps/lambda_manager/views.py b/cvat/apps/lambda_manager/views.py index acd402a1..47b80e1f 100644 --- a/cvat/apps/lambda_manager/views.py +++ b/cvat/apps/lambda_manager/views.py @@ -163,21 +163,24 @@ class LambdaFunction: def invoke(self, db_task, data): try: payload = {} + data = {k: v for k,v in data.items() if v is not None} threshold = data.get("threshold") if threshold: - payload.update({ - "threshold": threshold, - }) + payload.update({ "threshold": threshold }) quality = data.get("quality") mapping = data.get("mapping", {}) - mapping_by_default = {} + task_attributes = {} + mapping_by_default = {} for db_label in (db_task.project.label_set if db_task.project_id else db_task.label_set).prefetch_related("attributespec_set").all(): - mapping_by_default[db_label.name] = db_label.name + mapping_by_default[db_label.name] = { + 'name': db_label.name, + 'attributes': {} + } task_attributes[db_label.name] = {} for attribute in db_label.attributespec_set.all(): task_attributes[db_label.name][attribute.name] = { - 'input_rype': attribute.input_type, + 'input_type': attribute.input_type, 'values': attribute.values.split('\n') } if not mapping: @@ -186,15 +189,27 @@ class LambdaFunction: mapping = mapping_by_default else: # filter labels in mapping which don't exist in the task - mapping = {k:v for k,v in mapping.items() if v in mapping_by_default} + mapping = {k:v for k,v in mapping.items() if v['name'] in mapping_by_default} + + attr_mapping = { label: mapping[label]['attributes'] if 'attributes' in mapping[label] else {} for label in mapping } + mapping = { modelLabel: mapping[modelLabel]['name'] for modelLabel in mapping } + supported_attrs = {} for func_label, func_attrs in self.func_attributes.items(): - if func_label in mapping: - supported_attrs[func_label] = {} - task_attr_names = [task_attr for task_attr in task_attributes[mapping[func_label]]] + if func_label not in mapping: + continue + + mapped_label = mapping[func_label] + mapped_attributes = attr_mapping.get(func_label, {}) + supported_attrs[func_label] = {} + + if mapped_attributes: + task_attr_names = [task_attr for task_attr in task_attributes[mapped_label]] for attr in func_attrs: - if attr['name'] in task_attr_names: - supported_attrs[func_label].update({attr["name"] : attr}) + mapped_attr = mapped_attributes.get(attr["name"]) + if mapped_attr in task_attr_names: + supported_attrs[func_label].update({ attr["name"]: task_attributes[mapped_label][mapped_attr] }) + if self.kind == LambdaType.DETECTOR: payload.update({ "image": self._get_image(db_task, data["frame"], quality) @@ -259,29 +274,43 @@ class LambdaFunction: return db_attr_type == "text" or \ (db_attr_type in ["select", "radio"] and len(value.split(" ")) == 1) elif func_attr_type == "select": - return db_attr["input_type"] in ["radio", "text"] + return db_attr_type in ["radio", "text"] elif func_attr_type == "radio": - return db_attr["input_type"] in ["select", "text"] + return db_attr_type in ["select", "text"] elif func_attr_type == "checkbox": return value in ["true", "false"] else: return False if self.kind == LambdaType.DETECTOR: for item in response: - if item['label'] in mapping: - attributes = deepcopy(item.get("attributes", [])) - item["attributes"] = [] - for attr in attributes: - db_attr = supported_attrs.get(item['label'], {}).get(attr["name"]) - func_attr = [func_attr for func_attr in self.func_attributes.get(item['label'], []) if func_attr['name'] == attr["name"]] - # Skip current attribute if it was not declared as supportd in function config - if not func_attr: - continue - if attr["name"] in supported_attrs.get(item['label'], {}) and check_attr_value(attr["value"], func_attr[0], db_attr): - item["attributes"].append(attr) - item['label'] = mapping[item['label']] - response_filtered.append(item) - response = response_filtered + item_label = item['label'] + + if item_label not in mapping: + continue + + attributes = deepcopy(item.get("attributes", [])) + item["attributes"] = [] + mapped_attributes = attr_mapping[item_label] + + for attr in attributes: + if attr['name'] not in mapped_attributes: + continue + + func_attr = [func_attr for func_attr in self.func_attributes.get(item_label, []) if func_attr['name'] == attr["name"]] + # Skip current attribute if it was not declared as supported in function config + if not func_attr: + continue + + db_attr = supported_attrs.get(item_label, {}).get(attr["name"]) + + if check_attr_value(attr["value"], func_attr[0], db_attr): + attr["name"] = mapped_attributes[attr['name']] + item["attributes"].append(attr) + + item['label'] = mapping[item['label']] + response_filtered.append(item) + response = response_filtered + return response def _get_image(self, db_task, frame, quality): @@ -444,7 +473,7 @@ class LambdaJob: for frame in range(db_task.data.size): annotations = function.invoke(db_task, data={ "frame": frame, "quality": quality, "mapping": mapping, - "threshold": threshold}) + "threshold": threshold }) progress = (frame + 1) / db_task.data.size if not LambdaJob._update_progress(progress): break