From 579bfb38c356dbca6694a4f156735040934986a6 Mon Sep 17 00:00:00 2001 From: Dmitry Kalinin Date: Fri, 17 Dec 2021 15:39:45 +0300 Subject: [PATCH] Project import simple implementation (#3790) --- CHANGELOG.md | 1 + cvat-core/package-lock.json | 4 +- cvat-core/package.json | 2 +- cvat-core/src/annotations.js | 17 + cvat-core/src/project-implementation.js | 13 +- cvat-core/src/project.js | 11 + cvat-core/src/server-proxy.js | 45 +- cvat-ui/package-lock.json | 4 +- cvat-ui/package.json | 2 +- cvat-ui/src/actions/import-actions.ts | 59 +++ cvat-ui/src/actions/tasks-actions.ts | 4 +- .../export-dataset/export-dataset-modal.tsx | 2 +- .../import-dataset-modal.tsx | 153 +++++++ .../import-dataset-status-modal.tsx | 34 ++ .../import-dataset-modal/styles.scss | 32 ++ .../components/project-page/project-page.tsx | 2 + .../components/projects-page/actions-menu.tsx | 26 +- .../projects-page/projects-page.tsx | 2 + cvat-ui/src/reducers/import-reducer.ts | 61 +++ cvat-ui/src/reducers/interfaces.ts | 17 + cvat-ui/src/reducers/notifications-reducer.ts | 34 +- cvat-ui/src/reducers/root-reducer.ts | 2 + cvat/apps/dataset_manager/bindings.py | 407 ++++++++++++++--- cvat/apps/dataset_manager/formats/camvid.py | 4 +- .../dataset_manager/formats/cityscapes.py | 4 +- cvat/apps/dataset_manager/formats/coco.py | 8 +- cvat/apps/dataset_manager/formats/cvat.py | 431 +++++++++++++++++- cvat/apps/dataset_manager/formats/icdar.py | 12 +- cvat/apps/dataset_manager/formats/imagenet.py | 4 +- cvat/apps/dataset_manager/formats/labelme.py | 4 +- cvat/apps/dataset_manager/formats/lfw.py | 5 +- .../dataset_manager/formats/market1501.py | 4 +- cvat/apps/dataset_manager/formats/mask.py | 4 +- cvat/apps/dataset_manager/formats/mot.py | 151 +++--- cvat/apps/dataset_manager/formats/mots.py | 149 +++--- .../dataset_manager/formats/openimages.py | 12 +- .../dataset_manager/formats/pascal_voc.py | 4 +- .../dataset_manager/formats/pointcloud.py | 18 +- cvat/apps/dataset_manager/formats/registry.py | 4 +- cvat/apps/dataset_manager/formats/tfrecord.py | 4 +- .../dataset_manager/formats/velodynepoint.py | 23 +- cvat/apps/dataset_manager/formats/vggface2.py | 4 +- .../apps/dataset_manager/formats/widerface.py | 4 +- cvat/apps/dataset_manager/formats/yolo.py | 12 +- cvat/apps/dataset_manager/project.py | 127 +++++- cvat/apps/dataset_manager/task.py | 17 +- cvat/apps/dataset_manager/util.py | 17 + cvat/apps/engine/serializers.py | 10 + cvat/apps/engine/task.py | 40 +- cvat/apps/engine/tests/test_rest_api.py | 159 +++++++ cvat/apps/engine/views.py | 130 +++++- cvat/requirements/base.txt | 1 + tests/cypress/support/commands_projects.js | 2 +- 53 files changed, 1941 insertions(+), 361 deletions(-) create mode 100644 cvat-ui/src/actions/import-actions.ts create mode 100644 cvat-ui/src/components/import-dataset-modal/import-dataset-modal.tsx create mode 100644 cvat-ui/src/components/import-dataset-modal/import-dataset-status-modal.tsx create mode 100644 cvat-ui/src/components/import-dataset-modal/styles.scss create mode 100644 cvat-ui/src/reducers/import-reducer.ts diff --git a/CHANGELOG.md b/CHANGELOG.md index 1846a53b..38f9a6a9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Data sorting option () - Options to change font size & position of text labels on the canvas () - Add "tag" return type for automatic annotation in Nuclio () +- Dataset importing to a project () - User is able to customize information that text labels show () ### Changed diff 
--git a/cvat-core/package-lock.json b/cvat-core/package-lock.json index a98ed010..b4436451 100644 --- a/cvat-core/package-lock.json +++ b/cvat-core/package-lock.json @@ -1,12 +1,12 @@ { "name": "cvat-core", - "version": "3.21.1", + "version": "3.22.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "cvat-core", - "version": "3.21.1", + "version": "3.22.0", "license": "MIT", "dependencies": { "axios": "^0.21.4", diff --git a/cvat-core/package.json b/cvat-core/package.json index f0c70ad9..a420a715 100644 --- a/cvat-core/package.json +++ b/cvat-core/package.json @@ -1,6 +1,6 @@ { "name": "cvat-core", - "version": "3.21.1", + "version": "3.22.0", "description": "Part of Computer Vision Tool which presents an interface for client-side integration", "main": "babel.config.js", "scripts": { diff --git a/cvat-core/src/annotations.js b/cvat-core/src/annotations.js index 9b9cd695..e2868caf 100644 --- a/cvat-core/src/annotations.js +++ b/cvat-core/src/annotations.js @@ -284,6 +284,22 @@ return result; } + function importDataset(instance, format, file, updateStatusCallback = () => {}) { + if (!(typeof format === 'string')) { + throw new ArgumentError('Format must be a string'); + } + if (!(instance instanceof Project)) { + throw new ArgumentError('Instance should be a Project instance'); + } + if (!(typeof updateStatusCallback === 'function')) { + throw new ArgumentError('Callback should be a function'); + } + if (!(['application/zip', 'application/x-zip-compressed'].includes(file.type))) { + throw new ArgumentError('File should be file instance with ZIP extension'); + } + return serverProxy.projects.importDataset(instance.id, format, file, updateStatusCallback); + } + function undoActions(session, count) { const sessionType = session instanceof Task ? 
'task' : 'job'; const cache = getCache(sessionType); @@ -366,6 +382,7 @@ importAnnotations, exportAnnotations, exportDataset, + importDataset, undoActions, redoActions, freezeHistory, diff --git a/cvat-core/src/project-implementation.js b/cvat-core/src/project-implementation.js index c5bb2387..d1d59475 100644 --- a/cvat-core/src/project-implementation.js +++ b/cvat-core/src/project-implementation.js @@ -7,7 +7,7 @@ const { getPreview } = require('./frames'); const { Project } = require('./project'); - const { exportDataset } = require('./annotations'); + const { exportDataset, importDataset } = require('./annotations'); function implementProject(projectClass) { projectClass.prototype.save.implementation = async function () { @@ -61,11 +61,20 @@ }; projectClass.prototype.annotations.exportDataset.implementation = async function ( - format, saveImages, customName, + format, + saveImages, + customName, ) { const result = exportDataset(this, format, customName, saveImages); return result; }; + projectClass.prototype.annotations.importDataset.implementation = async function ( + format, + file, + updateStatusCallback, + ) { + return importDataset(this, format, file, updateStatusCallback); + }; return projectClass; } diff --git a/cvat-core/src/project.js b/cvat-core/src/project.js index acfb21ee..5f34df5c 100644 --- a/cvat-core/src/project.js +++ b/cvat-core/src/project.js @@ -244,6 +244,7 @@ // So, we need return it this.annotations = { exportDataset: Object.getPrototypeOf(this).annotations.exportDataset.bind(this), + importDataset: Object.getPrototypeOf(this).annotations.importDataset.bind(this), }; } @@ -310,6 +311,16 @@ ); return result; }, + async importDataset(format, file, updateStatusCallback = null) { + const result = await PluginRegistry.apiWrapper.call( + this, + Project.prototype.annotations.importDataset, + format, + file, + updateStatusCallback, + ); + return result; + }, }, writable: true, }), diff --git a/cvat-core/src/server-proxy.js b/cvat-core/src/server-proxy.js index d8e96e13..c0f59c9c 100644 --- a/cvat-core/src/server-proxy.js +++ b/cvat-core/src/server-proxy.js @@ -514,6 +514,44 @@ }; } + async function importDataset(id, format, file, onUpdate) { + const { backendAPI } = config; + const url = `${backendAPI}/projects/${id}/dataset`; + + const formData = new FormData(); + formData.append('dataset_file', file); + + return new Promise((resolve, reject) => { + async function requestStatus() { + try { + const response = await Axios.get(`${url}?action=import_status`, { + proxy: config.proxy, + }); + if (response.status === 202) { + if (onUpdate && response.data.message !== '') { + onUpdate(response.data.message, response.data.progress || 0); + } + setTimeout(requestStatus, 3000); + } else if (response.status === 201) { + resolve(); + } else { + reject(generateError(response)); + } + } catch (error) { + reject(generateError(error)); + } + } + + Axios.post(`${url}?format=${format}`, formData, { + proxy: config.proxy, + }).then(() => { + setTimeout(requestStatus, 2000); + }).catch((error) => { + reject(generateError(error)); + }); + }); + } + async function exportTask(id) { const { backendAPI } = config; const url = `${backendAPI}/tasks/${id}`; @@ -577,7 +615,7 @@ const response = await Axios.get(`${backendAPI}/tasks/${id}/status`); if (['Queued', 'Started'].includes(response.data.state)) { if (response.data.message !== '') { - onUpdate(response.data.message); + onUpdate(response.data.message, response.data.progress || 0); } setTimeout(checkStatus, 1000); } else if 
(response.data.state === 'Finished') { @@ -637,7 +675,7 @@ let response = null; - onUpdate('The task is being created on the server..'); + onUpdate('The task is being created on the server..', null); try { response = await Axios.post(`${backendAPI}/tasks`, JSON.stringify(taskSpec), { proxy: config.proxy, @@ -649,7 +687,7 @@ throw generateError(errorData); } - onUpdate('The data are being uploaded to the server 0%'); + onUpdate('The data are being uploaded to the server..', null); async function chunkUpload(taskId, file) { return new Promise((resolve, reject) => { @@ -1438,6 +1476,7 @@ create: createProject, delete: deleteProject, exportDataset: exportDataset('projects'), + importDataset, }), writable: false, }, diff --git a/cvat-ui/package-lock.json b/cvat-ui/package-lock.json index ebe48e3d..34f85b08 100644 --- a/cvat-ui/package-lock.json +++ b/cvat-ui/package-lock.json @@ -1,12 +1,12 @@ { "name": "cvat-ui", - "version": "1.29.0", + "version": "1.30.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "cvat-ui", - "version": "1.29.0", + "version": "1.30.0", "license": "MIT", "dependencies": { "@ant-design/icons": "^4.6.3", diff --git a/cvat-ui/package.json b/cvat-ui/package.json index 5030d6b1..00e91ff2 100644 --- a/cvat-ui/package.json +++ b/cvat-ui/package.json @@ -1,6 +1,6 @@ { "name": "cvat-ui", - "version": "1.29.0", + "version": "1.30.0", "description": "CVAT single-page application", "main": "src/index.tsx", "scripts": { diff --git a/cvat-ui/src/actions/import-actions.ts b/cvat-ui/src/actions/import-actions.ts new file mode 100644 index 00000000..71c9a8eb --- /dev/null +++ b/cvat-ui/src/actions/import-actions.ts @@ -0,0 +1,59 @@ +// Copyright (C) 2021 Intel Corporation +// +// SPDX-License-Identifier: MIT + +import { createAction, ActionUnion, ThunkAction } from 'utils/redux'; +import { CombinedState } from 'reducers/interfaces'; +import { getProjectsAsync } from './projects-actions'; + +export enum ImportActionTypes { + OPEN_IMPORT_MODAL = 'OPEN_IMPORT_MODAL', + CLOSE_IMPORT_MODAL = 'CLOSE_IMPORT_MODAL', + IMPORT_DATASET = 'IMPORT_DATASET', + IMPORT_DATASET_SUCCESS = 'IMPORT_DATASET_SUCCESS', + IMPORT_DATASET_FAILED = 'IMPORT_DATASET_FAILED', + IMPORT_DATASET_UPDATE_STATUS = 'IMPORT_DATASET_UPDATE_STATUS', +} + +export const importActions = { + openImportModal: (instance: any) => createAction(ImportActionTypes.OPEN_IMPORT_MODAL, { instance }), + closeImportModal: () => createAction(ImportActionTypes.CLOSE_IMPORT_MODAL), + importDataset: (projectId: number) => ( + createAction(ImportActionTypes.IMPORT_DATASET, { id: projectId }) + ), + importDatasetSuccess: () => ( + createAction(ImportActionTypes.IMPORT_DATASET_SUCCESS) + ), + importDatasetFailed: (instance: any, error: any) => ( + createAction(ImportActionTypes.IMPORT_DATASET_FAILED, { + instance, + error, + }) + ), + importDatasetUpdateStatus: (progress: number, status: string) => ( + createAction(ImportActionTypes.IMPORT_DATASET_UPDATE_STATUS, { progress, status }) + ), +}; + +export const importDatasetAsync = (instance: any, format: string, file: File): ThunkAction => ( + async (dispatch, getState) => { + try { + const state: CombinedState = getState(); + if (state.import.importingId !== null) { + throw Error('Only one importing of dataset allowed at the same time'); + } + dispatch(importActions.importDataset(instance.id)); + await instance.annotations.importDataset(format, file, (message: string, progress: number) => ( + dispatch(importActions.importDatasetUpdateStatus(progress * 100, message)) + )); + } 
catch (error) { + dispatch(importActions.importDatasetFailed(instance, error)); + return; + } + + dispatch(importActions.importDatasetSuccess()); + dispatch(getProjectsAsync({ id: instance.id })); + } +); + +export type ImportActions = ActionUnion; diff --git a/cvat-ui/src/actions/tasks-actions.ts b/cvat-ui/src/actions/tasks-actions.ts index 9ef4fc91..25df033d 100644 --- a/cvat-ui/src/actions/tasks-actions.ts +++ b/cvat-ui/src/actions/tasks-actions.ts @@ -414,8 +414,8 @@ export function createTaskAsync(data: any): ThunkAction, {}, {}, A dispatch(createTask()); try { - const savedTask = await taskInstance.save((status: string): void => { - dispatch(createTaskUpdateStatus(status)); + const savedTask = await taskInstance.save((status: string, progress: number): void => { + dispatch(createTaskUpdateStatus(status + (progress !== null ? ` ${Math.floor(progress * 100)}%` : ''))); }); dispatch(createTaskSuccess(savedTask.id)); } catch (error) { diff --git a/cvat-ui/src/components/export-dataset/export-dataset-modal.tsx b/cvat-ui/src/components/export-dataset/export-dataset-modal.tsx index 4eda40a4..32e5654c 100644 --- a/cvat-ui/src/components/export-dataset/export-dataset-modal.tsx +++ b/cvat-ui/src/components/export-dataset/export-dataset-modal.tsx @@ -55,7 +55,7 @@ function ExportDatasetModal(): JSX.Element { useEffect(() => { initActivities(); - }, [instance?.id, instance instanceof core.classes.Project]); + }, [instance?.id, instance instanceof core.classes.Project, taskExportActivities, projectExportActivities]); const closeModal = (): void => { form.resetFields(); diff --git a/cvat-ui/src/components/import-dataset-modal/import-dataset-modal.tsx b/cvat-ui/src/components/import-dataset-modal/import-dataset-modal.tsx new file mode 100644 index 00000000..48ee2e42 --- /dev/null +++ b/cvat-ui/src/components/import-dataset-modal/import-dataset-modal.tsx @@ -0,0 +1,153 @@ +// Copyright (C) 2021 Intel Corporation +// +// SPDX-License-Identifier: MIT + +import './styles.scss'; +import React, { useCallback, useState } from 'react'; +import { useDispatch, useSelector } from 'react-redux'; +import Modal from 'antd/lib/modal'; +import Form from 'antd/lib/form'; +import Text from 'antd/lib/typography/Text'; +import Select from 'antd/lib/select'; +import Notification from 'antd/lib/notification'; +import message from 'antd/lib/message'; +import Upload, { RcFile } from 'antd/lib/upload'; + +import { + DownloadOutlined, InboxOutlined, LoadingOutlined, QuestionCircleFilled, +} from '@ant-design/icons'; + +import CVATTooltip from 'components/common/cvat-tooltip'; +import { CombinedState } from 'reducers/interfaces'; +import { importActions, importDatasetAsync } from 'actions/import-actions'; + +import ImportDatasetStatusModal from './import-dataset-status-modal'; + +type FormValues = { + selectedFormat: string | undefined; +}; + +function ImportDatasetModal(): JSX.Element { + const [form] = Form.useForm(); + const [file, setFile] = useState(null); + const modalVisible = useSelector((state: CombinedState) => state.import.modalVisible); + const instance = useSelector((state: CombinedState) => state.import.instance); + const currentImportId = useSelector((state: CombinedState) => state.import.importingId); + const importers = useSelector((state: CombinedState) => state.formats.annotationFormats.loaders); + const dispatch = useDispatch(); + + const closeModal = useCallback((): void => { + form.resetFields(); + setFile(null); + dispatch(importActions.closeImportModal()); + }, [form]); + + const handleImport = 
useCallback(
+        (values: FormValues): void => {
+            if (file === null) {
+                Notification.error({
+                    message: 'No dataset file selected',
+                });
+                return;
+            }
+            dispatch(importDatasetAsync(instance, values.selectedFormat as string, file));
+            closeModal();
+            Notification.info({
+                message: 'Dataset import started',
+                description: `Dataset import was started for project #${instance?.id}. `,
+                className: 'cvat-notification-notice-import-dataset-start',
+            });
+        },
+        [instance?.id, file],
+    );
+
+    return (
+        <>
+            <Modal
+                title={(
+                    <>
+                        <Text strong>Import dataset to project</Text>
+                        <CVATTooltip title='…'>
+                            <QuestionCircleFilled className='cvat-modal-import-header-question-icon' />
+                        </CVATTooltip>
+                    </>
+                )}
+                visible={modalVisible}
+                onCancel={closeModal}
+                onOk={() => form.submit()}
+                className='cvat-modal-import-dataset'
+            >
+                <Form form={form} onFinish={handleImport}>
+                    <Form.Item
+                        name='selectedFormat'
+                        label='Import format'
+                        rules={[{ required: true, message: 'Format must be selected' }]}
+                    >
+                        <Select placeholder='Select dataset format' className='cvat-modal-import-select'>
+                            {importers.sort((a: any, b: any) => a.name.localeCompare(b.name))
+                                .map((importer: any): JSX.Element => {
+                                    const pending = currentImportId !== null;
+                                    const disabled = !importer.enabled || pending;
+                                    return (
+                                        <Select.Option
+                                            value={importer.name}
+                                            key={importer.name}
+                                            disabled={disabled}
+                                            className='cvat-modal-import-dataset-option-item'
+                                        >
+                                            <DownloadOutlined />
+                                            <Text disabled={disabled}>{importer.name}</Text>
+                                            {pending && <LoadingOutlined />}
+                                        </Select.Option>
+                                    );
+                                })}
+                        </Select>
+                    </Form.Item>
+                    <Upload.Dragger
+                        listType='text'
+                        fileList={file ? [file] : []}
+                        beforeUpload={(_file: RcFile): boolean => {
+                            if (!['application/zip', 'application/x-zip-compressed'].includes(_file.type)) {
+                                message.error('Only ZIP archive is supported');
+                            } else {
+                                setFile(_file);
+                            }
+                            return false;
+                        }}
+                        onRemove={() => {
+                            setFile(null);
+                        }}
+                    >
+                        <p className='ant-upload-drag-icon'>
+                            <InboxOutlined />
+                        </p>
+                        <p className='ant-upload-text'>Click or drag file to this area</p>
+                    </Upload.Dragger>
+                </Form>
+            </Modal>
+            <ImportDatasetStatusModal />
+        </>
+    );
+}
+
+export default React.memo(ImportDatasetModal);
diff --git a/cvat-ui/src/components/import-dataset-modal/import-dataset-status-modal.tsx b/cvat-ui/src/components/import-dataset-modal/import-dataset-status-modal.tsx
new file mode 100644
index 00000000..01d29ee9
--- /dev/null
+++ b/cvat-ui/src/components/import-dataset-modal/import-dataset-status-modal.tsx
@@ -0,0 +1,34 @@
+// Copyright (C) 2021 Intel Corporation
+//
+// SPDX-License-Identifier: MIT
+
+import './styles.scss';
+import React from 'react';
+import { useSelector } from 'react-redux';
+import Modal from 'antd/lib/modal';
+import Alert from 'antd/lib/alert';
+import Progress from 'antd/lib/progress';
+
+import { CombinedState } from 'reducers/interfaces';
+
+function ImportDatasetStatusModal(): JSX.Element {
+    const currentImportId = useSelector((state: CombinedState) => state.import.importingId);
+    const progress = useSelector((state: CombinedState) => state.import.progress);
+    const status = useSelector((state: CombinedState) => state.import.status);
+
+    return (
+        <Modal
+            title={`Importing a dataset for the project #${currentImportId}`}
+            visible={currentImportId !== null}
+            closable={false}
+            footer={null}
+            className='cvat-modal-import-dataset-status'
+        >
+            <Progress type='circle' percent={progress} />
+            <Alert message={status} type='info' />
+        </Modal>
+    );
+}
+
+export default React.memo(ImportDatasetStatusModal);
diff --git a/cvat-ui/src/components/import-dataset-modal/styles.scss b/cvat-ui/src/components/import-dataset-modal/styles.scss
new file mode 100644
index 00000000..3ad34643
--- /dev/null
+++ b/cvat-ui/src/components/import-dataset-modal/styles.scss
@@ -0,0 +1,32 @@
+// Copyright (C) 2021 Intel Corporation
+//
+// SPDX-License-Identifier: MIT
+
+@import '../../base.scss';
+
+.cvat-modal-import-dataset-option-item > .ant-select-item-option-content,
+.cvat-modal-import-select .ant-select-selection-item {
+    > span[role='img'] {
+        color: $info-icon-color;
+        margin-right: $grid-unit-size;
+    }
+}
+
+.cvat-modal-import-header-question-icon {
+    margin-left: $grid-unit-size;
+    color: $text-color-secondary;
+}
+
+.cvat-modal-import-dataset-status .ant-modal-body {
+    display: flex;
+    align-items: center;
+    flex-flow: column;
+
+    .ant-progress {
+        margin-bottom: $grid-unit-size * 2;
+    }
+
+    .ant-alert {
+        width: 100%;
+    }
+}
diff --git a/cvat-ui/src/components/project-page/project-page.tsx b/cvat-ui/src/components/project-page/project-page.tsx
index 56847ea4..80b5f5af 100644
--- a/cvat-ui/src/components/project-page/project-page.tsx
+++ b/cvat-ui/src/components/project-page/project-page.tsx
@@ -21,6 +21,7 @@ import TaskItem from 'components/tasks-page/task-item';
 import SearchField from 'components/search-field/search-field';
 import MoveTaskModal from 'components/move-task-modal/move-task-modal';
 import ModelRunnerDialog from 'components/model-runner-modal/model-runner-dialog';
+import ImportDatasetModal from 'components/import-dataset-modal/import-dataset-modal';
 import { useDidUpdateEffect } from 'utils/hooks';
 import DetailsComponent from './details';
 import ProjectTopBar from './top-bar';
@@ -171,6 +172,7 @@ export default function ProjectPageComponent(): JSX.Element {
             <MoveTaskModal />
             <ModelRunnerDialog />
+            <ImportDatasetModal />
         </div>
     );
 }
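Taken together, the cvat-core changes above turn dataset import into a first-class method on the Project class. A minimal usage sketch, not part of this patch (the project id, format name, and the zipFile variable are placeholders; error handling omitted):

    import getCore from 'cvat-core-wrapper';

    const core = getCore();
    // Inside an async context: fetch a project, then upload a ZIP archive into it.
    // The callback receives the server-side status message and a 0..1 progress fraction.
    const [project] = await core.projects.get({ id: 5 });
    await project.annotations.importDataset('CVAT 1.1', zipFile, (message: string, progress: number) => {
        console.log(`${message} ${Math.round(progress * 100)}%`);
    });
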
diff --git a/cvat-ui/src/components/projects-page/actions-menu.tsx b/cvat-ui/src/components/projects-page/actions-menu.tsx
index 4b6657e7..f7f16611 100644
--- a/cvat-ui/src/components/projects-page/actions-menu.tsx
+++ b/cvat-ui/src/components/projects-page/actions-menu.tsx
@@ -2,13 +2,14 @@
 //
 // SPDX-License-Identifier: MIT
 
-import React from 'react';
+import React, { useCallback } from 'react';
 import { useDispatch } from 'react-redux';
 import Modal from 'antd/lib/modal';
 import Menu from 'antd/lib/menu';
 
 import { deleteProjectAsync } from 'actions/projects-actions';
 import { exportActions } from 'actions/export-actions';
+import { importActions } from 'actions/import-actions';
 
 interface Props {
     projectInstance: any;
@@ -19,7 +20,7 @@ export default function ProjectActionsMenuComponent(props: Props): JSX.Element {
     const dispatch = useDispatch();
 
-    const onDeleteProject = (): void => {
+    const onDeleteProject = useCallback((): void => {
         Modal.confirm({
             title: `The project ${projectInstance.id} will be deleted`,
             content: 'All related data (images, annotations) will be lost. Continue?',
@@ -33,21 +34,18 @@ export default function ProjectActionsMenuComponent(props: Props): JSX.Element {
             },
             okText: 'Delete',
         });
-    };
+    }, []);
 
     return (
         <Menu>
-            <Menu.Item
-                onClick={() => dispatch(exportActions.openExportModal(projectInstance))}
-            >
-                Export project dataset
-            </Menu.Item>
+            <Menu.Item onClick={() => dispatch(exportActions.openExportModal(projectInstance))}>
+                Export dataset
+            </Menu.Item>
+            <Menu.Item onClick={() => dispatch(importActions.openImportModal(projectInstance))}>
+                Import dataset
+            </Menu.Item>
             <Menu.Item onClick={onDeleteProject}>
                 Delete
             </Menu.Item>
         </Menu>
     );
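The menu item above only opens the modal; the upload itself goes through the importDatasetAsync thunk and, ultimately, the new REST endpoint. A rough sketch of the HTTP contract as the server-proxy code above exercises it (host, project id, and format name are placeholders; authentication omitted):

    // Start the import by posting the archive:
    const form = new FormData();
    form.append('dataset_file', zipFile);
    await fetch(`/api/v1/projects/5/dataset?format=${encodeURIComponent('CVAT 1.1')}`, {
        method: 'POST',
        body: form,
    });

    // Then poll: the server answers 202 with { message, progress } while the
    // import job runs, and 201 once the dataset has been imported.
    const response = await fetch('/api/v1/projects/5/dataset?action=import_status');
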
diff --git a/cvat-ui/src/components/projects-page/projects-page.tsx b/cvat-ui/src/components/projects-page/projects-page.tsx
index f200b4ba..941079a5 100644
--- a/cvat-ui/src/components/projects-page/projects-page.tsx
+++ b/cvat-ui/src/components/projects-page/projects-page.tsx
@@ -11,6 +11,7 @@ import Spin from 'antd/lib/spin';
 import { CombinedState, ProjectsQuery } from 'reducers/interfaces';
 import { getProjectsAsync } from 'actions/projects-actions';
 import FeedbackComponent from 'components/feedback/feedback';
+import ImportDatasetModal from 'components/import-dataset-modal/import-dataset-modal';
 import EmptyListComponent from './empty-list';
 import TopBarComponent from './top-bar';
 import ProjectListComponent from './project-list';
@@ -55,6 +56,7 @@ export default function ProjectsPageComponent(): JSX.Element {
             {projectsCount ? <ProjectListComponent /> : <EmptyListComponent />}
             <FeedbackComponent />
+            <ImportDatasetModal />
         </div>
     );
 }
diff --git a/cvat-ui/src/reducers/import-reducer.ts b/cvat-ui/src/reducers/import-reducer.ts
new file mode 100644
index 00000000..db851e83
--- /dev/null
+++ b/cvat-ui/src/reducers/import-reducer.ts
@@ -0,0 +1,61 @@
+// Copyright (C) 2021 Intel Corporation
+//
+// SPDX-License-Identifier: MIT
+
+import { ImportActions, ImportActionTypes } from 'actions/import-actions';
+
+import { ImportState } from './interfaces';
+
+const defaultState: ImportState = {
+    progress: 0.0,
+    status: '',
+    instance: null,
+    importingId: null,
+    modalVisible: false,
+};
+
+export default (state: ImportState = defaultState, action: ImportActions): ImportState => {
+    switch (action.type) {
+        case ImportActionTypes.OPEN_IMPORT_MODAL:
+            return {
+                ...state,
+                modalVisible: true,
+                instance: action.payload.instance,
+            };
+        case ImportActionTypes.CLOSE_IMPORT_MODAL: {
+            return {
+                ...state,
+                modalVisible: false,
+                instance: null,
+            };
+        }
+        case ImportActionTypes.IMPORT_DATASET: {
+            const { id } = action.payload;
+
+            return {
+                ...state,
+                importingId: id,
+                status: 'The file is being uploaded to the server',
+            };
+        }
+        case ImportActionTypes.IMPORT_DATASET_UPDATE_STATUS: {
+            const { progress, status } = action.payload;
+            return {
+                ...state,
+                progress,
+                status,
+            };
+        }
+        case ImportActionTypes.IMPORT_DATASET_FAILED:
+        case ImportActionTypes.IMPORT_DATASET_SUCCESS: {
+            return {
+                ...state,
+                progress: defaultState.progress,
+                status: defaultState.status,
+                importingId: null,
+            };
+        }
+        default:
+            return state;
+    }
+};
diff --git a/cvat-ui/src/reducers/interfaces.ts b/cvat-ui/src/reducers/interfaces.ts
index ff83618b..895d0698 100644
--- a/cvat-ui/src/reducers/interfaces.ts
+++ b/cvat-ui/src/reducers/interfaces.ts
@@ -117,6 +117,14 @@ export interface ExportState {
     modalVisible: boolean;
 }
 
+export interface ImportState {
+    importingId: number | null;
+    progress: number;
+    status: string;
+    instance: any;
+    modalVisible: boolean;
+}
+
 export interface FormatsState {
     annotationFormats: any;
     fetching: boolean;
@@ -396,6 +404,14 @@ export interface NotificationsState {
     predictor: {
         prediction: null | ErrorState;
     };
+    exporting: {
+        dataset: null | ErrorState;
+        annotation: null | ErrorState;
+    };
+    importing: {
+        dataset: null | ErrorState;
+        annotation: null | ErrorState;
+    };
     cloudStorages: {
         creating: null | ErrorState;
         fetching: null | ErrorState;
@@ -705,6 +721,7 @@ export interface CombinedState {
     shortcuts: ShortcutsState;
     review: ReviewState;
     export: ExportState;
+    import: ImportState;
     cloudStorages: CloudStoragesState;
 }
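With ImportState registered in CombinedState, any component can observe a running import through an ordinary selector, as the status modal above does. A small sketch (the hook name is hypothetical, not part of this patch):

    import { useSelector } from 'react-redux';
    import { CombinedState } from 'reducers/interfaces';

    // True while a project dataset import is in flight (importingId holds the project id).
    function useImportInProgress(): boolean {
        return useSelector((state: CombinedState) => state.import.importingId !== null);
    }
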
diff --git a/cvat-ui/src/reducers/notifications-reducer.ts b/cvat-ui/src/reducers/notifications-reducer.ts
index 398a0a47..2f6b2f8d 100644
--- a/cvat-ui/src/reducers/notifications-reducer.ts
+++ b/cvat-ui/src/reducers/notifications-reducer.ts
@@ -17,6 +17,7 @@ import { BoundariesActionTypes } from 'actions/boundaries-actions';
 import { UserAgreementsActionTypes } from 'actions/useragreements-actions';
 import { ReviewActionTypes } from 'actions/review-actions';
 import { ExportActionTypes } from 'actions/export-actions';
+import { ImportActionTypes } from 'actions/import-actions';
 import { CloudStorageActionTypes } from 'actions/cloud-storage-actions';
 
 import getCore from 'cvat-core-wrapper';
@@ -115,6 +116,14 @@ const defaultState: NotificationsState = {
         predictor: {
             prediction: null,
         },
+        exporting: {
+            dataset: null,
+            annotation: null,
+        },
+        importing: {
+            dataset: null,
+            annotation: null,
+        },
         cloudStorages: {
             creating: null,
             fetching: null,
@@ -327,9 +336,9 @@ export default function (state = defaultState, action: AnyAction): NotificationsState {
                 ...state,
                 errors: {
                     ...state.errors,
-                    tasks: {
-                        ...state.errors.tasks,
-                        exportingAsDataset: {
+                    exporting: {
+                        ...state.errors.exporting,
+                        dataset: {
                             message:
                                 'Could not export dataset for the ' +
                                 `<a href="/${instanceType}s/${instanceID}" target="_blank">` +
@@ -340,6 +349,25 @@ export default function (state = defaultState, action: AnyAction): NotificationsState {
                 },
             };
         }
+        case ImportActionTypes.IMPORT_DATASET_FAILED: {
+            const instanceID = action.payload.instance.id;
+            return {
+                ...state,
+                errors: {
+                    ...state.errors,
+                    importing: {
+                        ...state.errors.importing,
+                        dataset: {
+                            message:
+                                'Could not import dataset to the ' +
+                                `<a href="/projects/${instanceID}" target="_blank">` +
+                                `project ${instanceID}</a>`,
+                            reason: action.payload.error.toString(),
+                        },
+                    },
+                },
+            };
+        }
         case TasksActionTypes.GET_TASKS_FAILED: {
             return {
                 ...state,
diff --git a/cvat-ui/src/reducers/root-reducer.ts b/cvat-ui/src/reducers/root-reducer.ts
index ae2e400f..df534b22 100644
--- a/cvat-ui/src/reducers/root-reducer.ts
+++ b/cvat-ui/src/reducers/root-reducer.ts
@@ -18,6 +18,7 @@ import shortcutsReducer from './shortcuts-reducer';
 import userAgreementsReducer from './useragreements-reducer';
 import reviewReducer from './review-reducer';
 import exportReducer from './export-reducer';
+import importReducer from './import-reducer';
 import cloudStoragesReducer from './cloud-storages-reducer';
 
 export default function createRootReducer(): Reducer {
@@ -37,6 +38,7 @@ export default function createRootReducer(): Reducer {
         userAgreements: userAgreementsReducer,
         review: reviewReducer,
         export: exportReducer,
+        import: importReducer,
         cloudStorages: cloudStoragesReducer,
     });
 }
diff --git a/cvat/apps/dataset_manager/bindings.py b/cvat/apps/dataset_manager/bindings.py
index 76873ef1..64dc1bb5 100644
--- a/cvat/apps/dataset_manager/bindings.py
+++ b/cvat/apps/dataset_manager/bindings.py
@@ -3,27 +3,33 @@
 #
 # SPDX-License-Identifier: MIT
 
-import os.path as osp
 import sys
+import rq
+import os.path as osp
+from attr import attrib, attrs
 from collections import namedtuple
 from pathlib import Path
 from typing import (Any, Callable, DefaultDict, Dict, List, Literal, Mapping,
-                    NamedTuple, OrderedDict, Tuple, Union)
+                    NamedTuple, OrderedDict, Tuple, Union, Set)
 
 import datumaro.components.annotation as datum_annotation
 import datumaro.components.extractor as datum_extractor
+from datumaro.components.dataset import Dataset
 from datumaro.util import cast
 from datumaro.util.image import ByteImage, Image
 from django.utils import timezone
 
 from cvat.apps.engine.frame_provider import FrameProvider
-from cvat.apps.engine.models import AttributeType, DimensionType
+from cvat.apps.engine.models import AttributeType, DimensionType, AttributeSpec
 from
cvat.apps.engine.models import Image as Img from cvat.apps.engine.models import Label, Project, ShapeType, Task +from cvat.apps.dataset_manager.formats.utils import get_label_color from .annotation import AnnotationIR, AnnotationManager, TrackManager +CVAT_INTERNAL_ATTRIBUTES = {'occluded', 'outside', 'keyframe', 'track_id', 'rotation'} + class InstanceLabelData: Attribute = NamedTuple('Attribute', [('name', str), ('value', Any)]) @@ -32,6 +38,8 @@ class InstanceLabelData: db_labels = instance.label_set.all().prefetch_related('attributespec_set').order_by('pk') + # If this flag is set to true, create attribute within anntations import + self._soft_attribute_import = False self._label_mapping = OrderedDict[int, Label]( ((db_label.id, db_label) for db_label in db_labels), ) @@ -86,7 +94,7 @@ class InstanceLabelData: def _get_immutable_attribute_id(self, label_id, attribute_name): return self._get_attribute_id(label_id, attribute_name, 'immutable') - def _import_attribute(self, label_id, attribute): + def _import_attribute(self, label_id, attribute, mutable=False): spec_id = self._get_attribute_id(label_id, attribute.name) value = attribute.value @@ -108,6 +116,39 @@ class InstanceLabelData: raise Exception("Failed to convert attribute '%s'='%s': %s" % (self._get_label_name(label_id), value, e)) + elif self._soft_attribute_import: + if isinstance(value, (int, float)): + attr_type = AttributeType.NUMBER + elif isinstance(value, bool): + attr_type = AttributeType.CHECKBOX + else: + value = str(value) + if value.lower() in {'true', 'false'}: + value = value.lower() == 'true' + attr_type = AttributeType.CHECKBOX + else: + attr_type = AttributeType.TEXT + + attr_spec = AttributeSpec( + label_id=label_id, + name=attribute.name, + input_type=attr_type, + mutable=mutable, + ) + attr_spec.save() + spec_id = attr_spec.id + if label_id not in self._label_mapping: + self._label_mapping[label_id] = Label.objects.get(id=label_id) + if label_id not in self._attribute_mapping: + self._attribute_mapping[label_id] = {'mutable': {}, 'immutable': {}, 'spec': {}} + self._attribute_mapping[label_id]['immutable'][spec_id] = attribute.name + self._attribute_mapping[label_id]['spec'][spec_id] = attr_spec + self._attribute_mapping_merged[label_id] = { + **self._attribute_mapping[label_id]['mutable'], + **self._attribute_mapping[label_id]['immutable'], + } + + return { 'spec_id': spec_id, 'value': value } def _export_attributes(self, attributes): @@ -397,6 +438,14 @@ class TaskData(InstanceLabelData): def meta(self): return self._meta + @property + def soft_attribute_import(self): + return self._soft_attribute_import + + @soft_attribute_import.setter + def soft_attribute_import(self, value: bool): + self._soft_attribute_import = value + def _import_tag(self, tag): _tag = tag._asdict() label_id = self._get_label_id(_tag.pop('label')) @@ -404,7 +453,10 @@ class TaskData(InstanceLabelData): _tag['label_id'] = label_id _tag['attributes'] = [self._import_attribute(label_id, attrib) for attrib in _tag['attributes'] - if self._get_attribute_id(label_id, attrib.name)] + if self._get_attribute_id(label_id, attrib.name) or ( + self.soft_attribute_import and attrib.name not in CVAT_INTERNAL_ATTRIBUTES + ) + ] return _tag def _import_shape(self, shape): @@ -414,7 +466,10 @@ class TaskData(InstanceLabelData): _shape['label_id'] = label_id _shape['attributes'] = [self._import_attribute(label_id, attrib) for attrib in _shape['attributes'] - if self._get_attribute_id(label_id, attrib.name)] + if self._get_attribute_id(label_id, 
attrib.name) or ( + self.soft_attribute_import and attrib.name not in CVAT_INTERNAL_ATTRIBUTES + ) + ] _shape['points'] = list(map(float, _shape['points'])) return _shape @@ -430,10 +485,16 @@ class TaskData(InstanceLabelData): shape['frame'] = self.rel_frame_id(int(shape['frame'])) _track['attributes'] = [self._import_attribute(label_id, attrib) for attrib in shape['attributes'] - if self._get_immutable_attribute_id(label_id, attrib.name)] - shape['attributes'] = [self._import_attribute(label_id, attrib) + if self._get_immutable_attribute_id(label_id, attrib.name) or ( + self.soft_attribute_import and attrib.name not in CVAT_INTERNAL_ATTRIBUTES + ) + ] + shape['attributes'] = [self._import_attribute(label_id, attrib, mutable=True) for attrib in shape['attributes'] - if self._get_mutable_attribute_id(label_id, attrib.name)] + if self._get_mutable_attribute_id(label_id, attrib.name) or ( + self.soft_attribute_import and attrib.name not in CVAT_INTERNAL_ATTRIBUTES + ) + ] shape['points'] = list(map(float, shape['points'])) return _track @@ -510,40 +571,86 @@ class TaskData(InstanceLabelData): return None class ProjectData(InstanceLabelData): - LabeledShape = NamedTuple('LabledShape', [('type', str), ('frame', int), ('label', str), ('points', List[float]), ('occluded', bool), ('attributes', List[InstanceLabelData.Attribute]), ('source', str), ('group', int), ('rotation', float), ('z_order', int), ('task_id', int)]) - LabeledShape.__new__.__defaults__ = (0, 0, 0) - TrackedShape = NamedTuple('TrackedShape', - [('type', str), ('frame', int), ('points', List[float]), ('occluded', bool), ('outside', bool), ('keyframe', bool), ('attributes', List[InstanceLabelData.Attribute]), ('rotation', float), ('source', str), ('group', int), ('z_order', int), ('label', str), ('track_id', int)], - ) - TrackedShape.__new__.__defaults__ = (0, 'manual', 0, 0, None, 0) - Track = NamedTuple('Track', [('label', str), ('group', int), ('source', str), ('shapes', List[TrackedShape]), ('task_id', int)]) - Tag = NamedTuple('Tag', [('frame', int), ('label', str), ('attributes', List[InstanceLabelData.Attribute]), ('source', str), ('group', int), ('task_id', int)]) - Tag.__new__.__defaults__ = (0, ) - Frame = NamedTuple('Frame', [('task_id', int), ('subset', str), ('idx', int), ('id', int), ('frame', int), ('name', str), ('width', int), ('height', int), ('labeled_shapes', List[Union[LabeledShape, TrackedShape]]), ('tags', List[Tag])]) - - def __init__(self, annotation_irs: Mapping[str, AnnotationIR], db_project: Project, host: str, create_callback: Callable = None): + @attrs + class LabeledShape: + type: str = attrib() + frame: int = attrib() + label: str = attrib() + points: List[float] = attrib() + occluded: bool = attrib() + attributes: List[InstanceLabelData.Attribute] = attrib() + source: str = attrib(default='manual') + group: int = attrib(default=0) + rotation: int = attrib(default=0) + z_order: int = attrib(default=0) + task_id: int = attrib(default=None) + subset: str = attrib(default=None) + + @attrs + class TrackedShape: + type: str = attrib() + frame: int = attrib() + points: List[float] = attrib() + occluded: bool = attrib() + outside: bool = attrib() + keyframe: bool = attrib() + attributes: List[InstanceLabelData.Attribute] = attrib() + rotation: int = attrib(default=0) + source: str = attrib(default='manual') + group: int = attrib(default=0) + z_order: int = attrib(default=0) + label: str = attrib(default=None) + track_id: int = attrib(default=0) + + @attrs + class Track: + label: str = attrib() + shapes: 
List['ProjectData.TrackedShape'] = attrib() + source: str = attrib(default='manual') + group: int = attrib(default=0) + task_id: int = attrib(default=None) + subset: str = attrib(default=None) + + @attrs + class Tag: + frame: int = attrib() + label: str = attrib() + attributes: List[InstanceLabelData.Attribute] = attrib() + source: str = attrib(default='manual') + group: int = attrib(default=0) + task_id: int = attrib(default=None) + subset: str = attrib(default=None) + + @attrs + class Frame: + idx: int = attrib() + id: int = attrib() + frame: int = attrib() + name: str = attrib() + width: int = attrib() + height: int = attrib() + labeled_shapes: List[Union['ProjectData.LabeledShape', 'ProjectData.TrackedShape']] = attrib() + tags: List['ProjectData.Tag'] = attrib() + task_id: int = attrib(default=None) + subset: str = attrib(default=None) + + def __init__(self, annotation_irs: Mapping[str, AnnotationIR], db_project: Project, host: str = '', task_annotations: Mapping[int, Any] = None, project_annotation=None): self._annotation_irs = annotation_irs self._db_project = db_project - self._db_tasks: OrderedDict[int, Task] = OrderedDict( - ((db_task.id, db_task) for db_task in db_project.tasks.order_by("subset","id").all()) - ) - self._subsets = set() + self._task_annotations = task_annotations self._host = host - self._create_callback = create_callback - self._MAX_ANNO_SIZE = 30000 + self._soft_attribute_import = False + self._project_annotation = project_annotation + self._tasks_data: Dict[int, TaskData] = {} self._frame_info: Dict[Tuple[int, int], Literal["path", "width", "height", "subset"]] = dict() - self._frame_mapping: Dict[Tuple[str, str], Tuple[str, str]] = dict() - self._frame_steps: Dict[int, int] = {task.id: task.data.get_frame_step() for task in self._db_tasks.values()} - - for task in self._db_tasks.values(): - self._subsets.add(task.subset) - self._subsets: List[str] = list(self._subsets) + # (subset, path): (task id, frame number) + self._frame_mapping: Dict[Tuple[str, str], Tuple[int, int]] = dict() + self._frame_steps: Dict[int, int] = {} + self.new_tasks: Set[int] = set() InstanceLabelData.__init__(self, db_project) + self.init() - self._init_task_frame_offsets() - self._init_frame_info() - self._init_meta() def abs_frame_id(self, task_id: int, relative_id: int) -> int: task = self._db_tasks[task_id] @@ -559,6 +666,24 @@ class ProjectData(InstanceLabelData): raise ValueError(f"Unknown frame {absolute_id}") return d + def init(self): + self._init_tasks() + self._init_task_frame_offsets() + self._init_frame_info() + self._init_meta() + + def _init_tasks(self): + self._db_tasks: OrderedDict[int, Task] = OrderedDict( + ((db_task.id, db_task) for db_task in self._db_project.tasks.order_by("subset","id").all()) + ) + + subsets = set() + for task in self._db_tasks.values(): + subsets.add(task.subset) + self._subsets: List[str] = list(subsets) + + self._frame_steps: Dict[int, int] = {task.id: task.data.get_frame_step() for task in self._db_tasks.values()} + def _init_task_frame_offsets(self): self._task_frame_offsets: Dict[int, int] = dict() s = 0 @@ -627,6 +752,8 @@ class ProjectData(InstanceLabelData): ])) for db_label in self._label_mapping.values() ]), + ("subsets", '\n'.join([s if s else datum_extractor.DEFAULT_SUBSET_NAME for s in self._subsets])), + ("owner", OrderedDict([ ("username", self._db_project.owner.username), ("email", self._db_project.owner.email), @@ -789,15 +916,72 @@ class ProjectData(InstanceLabelData): def tasks(self): return list(self._db_tasks.values()) + 
@property + def soft_attribute_import(self): + return self._soft_attribute_import + + @soft_attribute_import.setter + def soft_attribute_import(self, value: bool): + self._soft_attribute_import = value + for task_data in self._tasks_data.values(): + task_data.soft_attribute_import = value + @property def task_data(self): for task_id, task in self._db_tasks.items(): - yield TaskData(self._annotation_irs[task_id], task, self._host) + if task_id in self._tasks_data: + yield self._tasks_data[task_id] + else: + task_data = TaskData( + annotation_ir=self._annotation_irs[task_id], + db_task=task, + host=self._host, + create_callback=self._task_annotations[task_id].create \ + if self._task_annotations is not None else None, + ) + task_data._MAX_ANNO_SIZE //= len(self._db_tasks) + task_data.soft_attribute_import = self.soft_attribute_import + self._tasks_data[task_id] = task_data + yield task_data @staticmethod def _get_filename(path): return osp.splitext(path)[0] + def match_frame(self, path: str, subset: str=datum_extractor.DEFAULT_SUBSET_NAME, root_hint: str=None, path_has_ext: bool=True): + if path_has_ext: + path = self._get_filename(path) + match_task, match_frame = self._frame_mapping.get((subset, path), (None, None)) + if not match_frame and root_hint and not path.startswith(root_hint): + path = osp.join(root_hint, path) + match_task, match_frame = self._frame_mapping.get((subset, path), (None, None)) + return match_task, match_frame + + def match_frame_fuzzy(self, path): + path = Path(self._get_filename(path)).parts + for (_subset, _path), (_tid, frame_number) in self._frame_mapping.items(): + if Path(_path).parts[-len(path):] == path : + return frame_number + return None + + def split_dataset(self, dataset: Dataset): + for task_data in self.task_data: + if task_data._db_task.id not in self.new_tasks: + continue + subset_dataset: Dataset = dataset.subsets()[task_data.db_task.subset].as_dataset() + yield subset_dataset, task_data + + def add_labels(self, labels: List[dict]): + attributes = [] + _labels = [] + for label in labels: + _attributes = label.pop('attributes') + _labels.append(Label(**label)) + attributes += [(label['name'], AttributeSpec(**at)) for at in _attributes] + self._project_annotation.add_labels(_labels, attributes) + + def add_task(self, task, files): + self._project_annotation.add_task(task, files, self) class CVATDataExtractorMixin: def __init__(self): @@ -1192,23 +1376,33 @@ def match_dm_item(item, task_data, root_hint=None): "'%s' with any task frame" % item.id) return frame_number -def find_dataset_root(dm_dataset, task_data): +def find_dataset_root(dm_dataset, instance_data: Union[TaskData, ProjectData]): longest_path = max(dm_dataset, key=lambda x: len(Path(x.id).parts), default=None) if longest_path is None: return None longest_path = longest_path.id - longest_match = task_data.match_frame_fuzzy(longest_path) + longest_match = instance_data.match_frame_fuzzy(longest_path) if longest_match is None: return None - longest_match = osp.dirname(task_data.frame_info[longest_match]['path']) + longest_match = osp.dirname(instance_data.frame_info[longest_match]['path']) prefix = longest_match[:-len(osp.dirname(longest_path)) or None] if prefix.endswith('/'): prefix = prefix[:-1] return prefix -def import_dm_annotations(dm_dataset, task_data): +def import_dm_annotations(dm_dataset: Dataset, instance_data: Union[TaskData, ProjectData]): + if len(dm_dataset) == 0: + return + + if isinstance(instance_data, ProjectData): + for sub_dataset, task_data in 
instance_data.split_dataset(dm_dataset): + # FIXME: temporary workaround for cvat format, will be removed after migration importer to datumaro + sub_dataset._format = dm_dataset.format + import_dm_annotations(sub_dataset, task_data) + return + shapes = { datum_annotation.AnnotationType.bbox: ShapeType.RECTANGLE, datum_annotation.AnnotationType.polygon: ShapeType.POLYGON, @@ -1217,16 +1411,15 @@ def import_dm_annotations(dm_dataset, task_data): datum_annotation.AnnotationType.cuboid_3d: ShapeType.CUBOID } - if len(dm_dataset) == 0: - return - label_cat = dm_dataset.categories()[datum_annotation.AnnotationType.label] - root_hint = find_dataset_root(dm_dataset, task_data) + root_hint = find_dataset_root(dm_dataset, instance_data) + + tracks = {} for item in dm_dataset: - frame_number = task_data.abs_frame_id( - match_dm_item(item, task_data, root_hint=root_hint)) + frame_number = instance_data.abs_frame_id( + match_dm_item(item, instance_data, root_hint=root_hint)) # do not store one-item groups group_map = {0: 0} @@ -1255,27 +1448,117 @@ def import_dm_annotations(dm_dataset, task_data): except Exception as e: ann.points = ann.points ann.z_order = 0 - task_data.add_shape(task_data.LabeledShape( - type=shapes[ann.type], - frame=frame_number, - points = ann.points, - label=label_cat.items[ann.label].name, - occluded=ann.attributes.get('occluded') == True, - z_order=ann.z_order, - group=group_map.get(ann.group, 0), - source='manual', - attributes=[task_data.Attribute(name=n, value=str(v)) - for n, v in ann.attributes.items()], - )) + + track_id = ann.attributes.pop('track_id', None) + if track_id is None or dm_dataset.format != 'cvat' : + instance_data.add_shape(instance_data.LabeledShape( + type=shapes[ann.type], + frame=frame_number, + points=ann.points, + label=label_cat.items[ann.label].name, + occluded=ann.attributes.pop('occluded', None) == True, + z_order=ann.z_order, + group=group_map.get(ann.group, 0), + source=str(ann.attributes.pop('source')).lower() \ + if str(ann.attributes.get('source', None)).lower() in {'auto', 'manual'} else 'manual', + attributes=[instance_data.Attribute(name=n, value=str(v)) + for n, v in ann.attributes.items()], + )) + continue + + if ann.attributes.get('keyframe', None) == True or ann.attributes.get('outside', None) == True: + track = instance_data.TrackedShape( + type=shapes[ann.type], + frame=frame_number, + occluded=ann.attributes.pop('occluded', None) == True, + outside=ann.attributes.pop('outside', None) == True, + keyframe=ann.attributes.get('keyframe', None) == True, + points=ann.points, + z_order=ann.z_order, + source=str(ann.attributes.pop('source')).lower() \ + if str(ann.attributes.get('source', None)).lower() in {'auto', 'manual'} else 'manual', + attributes=[instance_data.Attribute(name=n, value=str(v)) + for n, v in ann.attributes.items()], + ) + + if track_id not in tracks: + tracks[track_id] = instance_data.Track( + label=label_cat.items[ann.label].name, + group=group_map.get(ann.group, 0), + source=str(ann.attributes.pop('source')).lower() \ + if str(ann.attributes.get('source', None)).lower() in {'auto', 'manual'} else 'manual', + shapes=[], + ) + + tracks[track_id].shapes.append(track) + elif ann.type == datum_annotation.AnnotationType.label: - task_data.add_tag(task_data.Tag( + instance_data.add_tag(instance_data.Tag( frame=frame_number, label=label_cat.items[ann.label].name, group=group_map.get(ann.group, 0), source='manual', - attributes=[task_data.Attribute(name=n, value=str(v)) + attributes=[instance_data.Attribute(name=n, 
value=str(v)) for n, v in ann.attributes.items()], )) except Exception as e: raise CvatImportError("Image {}: can't import annotation " - "#{} ({}): {}".format(item.id, idx, ann.type.name, e)) + "#{} ({}): {}".format(item.id, idx, ann.type.name, e)) from e + + for track in tracks.values(): + instance_data.add_track(track) + + +def import_labels_to_project(project_annotation, dataset: Dataset): + labels = [] + label_colors = [] + for label in dataset.categories()[datum_annotation.AnnotationType.label].items: + db_label = Label( + name=label.name, + color=get_label_color(label.name, label_colors) + ) + labels.append(db_label) + label_colors.append(db_label.color) + project_annotation.add_labels(labels) + +def load_dataset_data(project_annotation, dataset: Dataset, project_data): + if not project_annotation.db_project.label_set.count(): + import_labels_to_project(project_annotation, dataset) + else: + for label in dataset.categories()[datum_annotation.AnnotationType.label].items: + if not project_annotation.db_project.label_set.filter(name=label.name).exists(): + raise CvatImportError(f'Target project does not have label with name "{label.name}"') + for subset_id, subset in enumerate(dataset.subsets().values()): + job = rq.get_current_job() + job.meta['status'] = 'Task from dataset is being created...' + job.meta['progress'] = (subset_id + job.meta.get('task_progress', 0.)) / len(dataset.subsets().keys()) + job.save_meta() + + task_fields = { + 'project': project_annotation.db_project, + 'name': subset.name, + 'owner': project_annotation.db_project.owner, + 'subset': subset.name, + } + + subset_dataset = subset.as_dataset() + + dataset_files = { + 'media': [], + 'data_root': dataset.data_path + osp.sep, + } + + for dataset_item in subset_dataset: + if dataset_item.image and dataset_item.image.has_data: + dataset_files['media'].append(dataset_item.image.path) + elif dataset_item.point_cloud: + dataset_files['media'].append(dataset_item.point_cloud) + if isinstance(dataset_item.related_images, list): + dataset_files['media'] += \ + list(map(lambda ri: ri.path, dataset_item.related_images)) + + shortes_path = min(dataset_files['media'], key=lambda x: len(Path(x).parts), default=None) + if shortes_path is not None: + dataset_files['data_root'] = str(Path(shortes_path).parent.absolute()) + osp.sep + + project_annotation.add_task(task_fields, dataset_files, project_data) diff --git a/cvat/apps/dataset_manager/formats/camvid.py b/cvat/apps/dataset_manager/formats/camvid.py index 983fcec8..bbb31217 100644 --- a/cvat/apps/dataset_manager/formats/camvid.py +++ b/cvat/apps/dataset_manager/formats/camvid.py @@ -33,10 +33,12 @@ def _export(dst_file, instance_data, save_images=False): make_zip_archive(temp_dir, dst_file) @importer(name='CamVid', ext='ZIP', version='1.0') -def _import(src_file, instance_data): +def _import(src_file, instance_data, load_data_callback=None): with TemporaryDirectory() as tmp_dir: Archive(src_file.name).extractall(tmp_dir) dataset = Dataset.import_from(tmp_dir, 'camvid', env=dm_env) dataset.transform('masks_to_polygons') + if load_data_callback is not None: + load_data_callback(dataset, instance_data) import_dm_annotations(dataset, instance_data) diff --git a/cvat/apps/dataset_manager/formats/cityscapes.py b/cvat/apps/dataset_manager/formats/cityscapes.py index fd5655a1..df82f4d6 100644 --- a/cvat/apps/dataset_manager/formats/cityscapes.py +++ b/cvat/apps/dataset_manager/formats/cityscapes.py @@ -34,7 +34,7 @@ def _export(dst_file, instance_data, save_images=False): 
make_zip_archive(temp_dir, dst_file) @importer(name='Cityscapes', ext='ZIP', version='1.0') -def _import(src_file, instance_data): +def _import(src_file, instance_data, load_data_callback=None): with TemporaryDirectory() as tmp_dir: Archive(src_file.name).extractall(tmp_dir) @@ -46,4 +46,6 @@ def _import(src_file, instance_data): dataset = Dataset.import_from(tmp_dir, 'cityscapes', env=dm_env) dataset.transform('masks_to_polygons') + if load_data_callback is not None: + load_data_callback(dataset, instance_data) import_dm_annotations(dataset, instance_data) diff --git a/cvat/apps/dataset_manager/formats/coco.py b/cvat/apps/dataset_manager/formats/coco.py index 927df2de..be85277e 100644 --- a/cvat/apps/dataset_manager/formats/coco.py +++ b/cvat/apps/dataset_manager/formats/coco.py @@ -13,7 +13,6 @@ from cvat.apps.dataset_manager.util import make_zip_archive from .registry import dm_env, exporter, importer - @exporter(name='COCO', ext='ZIP', version='1.0') def _export(dst_file, instance_data, save_images=False): dataset = Dataset.from_extractors(GetCVATDataExtractor( @@ -25,14 +24,17 @@ def _export(dst_file, instance_data, save_images=False): make_zip_archive(temp_dir, dst_file) @importer(name='COCO', ext='JSON, ZIP', version='1.0') -def _import(src_file, instance_data): +def _import(src_file, instance_data, load_data_callback=None): if zipfile.is_zipfile(src_file): with TemporaryDirectory() as tmp_dir: zipfile.ZipFile(src_file).extractall(tmp_dir) + dataset = Dataset.import_from(tmp_dir, 'coco', env=dm_env) + if load_data_callback is not None: + load_data_callback(dataset, instance_data) import_dm_annotations(dataset, instance_data) else: dataset = Dataset.import_from(src_file.name, 'coco_instances', env=dm_env) - import_dm_annotations(dataset, instance_data) \ No newline at end of file + import_dm_annotations(dataset, instance_data) diff --git a/cvat/apps/dataset_manager/formats/cvat.py b/cvat/apps/dataset_manager/formats/cvat.py index 8dde913e..1bc472f7 100644 --- a/cvat/apps/dataset_manager/formats/cvat.py +++ b/cvat/apps/dataset_manager/formats/cvat.py @@ -5,19 +5,418 @@ from io import BufferedWriter import os import os.path as osp +from glob import glob from typing import Callable import zipfile from collections import OrderedDict -from glob import glob from tempfile import TemporaryDirectory +from defusedxml import ElementTree + +from datumaro.components.dataset import Dataset, DatasetItem +from datumaro.components.extractor import Importer, Extractor, DEFAULT_SUBSET_NAME +from datumaro.components.annotation import ( + AnnotationType, Bbox, Points, Polygon, PolyLine, Label, LabelCategories, +) -from datumaro.components.extractor import DatasetItem +from datumaro.util.image import Image -from cvat.apps.dataset_manager.bindings import TaskData, match_dm_item, ProjectData, get_defaulted_subset +from cvat.apps.dataset_manager.bindings import TaskData, match_dm_item, ProjectData, get_defaulted_subset, import_dm_annotations from cvat.apps.dataset_manager.util import make_zip_archive from cvat.apps.engine.frame_provider import FrameProvider -from .registry import exporter, importer +from .registry import exporter, importer, dm_env + +class CvatPath: + IMAGES_DIR = 'images' + + MEDIA_EXTS = ('.jpg', '.jpeg', '.png') + + BUILTIN_ATTRS = {'occluded', 'outside', 'keyframe', 'track_id'} + +class CvatExtractor(Extractor): + _SUPPORTED_SHAPES = ('box', 'polygon', 'polyline', 'points') + + def __init__(self, path, subsets=None): + assert osp.isfile(path), path + rootpath = osp.dirname(path) + 
images_dir = '' + if osp.isdir(osp.join(rootpath, CvatPath.IMAGES_DIR)): + images_dir = osp.join(rootpath, CvatPath.IMAGES_DIR) + self._images_dir = images_dir + self._path = path + + if not subsets: + subsets = self._get_subsets_from_anno(path) + self._subsets = subsets + super().__init__(subsets=self._subsets) + + image_items = self._parse_images(images_dir, self._subsets) + items, categories = self._parse(path) + self._items = list(self._load_items(items, image_items).values()) + self._categories = categories + + def categories(self): + return self._categories + + def __iter__(self): + yield from self._items + + def __len__(self): + return len(self._items) + + def get(self, _id, subset=DEFAULT_SUBSET_NAME): + assert subset in self._subsets, '{} not in {}'.format(subset, ', '.join(self._subsets)) + return super().get(_id, subset) + + @staticmethod + def _get_subsets_from_anno(path): + context = ElementTree.iterparse(path, events=("start", "end")) + context = iter(context) + + for ev, el in context: + if ev == 'start': + if el.tag == 'subsets': + if el.text is not None: + subsets = el.text.split('\n') + return subsets + if ev == 'end': + if el.tag == 'meta': + return [DEFAULT_SUBSET_NAME] + el.clear() + return [DEFAULT_SUBSET_NAME] + + @staticmethod + def _parse_images(image_dir, subsets): + items = OrderedDict() + + def parse_image_dir(image_dir, subset): + for file in sorted(glob(image_dir), key=osp.basename): + name, ext = osp.splitext(osp.basename(file)) + if ext.lower() in CvatPath.MEDIA_EXTS: + items[(subset, name)] = DatasetItem(id=name, annotations=[], + image=Image(path=file), subset=subset or DEFAULT_SUBSET_NAME, + ) + + if subsets == [DEFAULT_SUBSET_NAME] and not osp.isdir(osp.join(image_dir, DEFAULT_SUBSET_NAME)): + parse_image_dir(osp.join(image_dir, '*.*'), None) + else: + for subset in subsets: + parse_image_dir(osp.join(image_dir, subset, '*.*'), subset) + return items + + @classmethod + def _parse(cls, path): + context = ElementTree.iterparse(path, events=("start", "end")) + context = iter(context) + + categories, tasks_info, attribute_types = cls._parse_meta(context) + + items = OrderedDict() + + track = None + shape = None + tag = None + attributes = None + image = None + subset = None + for ev, el in context: + if ev == 'start': + if el.tag == 'track': + frame_size = tasks_info[int(el.attrib.get('task_id'))]['frame_size'] if el.attrib.get('task_id') else tuple(tasks_info.values())[0]['frame_size'] + track = { + 'id': el.attrib['id'], + 'label': el.attrib.get('label'), + 'group': int(el.attrib.get('group_id', 0)), + 'height': frame_size[0], + 'width': frame_size[1], + } + subset = el.attrib.get('subset') + elif el.tag == 'image': + image = { + 'name': el.attrib.get('name'), + 'frame': el.attrib['id'], + 'width': el.attrib.get('width'), + 'height': el.attrib.get('height'), + } + subset = el.attrib.get('subset') + elif el.tag in cls._SUPPORTED_SHAPES and (track or image): + attributes = {} + shape = { + 'type': None, + 'attributes': attributes, + } + if track: + shape.update(track) + shape['track_id'] = int(track['id']) + if image: + shape.update(image) + elif el.tag == 'tag' and image: + attributes = {} + tag = { + 'frame': image['frame'], + 'attributes': attributes, + 'group': int(el.attrib.get('group_id', 0)), + 'label': el.attrib['label'], + } + subset = el.attrib.get('subset') + elif ev == 'end': + if el.tag == 'attribute' and attributes is not None: + attr_value = el.text or '' + attr_type = attribute_types.get(el.attrib['name']) + if el.text in ['true', 'false']: + 
attr_value = attr_value == 'true' + elif attr_type is not None and attr_type != 'text': + try: + attr_value = float(attr_value) + except ValueError: + pass + attributes[el.attrib['name']] = attr_value + elif el.tag in cls._SUPPORTED_SHAPES: + if track is not None: + shape['frame'] = el.attrib['frame'] + shape['outside'] = (el.attrib.get('outside') == '1') + shape['keyframe'] = (el.attrib.get('keyframe') == '1') + if image is not None: + shape['label'] = el.attrib.get('label') + shape['group'] = int(el.attrib.get('group_id', 0)) + + shape['type'] = el.tag + shape['occluded'] = (el.attrib.get('occluded') == '1') + shape['z_order'] = int(el.attrib.get('z_order', 0)) + + if el.tag == 'box': + shape['points'] = list(map(float, [ + el.attrib['xtl'], el.attrib['ytl'], + el.attrib['xbr'], el.attrib['ybr'], + ])) + else: + shape['points'] = [] + for pair in el.attrib['points'].split(';'): + shape['points'].extend(map(float, pair.split(','))) + + frame_desc = items.get((subset, shape['frame']), {'annotations': []}) + frame_desc['annotations'].append( + cls._parse_shape_ann(shape, categories)) + items[(subset, shape['frame'])] = frame_desc + shape = None + + elif el.tag == 'tag': + frame_desc = items.get((subset, tag['frame']), {'annotations': []}) + frame_desc['annotations'].append( + cls._parse_tag_ann(tag, categories)) + items[(subset, tag['frame'])] = frame_desc + tag = None + elif el.tag == 'track': + track = None + elif el.tag == 'image': + frame_desc = items.get((subset, image['frame']), {'annotations': []}) + frame_desc.update({ + 'name': image.get('name'), + 'height': image.get('height'), + 'width': image.get('width'), + 'subset': subset, + }) + items[(subset, image['frame'])] = frame_desc + image = None + el.clear() + + return items, categories + + @staticmethod + def _parse_meta(context): + ev, el = next(context) + if not (ev == 'start' and el.tag == 'annotations'): + raise Exception("Unexpected token ") + + categories = {} + + tasks_info = {} + frame_size = [None, None] + task_id = None + mode = None + labels = OrderedDict() + label = None + + # Recursive descent parser + el = None + states = ['annotations'] + def accepted(expected_state, tag, next_state=None): + state = states[-1] + if state == expected_state and el is not None and el.tag == tag: + if not next_state: + next_state = tag + states.append(next_state) + return True + return False + def consumed(expected_state, tag): + state = states[-1] + if state == expected_state and el is not None and el.tag == tag: + states.pop() + return True + return False + + for ev, el in context: + if ev == 'start': + if accepted('annotations', 'meta'): pass + elif accepted('meta', 'task'): pass + elif accepted('meta', 'project'): pass + elif accepted('project', 'tasks'): pass + elif accepted('tasks', 'task'): pass + elif accepted('task', 'id', next_state='task_id'): pass + elif accepted('task', 'segment'): pass + elif accepted('task', 'mode'): pass + elif accepted('task', 'original_size'): pass + elif accepted('original_size', 'height', next_state='frame_height'): pass + elif accepted('original_size', 'width', next_state='frame_width'): pass + elif accepted('task', 'labels'): pass + elif accepted('project', 'labels'): pass + elif accepted('labels', 'label'): + label = { 'name': None, 'attributes': [] } + elif accepted('label', 'name', next_state='label_name'): pass + elif accepted('label', 'attributes'): pass + elif accepted('attributes', 'attribute'): pass + elif accepted('attribute', 'name', next_state='attr_name'): pass + elif 
accepted('attribute', 'input_type', next_state='attr_type'): pass + elif accepted('annotations', 'image') or \ + accepted('annotations', 'track') or \ + accepted('annotations', 'tag'): + break + else: + pass + elif ev == 'end': + if consumed('meta', 'meta'): + break + elif consumed('project', 'project'): pass + elif consumed('tasks', 'tasks'): pass + elif consumed('task', 'task'): + tasks_info[task_id] = { + 'frame_size': frame_size, + 'mode': mode, + } + frame_size = [None, None] + mode = None + elif consumed('task_id', 'id'): + task_id = int(el.text) + elif consumed('segment', 'segment'): pass + elif consumed('mode', 'mode'): + mode = el.text + elif consumed('original_size', 'original_size'): pass + elif consumed('frame_height', 'height'): + frame_size[0] = int(el.text) + elif consumed('frame_width', 'width'): + frame_size[1] = int(el.text) + elif consumed('label_name', 'name'): + label['name'] = el.text + elif consumed('attr_name', 'name'): + label['attributes'].append({'name': el.text}) + elif consumed('attr_type', 'input_type'): + label['attributes'][-1]['input_type'] = el.text + elif consumed('attribute', 'attribute'): pass + elif consumed('attributes', 'attributes'): pass + elif consumed('label', 'label'): + labels[label['name']] = label['attributes'] + label = None + elif consumed('labels', 'labels'): pass + else: + pass + + assert len(states) == 1 and states[0] == 'annotations', \ + "Expected 'meta' section in the annotation file, path: %s" % states + + common_attrs = ['occluded'] + if 'interpolation' in map(lambda t: t['mode'], tasks_info.values()): + common_attrs.append('keyframe') + common_attrs.append('outside') + common_attrs.append('track_id') + + label_cat = LabelCategories(attributes=common_attrs) + attribute_types = {} + for label, attrs in labels.items(): + attr_names = {v['name'] for v in attrs} + label_cat.add(label, attributes=attr_names) + for attr in attrs: + attribute_types[attr['name']] = attr['input_type'] + + categories[AnnotationType.label] = label_cat + return categories, tasks_info, attribute_types + + @classmethod + def _parse_shape_ann(cls, ann, categories): + ann_id = ann.get('id', 0) + ann_type = ann['type'] + + attributes = ann.get('attributes') or {} + if 'occluded' in categories[AnnotationType.label].attributes: + attributes['occluded'] = ann.get('occluded', False) + if 'outside' in ann: + attributes['outside'] = ann['outside'] + if 'keyframe' in ann: + attributes['keyframe'] = ann['keyframe'] + if 'track_id' in ann: + attributes['track_id'] = ann['track_id'] + + group = ann.get('group') + + label = ann.get('label') + label_id = categories[AnnotationType.label].find(label)[0] + + z_order = ann.get('z_order', 0) + points = ann.get('points', []) + + if ann_type == 'polyline': + return PolyLine(points, label=label_id, z_order=z_order, + id=ann_id, attributes=attributes, group=group) + + elif ann_type == 'polygon': + return Polygon(points, label=label_id, z_order=z_order, + id=ann_id, attributes=attributes, group=group) + + elif ann_type == 'points': + return Points(points, label=label_id, z_order=z_order, + id=ann_id, attributes=attributes, group=group) + + elif ann_type == 'box': + x, y = points[0], points[1] + w, h = points[2] - x, points[3] - y + return Bbox(x, y, w, h, label=label_id, z_order=z_order, + id=ann_id, attributes=attributes, group=group) + + else: + raise NotImplementedError("Unknown annotation type '%s'" % ann_type) + + @classmethod + def _parse_tag_ann(cls, ann, categories): + label = ann.get('label') + label_id = 
categories[AnnotationType.label].find(label)[0] + group = ann.get('group') + attributes = ann.get('attributes') + return Label(label_id, attributes=attributes, group=group) + + def _load_items(self, parsed, image_items): + for (subset, frame_id), item_desc in parsed.items(): + name = item_desc.get('name', 'frame_%06d.PNG' % int(frame_id)) + image = osp.join(self._images_dir, subset, name) if subset else osp.join(self._images_dir, name) + image_size = (item_desc.get('height'), item_desc.get('width')) + if all(image_size): + image = Image(path=image, size=tuple(map(int, image_size))) + di = image_items.get((subset, osp.splitext(name)[0]), DatasetItem( + id=name, annotations=[], + )) + di.subset = subset or DEFAULT_SUBSET_NAME + di.annotations = item_desc.get('annotations') + di.attributes = {'frame': int(frame_id)} + di.image = image if isinstance(image, Image) else di.image + image_items[(subset, osp.splitext(name)[0])] = di + return image_items + +dm_env.extractors.register('cvat', CvatExtractor) + +class CvatImporter(Importer): + @classmethod + def find_sources(cls, path): + return cls._find_sources_recursive(path, '.xml', 'cvat') + +dm_env.importers.register('cvat', CvatImporter) def pairwise(iterable): @@ -457,13 +856,11 @@ def dump_as_cvat_interpolation(dumper, annotations): dumper.close_root() -def load(file_object, annotations): - from defusedxml import ElementTree +def load_anno(file_object, annotations): + supported_shapes = ('box', 'polygon', 'polyline', 'points', 'cuboid') context = ElementTree.iterparse(file_object, events=("start", "end")) context = iter(context) - ev, _ = next(context) - - supported_shapes = ('box', 'polygon', 'polyline', 'points', 'cuboid') + next(context) track = None shape = None @@ -641,15 +1038,21 @@ def _export_images(dst_file, instance_data, save_images=False): anno_callback=dump_as_cvat_annotation, save_images=save_images) @importer(name='CVAT', ext='XML, ZIP', version='1.1') -def _import(src_file, task_data): +def _import(src_file, instance_data, load_data_callback=None): is_zip = zipfile.is_zipfile(src_file) src_file.seek(0) if is_zip: with TemporaryDirectory() as tmp_dir: zipfile.ZipFile(src_file).extractall(tmp_dir) - anno_paths = glob(osp.join(tmp_dir, '**', '*.xml'), recursive=True) - for p in anno_paths: - load(p, task_data) + if isinstance(instance_data, ProjectData): + dataset = Dataset.import_from(tmp_dir, 'cvat', env=dm_env) + if load_data_callback is not None: + load_data_callback(dataset, instance_data) + import_dm_annotations(dataset, instance_data) + else: + anno_paths = glob(osp.join(tmp_dir, '**', '*.xml'), recursive=True) + for p in anno_paths: + load_anno(p, instance_data) else: - load(src_file, task_data) + load_anno(src_file, instance_data) diff --git a/cvat/apps/dataset_manager/formats/icdar.py b/cvat/apps/dataset_manager/formats/icdar.py index 524d2283..27f85ed6 100644 --- a/cvat/apps/dataset_manager/formats/icdar.py +++ b/cvat/apps/dataset_manager/formats/icdar.py @@ -86,11 +86,13 @@ def _export_recognition(dst_file, instance_data, save_images=False): make_zip_archive(temp_dir, dst_file) @importer(name='ICDAR Recognition', ext='ZIP', version='1.0') -def _import(src_file, instance_data): +def _import(src_file, instance_data, load_data_callback=None): with TemporaryDirectory() as tmp_dir: zipfile.ZipFile(src_file).extractall(tmp_dir) dataset = Dataset.import_from(tmp_dir, 'icdar_word_recognition', env=dm_env) dataset.transform(CaptionToLabel, 'icdar') + if load_data_callback is not None: + load_data_callback(dataset, 
instance_data) import_dm_annotations(dataset, instance_data) @@ -103,12 +105,14 @@ def _export_localization(dst_file, instance_data, save_images=False): make_zip_archive(temp_dir, dst_file) @importer(name='ICDAR Localization', ext='ZIP', version='1.0') -def _import(src_file, instance_data): +def _import(src_file, instance_data, load_data_callback=None): with TemporaryDirectory() as tmp_dir: zipfile.ZipFile(src_file).extractall(tmp_dir) dataset = Dataset.import_from(tmp_dir, 'icdar_text_localization', env=dm_env) dataset.transform(AddLabelToAnns, 'icdar') + if load_data_callback is not None: + load_data_callback(dataset, instance_data) import_dm_annotations(dataset, instance_data) @@ -125,10 +129,12 @@ def _export_segmentation(dst_file, instance_data, save_images=False): make_zip_archive(temp_dir, dst_file) @importer(name='ICDAR Segmentation', ext='ZIP', version='1.0') -def _import(src_file, instance_data): +def _import(src_file, instance_data, load_data_callback=None): with TemporaryDirectory() as tmp_dir: zipfile.ZipFile(src_file).extractall(tmp_dir) dataset = Dataset.import_from(tmp_dir, 'icdar_text_segmentation', env=dm_env) dataset.transform(AddLabelToAnns, 'icdar') dataset.transform('masks_to_polygons') + if load_data_callback is not None: + load_data_callback(dataset, instance_data) import_dm_annotations(dataset, instance_data) diff --git a/cvat/apps/dataset_manager/formats/imagenet.py b/cvat/apps/dataset_manager/formats/imagenet.py index 1085ef74..a84f487f 100644 --- a/cvat/apps/dataset_manager/formats/imagenet.py +++ b/cvat/apps/dataset_manager/formats/imagenet.py @@ -29,11 +29,13 @@ def _export(dst_file, instance_data, save_images=False): make_zip_archive(temp_dir, dst_file) @importer(name='ImageNet', ext='ZIP', version='1.0') -def _import(src_file, instance_data): +def _import(src_file, instance_data, load_data_callback=None): with TemporaryDirectory() as tmp_dir: zipfile.ZipFile(src_file).extractall(tmp_dir) if glob(osp.join(tmp_dir, '*.txt')): dataset = Dataset.import_from(tmp_dir, 'imagenet_txt', env=dm_env) else: dataset = Dataset.import_from(tmp_dir, 'imagenet', env=dm_env) + if load_data_callback is not None: + load_data_callback(dataset, instance_data) import_dm_annotations(dataset, instance_data) \ No newline at end of file diff --git a/cvat/apps/dataset_manager/formats/labelme.py b/cvat/apps/dataset_manager/formats/labelme.py index 2fc1f7f7..9918056b 100644 --- a/cvat/apps/dataset_manager/formats/labelme.py +++ b/cvat/apps/dataset_manager/formats/labelme.py @@ -24,10 +24,12 @@ def _export(dst_file, instance_data, save_images=False): make_zip_archive(temp_dir, dst_file) @importer(name='LabelMe', ext='ZIP', version='3.0') -def _import(src_file, instance_data): +def _import(src_file, instance_data, load_data_callback=None): with TemporaryDirectory() as tmp_dir: Archive(src_file.name).extractall(tmp_dir) dataset = Dataset.import_from(tmp_dir, 'label_me', env=dm_env) dataset.transform('masks_to_polygons') + if load_data_callback is not None: + load_data_callback(dataset, instance_data) import_dm_annotations(dataset, instance_data) diff --git a/cvat/apps/dataset_manager/formats/lfw.py b/cvat/apps/dataset_manager/formats/lfw.py index 97cc9dcb..6ec4caba 100644 --- a/cvat/apps/dataset_manager/formats/lfw.py +++ b/cvat/apps/dataset_manager/formats/lfw.py @@ -14,12 +14,13 @@ from .registry import dm_env, exporter, importer @importer(name='LFW', ext='ZIP', version='1.0') -def _import(src_file, instance_data): +def _import(src_file, instance_data, load_data_callback=None): with 
TemporaryDirectory() as tmp_dir: Archive(src_file.name).extractall(tmp_dir) dataset = Dataset.import_from(tmp_dir, 'lfw') - + if load_data_callback is not None: + load_data_callback(dataset, instance_data) import_dm_annotations(dataset, instance_data) @exporter(name='LFW', ext='ZIP', version='1.0') diff --git a/cvat/apps/dataset_manager/formats/market1501.py b/cvat/apps/dataset_manager/formats/market1501.py index 3ba14f98..272f7f15 100644 --- a/cvat/apps/dataset_manager/formats/market1501.py +++ b/cvat/apps/dataset_manager/formats/market1501.py @@ -70,10 +70,12 @@ def _export(dst_file, instance_data, save_images=False): make_zip_archive(temp_dir, dst_file) @importer(name='Market-1501', ext='ZIP', version='1.0') -def _import(src_file, instance_data): +def _import(src_file, instance_data, load_data_callback=None): with TemporaryDirectory() as tmp_dir: zipfile.ZipFile(src_file).extractall(tmp_dir) dataset = Dataset.import_from(tmp_dir, 'market1501', env=dm_env) dataset.transform(AttrToLabelAttr, 'market-1501') + if load_data_callback is not None: + load_data_callback(dataset, instance_data) import_dm_annotations(dataset, instance_data) diff --git a/cvat/apps/dataset_manager/formats/mask.py b/cvat/apps/dataset_manager/formats/mask.py index 3074d1b9..026e6a56 100644 --- a/cvat/apps/dataset_manager/formats/mask.py +++ b/cvat/apps/dataset_manager/formats/mask.py @@ -30,10 +30,12 @@ def _export(dst_file, instance_data, save_images=False): make_zip_archive(temp_dir, dst_file) @importer(name='Segmentation mask', ext='ZIP', version='1.1') -def _import(src_file, instance_data): +def _import(src_file, instance_data, load_data_callback=None): with TemporaryDirectory() as tmp_dir: Archive(src_file.name).extractall(tmp_dir) dataset = Dataset.import_from(tmp_dir, 'voc', env=dm_env) dataset.transform('masks_to_polygons') + if load_data_callback is not None: + load_data_callback(dataset, instance_data) import_dm_annotations(dataset, instance_data) diff --git a/cvat/apps/dataset_manager/formats/mot.py b/cvat/apps/dataset_manager/formats/mot.py index 26fc7b0d..3f0dfd3f 100644 --- a/cvat/apps/dataset_manager/formats/mot.py +++ b/cvat/apps/dataset_manager/formats/mot.py @@ -13,6 +13,78 @@ from cvat.apps.dataset_manager.util import make_zip_archive from .registry import dm_env, exporter, importer +def _import_task(dataset, task_data): + tracks = {} + label_cat = dataset.categories()[datumaro.AnnotationType.label] + + for item in dataset: + frame_number = int(item.id) - 1 # NOTE: MOT frames start from 1 + frame_number = task_data.abs_frame_id(frame_number) + + for ann in item.annotations: + if ann.type != datumaro.AnnotationType.bbox: + continue + + track_id = ann.attributes.get('track_id') + if track_id is None: + # Extension. 
Import regular boxes: + task_data.add_shape(task_data.LabeledShape( + type='rectangle', + label=label_cat.items[ann.label].name, + points=ann.points, + occluded=ann.attributes.get('occluded') == True, + z_order=ann.z_order, + group=0, + frame=frame_number, + attributes=[], + source='manual', + )) + continue + + shape = task_data.TrackedShape( + type='rectangle', + points=ann.points, + occluded=ann.attributes.get('occluded') == True, + outside=False, + keyframe=True, + z_order=ann.z_order, + frame=frame_number, + attributes=[], + source='manual', + ) + + # build trajectories as lists of shapes in track dict + if track_id not in tracks: + tracks[track_id] = task_data.Track( + label_cat.items[ann.label].name, 0, 'manual', []) + tracks[track_id].shapes.append(shape) + + for track in tracks.values(): + # MOT annotations do not require frames to be ordered + track.shapes.sort(key=lambda t: t.frame) + + # insert outside=True in skips between the frames track is visible + prev_shape_idx = 0 + prev_shape = track.shapes[0] + for shape in track.shapes[1:]: + has_skip = task_data.frame_step < shape.frame - prev_shape.frame + if has_skip and not prev_shape.outside: + prev_shape = prev_shape._replace(outside=True, + frame=prev_shape.frame + task_data.frame_step) + prev_shape_idx += 1 + track.shapes.insert(prev_shape_idx, prev_shape) + prev_shape = shape + prev_shape_idx += 1 + + # Append a shape with outside=True to finish the track + last_shape = track.shapes[-1] + if last_shape.frame + task_data.frame_step <= \ + int(task_data.meta['task']['stop_frame']): + track.shapes.append(last_shape._replace(outside=True, + frame=last_shape.frame + task_data.frame_step) + ) + task_data.add_track(track) + @exporter(name='MOT', ext='ZIP', version='1.1') def _export(dst_file, instance_data, save_images=False): @@ -24,79 +96,18 @@ def _export(dst_file, instance_data, save_images=False): make_zip_archive(temp_dir, dst_file) @importer(name='MOT', ext='ZIP', version='1.1') -def _import(src_file, task_data): +def _import(src_file, instance_data, load_data_callback=None): with TemporaryDirectory() as tmp_dir: Archive(src_file.name).extractall(tmp_dir) dataset = Dataset.import_from(tmp_dir, 'mot_seq', env=dm_env) + if load_data_callback is not None: + load_data_callback(dataset, instance_data) - tracks = {} - label_cat = dataset.categories()[datumaro.AnnotationType.label] - - for item in dataset: - frame_number = int(item.id) - 1 # NOTE: MOT frames start from 1 - frame_number = task_data.abs_frame_id(frame_number) - - for ann in item.annotations: - if ann.type != datumaro.AnnotationType.bbox: - continue - - track_id = ann.attributes.get('track_id') - if track_id is None: - # Extension. 
Import regular boxes: - task_data.add_shape(task_data.LabeledShape( - type='rectangle', - label=label_cat.items[ann.label].name, - points=ann.points, - occluded=ann.attributes.get('occluded') == True, - z_order=ann.z_order, - group=0, - frame=frame_number, - attributes=[], - source='manual', - )) - continue - - shape = task_data.TrackedShape( - type='rectangle', - points=ann.points, - occluded=ann.attributes.get('occluded') == True, - outside=False, - keyframe=True, - z_order=ann.z_order, - frame=frame_number, - attributes=[], - source='manual', - ) - - # build trajectories as lists of shapes in track dict - if track_id not in tracks: - tracks[track_id] = task_data.Track( - label_cat.items[ann.label].name, 0, 'manual', []) - tracks[track_id].shapes.append(shape) - - for track in tracks.values(): - # MOT annotations do not require frames to be ordered - track.shapes.sort(key=lambda t: t.frame) - - # insert outside=True in skips between the frames track is visible - prev_shape_idx = 0 - prev_shape = track.shapes[0] - for shape in track.shapes[1:]: - has_skip = task_data.frame_step < shape.frame - prev_shape.frame - if has_skip and not prev_shape.outside: - prev_shape = prev_shape._replace(outside=True, - frame=prev_shape.frame + task_data.frame_step) - prev_shape_idx += 1 - track.shapes.insert(prev_shape_idx, prev_shape) - prev_shape = shape - prev_shape_idx += 1 + # Dirty way to determine instance type to avoid circular dependency + if hasattr(instance_data, '_db_project'): + for sub_dataset, task_data in instance_data.split_dataset(dataset): + _import_task(sub_dataset, task_data) + else: + _import_task(dataset, instance_data) - # Append a shape with outside=True to finish the track - last_shape = track.shapes[-1] - if last_shape.frame + task_data.frame_step <= \ - int(task_data.meta['task']['stop_frame']): - track.shapes.append(last_shape._replace(outside=True, - frame=last_shape.frame + task_data.frame_step) - ) - task_data.add_track(track) diff --git a/cvat/apps/dataset_manager/formats/mots.py b/cvat/apps/dataset_manager/formats/mots.py index 1d3371f2..d26e2237 100644 --- a/cvat/apps/dataset_manager/formats/mots.py +++ b/cvat/apps/dataset_manager/formats/mots.py @@ -22,6 +22,77 @@ class KeepTracks(ItemTransform): return item.wrap(annotations=[a for a in item.annotations if 'track_id' in a.attributes]) +def _import_task(dataset, task_data): + tracks = {} + label_cat = dataset.categories()[AnnotationType.label] + + root_hint = find_dataset_root(dataset, task_data) + + shift = 0 + for item in dataset: + frame_number = task_data.abs_frame_id( + match_dm_item(item, task_data, root_hint=root_hint)) + + track_ids = set() + + for ann in item.annotations: + if ann.type != AnnotationType.polygon: + continue + + track_id = ann.attributes['track_id'] + group_id = track_id + + if track_id in track_ids: + # use negative id for tracks with the same id on the same frame + shift -= 1 + track_id = shift + else: + track_ids.add(track_id) + + shape = task_data.TrackedShape( + type='polygon', + points=ann.points, + occluded=ann.attributes.get('occluded') == True, + outside=False, + keyframe=True, + z_order=ann.z_order, + frame=frame_number, + attributes=[], + source='manual', + group=group_id + ) + + # build trajectories as lists of shapes in track dict + if track_id not in tracks: + tracks[track_id] = task_data.Track( + label_cat.items[ann.label].name, 0, 'manual', []) + tracks[track_id].shapes.append(shape) + + for track in tracks.values(): + track.shapes.sort(key=lambda t: t.frame) + + # insert 
outside=True in skips between the frames track is visible + prev_shape_idx = 0 + prev_shape = track.shapes[0] + for shape in track.shapes[1:]: + has_skip = task_data.frame_step < shape.frame - prev_shape.frame + if has_skip and not prev_shape.outside: + prev_shape = prev_shape._replace(outside=True, + frame=prev_shape.frame + task_data.frame_step) + prev_shape_idx += 1 + track.shapes.insert(prev_shape_idx, prev_shape) + prev_shape = shape + prev_shape_idx += 1 + + # Append a shape with outside=True to finish the track + last_shape = track.shapes[-1] + if last_shape.frame + task_data.frame_step <= \ + int(task_data.meta['task']['stop_frame']): + track.shapes.append(last_shape._replace(outside=True, + frame=last_shape.frame + task_data.frame_step) + ) + task_data.add_track(track) + @exporter(name='MOTS PNG', ext='ZIP', version='1.0') def _export(dst_file, instance_data, save_images=False): dataset = Dataset.from_extractors(GetCVATDataExtractor( @@ -37,79 +108,19 @@ def _export(dst_file, instance_data, save_images=False): make_zip_archive(temp_dir, dst_file) @importer(name='MOTS PNG', ext='ZIP', version='1.0') -def _import(src_file, task_data): +def _import(src_file, instance_data, load_data_callback=None): with TemporaryDirectory() as tmp_dir: Archive(src_file.name).extractall(tmp_dir) dataset = Dataset.import_from(tmp_dir, 'mots', env=dm_env) dataset.transform('masks_to_polygons') + if load_data_callback is not None: + load_data_callback(dataset, instance_data) - tracks = {} - label_cat = dataset.categories()[AnnotationType.label] - - root_hint = find_dataset_root(dataset, task_data) - - shift = 0 - for item in dataset: - frame_number = task_data.abs_frame_id( - match_dm_item(item, task_data, root_hint=root_hint)) - - track_ids = set() - - for ann in item.annotations: - if ann.type != AnnotationType.polygon: - continue - - track_id = ann.attributes['track_id'] - group_id = track_id - - if track_id in track_ids: - # use negative id for tracks with the same id on the same frame - shift -= 1 - track_id = shift - else: - track_ids.add(track_id) - - shape = task_data.TrackedShape( - type='polygon', - points=ann.points, - occluded=ann.attributes.get('occluded') == True, - outside=False, - keyframe=True, - z_order=ann.z_order, - frame=frame_number, - attributes=[], - source='manual', - group=group_id - ) - - # build trajectories as lists of shapes in track dict - if track_id not in tracks: - tracks[track_id] = task_data.Track( - label_cat.items[ann.label].name, 0, 'manual', []) - tracks[track_id].shapes.append(shape) - - for track in tracks.values(): - track.shapes.sort(key=lambda t: t.frame) - - # insert outside=True in skips between the frames track is visible - prev_shape_idx = 0 - prev_shape = track.shapes[0] - for shape in track.shapes[1:]: - has_skip = task_data.frame_step < shape.frame - prev_shape.frame - if has_skip and not prev_shape.outside: - prev_shape = prev_shape._replace(outside=True, - frame=prev_shape.frame + task_data.frame_step) - prev_shape_idx += 1 - track.shapes.insert(prev_shape_idx, prev_shape) - prev_shape = shape - prev_shape_idx += 1 + # Dirty way to determine instance type to avoid circular dependency + if hasattr(instance_data, '_db_project'): + for sub_dataset, task_data in instance_data.split_dataset(dataset): + _import_task(sub_dataset, task_data) + else: + _import_task(dataset, instance_data) - # Append a shape with outside=True to finish the track - last_shape = track.shapes[-1] - if last_shape.frame + task_data.frame_step <= \ - 
int(task_data.meta['task']['stop_frame']): - track.shapes.append(last_shape._replace(outside=True, - frame=last_shape.frame + task_data.frame_step) - ) - task_data.add_track(track) diff --git a/cvat/apps/dataset_manager/formats/openimages.py b/cvat/apps/dataset_manager/formats/openimages.py index aa2ceaff..1842693f 100644 --- a/cvat/apps/dataset_manager/formats/openimages.py +++ b/cvat/apps/dataset_manager/formats/openimages.py @@ -51,7 +51,7 @@ def _export(dst_file, task_data, save_images=False): make_zip_archive(temp_dir, dst_file) @importer(name='Open Images V6', ext='ZIP', version='1.0') -def _import(src_file, task_data): +def _import(src_file, instance_data, load_data_callback=None): with TemporaryDirectory() as tmp_dir: Archive(src_file.name).extractall(tmp_dir) @@ -64,14 +64,14 @@ def _import(src_file, task_data): item_ids = list(find_item_ids(tmp_dir)) root_hint = find_dataset_root( - [DatasetItem(id=item_id) for item_id in item_ids], task_data) + [DatasetItem(id=item_id) for item_id in item_ids], instance_data) for item_id in item_ids: frame_info = None try: frame_id = match_dm_item(DatasetItem(id=item_id), - task_data, root_hint) - frame_info = task_data.frame_info[frame_id] + instance_data, root_hint) + frame_info = instance_data.frame_info[frame_id] except Exception: # nosec pass if frame_info is not None: @@ -80,6 +80,8 @@ def _import(src_file, task_data): dataset = Dataset.import_from(tmp_dir, 'open_images', image_meta=image_meta, env=dm_env) dataset.transform('masks_to_polygons') - import_dm_annotations(dataset, task_data) + if load_data_callback is not None: + load_data_callback(dataset, instance_data) + import_dm_annotations(dataset, instance_data) diff --git a/cvat/apps/dataset_manager/formats/pascal_voc.py b/cvat/apps/dataset_manager/formats/pascal_voc.py index 93504628..6958a8a7 100644 --- a/cvat/apps/dataset_manager/formats/pascal_voc.py +++ b/cvat/apps/dataset_manager/formats/pascal_voc.py @@ -29,7 +29,7 @@ def _export(dst_file, instance_data, save_images=False): make_zip_archive(temp_dir, dst_file) @importer(name='PASCAL VOC', ext='ZIP', version='1.1') -def _import(src_file, instance_data): +def _import(src_file, instance_data, load_data_callback=None): with TemporaryDirectory() as tmp_dir: Archive(src_file.name).extractall(tmp_dir) @@ -58,4 +58,6 @@ def _import(src_file, instance_data): dataset = Dataset.import_from(tmp_dir, 'voc', env=dm_env) dataset.transform('masks_to_polygons') + if load_data_callback is not None: + load_data_callback(dataset, instance_data) import_dm_annotations(dataset, instance_data) diff --git a/cvat/apps/dataset_manager/formats/pointcloud.py b/cvat/apps/dataset_manager/formats/pointcloud.py index 0009cd2f..a9485a42 100644 --- a/cvat/apps/dataset_manager/formats/pointcloud.py +++ b/cvat/apps/dataset_manager/formats/pointcloud.py @@ -28,15 +28,17 @@ def _export_images(dst_file, task_data, save_images=False): @importer(name='Sly Point Cloud Format', ext='ZIP', version='1.0', dimension=DimensionType.DIM_3D) -def _import(src_file, task_data): +def _import(src_file, instance_data, load_data_callback=None): - if zipfile.is_zipfile(src_file): - with TemporaryDirectory() as tmp_dir: + with TemporaryDirectory() as tmp_dir: + if zipfile.is_zipfile(src_file): zipfile.ZipFile(src_file).extractall(tmp_dir) dataset = Dataset.import_from(tmp_dir, 'sly_pointcloud', env=dm_env) - import_dm_annotations(dataset, task_data) - else: - dataset = Dataset.import_from(src_file.name, - 'sly_pointcloud', env=dm_env) - import_dm_annotations(dataset, task_data) + else: 
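# Review note: every importer in this app now accepts an optional load_data_callback and
# invokes it on the parsed Datumaro dataset before import_dm_annotations. For project
# imports the callback is bound to ProjectAnnotationAndData.load_dataset_data (see
# cvat/apps/dataset_manager/project.py below), which creates the project's tasks and
# labels from the dataset so that the annotation import that follows has somewhere to land.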
+ dataset = Dataset.import_from(src_file.name, + 'sly_pointcloud', env=dm_env) + if load_data_callback is not None: + load_data_callback(dataset, instance_data) + import_dm_annotations(dataset, instance_data) + diff --git a/cvat/apps/dataset_manager/formats/registry.py b/cvat/apps/dataset_manager/formats/registry.py index 876c40fc..4e86969e 100644 --- a/cvat/apps/dataset_manager/formats/registry.py +++ b/cvat/apps/dataset_manager/formats/registry.py @@ -17,11 +17,11 @@ class _Format: ENABLED = True class Exporter(_Format): - def __call__(self, dst_file, task_data, **options): + def __call__(self, dst_file, instance_data, **options): raise NotImplementedError() class Importer(_Format): - def __call__(self, src_file, task_data, **options): + def __call__(self, src_file, instance_data, load_data_callback=None, **options): raise NotImplementedError() def _wrap_format(f_or_cls, klass, name, version, ext, display_name, enabled, dimension=DimensionType.DIM_2D): diff --git a/cvat/apps/dataset_manager/formats/tfrecord.py b/cvat/apps/dataset_manager/formats/tfrecord.py index d9c705a7..3b2c2af2 100644 --- a/cvat/apps/dataset_manager/formats/tfrecord.py +++ b/cvat/apps/dataset_manager/formats/tfrecord.py @@ -32,9 +32,11 @@ def _export(dst_file, instance_data, save_images=False): make_zip_archive(temp_dir, dst_file) @importer(name='TFRecord', ext='ZIP', version='1.0', enabled=tf_available) -def _import(src_file, instance_data): +def _import(src_file, instance_data, load_data_callback=None): with TemporaryDirectory() as tmp_dir: Archive(src_file.name).extractall(tmp_dir) dataset = Dataset.import_from(tmp_dir, 'tf_detection_api', env=dm_env) + if load_data_callback is not None: + load_data_callback(dataset, instance_data) import_dm_annotations(dataset, instance_data) diff --git a/cvat/apps/dataset_manager/formats/velodynepoint.py b/cvat/apps/dataset_manager/formats/velodynepoint.py index 747c4751..9a7598c5 100644 --- a/cvat/apps/dataset_manager/formats/velodynepoint.py +++ b/cvat/apps/dataset_manager/formats/velodynepoint.py @@ -30,16 +30,17 @@ def _export_images(dst_file, task_data, save_images=False): @importer(name='Kitti Raw Format', ext='ZIP', version='1.0', dimension=DimensionType.DIM_3D) -def _import(src_file, task_data): - if zipfile.is_zipfile(src_file): - with TemporaryDirectory() as tmp_dir: - zipfile.ZipFile(src_file).extractall(tmp_dir) +def _import(src_file, instance_data, load_data_callback=None): + with TemporaryDirectory() as tmp_dir: + if zipfile.is_zipfile(src_file): + zipfile.ZipFile(src_file).extractall(tmp_dir) - dataset = Dataset.import_from( - tmp_dir, 'kitti_raw', env=dm_env) - import_dm_annotations(dataset, task_data) - else: + dataset = Dataset.import_from( + tmp_dir, 'kitti_raw', env=dm_env) + else: - dataset = Dataset.import_from( - src_file.name, 'kitti_raw', env=dm_env) - import_dm_annotations(dataset, task_data) + dataset = Dataset.import_from( + src_file.name, 'kitti_raw', env=dm_env) + if load_data_callback is not None: + load_data_callback(dataset, instance_data) + import_dm_annotations(dataset, instance_data) diff --git a/cvat/apps/dataset_manager/formats/vggface2.py b/cvat/apps/dataset_manager/formats/vggface2.py index d75f960a..603c090e 100644 --- a/cvat/apps/dataset_manager/formats/vggface2.py +++ b/cvat/apps/dataset_manager/formats/vggface2.py @@ -24,10 +24,12 @@ def _export(dst_file, instance_data, save_images=False): make_zip_archive(temp_dir, dst_file) @importer(name='VGGFace2', ext='ZIP', version='1.0') -def _import(src_file, instance_data): +def 
_import(src_file, instance_data, load_data_callback=None): with TemporaryDirectory() as tmp_dir: zipfile.ZipFile(src_file).extractall(tmp_dir) dataset = Dataset.import_from(tmp_dir, 'vgg_face2', env=dm_env) dataset.transform('rename', r"|([^/]+/)?(.+)|\2|") + if load_data_callback is not None: + load_data_callback(dataset, instance_data) import_dm_annotations(dataset, instance_data) diff --git a/cvat/apps/dataset_manager/formats/widerface.py b/cvat/apps/dataset_manager/formats/widerface.py index b578c14c..85a93478 100644 --- a/cvat/apps/dataset_manager/formats/widerface.py +++ b/cvat/apps/dataset_manager/formats/widerface.py @@ -24,9 +24,11 @@ def _export(dst_file, instance_data, save_images=False): make_zip_archive(temp_dir, dst_file) @importer(name='WiderFace', ext='ZIP', version='1.0') -def _import(src_file, instance_data): +def _import(src_file, instance_data, load_data_callback=None): with TemporaryDirectory() as tmp_dir: zipfile.ZipFile(src_file).extractall(tmp_dir) dataset = Dataset.import_from(tmp_dir, 'wider_face', env=dm_env) + if load_data_callback is not None: + load_data_callback(dataset, instance_data) import_dm_annotations(dataset, instance_data) diff --git a/cvat/apps/dataset_manager/formats/yolo.py b/cvat/apps/dataset_manager/formats/yolo.py index 6327f3c0..6dbd59d4 100644 --- a/cvat/apps/dataset_manager/formats/yolo.py +++ b/cvat/apps/dataset_manager/formats/yolo.py @@ -28,7 +28,7 @@ def _export(dst_file, instance_data, save_images=False): make_zip_archive(temp_dir, dst_file) @importer(name='YOLO', ext='ZIP', version='1.1') -def _import(src_file, task_data): +def _import(src_file, instance_data, load_data_callback=None): with TemporaryDirectory() as tmp_dir: Archive(src_file.name).extractall(tmp_dir) @@ -36,13 +36,13 @@ def _import(src_file, task_data): frames = [YoloExtractor.name_from_path(osp.relpath(p, tmp_dir)) for p in glob(osp.join(tmp_dir, '**', '*.txt'), recursive=True)] root_hint = find_dataset_root( - [DatasetItem(id=frame) for frame in frames], task_data) + [DatasetItem(id=frame) for frame in frames], instance_data) for frame in frames: frame_info = None try: - frame_id = match_dm_item(DatasetItem(id=frame), task_data, + frame_id = match_dm_item(DatasetItem(id=frame), instance_data, root_hint=root_hint) - frame_info = task_data.frame_info[frame_id] + frame_info = instance_data.frame_info[frame_id] except Exception: # nosec pass if frame_info is not None: @@ -50,4 +50,6 @@ def _import(src_file, task_data): dataset = Dataset.import_from(tmp_dir, 'yolo', env=dm_env, image_info=image_info) - import_dm_annotations(dataset, task_data) + if load_data_callback is not None: + load_data_callback(dataset, instance_data) + import_dm_annotations(dataset, instance_data) diff --git a/cvat/apps/dataset_manager/project.py b/cvat/apps/dataset_manager/project.py index 866a75d4..a649ba22 100644 --- a/cvat/apps/dataset_manager/project.py +++ b/cvat/apps/dataset_manager/project.py @@ -2,16 +2,19 @@ # # SPDX-License-Identifier: MIT -from typing import Callable +import rq +from typing import Any, Callable, List, Mapping, Tuple from django.db import transaction from cvat.apps.engine import models +from cvat.apps.engine.serializers import DataSerializer, TaskSerializer +from cvat.apps.engine.task import _create_thread as create_task from cvat.apps.dataset_manager.task import TaskAnnotation from .annotation import AnnotationIR -from .bindings import ProjectData -from .formats.registry import make_exporter +from .bindings import ProjectData, load_dataset_data +from .formats.registry 
import make_exporter, make_importer def export_project(project_id, dst_file, format_name, server_url=None, save_images=False): @@ -21,35 +24,93 @@ def export_project(project_id, dst_file, format_name, # more dump request received at the same time: # https://github.com/opencv/cvat/issues/217 with transaction.atomic(): - project = ProjectAnnotation(project_id) + project = ProjectAnnotationAndData(project_id) project.init_from_db() exporter = make_exporter(format_name) with open(dst_file, 'wb') as f: project.export(f, exporter, host=server_url, save_images=save_images) -class ProjectAnnotation: +class ProjectAnnotationAndData: def __init__(self, pk: int): self.db_project = models.Project.objects.get(id=pk) self.db_tasks = models.Task.objects.filter(project__id=pk).order_by('id') + self.task_annotations: dict[int, TaskAnnotation] = dict() self.annotation_irs: dict[int, AnnotationIR] = dict() + self.tasks_to_add: list[models.Task] = [] + def reset(self): for annotation_ir in self.annotation_irs.values(): annotation_ir.reset() - def put(self, data): - raise NotImplementedError() - - def create(self, data): - raise NotImplementedError() - - def update(self, data): - raise NotImplementedError() - - def delete(self, data=None): - raise NotImplementedError() + def put(self, tasks_data: Mapping[int,Any]): + for task_id, data in tasks_data.items(): + self.task_annotations[task_id].put(data) + + def create(self, tasks_data: Mapping[int,Any]): + for task_id, data in tasks_data.items(): + self.task_annotations[task_id].create(data) + + def update(self, tasks_data: Mapping[int,Any]): + for task_id, data in tasks_data.items(): + self.task_annotations[task_id].update(data) + + def delete(self, tasks_data: Mapping[int,Any]=None): + if tasks_data is not None: + for task_id, data in tasks_data.items(): + self.task_annotations[task_id].delete(data) + else: + for task_annotation in self.task_annotations.values(): + task_annotation.delete() + + def add_task(self, task_fields: dict, files: dict, project_data: ProjectData = None): + def split_name(file): + _, name = file.split(files['data_root']) + return name + + + data_serializer = DataSerializer(data={ + "server_files": files['media'], + # TODO: the following fields should be replaced with proper input values from the request in the future + "use_cache": False, + "use_zip_chunks": True, + "image_quality": 70, + }) + data_serializer.is_valid(raise_exception=True) + db_data = data_serializer.save() + db_task = TaskSerializer.create(None, { + **task_fields, + 'data_id': db_data.id, + 'project_id': self.db_project.id + }) + data = {k:v for k, v in data_serializer.data.items()} + data['use_zip_chunks'] = data_serializer.validated_data['use_zip_chunks'] + data['use_cache'] = data_serializer.validated_data['use_cache'] + data['copy_data'] = data_serializer.validated_data['copy_data'] + data['server_files_path'] = files['data_root'] + data['stop_frame'] = None + data['server_files'] = list(map(split_name, data['server_files'])) + + create_task(db_task, data, isDatasetImport=True) + self.db_tasks = models.Task.objects.filter(project__id=self.db_project.id).order_by('id') + self.init_from_db() + if project_data is not None: + project_data.new_tasks.add(db_task.id) + project_data.init() + + def add_labels(self, labels: List[models.Label], attributes: List[Tuple[str, models.AttributeSpec]] = None): + for label in labels: + label.project = self.db_project + # We need label_id here, so we can't use bulk_create + label.save() + + for label_name, attribute in attributes or []: + label, = 
filter(lambda l: l.name == label_name, labels) + attribute.label = label + if attributes: + models.AttributeSpec.objects.bulk_create([a[1] for a in attributes]) def init_from_db(self): self.reset() @@ -57,6 +118,7 @@ class ProjectAnnotation: for task in self.db_tasks: annotation = TaskAnnotation(pk=task.id) annotation.init_from_db() + self.task_annotations[task.id] = annotation self.annotation_irs[task.id] = annotation.ir_data def export(self, dst_file: str, exporter: Callable, host: str='', **options): @@ -66,6 +128,37 @@ class ProjectAnnotation: host=host ) exporter(dst_file, project_data, **options) + + def load_dataset_data(self, *args, **kwargs): + load_dataset_data(self, *args, **kwargs) + + def import_dataset(self, dataset_file, importer): + project_data = ProjectData( + annotation_irs=self.annotation_irs, + db_project=self.db_project, + task_annotations=self.task_annotations, + project_annotation=self, + ) + project_data.soft_attribute_import = True + + importer(dataset_file, project_data, self.load_dataset_data) + + self.create({tid: ir.serialize() for tid, ir in self.annotation_irs.items() if tid in project_data.new_tasks}) + @property def data(self) -> dict: - raise NotImplementedError() \ No newline at end of file + raise NotImplementedError() + +@transaction.atomic +def import_dataset_as_project(project_id, dataset_file, format_name): + rq_job = rq.get_current_job() + rq_job.meta['status'] = 'Dataset import has been started...' + rq_job.meta['progress'] = 0. + rq_job.save_meta() + + project = ProjectAnnotationAndData(project_id) + project.init_from_db() + + importer = make_importer(format_name) + with open(dataset_file, 'rb') as f: + project.import_dataset(f, importer) \ No newline at end of file diff --git a/cvat/apps/dataset_manager/task.py b/cvat/apps/dataset_manager/task.py index d278fd9a..da692d4f 100644 --- a/cvat/apps/dataset_manager/task.py +++ b/cvat/apps/dataset_manager/task.py @@ -6,7 +6,6 @@ from collections import OrderedDict from enum import Enum -from django.conf import settings from django.db import transaction from django.utils import timezone @@ -17,6 +16,7 @@ from cvat.apps.profiler import silk_profile from .annotation import AnnotationIR, AnnotationManager from .bindings import TaskData from .formats.registry import make_exporter, make_importer +from .util import bulk_create class dotdict(OrderedDict): @@ -39,21 +39,6 @@ class PatchAction(str, Enum): def __str__(self): return self.value -def bulk_create(db_model, objects, flt_param): - if objects: - if flt_param: - if 'postgresql' in settings.DATABASES["default"]["ENGINE"]: - return db_model.objects.bulk_create(objects) - else: - ids = list(db_model.objects.filter(**flt_param).values_list('id', flat=True)) - db_model.objects.bulk_create(objects) - - return list(db_model.objects.exclude(id__in=ids).filter(**flt_param)) - else: - return db_model.objects.bulk_create(objects) - - return [] - def _merge_table_rows(rows, keys_for_merge, field_id): # It is necessary to keep a stable order of original rows # (e.g. for tracked boxes). 
Otherwise prev_box.frame can be bigger diff --git a/cvat/apps/dataset_manager/util.py b/cvat/apps/dataset_manager/util.py index c18db840..cd11ec2d 100644 --- a/cvat/apps/dataset_manager/util.py +++ b/cvat/apps/dataset_manager/util.py @@ -6,6 +6,7 @@ import inspect import os, os.path as osp import zipfile +from django.conf import settings def current_function_name(depth=1): @@ -18,3 +19,19 @@ def make_zip_archive(src_path, dst_path): for name in filenames: path = osp.join(dirpath, name) archive.write(path, osp.relpath(path, src_path)) + + +def bulk_create(db_model, objects, flt_param): + if objects: + if flt_param: + if 'postgresql' in settings.DATABASES["default"]["ENGINE"]: + return db_model.objects.bulk_create(objects) + else: + ids = list(db_model.objects.filter(**flt_param).values_list('id', flat=True)) + db_model.objects.bulk_create(objects) + + return list(db_model.objects.exclude(id__in=ids).filter(**flt_param)) + else: + return db_model.objects.bulk_create(objects) + + return [] \ No newline at end of file diff --git a/cvat/apps/engine/serializers.py b/cvat/apps/engine/serializers.py index 3356de1a..d6eeb9be 100644 --- a/cvat/apps/engine/serializers.py +++ b/cvat/apps/engine/serializers.py @@ -220,6 +220,7 @@ class RqStatusSerializer(serializers.Serializer): state = serializers.ChoiceField(choices=[ "Queued", "Started", "Finished", "Failed"]) message = serializers.CharField(allow_blank=True, default="") + progress = serializers.FloatField(max_value=100, default=0) class WriteOnceMixin: @@ -726,6 +727,15 @@ class LogEventSerializer(serializers.Serializer): class AnnotationFileSerializer(serializers.Serializer): annotation_file = serializers.FileField() +class DatasetFileSerializer(serializers.Serializer): + dataset_file = serializers.FileField() + + @staticmethod + def validate_dataset_file(value): + if os.path.splitext(value.name)[1] != '.zip': + raise serializers.ValidationError('Dataset file should be a zip archive') + return value + class TaskFileSerializer(serializers.Serializer): task_file = serializers.FileField() diff --git a/cvat/apps/engine/task.py b/cvat/apps/engine/task.py index b364c338..220ee2e1 100644 --- a/cvat/apps/engine/task.py +++ b/cvat/apps/engine/task.py @@ -53,13 +53,16 @@ def rq_handler(job, exc_type, exc_value, traceback): ############################# Internal implementation for server API -def _copy_data_from_share(server_files, upload_dir): +def _copy_data_from_source(server_files, upload_dir, server_dir=None): job = rq.get_current_job() - job.meta['status'] = 'Data are being copied from share..' + job.meta['status'] = 'Data are being copied from the source...' 
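# Review note: server_dir is the new parameter here. With server_dir=None the helper keeps
# the old _copy_data_from_share behaviour and resolves paths against settings.SHARE_ROOT;
# for dataset imports it receives the extracted archive directory instead, passed through
# data['server_files_path'] (set in ProjectAnnotationAndData.add_task above).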
job.save_meta() for path in server_files: - source_path = os.path.join(settings.SHARE_ROOT, os.path.normpath(path)) + if server_dir is None: + source_path = os.path.join(settings.SHARE_ROOT, os.path.normpath(path)) + else: + source_path = os.path.join(server_dir, os.path.normpath(path)) target_path = os.path.join(upload_dir, path) if os.path.isdir(source_path): copy_tree(source_path, target_path) @@ -218,14 +221,16 @@ def _get_manifest_frame_indexer(start_frame=0, frame_step=1): @transaction.atomic -def _create_thread(tid, data, isImport=False): - slogger.glob.info("create task #{}".format(tid)) +def _create_thread(db_task, data, isBackupRestore=False, isDatasetImport=False): + if isinstance(db_task, int): + db_task = models.Task.objects.select_for_update().get(pk=db_task) + + slogger.glob.info("create task #{}".format(db_task.id)) - db_task = models.Task.objects.select_for_update().get(pk=tid) db_data = db_task.data upload_dir = db_data.get_upload_dirname() - if data['remote_files']: + if data['remote_files'] and not isDatasetImport: data['remote_files'] = _download_data(data['remote_files'], upload_dir) manifest_file = [] @@ -236,7 +241,7 @@ def _create_thread(tid, data, isImport=False): if data['server_files']: if db_data.storage == models.StorageChoice.LOCAL: - _copy_data_from_share(data['server_files'], upload_dir) + _copy_data_from_source(data['server_files'], upload_dir, data.get('server_files_path')) elif db_data.storage == models.StorageChoice.SHARE: upload_dir = settings.SHARE_ROOT else: # cloud storage @@ -297,12 +302,12 @@ def _create_thread(tid, data, isImport=False): if media_files: if extractor is not None: raise Exception('Combined data types are not supported') - if isImport and media_type == 'image' and db_data.storage == models.StorageChoice.SHARE: + if (isDatasetImport or isBackupRestore) and media_type == 'image' and db_data.storage == models.StorageChoice.SHARE: manifest_index = _get_manifest_frame_indexer(db_data.start_frame, db_data.get_frame_step()) db_data.start_frame = 0 data['stop_frame'] = None db_data.frame_filter = '' - if isImport and media_type != 'video' and db_data.storage_method == models.StorageMethodChoice.CACHE: + if isBackupRestore and media_type != 'video' and db_data.storage_method == models.StorageMethodChoice.CACHE: # we should sort media_files according to the manifest content sequence manifest = ImageManifestManager(db_data.get_manifest_path()) manifest.set_index() @@ -319,9 +324,9 @@ def _create_thread(tid, data, isImport=False): del sorted_media_files data['sorting_method'] = models.SortingMethod.PREDEFINED source_paths=[os.path.join(upload_dir, f) for f in media_files] - if manifest_file and not isImport and data['sorting_method'] in {models.SortingMethod.RANDOM, models.SortingMethod.PREDEFINED}: + if manifest_file and not isBackupRestore and data['sorting_method'] in {models.SortingMethod.RANDOM, models.SortingMethod.PREDEFINED}: raise Exception("It isn't supported to upload manifest file and use random sorting") - if isImport and db_data.storage_method == models.StorageMethodChoice.FILE_SYSTEM and \ + if isBackupRestore and db_data.storage_method == models.StorageMethodChoice.FILE_SYSTEM and \ data['sorting_method'] in {models.SortingMethod.RANDOM, models.SortingMethod.PREDEFINED}: raise Exception("It isn't supported to import the task that was created without cache but with random/predefined sorting") @@ -377,12 +382,11 @@ def _create_thread(tid, data, isImport=False): if not hasattr(update_progress, 'call_counter'): 
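# Review note: the hunk below stops formatting a percentage into the status string and
# stores the raw float in job.meta['task_progress'] instead; TaskViewSet returns it as
# response['progress'] and the new RqStatusSerializer.progress field serializes it, so
# clients get a numeric value rather than having to parse the message text.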
update_progress.call_counter = 0 - status_template = 'Images are being compressed {}' - if progress: - current_progress = '{}%'.format(round(progress * 100)) - else: - current_progress = '{}'.format(progress_animation[update_progress.call_counter]) - job.meta['status'] = status_template.format(current_progress) + status_message = 'Images are being compressed' + if not progress: + status_message = '{} {}'.format(status_message, progress_animation[update_progress.call_counter]) + job.meta['status'] = status_message + job.meta['task_progress'] = progress or 0. job.save_meta() update_progress.call_counter = (update_progress.call_counter + 1) % len(progress_animation) diff --git a/cvat/apps/engine/tests/test_rest_api.py b/cvat/apps/engine/tests/test_rest_api.py index 20b2ee06..7de2c199 100644 --- a/cvat/apps/engine/tests/test_rest_api.py +++ b/cvat/apps/engine/tests/test_rest_api.py @@ -1470,6 +1470,165 @@ class ProjectExportAPITestCase(APITestCase): self._check_xml(pid, user, 3) +class ProjectImportExportAPITestCase(APITestCase): + def setUp(self) -> None: + self.client = APIClient() + self.tasks = [] + self.projects = [] + + @classmethod + def setUpTestData(cls) -> None: + create_db_users(cls) + + cls.media_data = [ + { + **{ + **{"client_files[{}]".format(i): generate_image_file("test_{}.jpg".format(i))[1] for i in range(10)}, + }, + **{ + "image_quality": 75, + }, + }, + { + **{ + **{"client_files[{}]".format(i): generate_image_file("test_{}.jpg".format(i))[1] for i in range(10)}, + }, + "image_quality": 75, + }, + ] + + def _create_tasks(self): + self.tasks = [] + + def _create_task(task_data, media_data): + response = self.client.post('/api/v1/tasks', data=task_data, format="json") + assert response.status_code == status.HTTP_201_CREATED + tid = response.data["id"] + + for media in media_data.values(): + if isinstance(media, io.BytesIO): + media.seek(0) + response = self.client.post("/api/v1/tasks/{}/data".format(tid), data=media_data) + assert response.status_code == status.HTTP_202_ACCEPTED + response = self.client.get("/api/v1/tasks/{}".format(tid)) + data_id = response.data["data"] + self.tasks.append({ + "id": tid, + "data_id": data_id, + }) + + task_data = [ + { + "name": "my task #1", + "owner_id": self.owner.id, + "assignee_id": self.assignee.id, + "overlap": 0, + "segment_size": 100, + "project_id": self.projects[0]["id"], + }, + { + "name": "my task #2", + "owner_id": self.owner.id, + "assignee_id": self.assignee.id, + "overlap": 1, + "segment_size": 3, + "project_id": self.projects[0]["id"], + }, + ] + + with ForceLogin(self.owner, self.client): + for data, media in zip(task_data, self.media_data): + _create_task(data, media) + + def _create_projects(self): + self.projects = [] + + def _create_project(project_data): + response = self.client.post('/api/v1/projects', data=project_data, format="json") + assert response.status_code == status.HTTP_201_CREATED + self.projects.append(response.data) + + project_data = [ + { + "name": "Project for export", + "owner_id": self.owner.id, + "assignee_id": self.assignee.id, + "labels": [ + { + "name": "car", + "color": "#ff00ff", + "attributes": [{ + "name": "bool_attribute", + "mutable": True, + "input_type": AttributeType.CHECKBOX, + "default_value": "true" + }], + }, { + "name": "person", + }, + ] + }, { + "name": "Project for import", + "owner_id": self.owner.id, + "assignee_id": self.assignee.id, + }, + ] + + with ForceLogin(self.owner, self.client): + for data in project_data: + _create_project(data) + + def 
_run_api_v1_projects_id_dataset_export(self, pid, user, query_params=""): + with ForceLogin(user, self.client): + response = self.client.get("/api/v1/projects/{}/dataset?{}".format(pid, query_params), format="json") + return response + + def _run_api_v1_projects_id_dataset_import(self, pid, user, data, f): + with ForceLogin(user, self.client): + response = self.client.post("/api/v1/projects/{}/dataset?format={}".format(pid, f), data=data, format="multipart") + return response + + def _run_api_v1_projects_id_dataset_import_status(self, pid, user): + with ForceLogin(user, self.client): + response = self.client.get("/api/v1/projects/{}/dataset?action=import_status".format(pid), format="json") + return response + + def test_api_v1_projects_id_export_import(self): + + self._create_projects() + self._create_tasks() + pid_export, pid_import = self.projects[0]["id"], self.projects[1]["id"] + response = self._run_api_v1_projects_id_dataset_export(pid_export, self.owner, "format=CVAT for images 1.1") + self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED) + + response = self._run_api_v1_projects_id_dataset_export(pid_export, self.owner, "format=CVAT for images 1.1") + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + + response = self._run_api_v1_projects_id_dataset_export(pid_export, self.owner, "format=CVAT for images 1.1&action=download") + self.assertEqual(response.status_code, status.HTTP_200_OK) + + self.assertTrue(response.streaming) + tmp_file = tempfile.NamedTemporaryFile(suffix=".zip") + tmp_file.write(b"".join(response.streaming_content)) + tmp_file.seek(0) + + import_data = { + "dataset_file": tmp_file, + } + + response = self._run_api_v1_projects_id_dataset_import(pid_import, self.owner, import_data, "CVAT 1.1") + self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED) + + response = self._run_api_v1_projects_id_dataset_import_status(pid_import, self.owner) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + + def tearDown(self) -> None: + for task in self.tasks: + shutil.rmtree(os.path.join(settings.TASKS_ROOT, str(task["id"]))) + shutil.rmtree(os.path.join(settings.MEDIA_DATA_ROOT, str(task["data_id"]))) + for project in self.projects: + shutil.rmtree(os.path.join(settings.PROJECTS_ROOT, str(project["id"]))) + class TaskListAPITestCase(APITestCase): def setUp(self): self.client = APIClient() diff --git a/cvat/apps/engine/views.py b/cvat/apps/engine/views.py index 45e23c7b..61075df0 100644 --- a/cvat/apps/engine/views.py +++ b/cvat/apps/engine/views.py @@ -60,7 +60,7 @@ from cvat.apps.engine.serializers import ( LogEventSerializer, ProjectSerializer, ProjectSearchSerializer, RqStatusSerializer, TaskSerializer, UserSerializer, PluginsSerializer, ReviewSerializer, CombinedReviewSerializer, IssueSerializer, CombinedIssueSerializer, CommentSerializer, - CloudStorageSerializer, BaseCloudStorageSerializer, TaskFileSerializer,) + CloudStorageSerializer, BaseCloudStorageSerializer, TaskFileSerializer, DatasetFileSerializer) from utils.dataset_manifest import ImageManifestManager from cvat.apps.engine.utils import av_scan_paths from cvat.apps.engine.backup import import_task @@ -312,7 +312,7 @@ class ProjectViewSet(auth.ProjectGetQuerySetMixin, viewsets.ModelViewSet): type=openapi.TYPE_STRING, required=False), openapi.Parameter('action', in_=openapi.IN_QUERY, description='Used to start downloading process after annotation file had been created', - type=openapi.TYPE_STRING, required=False, enum=['download']) + type=openapi.TYPE_STRING, 
required=False, enum=['download', 'import_status']) ], responses={'202': openapi.Response(description='Exporting has been started'), '201': openapi.Response(description='Output file is ready for downloading'), '405': openapi.Response(description='Format is not available'), } ) - @action(detail=True, methods=['GET'], serializer_class=None, + @swagger_auto_schema(method='post', operation_summary='Import dataset in specific format as a project', + manual_parameters=[ + openapi.Parameter('format', openapi.IN_QUERY, + description="Desired dataset format name\nYou can get the list of supported formats at:\n/server/annotation/formats", + type=openapi.TYPE_STRING, required=True) + ], + responses={'202': openapi.Response(description='Importing has been started'), + '400': openapi.Response(description='Failed to import dataset'), + '405': openapi.Response(description='Format is not available'), + } + ) + @action(detail=True, methods=['GET', 'POST'], serializer_class=None, url_path='dataset') - def dataset_export(self, request, pk): + def dataset(self, request, pk): db_project = self.get_object() # force to call check_object_permissions - format_name = request.query_params.get("format", "") - return _export_annotations(db_instance=db_project, - rq_id="/api/v1/project/{}/dataset/{}".format(pk, format_name), - request=request, - action=request.query_params.get("action", "").lower(), - callback=dm.views.export_project_as_dataset, - format_name=format_name, - filename=request.query_params.get("filename", "").lower(), - ) + if request.method == 'POST': + format_name = request.query_params.get("format", "") + + return _import_project_dataset( + request=request, + rq_id=f"/api/v1/project/{pk}/dataset_import", + rq_func=dm.project.import_dataset_as_project, + pk=pk, + format_name=format_name, + ) + else: + action = request.query_params.get("action", "").lower() + if action in ("import_status",): + queue = django_rq.get_queue("default") + rq_job = queue.fetch_job(f"/api/v1/project/{pk}/dataset_import") + if rq_job is None: + return Response(status=status.HTTP_404_NOT_FOUND) + elif rq_job.is_finished: + os.close(rq_job.meta['tmp_file_descriptor']) + os.remove(rq_job.meta['tmp_file']) + rq_job.delete() + return Response(status=status.HTTP_201_CREATED) + elif rq_job.is_failed: + os.close(rq_job.meta['tmp_file_descriptor']) + os.remove(rq_job.meta['tmp_file']) + rq_job.delete() + return Response( + data=str(rq_job.exc_info), + status=status.HTTP_500_INTERNAL_SERVER_ERROR + ) + else: + return Response( + data=self._get_rq_response('default', f'/api/v1/project/{pk}/dataset_import'), + status=status.HTTP_202_ACCEPTED + ) + else: + format_name = request.query_params.get("format", "") + return _export_annotations( + db_instance=db_project, + rq_id="/api/v1/project/{}/dataset/{}".format(pk, format_name), + request=request, + action=action, + callback=dm.views.export_project_as_dataset, + format_name=format_name, + filename=request.query_params.get("filename", "").lower(), + ) @swagger_auto_schema(method='get', operation_summary='Method allows to download project annotations', manual_parameters=[ @@ -372,6 +420,24 @@ class ProjectViewSet(auth.ProjectGetQuerySetMixin, viewsets.ModelViewSet): else: return Response("Format is not specified",status=status.HTTP_400_BAD_REQUEST) + @staticmethod + def _get_rq_response(queue, job_id): + queue = django_rq.get_queue(queue) + job = queue.fetch_job(job_id) + response = {} + if job is 
 
 @swagger_auto_schema(method='get', operation_summary='Method allows to download project annotations',
     manual_parameters=[
@@ -372,6 +420,24 @@ class ProjectViewSet(auth.ProjectGetQuerySetMixin, viewsets.ModelViewSet):
         else:
             return Response("Format is not specified",status=status.HTTP_400_BAD_REQUEST)
 
+    @staticmethod
+    def _get_rq_response(queue, job_id):
+        queue = django_rq.get_queue(queue)
+        job = queue.fetch_job(job_id)
+        response = {}
+        # a missing job is reported as finished, since terminal jobs get deleted
+        if job is None or job.is_finished:
+            response = { "state": "Finished" }
+        elif job.is_queued:
+            response = { "state": "Queued" }
+        elif job.is_failed:
+            response = { "state": "Failed", "message": job.exc_info }
+        else:
+            response = { "state": "Started" }
+            response['message'] = job.meta.get('status', '')
+            response['progress'] = job.meta.get('progress', 0.)
+
+        return response
+
 class TaskFilter(filters.FilterSet):
     project = filters.CharFilter(field_name="project__name", lookup_expr="icontains")
     name = filters.CharFilter(field_name="name", lookup_expr="icontains")
@@ -859,6 +925,7 @@ class TaskViewSet(UploadMixin, auth.TaskGetQuerySetMixin, viewsets.ModelViewSet)
             response = { "state": "Started" }
             if 'status' in job.meta:
                 response['message'] = job.meta['status']
+            response['progress'] = job.meta.get('task_progress', 0.)
 
         return response
 
@@ -1608,8 +1675,8 @@ def _export_annotations(db_instance, rq_id, request, format_name, action, callba
         return Response(status=status.HTTP_405_METHOD_NOT_ALLOWED)
 
     queue = django_rq.get_queue("default")
-
     rq_job = queue.fetch_job(rq_id)
+
     if rq_job:
         last_instance_update_time = timezone.localtime(db_instance.updated_date)
         if isinstance(db_instance, Project):
@@ -1659,3 +1726,38 @@ def _export_annotations(db_instance, rq_id, request, format_name, action, callba
             meta={ 'request_time': timezone.localtime() },
             result_ttl=ttl, failure_ttl=ttl)
     return Response(status=status.HTTP_202_ACCEPTED)
+
+def _import_project_dataset(request, rq_id, rq_func, pk, format_name):
+    format_desc = {f.DISPLAY_NAME: f
+        for f in dm.views.get_import_formats()}.get(format_name)
+    if format_desc is None:
+        raise serializers.ValidationError(
+            "Unknown input format '{}'".format(format_name))
+    elif not format_desc.ENABLED:
+        return Response(status=status.HTTP_405_METHOD_NOT_ALLOWED)
+
+    queue = django_rq.get_queue("default")
+    rq_job = queue.fetch_job(rq_id)
+
+    if not rq_job:
+        serializer = DatasetFileSerializer(data=request.data)
+        if serializer.is_valid(raise_exception=True):
+            dataset_file = serializer.validated_data['dataset_file']
+            # spool the upload to a temp file; its path and descriptor travel in
+            # the job meta so the import_status handler can clean them up later
+            fd, filename = mkstemp(prefix='cvat_{}'.format(pk))
+            with open(filename, 'wb+') as f:
+                for chunk in dataset_file.chunks():
+                    f.write(chunk)
+
+            rq_job = queue.enqueue_call(
+                func=rq_func,
+                args=(pk, filename, format_name),
+                job_id=rq_id,
+                meta={
+                    'tmp_file': filename,
+                    'tmp_file_descriptor': fd,
+                },
+            )
+    else:
+        return Response(status=status.HTTP_409_CONFLICT, data='Import job already exists')
+
+    return Response(status=status.HTTP_202_ACCEPTED)
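Taken together, _import_project_dataset and the import_status branch implement a temp-file hand-off through the job meta: the view spools the upload with mkstemp and records the path and descriptor, and whoever observes the job reach a terminal state closes and removes both. A standalone sketch of that pattern, with an illustrative stand-in for dm.project.import_dataset_as_project; all names here are hypothetical:

    import os
    from tempfile import mkstemp

    import django_rq

    def process_file(path):  # stand-in for the real import callback
        with open(path, 'rb') as f:
            return len(f.read())

    def enqueue_with_tmp_file(payload, job_id):
        # producer side: create the temp file and record it in the job meta
        fd, path = mkstemp(prefix='cvat_example_')
        with open(path, 'wb') as f:
            f.write(payload)
        queue = django_rq.get_queue('default')
        return queue.enqueue_call(
            func=process_file,
            args=(path,),
            job_id=job_id,
            meta={'tmp_file': path, 'tmp_file_descriptor': fd},
        )

    def cleanup_if_done(job_id):
        # observer side: release the temp file once the job is terminal
        queue = django_rq.get_queue('default')
        job = queue.fetch_job(job_id)
        if job is not None and (job.is_finished or job.is_failed):
            os.close(job.meta['tmp_file_descriptor'])
            os.remove(job.meta['tmp_file'])
            job.delete()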
diff --git a/cvat/requirements/base.txt b/cvat/requirements/base.txt
index 41580bac..e83a1271 100644
--- a/cvat/requirements/base.txt
+++ b/cvat/requirements/base.txt
@@ -1,3 +1,4 @@
+attrs==21.2.0
 click==7.1.2
 Django==3.1.13
 django-appconf==1.0.4
diff --git a/tests/cypress/support/commands_projects.js b/tests/cypress/support/commands_projects.js
index 5a0f75b6..f1db0968 100644
--- a/tests/cypress/support/commands_projects.js
+++ b/tests/cypress/support/commands_projects.js
@@ -71,7 +71,7 @@ Cypress.Commands.add('exportProject', ({ projectName, type, dumpType, archiveCustomeName, }) => {
     cy.projectActions(projectName);
-    cy.get('.cvat-project-actions-menu').contains('Export project dataset').click();
+    cy.get('.cvat-project-actions-menu').contains('Export dataset').click();
     cy.get('.cvat-modal-export-project').should('be.visible').find('.cvat-modal-export-select').click();
     cy.contains('.cvat-modal-export-option-item', dumpType).should('be.visible').click();
     cy.get('.cvat-modal-export-select').should('contain.text', dumpType);