diff --git a/CHANGELOG.md b/CHANGELOG.md index bdeee72b..4525b110 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## \[2.2.0] - Unreleased ### Added -- TDB +- Support of attributes returned by serverless functions () based on () ### Changed - TDB diff --git a/cvat-core/src/ml-model.js b/cvat-core/src/ml-model.js index b2089a13..68e29bc4 100644 --- a/cvat-core/src/ml-model.js +++ b/cvat-core/src/ml-model.js @@ -1,9 +1,9 @@ -// Copyright (C) 2019-2021 Intel Corporation +// Copyright (C) 2019-2022 Intel Corporation // // SPDX-License-Identifier: MIT /** - * Class representing a machine learning model + * Class representing a serverless function * @memberof module:API.cvat.classes */ class MLModel { @@ -11,6 +11,7 @@ class MLModel { this._id = data.id; this._name = data.name; this._labels = data.labels; + this._attributes = data.attributes || []; this._framework = data.framework; this._description = data.description; this._type = data.type; @@ -28,7 +29,7 @@ class MLModel { } /** - * @returns {string} + * @type {string} * @readonly */ get id() { @@ -36,7 +37,7 @@ class MLModel { } /** - * @returns {string} + * @type {string} * @readonly */ get name() { @@ -44,7 +45,8 @@ class MLModel { } /** - * @returns {string[]} + * @description labels supported by the model + * @type {string[]} * @readonly */ get labels() { @@ -56,7 +58,21 @@ class MLModel { } /** - * @returns {string} + * @typedef ModelAttribute + * @property {string} name + * @property {string[]} values + * @property {'select'|'number'|'checkbox'|'radio'|'text'} input_type + */ + /** + * @type {Object} + * @readonly + */ + get attributes() { + return { ...this._attributes }; + } + + /** + * @type {string} * @readonly */ get framework() { @@ -64,7 +80,7 @@ class MLModel { } /** - * @returns {string} + * @type {string} * @readonly */ get description() { @@ -72,7 +88,7 @@ class MLModel { } /** - * @returns {module:API.cvat.enums.ModelType} + * @type {module:API.cvat.enums.ModelType} * @readonly */ get type() { @@ -80,7 +96,7 @@ class MLModel { } /** - * @returns {object} + * @type {object} * @readonly */ get params() { @@ -90,10 +106,9 @@ class MLModel { } /** - * @typedef {Object} MlModelTip + * @type {MlModelTip} * @property {string} message A short message for a user about the model - * @property {string} gif A gif URL to be shawn to a user as an example - * @returns {MlModelTip} + * @property {string} gif A gif URL to be shown to a user as an example * @readonly */ get tip() { @@ -101,14 +116,16 @@ class MLModel { } /** - * @callback onRequestStatusChange + * @typedef onRequestStatusChange * @param {string} event * @global - */ + */ /** - * @param {onRequestStatusChange} onRequestStatusChange Set canvas onChangeToolsBlockerState callback + * @param {onRequestStatusChange} onRequestStatusChange + * @instance + * @description Used to set a callback when the tool is blocked in UI * @returns {void} - */ + */ set onChangeToolsBlockerState(onChangeToolsBlockerState) { this._params.canvas.onChangeToolsBlockerState = onChangeToolsBlockerState; } diff --git a/cvat-core/src/object-state.js b/cvat-core/src/object-state.js index d1fc8784..5de80ded 100644 --- a/cvat-core/src/object-state.js +++ b/cvat-core/src/object-state.js @@ -1,4 +1,4 @@ -// Copyright (C) 2019-2021 Intel Corporation +// Copyright (C) 2019-2022 Intel Corporation // // SPDX-License-Identifier: MIT @@ -208,7 +208,8 @@ const { Source } = require('./enums'); rotation: { /** * @name rotation - * @type {number} angle measured by degrees + * @description angle measured by degrees + * @type {number} * @memberof module:API.cvat.classes.ObjectState * @throws {module:API.cvat.exceptions.ArgumentError} * @instance diff --git a/cvat-ui/package-lock.json b/cvat-ui/package-lock.json index 841a7f66..66beba8c 100644 --- a/cvat-ui/package-lock.json +++ b/cvat-ui/package-lock.json @@ -1,12 +1,12 @@ { "name": "cvat-ui", - "version": "1.37.1", + "version": "1.38.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "cvat-ui", - "version": "1.37.1", + "version": "1.38.0", "license": "MIT", "dependencies": { "@ant-design/icons": "^4.6.3", diff --git a/cvat-ui/package.json b/cvat-ui/package.json index 02caaace..12c76943 100644 --- a/cvat-ui/package.json +++ b/cvat-ui/package.json @@ -1,6 +1,6 @@ { "name": "cvat-ui", - "version": "1.37.1", + "version": "1.38.0", "description": "CVAT single-page application", "main": "src/index.tsx", "scripts": { diff --git a/cvat-ui/src/components/annotation-page/standard-workspace/controls-side-bar/tools-control.tsx b/cvat-ui/src/components/annotation-page/standard-workspace/controls-side-bar/tools-control.tsx index e2f85638..21541bd8 100644 --- a/cvat-ui/src/components/annotation-page/standard-workspace/controls-side-bar/tools-control.tsx +++ b/cvat-ui/src/components/annotation-page/standard-workspace/controls-side-bar/tools-control.tsx @@ -28,7 +28,7 @@ import { Canvas, convertShapesForInteractor } from 'cvat-canvas-wrapper'; import getCore from 'cvat-core-wrapper'; import openCVWrapper from 'utils/opencv-wrapper/opencv-wrapper'; import { - CombinedState, ActiveControl, Model, ObjectType, ShapeType, ToolsBlockerState, + CombinedState, ActiveControl, Model, ObjectType, ShapeType, ToolsBlockerState, ModelAttribute, } from 'reducers/interfaces'; import { interactWithCanvas, @@ -37,9 +37,10 @@ import { updateAnnotationsAsync, createAnnotationsAsync, } from 'actions/annotation-actions'; -import DetectorRunner from 'components/model-runner-modal/detector-runner'; +import DetectorRunner, { DetectorRequestBody } from 'components/model-runner-modal/detector-runner'; import LabelSelector from 'components/label-selector/label-selector'; import CVATTooltip from 'components/common/cvat-tooltip'; +import { Attribute, Label } from 'components/labels-editor/common'; import ApproximationAccuracy, { thresholdFromAccuracy, @@ -374,7 +375,7 @@ export class ToolsControlComponent extends React.PureComponent { } setTimeout(() => this.runInteractionRequest(interactionId)); - } catch (err) { + } catch (err: any) { notification.error({ description: err.toString(), message: 'Interaction error occured', @@ -466,7 +467,7 @@ export class ToolsControlComponent extends React.PureComponent { // update annotations on a canvas fetchAnnotations(); - } catch (err) { + } catch (err: any) { notification.error({ description: err.toString(), message: 'Tracking error occured', @@ -706,7 +707,7 @@ export class ToolsControlComponent extends React.PureComponent { Array.prototype.push.apply(statefullContainer.states, serverlessStates); trackingData.statefull[trackerID] = statefullContainer; delete trackingData.stateless[trackerID]; - } catch (error) { + } catch (error: any) { notification.error({ message: 'Tracker initialization error', description: error.toString(), @@ -757,7 +758,7 @@ export class ToolsControlComponent extends React.PureComponent { trackedShape.shapePoints = shape; }); } - } catch (error) { + } catch (error: any) { notification.error({ message: 'Tracking error', description: error.toString(), @@ -1022,41 +1023,106 @@ export class ToolsControlComponent extends React.PureComponent { }); }); + function checkAttributesCompatibility( + functionAttribute: ModelAttribute | undefined, + dbAttribute: Attribute | undefined, + value: string, + ): boolean { + if (!dbAttribute || !functionAttribute) { + return false; + } + + const { inputType } = (dbAttribute as any as { inputType: string }); + if (functionAttribute.input_type === inputType) { + if (functionAttribute.input_type === 'number') { + const [min, max, step] = dbAttribute.values; + return !Number.isNaN(+value) && +value >= +min && +value <= +max && !(+value % +step); + } + + if (functionAttribute.input_type === 'checkbox') { + return ['true', 'false'].includes(value.toLowerCase()); + } + + if (['select', 'radio'].includes(functionAttribute.input_type)) { + return dbAttribute.values.includes(value); + } + + return true; + } + + switch (functionAttribute.input_type) { + case 'number': + return dbAttribute.values.includes(value) || inputType === 'text'; + case 'text': + return ['select', 'radio'].includes(dbAttribute.input_type) && dbAttribute.values.includes(value); + case 'select': + return (inputType === 'radio' && dbAttribute.values.includes(value)) || inputType === 'text'; + case 'radio': + return (inputType === 'select' && dbAttribute.values.includes(value)) || inputType === 'text'; + case 'checkbox': + return dbAttribute.values.includes(value) || inputType === 'text'; + default: + return false; + } + } + return ( { + runInference={async (model: Model, body: DetectorRequestBody) => { try { this.setState({ mode: 'detection', fetching: true }); const result = await core.lambda.call(jobInstance.taskId, model, { ...body, frame }); const states = result.map( - (data: any): any => new core.classes.ObjectState({ - shapeType: data.type, - label: jobInstance.labels.filter((label: any): boolean => label.name === data.label)[0], - points: data.points, - objectType: ObjectType.SHAPE, - frame, - occluded: false, - source: 'auto', - attributes: (data.attributes as { name: string, value: string }[]) - .reduce((mapping, attr) => { - mapping[attrsMap[data.label][attr.name]] = attr.value; - return mapping; - }, {} as Record), - zOrder: curZOrder, - }), - ); + (data: any): any => { + const jobLabel = (jobInstance.labels as Label[]) + .find((jLabel: Label): boolean => jLabel.name === data.label); + const [modelLabel] = Object.entries(body.mapping) + .find(([, { name }]) => name === data.label) || []; + + if (!jobLabel || !modelLabel) return null; + + return new core.classes.ObjectState({ + shapeType: data.type, + label: jobLabel, + points: data.points, + objectType: ObjectType.SHAPE, + frame, + occluded: false, + source: 'auto', + attributes: (data.attributes as { name: string, value: string }[]) + .reduce((acc, attr) => { + const [modelAttr] = Object.entries(body.mapping[modelLabel].attributes) + .find((value: string[]) => value[1] === attr.name) || []; + const areCompatible = checkAttributesCompatibility( + model.attributes[modelLabel].find((mAttr) => mAttr.name === modelAttr), + jobLabel.attributes.find((jobAttr: Attribute) => ( + jobAttr.name === attr.name + )), + attr.value, + ); + + if (areCompatible) { + acc[attrsMap[data.label][attr.name]] = attr.value; + } + + return acc; + }, {} as Record), + zOrder: curZOrder, + }); + }, + ).filter((state: any) => state); createAnnotations(jobInstance, frame, states); const { onSwitchToolsBlockerState } = this.props; onSwitchToolsBlockerState({ buttonVisible: false }); - } catch (error) { + } catch (error: any) { notification.error({ description: error.toString(), - message: 'Detection error occured', + message: 'Detection error occurred', }); } finally { this.setState({ fetching: false }); diff --git a/cvat-ui/src/components/create-cloud-storage-page/cloud-storage-form.tsx b/cvat-ui/src/components/create-cloud-storage-page/cloud-storage-form.tsx index d4724a23..06d2719b 100644 --- a/cvat-ui/src/components/create-cloud-storage-page/cloud-storage-form.tsx +++ b/cvat-ui/src/components/create-cloud-storage-page/cloud-storage-form.tsx @@ -74,7 +74,7 @@ export default function CreateCloudStorageForm(props: Props): JSX.Element { const fakeCredentialsData = { accountName: 'X'.repeat(24), sessionToken: 'X'.repeat(300), - key: 'X'.repeat(20), + key: 'X'.repeat(128), secretKey: 'X'.repeat(40), keyFile: new File([], 'fakeKey.json'), }; @@ -332,7 +332,7 @@ export default function CreateCloudStorageForm(props: Props): JSX.Element { {...internalCommonProps} > setKeyVisibility(true)} onFocus={() => onFocusCredentialsItem('key', 'key')} diff --git a/cvat-ui/src/components/model-runner-modal/detector-runner.tsx b/cvat-ui/src/components/model-runner-modal/detector-runner.tsx index 82932810..5f24b9fc 100644 --- a/cvat-ui/src/components/model-runner-modal/detector-runner.tsx +++ b/cvat-ui/src/components/model-runner-modal/detector-runner.tsx @@ -14,8 +14,10 @@ import InputNumber from 'antd/lib/input-number'; import Button from 'antd/lib/button'; import notification from 'antd/lib/notification'; -import { Model, StringObject } from 'reducers/interfaces'; +import { Model, ModelAttribute, StringObject } from 'reducers/interfaces'; + import CVATTooltip from 'components/common/cvat-tooltip'; +import { Label as LabelInterface } from 'components/labels-editor/common'; import { clamp } from 'utils/math'; import consts from 'consts'; import { DimensionType } from '../../reducers/interfaces'; @@ -23,28 +25,40 @@ import { DimensionType } from '../../reducers/interfaces'; interface Props { withCleanup: boolean; models: Model[]; - labels: any[]; + labels: LabelInterface[]; dimension: DimensionType; runInference(model: Model, body: object): void; } +interface MappedLabel { + name: string; + attributes: StringObject; +} + +type MappedLabelsList = Record; + +export interface DetectorRequestBody { + mapping: MappedLabelsList; + cleanup: boolean; +} + +interface Match { + model: string | null; + task: string | null; +} + function DetectorRunner(props: Props): JSX.Element { const { models, withCleanup, labels, dimension, runInference, } = props; const [modelID, setModelID] = useState(null); - const [mapping, setMapping] = useState({}); + const [mapping, setMapping] = useState({}); const [threshold, setThreshold] = useState(0.5); const [distance, setDistance] = useState(50); const [cleanup, setCleanup] = useState(false); - const [match, setMatch] = useState<{ - model: string | null; - task: string | null; - }>({ - model: null, - task: null, - }); + const [match, setMatch] = useState({ model: null, task: null }); + const [attrMatches, setAttrMatch] = useState>({}); const model = models.filter((_model): boolean => _model.id === modelID)[0]; const isDetector = model && model.type === 'detector'; @@ -57,24 +71,47 @@ function DetectorRunner(props: Props): JSX.Element { if (model && model.type !== 'reid' && !model.labels.length) { notification.warning({ - message: 'The selected model does not include any lables', + message: 'The selected model does not include any labels', }); } + function matchAttributes( + labelAttributes: LabelInterface['attributes'], + modelAttributes: ModelAttribute[], + ): StringObject { + if (Array.isArray(labelAttributes) && Array.isArray(modelAttributes)) { + return labelAttributes + .reduce((attrAcc: StringObject, attr: any): StringObject => { + if (modelAttributes.some((mAttr) => mAttr.name === attr.name)) { + attrAcc[attr.name] = attr.name; + } + + return attrAcc; + }, {}); + } + + return {}; + } + function updateMatch(modelLabel: string | null, taskLabel: string | null): void { - if (match.model && taskLabel) { - const newmatch: { [index: string]: string } = {}; - newmatch[match.model] = taskLabel; - setMapping({ ...mapping, ...newmatch }); + function addMatch(modelLbl: string, taskLbl: string): void { + const newMatch: MappedLabelsList = {}; + const label = labels.find((l) => l.name === taskLbl) as LabelInterface; + const currentModel = models.filter((_model): boolean => _model.id === modelID)[0]; + const attributes = matchAttributes(label.attributes, currentModel.attributes[modelLbl]); + + newMatch[modelLbl] = { name: taskLbl, attributes }; + setMapping({ ...mapping, ...newMatch }); setMatch({ model: null, task: null }); + } + + if (match.model && taskLabel) { + addMatch(match.model, taskLabel); return; } if (match.task && modelLabel) { - const newmatch: { [index: string]: string } = {}; - newmatch[modelLabel] = match.task; - setMapping({ ...mapping, ...newmatch }); - setMatch({ model: null, task: null }); + addMatch(modelLabel, match.task); return; } @@ -84,14 +121,72 @@ function DetectorRunner(props: Props): JSX.Element { }); } + function updateAttrMatch(modelLabel: string, modelAttrLabel: string | null, taskAttrLabel: string | null): void { + function addAttributeMatch(modelAttr: string, attrLabel: string): void { + const newMatch: StringObject = {}; + newMatch[modelAttr] = attrLabel; + mapping[modelLabel].attributes = { ...mapping[modelLabel].attributes, ...newMatch }; + + delete attrMatches[modelLabel]; + setAttrMatch({ ...attrMatches }); + } + + const modelAttr = attrMatches[modelLabel]?.model; + if (modelAttr && taskAttrLabel) { + addAttributeMatch(modelAttr, taskAttrLabel); + return; + } + + const taskAttrModel = attrMatches[modelLabel]?.task; + if (taskAttrModel && modelAttrLabel) { + addAttributeMatch(modelAttrLabel, taskAttrModel); + return; + } + + attrMatches[modelLabel] = { + model: modelAttrLabel, + task: taskAttrLabel, + }; + setAttrMatch({ ...attrMatches }); + } + + function renderMappingRow( + color: string, + leftLabel: string, + rightLabel: string, + removalTitle: string, + onClick: () => void, + className = '', + ): JSX.Element { + return ( + + + {leftLabel} + + + {rightLabel} + + + + + + + + ); + } + function renderSelector( value: string, tooltip: string, labelsToRender: string[], onChange: (label: string) => void, + className = '', ): JSX.Element { return ( - + {model.labels.map( (label): JSX.Element => ( diff --git a/cvat-ui/src/components/models-page/deployed-models-list.tsx b/cvat-ui/src/components/models-page/deployed-models-list.tsx index 6db6b881..8b49cd66 100644 --- a/cvat-ui/src/components/models-page/deployed-models-list.tsx +++ b/cvat-ui/src/components/models-page/deployed-models-list.tsx @@ -1,4 +1,4 @@ -// Copyright (C) 2020 Intel Corporation +// Copyright (C) 2020-2022 Intel Corporation // // SPDX-License-Identifier: MIT @@ -29,13 +29,13 @@ export default function DeployedModelsListComponent(props: Props): JSX.Element { Name - + Type - + Description - + Labels diff --git a/cvat-ui/src/reducers/interfaces.ts b/cvat-ui/src/reducers/interfaces.ts index 64d2d466..c760c78a 100644 --- a/cvat-ui/src/reducers/interfaces.ts +++ b/cvat-ui/src/reducers/interfaces.ts @@ -255,10 +255,17 @@ export interface ShareState { root: ShareItem; } +export interface ModelAttribute { + name: string; + values: string[]; + input_type: 'select' | 'number' | 'checkbox' | 'radio' | 'text'; +} + export interface Model { id: string; name: string; labels: string[]; + attributes: Record; framework: string; description: string; type: string; diff --git a/cvat/apps/engine/media_extractors.py b/cvat/apps/engine/media_extractors.py index 716b9b66..7cf11a32 100644 --- a/cvat/apps/engine/media_extractors.py +++ b/cvat/apps/engine/media_extractors.py @@ -95,6 +95,7 @@ def rotate_within_exif(img: Image): ORIENTATION.MIRROR_HORIZONTAL_270_ROTATED ,ORIENTATION.MIRROR_HORIZONTAL_90_ROTATED, ]: img = img.transpose(Image.FLIP_LEFT_RIGHT) + return img class IMediaReader(ABC): @@ -125,8 +126,8 @@ class IMediaReader(ABC): preview = Image.open(obj) else: preview = obj - preview.thumbnail(PREVIEW_SIZE) preview = rotate_within_exif(preview) + preview.thumbnail(PREVIEW_SIZE) return preview.convert('RGB') diff --git a/cvat/apps/lambda_manager/tests/test_lambda.py b/cvat/apps/lambda_manager/tests/test_lambda.py index 4a8699ea..831ff682 100644 --- a/cvat/apps/lambda_manager/tests/test_lambda.py +++ b/cvat/apps/lambda_manager/tests/test_lambda.py @@ -324,7 +324,7 @@ class LambdaTestCase(APITestCase): "threshold": 55, "quality": "original", "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data_main_task) @@ -364,7 +364,7 @@ class LambdaTestCase(APITestCase): "task": self.main_task["id"], "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(f'{LAMBDA_REQUESTS_PATH}', self.admin, data) @@ -404,7 +404,7 @@ class LambdaTestCase(APITestCase): "threshold": 55, "quality": "original", "mapping": { - "car": "car", + "car": { "name": "car" }, }, } data_assigneed_to_user_task = { @@ -414,7 +414,7 @@ class LambdaTestCase(APITestCase): "quality": "compressed", "max_distance": 70, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } @@ -442,7 +442,7 @@ class LambdaTestCase(APITestCase): "task": self.main_task["id"], "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } @@ -461,7 +461,7 @@ class LambdaTestCase(APITestCase): "task": self.main_task["id"], "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data) @@ -474,7 +474,7 @@ class LambdaTestCase(APITestCase): "task": self.main_task["id"], "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data) @@ -488,7 +488,7 @@ class LambdaTestCase(APITestCase): "task": self.main_task["id"], "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data) @@ -514,7 +514,7 @@ class LambdaTestCase(APITestCase): "function": id_function_detector, "task": self.main_task["id"], "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data) @@ -540,7 +540,7 @@ class LambdaTestCase(APITestCase): "function": id_function_detector, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data) @@ -553,7 +553,7 @@ class LambdaTestCase(APITestCase): "task": 12345, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data) @@ -569,7 +569,7 @@ class LambdaTestCase(APITestCase): "task": self.main_task["id"], "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } @@ -584,7 +584,7 @@ class LambdaTestCase(APITestCase): "cleanup": True, "threshold": 0.55, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } data_assigneed_to_user_task = { @@ -592,7 +592,7 @@ class LambdaTestCase(APITestCase): "frame": 0, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } @@ -612,7 +612,7 @@ class LambdaTestCase(APITestCase): "frame": 0, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.user, data) @@ -753,7 +753,7 @@ class LambdaTestCase(APITestCase): "frame": 0, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } @@ -767,7 +767,7 @@ class LambdaTestCase(APITestCase): "frame": 0, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } @@ -781,7 +781,7 @@ class LambdaTestCase(APITestCase): "frame": 0, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } @@ -796,7 +796,7 @@ class LambdaTestCase(APITestCase): "frame": 0, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } @@ -814,7 +814,7 @@ class LambdaTestCase(APITestCase): "cleanup": True, "quality": quality, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } @@ -827,7 +827,7 @@ class LambdaTestCase(APITestCase): "cleanup": True, "quality": "test-error-quality", "mapping": { - "car": "car", + "car": { "name": "car" }, }, } @@ -857,7 +857,7 @@ class LambdaTestCase(APITestCase): "task": self.main_task["id"], "frame": 0, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data) @@ -879,7 +879,7 @@ class LambdaTestCase(APITestCase): "frame": 0, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data) @@ -891,7 +891,7 @@ class LambdaTestCase(APITestCase): "task": self.main_task["id"], "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data) @@ -904,7 +904,7 @@ class LambdaTestCase(APITestCase): "frame": 0, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/test-functions-wrong-id", self.admin, data) @@ -917,7 +917,7 @@ class LambdaTestCase(APITestCase): "frame": 0, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data) @@ -931,7 +931,7 @@ class LambdaTestCase(APITestCase): "frame": 12345, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data) @@ -945,7 +945,7 @@ class LambdaTestCase(APITestCase): "frame": 0, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data) @@ -959,7 +959,7 @@ class LambdaTestCase(APITestCase): "frame": 0, "cleanup": True, "mapping": { - "car": "car", + "car": { "name": "car" }, }, } response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_state_building}", self.admin, data) diff --git a/cvat/apps/lambda_manager/views.py b/cvat/apps/lambda_manager/views.py index acd402a1..47b80e1f 100644 --- a/cvat/apps/lambda_manager/views.py +++ b/cvat/apps/lambda_manager/views.py @@ -163,21 +163,24 @@ class LambdaFunction: def invoke(self, db_task, data): try: payload = {} + data = {k: v for k,v in data.items() if v is not None} threshold = data.get("threshold") if threshold: - payload.update({ - "threshold": threshold, - }) + payload.update({ "threshold": threshold }) quality = data.get("quality") mapping = data.get("mapping", {}) - mapping_by_default = {} + task_attributes = {} + mapping_by_default = {} for db_label in (db_task.project.label_set if db_task.project_id else db_task.label_set).prefetch_related("attributespec_set").all(): - mapping_by_default[db_label.name] = db_label.name + mapping_by_default[db_label.name] = { + 'name': db_label.name, + 'attributes': {} + } task_attributes[db_label.name] = {} for attribute in db_label.attributespec_set.all(): task_attributes[db_label.name][attribute.name] = { - 'input_rype': attribute.input_type, + 'input_type': attribute.input_type, 'values': attribute.values.split('\n') } if not mapping: @@ -186,15 +189,27 @@ class LambdaFunction: mapping = mapping_by_default else: # filter labels in mapping which don't exist in the task - mapping = {k:v for k,v in mapping.items() if v in mapping_by_default} + mapping = {k:v for k,v in mapping.items() if v['name'] in mapping_by_default} + + attr_mapping = { label: mapping[label]['attributes'] if 'attributes' in mapping[label] else {} for label in mapping } + mapping = { modelLabel: mapping[modelLabel]['name'] for modelLabel in mapping } + supported_attrs = {} for func_label, func_attrs in self.func_attributes.items(): - if func_label in mapping: - supported_attrs[func_label] = {} - task_attr_names = [task_attr for task_attr in task_attributes[mapping[func_label]]] + if func_label not in mapping: + continue + + mapped_label = mapping[func_label] + mapped_attributes = attr_mapping.get(func_label, {}) + supported_attrs[func_label] = {} + + if mapped_attributes: + task_attr_names = [task_attr for task_attr in task_attributes[mapped_label]] for attr in func_attrs: - if attr['name'] in task_attr_names: - supported_attrs[func_label].update({attr["name"] : attr}) + mapped_attr = mapped_attributes.get(attr["name"]) + if mapped_attr in task_attr_names: + supported_attrs[func_label].update({ attr["name"]: task_attributes[mapped_label][mapped_attr] }) + if self.kind == LambdaType.DETECTOR: payload.update({ "image": self._get_image(db_task, data["frame"], quality) @@ -259,29 +274,43 @@ class LambdaFunction: return db_attr_type == "text" or \ (db_attr_type in ["select", "radio"] and len(value.split(" ")) == 1) elif func_attr_type == "select": - return db_attr["input_type"] in ["radio", "text"] + return db_attr_type in ["radio", "text"] elif func_attr_type == "radio": - return db_attr["input_type"] in ["select", "text"] + return db_attr_type in ["select", "text"] elif func_attr_type == "checkbox": return value in ["true", "false"] else: return False if self.kind == LambdaType.DETECTOR: for item in response: - if item['label'] in mapping: - attributes = deepcopy(item.get("attributes", [])) - item["attributes"] = [] - for attr in attributes: - db_attr = supported_attrs.get(item['label'], {}).get(attr["name"]) - func_attr = [func_attr for func_attr in self.func_attributes.get(item['label'], []) if func_attr['name'] == attr["name"]] - # Skip current attribute if it was not declared as supportd in function config - if not func_attr: - continue - if attr["name"] in supported_attrs.get(item['label'], {}) and check_attr_value(attr["value"], func_attr[0], db_attr): - item["attributes"].append(attr) - item['label'] = mapping[item['label']] - response_filtered.append(item) - response = response_filtered + item_label = item['label'] + + if item_label not in mapping: + continue + + attributes = deepcopy(item.get("attributes", [])) + item["attributes"] = [] + mapped_attributes = attr_mapping[item_label] + + for attr in attributes: + if attr['name'] not in mapped_attributes: + continue + + func_attr = [func_attr for func_attr in self.func_attributes.get(item_label, []) if func_attr['name'] == attr["name"]] + # Skip current attribute if it was not declared as supported in function config + if not func_attr: + continue + + db_attr = supported_attrs.get(item_label, {}).get(attr["name"]) + + if check_attr_value(attr["value"], func_attr[0], db_attr): + attr["name"] = mapped_attributes[attr['name']] + item["attributes"].append(attr) + + item['label'] = mapping[item['label']] + response_filtered.append(item) + response = response_filtered + return response def _get_image(self, db_task, frame, quality): @@ -444,7 +473,7 @@ class LambdaJob: for frame in range(db_task.data.size): annotations = function.invoke(db_task, data={ "frame": frame, "quality": quality, "mapping": mapping, - "threshold": threshold}) + "threshold": threshold }) progress = (frame + 1) / db_task.data.size if not LambdaJob._update_progress(progress): break