diff --git a/CHANGELOG.md b/CHANGELOG.md index 38f9a6a9..6c3eb8f5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -57,6 +57,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add a tutorial on attaching cloud storage AWS-S3 () and Azure Blob Container () - The feature to remove annotations in a specified range of frames () +- Project backup/restore () ### Changed diff --git a/cvat-core/package-lock.json b/cvat-core/package-lock.json index b4436451..8e2d1dc1 100644 --- a/cvat-core/package-lock.json +++ b/cvat-core/package-lock.json @@ -1,12 +1,12 @@ { "name": "cvat-core", - "version": "3.22.0", + "version": "3.23.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "cvat-core", - "version": "3.22.0", + "version": "3.23.0", "license": "MIT", "dependencies": { "axios": "^0.21.4", diff --git a/cvat-core/package.json b/cvat-core/package.json index a420a715..159b1c54 100644 --- a/cvat-core/package.json +++ b/cvat-core/package.json @@ -1,6 +1,6 @@ { "name": "cvat-core", - "version": "3.22.0", + "version": "3.23.0", "description": "Part of Computer Vision Tool which presents an interface for client-side integration", "main": "babel.config.js", "scripts": { diff --git a/cvat-core/src/project-implementation.js b/cvat-core/src/project-implementation.js index d1d59475..4baa0e77 100644 --- a/cvat-core/src/project-implementation.js +++ b/cvat-core/src/project-implementation.js @@ -76,6 +76,16 @@ return importDataset(this, format, file, updateStatusCallback); }; + projectClass.prototype.backup.implementation = async function () { + const result = await serverProxy.projects.backupProject(this.id); + return result; + }; + + projectClass.restore.implementation = async function (file) { + const result = await serverProxy.projects.restoreProject(file); + return result.id; + }; + return projectClass; } diff --git a/cvat-core/src/project.js b/cvat-core/src/project.js index 5f34df5c..32764744 100644 --- a/cvat-core/src/project.js +++ b/cvat-core/src/project.js @@ -294,6 +294,38 @@ const result = await PluginRegistry.apiWrapper.call(this, Project.prototype.delete); return result; } + + /** + * Method makes a backup of a project + * @method export + * @memberof module:API.cvat.classes.Project + * @readonly + * @instance + * @async + * @throws {module:API.cvat.exceptions.ServerError} + * @throws {module:API.cvat.exceptions.PluginError} + * @returns {string} URL to get result archive + */ + async backup() { + const result = await PluginRegistry.apiWrapper.call(this, Project.prototype.backup); + return result; + } + + /** + * Method restores a project from a backup + * @method restore + * @memberof module:API.cvat.classes.Project + * @readonly + * @instance + * @async + * @throws {module:API.cvat.exceptions.ServerError} + * @throws {module:API.cvat.exceptions.PluginError} + * @returns {number} ID of the imported project + */ + static async restore(file) { + const result = await PluginRegistry.apiWrapper.call(this, Project.restore, file); + return result; + } } Object.defineProperties( diff --git a/cvat-core/src/server-proxy.js b/cvat-core/src/server-proxy.js index c0f59c9c..b20e784f 100644 --- a/cvat-core/src/server-proxy.js +++ b/cvat-core/src/server-proxy.js @@ -554,12 +554,12 @@ async function exportTask(id) { const { backendAPI } = config; - const url = `${backendAPI}/tasks/${id}`; + const url = `${backendAPI}/tasks/${id}/backup`; return new Promise((resolve, reject) => { async function request() { try { - const response = await Axios.get(`${url}?action=export`, { + const response = await Axios.get(url, { proxy: config.proxy, }); if (response.status === 202) { @@ -585,7 +585,7 @@ return new Promise((resolve, reject) => { async function request() { try { - const response = await Axios.post(`${backendAPI}/tasks?action=import`, taskData, { + const response = await Axios.post(`${backendAPI}/tasks/backup`, taskData, { proxy: config.proxy, }); if (response.status === 202) { @@ -605,6 +605,59 @@ }); } + async function backupProject(id) { + const { backendAPI } = config; + const url = `${backendAPI}/projects/${id}/backup`; + + return new Promise((resolve, reject) => { + async function request() { + try { + const response = await Axios.get(url, { + proxy: config.proxy, + }); + if (response.status === 202) { + setTimeout(request, 3000); + } else { + resolve(`${url}?action=download`); + } + } catch (errorData) { + reject(generateError(errorData)); + } + } + + setTimeout(request); + }); + } + + async function restoreProject(file) { + const { backendAPI } = config; + + let data = new FormData(); + data.append('project_file', file); + + return new Promise((resolve, reject) => { + async function request() { + try { + const response = await Axios.post(`${backendAPI}/projects/backup`, data, { + proxy: config.proxy, + }); + if (response.status === 202) { + data = new FormData(); + data.append('rq_id', response.data.rq_id); + setTimeout(request, 3000); + } else { + const restoredProject = await getProjects(`?id=${response.data.id}`); + resolve(restoredProject[0]); + } + } catch (errorData) { + reject(generateError(errorData)); + } + } + + setTimeout(request); + }); + } + async function createTask(taskSpec, taskDataSpec, onUpdate) { const { backendAPI, origin } = config; @@ -1476,6 +1529,8 @@ create: createProject, delete: deleteProject, exportDataset: exportDataset('projects'), + backupProject, + restoreProject, importDataset, }), writable: false, diff --git a/cvat-ui/package-lock.json b/cvat-ui/package-lock.json index 34f85b08..f113b9bf 100644 --- a/cvat-ui/package-lock.json +++ b/cvat-ui/package-lock.json @@ -1,12 +1,12 @@ { "name": "cvat-ui", - "version": "1.30.0", + "version": "1.31.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "cvat-ui", - "version": "1.30.0", + "version": "1.31.0", "license": "MIT", "dependencies": { "@ant-design/icons": "^4.6.3", diff --git a/cvat-ui/package.json b/cvat-ui/package.json index 00e91ff2..ec5942cf 100644 --- a/cvat-ui/package.json +++ b/cvat-ui/package.json @@ -1,6 +1,6 @@ { "name": "cvat-ui", - "version": "1.30.0", + "version": "1.31.0", "description": "CVAT single-page application", "main": "src/index.tsx", "scripts": { diff --git a/cvat-ui/src/actions/projects-actions.ts b/cvat-ui/src/actions/projects-actions.ts index 2a1abf47..61b9105b 100644 --- a/cvat-ui/src/actions/projects-actions.ts +++ b/cvat-ui/src/actions/projects-actions.ts @@ -26,6 +26,12 @@ export enum ProjectsActionTypes { DELETE_PROJECT = 'DELETE_PROJECT', DELETE_PROJECT_SUCCESS = 'DELETE_PROJECT_SUCCESS', DELETE_PROJECT_FAILED = 'DELETE_PROJECT_FAILED', + BACKUP_PROJECT = 'BACKUP_PROJECT', + BACKUP_PROJECT_SUCCESS = 'BACKUP_PROJECT_SUCCESS', + BACKUP_PROJECT_FAILED = 'BACKUP_PROJECT_FAILED', + RESTORE_PROJECT = 'IMPORT_PROJECT', + RESTORE_PROJECT_SUCCESS = 'IMPORT_PROJECT_SUCCESS', + RESTORE_PROJECT_FAILED = 'IMPORT_PROJECT_FAILED', } // prettier-ignore @@ -55,6 +61,20 @@ const projectActions = { deleteProjectFailed: (projectId: number, error: any) => ( createAction(ProjectsActionTypes.DELETE_PROJECT_FAILED, { projectId, error }) ), + backupProject: (projectId: number) => createAction(ProjectsActionTypes.BACKUP_PROJECT, { projectId }), + backupProjectSuccess: (projectID: number) => ( + createAction(ProjectsActionTypes.BACKUP_PROJECT_SUCCESS, { projectID }) + ), + backupProjectFailed: (projectID: number, error: any) => ( + createAction(ProjectsActionTypes.BACKUP_PROJECT_FAILED, { projectId: projectID, error }) + ), + restoreProject: () => createAction(ProjectsActionTypes.RESTORE_PROJECT), + restoreProjectSuccess: (projectID: number) => ( + createAction(ProjectsActionTypes.RESTORE_PROJECT_SUCCESS, { projectID }) + ), + restoreProjectFailed: (error: any) => ( + createAction(ProjectsActionTypes.RESTORE_PROJECT_FAILED, { error }) + ), }; export type ProjectActions = ActionUnion; @@ -163,3 +183,31 @@ export function deleteProjectAsync(projectInstance: any): ThunkAction { } }; } + +export function restoreProjectAsync(file: File): ThunkAction { + return async (dispatch: ActionCreator): Promise => { + dispatch(projectActions.restoreProject()); + try { + const projectInstance = await cvat.classes.Project.restore(file); + dispatch(projectActions.restoreProjectSuccess(projectInstance)); + } catch (error) { + dispatch(projectActions.restoreProjectFailed(error)); + } + }; +} + +export function backupProjectAsync(projectInstance: any): ThunkAction { + return async (dispatch: ActionCreator): Promise => { + dispatch(projectActions.backupProject(projectInstance.id)); + + try { + const url = await projectInstance.backup(); + const downloadAnchor = window.document.getElementById('downloadAnchor') as HTMLAnchorElement; + downloadAnchor.href = url; + downloadAnchor.click(); + dispatch(projectActions.backupProjectSuccess(projectInstance.id)); + } catch (error) { + dispatch(projectActions.backupProjectFailed(projectInstance.id, error)); + } + }; +} diff --git a/cvat-ui/src/components/actions-menu/actions-menu.tsx b/cvat-ui/src/components/actions-menu/actions-menu.tsx index 747167d8..13763f57 100644 --- a/cvat-ui/src/components/actions-menu/actions-menu.tsx +++ b/cvat-ui/src/components/actions-menu/actions-menu.tsx @@ -3,7 +3,7 @@ // SPDX-License-Identifier: MIT import './styles.scss'; -import React from 'react'; +import React, { useCallback } from 'react'; import Menu from 'antd/lib/menu'; import Modal from 'antd/lib/modal'; import { LoadingOutlined } from '@ant-design/icons'; @@ -50,29 +50,32 @@ function ActionsMenuComponent(props: Props): JSX.Element { exportIsActive, } = props; - function onClickMenuWrapper(params: MenuInfo): void { - if (!params) { - return; - } + const onClickMenuWrapper = useCallback( + (params: MenuInfo) => { + if (!params) { + return; + } - if (params.key === Actions.DELETE_TASK) { - Modal.confirm({ - title: `The task ${taskID} will be deleted`, - content: 'All related data (images, annotations) will be lost. Continue?', - className: 'cvat-modal-confirm-delete-task', - onOk: () => { - onClickMenu(params); - }, - okButtonProps: { - type: 'primary', - danger: true, - }, - okText: 'Delete', - }); - } else { - onClickMenu(params); - } - } + if (params.key === Actions.DELETE_TASK) { + Modal.confirm({ + title: `The task ${taskID} will be deleted`, + content: 'All related data (images, annotations) will be lost. Continue?', + className: 'cvat-modal-confirm-delete-task', + onOk: () => { + onClickMenu(params); + }, + okButtonProps: { + type: 'primary', + danger: true, + }, + okText: 'Delete', + }); + } else { + onClickMenu(params); + } + }, + [taskID], + ); return ( @@ -104,9 +107,12 @@ function ActionsMenuComponent(props: Props): JSX.Element { Automatic annotation - - {exportIsActive && } - Export task + } + > + Backup Task Move to project diff --git a/cvat-ui/src/components/projects-page/actions-menu.tsx b/cvat-ui/src/components/projects-page/actions-menu.tsx index f7f16611..7c956306 100644 --- a/cvat-ui/src/components/projects-page/actions-menu.tsx +++ b/cvat-ui/src/components/projects-page/actions-menu.tsx @@ -3,11 +3,13 @@ // SPDX-License-Identifier: MIT import React, { useCallback } from 'react'; -import { useDispatch } from 'react-redux'; +import { useDispatch, useSelector } from 'react-redux'; import Modal from 'antd/lib/modal'; import Menu from 'antd/lib/menu'; +import { LoadingOutlined } from '@ant-design/icons'; -import { deleteProjectAsync } from 'actions/projects-actions'; +import { CombinedState } from 'reducers/interfaces'; +import { deleteProjectAsync, backupProjectAsync } from 'actions/projects-actions'; import { exportActions } from 'actions/export-actions'; import { importActions } from 'actions/import-actions'; @@ -19,6 +21,8 @@ export default function ProjectActionsMenuComponent(props: Props): JSX.Element { const { projectInstance } = props; const dispatch = useDispatch(); + const activeBackups = useSelector((state: CombinedState) => state.projects.activities.backups); + const exportIsActive = projectInstance.id in activeBackups; const onDeleteProject = useCallback((): void => { Modal.confirm({ @@ -44,6 +48,13 @@ export default function ProjectActionsMenuComponent(props: Props): JSX.Element { dispatch(importActions.openImportModal(projectInstance))}> Import dataset + dispatch(backupProjectAsync(projectInstance))} + icon={exportIsActive && } + > + Backup Project + Delete diff --git a/cvat-ui/src/components/projects-page/project-item.tsx b/cvat-ui/src/components/projects-page/project-item.tsx index d2f6a9e7..c9292909 100644 --- a/cvat-ui/src/components/projects-page/project-item.tsx +++ b/cvat-ui/src/components/projects-page/project-item.tsx @@ -24,7 +24,7 @@ interface Props { const useCardHeight = useCardHeightHOC({ containerClassName: 'cvat-projects-page', - siblingClassNames: ['cvat-projects-pagination', 'cvat-projects-top-bar'], + siblingClassNames: ['cvat-projects-pagination', 'cvat-projects-page-top-bar'], paddings: 40, numberOfRows: 3, }); diff --git a/cvat-ui/src/components/projects-page/projects-page.tsx b/cvat-ui/src/components/projects-page/projects-page.tsx index 941079a5..74691b96 100644 --- a/cvat-ui/src/components/projects-page/projects-page.tsx +++ b/cvat-ui/src/components/projects-page/projects-page.tsx @@ -23,16 +23,18 @@ export default function ProjectsPageComponent(): JSX.Element { const projectFetching = useSelector((state: CombinedState) => state.projects.fetching); const projectsCount = useSelector((state: CombinedState) => state.projects.current.length); const gettingQuery = useSelector((state: CombinedState) => state.projects.gettingQuery); + const isImporting = useSelector((state: CombinedState) => state.projects.restoring); const anySearchQuery = !!Array.from(new URLSearchParams(search).keys()).filter((value) => value !== 'page').length; - useEffect(() => { + const getSearchParams = (): Partial => { const searchParams: Partial = {}; for (const [param, value] of new URLSearchParams(search)) { searchParams[param] = ['page', 'id'].includes(param) ? Number.parseInt(value, 10) : value; } - dispatch(getProjectsAsync(searchParams)); - }, []); + + return searchParams; + }; useEffect(() => { const searchParams = new URLSearchParams(); @@ -47,6 +49,12 @@ export default function ProjectsPageComponent(): JSX.Element { }); }, [gettingQuery]); + useEffect(() => { + if (isImporting === false) { + dispatch(getProjectsAsync(getSearchParams())); + } + }, [isImporting]); + if (projectFetching) { return ; } diff --git a/cvat-ui/src/components/projects-page/styles.scss b/cvat-ui/src/components/projects-page/styles.scss index f46a86ba..6c060128 100644 --- a/cvat-ui/src/components/projects-page/styles.scss +++ b/cvat-ui/src/components/projects-page/styles.scss @@ -10,6 +10,22 @@ height: 100%; width: 100%; + .cvat-projects-page-top-bar { + > div:nth-child(1) { + > div:nth-child(1) { + width: 100%; + + > div:nth-child(1) { + display: flex; + + span { + margin-right: $grid-unit-size; + } + } + } + } + } + > div:nth-child(1) { padding-bottom: $grid-unit-size; @@ -45,18 +61,11 @@ } } -.cvat-projects-top-bar { - > div:first-child { - .cvat-title { - margin-right: $grid-unit-size; +.cvat-projects-page-top-bar { + > div:nth-child(1) { + > div:nth-child(1) { + width: 100%; } - - display: flex; - } - - > div:nth-child(2) { - display: flex; - justify-content: flex-end; } } @@ -142,3 +151,15 @@ display: flex; flex-wrap: wrap; } + +#cvat-export-project-loading { + margin-left: 10; +} + +#cvat-import-project-button { + padding: 0 30px; +} + +#cvat-import-project-button-loading { + margin-left: 10; +} diff --git a/cvat-ui/src/components/projects-page/top-bar.tsx b/cvat-ui/src/components/projects-page/top-bar.tsx index 46d2a03b..57c9c02b 100644 --- a/cvat-ui/src/components/projects-page/top-bar.tsx +++ b/cvat-ui/src/components/projects-page/top-bar.tsx @@ -8,44 +8,70 @@ import { useHistory } from 'react-router'; import { Row, Col } from 'antd/lib/grid'; import Button from 'antd/lib/button'; import Text from 'antd/lib/typography/Text'; -import { PlusOutlined } from '@ant-design/icons'; +import { PlusOutlined, UploadOutlined, LoadingOutlined } from '@ant-design/icons'; +import Upload from 'antd/lib/upload'; import SearchField from 'components/search-field/search-field'; import { CombinedState, ProjectsQuery } from 'reducers/interfaces'; -import { getProjectsAsync } from 'actions/projects-actions'; +import { getProjectsAsync, restoreProjectAsync } from 'actions/projects-actions'; export default function TopBarComponent(): JSX.Element { const history = useHistory(); const dispatch = useDispatch(); const query = useSelector((state: CombinedState) => state.projects.gettingQuery); - const dimensions = { - md: 11, - lg: 9, - xl: 8, - xxl: 8, - }; + const isImporting = useSelector((state: CombinedState) => state.projects.restoring); return ( - - - Projects - dispatch(getProjectsAsync(_query))} - /> - - - + + + + + Projects + dispatch(getProjectsAsync(_query))} + /> + + + + + { + dispatch(restoreProjectAsync(file)); + return false; + }} + > + + + + + + + + + ); diff --git a/cvat-ui/src/components/tasks-page/top-bar.tsx b/cvat-ui/src/components/tasks-page/top-bar.tsx index f1e2f685..e52383a2 100644 --- a/cvat-ui/src/components/tasks-page/top-bar.tsx +++ b/cvat-ui/src/components/tasks-page/top-bar.tsx @@ -55,7 +55,7 @@ export default function TopBarComponent(props: VisibleTopBarProps): JSX.Element disabled={taskImporting} icon={} > - Import Task + Create from backup {taskImporting && } diff --git a/cvat-ui/src/reducers/interfaces.ts b/cvat-ui/src/reducers/interfaces.ts index 895d0698..c5d61bda 100644 --- a/cvat-ui/src/reducers/interfaces.ts +++ b/cvat-ui/src/reducers/interfaces.ts @@ -53,7 +53,11 @@ export interface ProjectsState { deletes: { [projectId: number]: boolean; // deleted (deleting if in dictionary) }; + backups: { + [projectId: number]: boolean; + } }; + restoring: boolean; } export interface TasksQuery { @@ -330,6 +334,8 @@ export interface NotificationsState { updating: null | ErrorState; deleting: null | ErrorState; creating: null | ErrorState; + restoring: null | ErrorState; + backuping: null | ErrorState; }; tasks: { fetching: null | ErrorState; @@ -434,6 +440,9 @@ export interface NotificationsState { requestPasswordResetDone: string; resetPasswordDone: string; }; + projects: { + restoringDone: string; + } }; } diff --git a/cvat-ui/src/reducers/notifications-reducer.ts b/cvat-ui/src/reducers/notifications-reducer.ts index 2f6b2f8d..d35b3046 100644 --- a/cvat-ui/src/reducers/notifications-reducer.ts +++ b/cvat-ui/src/reducers/notifications-reducer.ts @@ -42,6 +42,8 @@ const defaultState: NotificationsState = { updating: null, deleting: null, creating: null, + restoring: null, + backuping: null, }, tasks: { fetching: null, @@ -146,6 +148,9 @@ const defaultState: NotificationsState = { requestPasswordResetDone: '', resetPasswordDone: '', }, + projects: { + restoringDone: '', + }, }, }; @@ -581,6 +586,51 @@ export default function (state = defaultState, action: AnyAction): Notifications }, }; } + case ProjectsActionTypes.BACKUP_PROJECT_FAILED: { + return { + ...state, + errors: { + ...state.errors, + projects: { + ...state.errors.projects, + backuping: { + message: `Could not backup the project #${action.payload.projectId}`, + reason: action.payload.error.toString(), + }, + }, + }, + }; + } + case ProjectsActionTypes.RESTORE_PROJECT_FAILED: { + return { + ...state, + errors: { + ...state.errors, + projects: { + ...state.errors.projects, + restoring: { + message: 'Could not restore the project', + reason: action.payload.error.toString(), + }, + }, + }, + }; + } + case ProjectsActionTypes.RESTORE_PROJECT_SUCCESS: { + const { projectID } = action.payload; + return { + ...state, + messages: { + ...state.messages, + projects: { + ...state.messages.projects, + restoringDone: + `Project has been created succesfully. + Click here to open`, + }, + }, + }; + } case FormatsActionTypes.GET_FORMATS_FAILED: { return { ...state, diff --git a/cvat-ui/src/reducers/projects-reducer.ts b/cvat-ui/src/reducers/projects-reducer.ts index d8867d2a..f9c1458a 100644 --- a/cvat-ui/src/reducers/projects-reducer.ts +++ b/cvat-ui/src/reducers/projects-reducer.ts @@ -3,6 +3,7 @@ // SPDX-License-Identifier: MIT import { AnyAction } from 'redux'; +import { omit } from 'lodash'; import { ProjectsActionTypes } from 'actions/projects-actions'; import { BoundariesActionTypes } from 'actions/boundaries-actions'; import { AuthActionTypes } from 'actions/auth-actions'; @@ -41,7 +42,9 @@ const defaultState: ProjectsState = { id: null, error: '', }, + backups: {}, }, + restoring: false, }; export default (state: ProjectsState = defaultState, action: AnyAction): ProjectsState => { @@ -206,6 +209,48 @@ export default (state: ProjectsState = defaultState, action: AnyAction): Project }, }; } + case ProjectsActionTypes.BACKUP_PROJECT: { + const { projectId } = action.payload; + const { backups } = state.activities; + + return { + ...state, + activities: { + ...state.activities, + backups: { + ...backups, + ...Object.fromEntries([[projectId, true]]), + }, + }, + }; + } + case ProjectsActionTypes.BACKUP_PROJECT_FAILED: + case ProjectsActionTypes.BACKUP_PROJECT_SUCCESS: { + const { projectID } = action.payload; + const { backups } = state.activities; + + return { + ...state, + activities: { + ...state.activities, + backups: omit(backups, [projectID]), + }, + }; + } + case ProjectsActionTypes.RESTORE_PROJECT: { + return { + ...state, + restoring: true, + }; + } + case ProjectsActionTypes.RESTORE_PROJECT_FAILED: + case ProjectsActionTypes.RESTORE_PROJECT_SUCCESS: { + return { + ...state, + restoring: false, + }; + } + case BoundariesActionTypes.RESET_AFTER_ERROR: case AuthActionTypes.LOGOUT_SUCCESS: { return { ...defaultState }; diff --git a/cvat/apps/authentication/auth.py b/cvat/apps/authentication/auth.py index 5c1f8ea3..707ac7f0 100644 --- a/cvat/apps/authentication/auth.py +++ b/cvat/apps/authentication/auth.py @@ -21,7 +21,7 @@ class TokenAuthentication(_TokenAuthentication): def authenticate(self, request): auth = super().authenticate(request) session = getattr(request, 'session') - if auth is not None and session.session_key is None: + if auth is not None and (session.session_key is None or (not session.modified and not session.load())): login(request, auth[0], 'django.contrib.auth.backends.ModelBackend') return auth diff --git a/cvat/apps/dataset_manager/views.py b/cvat/apps/dataset_manager/views.py index 12708242..a0127f3b 100644 --- a/cvat/apps/dataset_manager/views.py +++ b/cvat/apps/dataset_manager/views.py @@ -16,7 +16,6 @@ import cvat.apps.dataset_manager.task as task import cvat.apps.dataset_manager.project as project from cvat.apps.engine.log import slogger from cvat.apps.engine.models import Project, Task -from cvat.apps.engine.backup import TaskExporter from .formats.registry import EXPORT_FORMATS, IMPORT_FORMATS from .util import current_function_name @@ -80,8 +79,9 @@ def export(dst_format, task_id=None, project_id=None, server_url=None, save_imag scheduler = django_rq.get_scheduler() cleaning_job = scheduler.enqueue_in(time_delta=cache_ttl, func=clear_export_cache, - task_id=task_id, - file_path=output_path, file_ctime=archive_ctime) + file_path=output_path, + file_ctime=archive_ctime, + logger=logger) logger.info( "The {} '{}' is exported as '{}' at '{}' " "and available for downloading for the next {}. " @@ -109,50 +109,16 @@ def export_project_as_dataset(project_id, dst_format=None, server_url=None): def export_project_annotations(project_id, dst_format=None, server_url=None): return export(dst_format, project_id=project_id, server_url=server_url, save_images=False) -def clear_export_cache(task_id, file_path, file_ctime): +def clear_export_cache(file_path, file_ctime, logger): try: if osp.exists(file_path) and osp.getctime(file_path) == file_ctime: os.remove(file_path) - slogger.task[task_id].info( + + logger.info( "Export cache file '{}' successfully removed" \ .format(file_path)) except Exception: - log_exception(slogger.task[task_id]) - raise - -def backup_task(task_id, output_path): - try: - db_task = Task.objects.get(pk=task_id) - - cache_dir = get_export_cache_dir(db_task) - output_path = osp.join(cache_dir, output_path) - - task_time = timezone.localtime(db_task.updated_date).timestamp() - if not (osp.exists(output_path) and \ - task_time <= osp.getmtime(output_path)): - os.makedirs(cache_dir, exist_ok=True) - with tempfile.TemporaryDirectory(dir=cache_dir) as temp_dir: - temp_file = osp.join(temp_dir, 'dump') - task_exporter = TaskExporter(task_id) - task_exporter.export_to(temp_file) - os.replace(temp_file, output_path) - - archive_ctime = osp.getctime(output_path) - scheduler = django_rq.get_scheduler() - cleaning_job = scheduler.enqueue_in(time_delta=TASK_CACHE_TTL, - func=clear_export_cache, - task_id=task_id, - file_path=output_path, file_ctime=archive_ctime) - slogger.task[task_id].info( - "The task '{}' is backuped at '{}' " - "and available for downloading for the next {}. " - "Export cache cleaning job is enqueued, id '{}'".format( - db_task.name, output_path, TASK_CACHE_TTL, - cleaning_job.id)) - - return output_path - except Exception: - log_exception(slogger.task[task_id]) + log_exception(logger) raise def get_export_formats(): diff --git a/cvat/apps/engine/backup.py b/cvat/apps/engine/backup.py index 1803cf85..55e6c083 100644 --- a/cvat/apps/engine/backup.py +++ b/cvat/apps/engine/backup.py @@ -5,45 +5,93 @@ import io import os from enum import Enum +import re import shutil +import tempfile +import uuid from zipfile import ZipFile +from datetime import datetime +from tempfile import mkstemp +import django_rq from django.conf import settings from django.db import transaction +from django.utils import timezone +from rest_framework import serializers, status from rest_framework.parsers import JSONParser from rest_framework.renderers import JSONRenderer +from rest_framework.response import Response +from sendfile import sendfile import cvat.apps.dataset_manager as dm from cvat.apps.engine import models from cvat.apps.engine.log import slogger from cvat.apps.engine.serializers import (AttributeSerializer, DataSerializer, LabeledDataSerializer, SegmentSerializer, SimpleJobSerializer, TaskSerializer, - ReviewSerializer, IssueSerializer, CommentSerializer) + ReviewSerializer, IssueSerializer, CommentSerializer, ProjectSerializer, + ProjectFileSerializer, TaskFileSerializer) from cvat.apps.engine.utils import av_scan_paths -from cvat.apps.engine.models import StorageChoice, StorageMethodChoice, DataChoice +from cvat.apps.engine.models import StorageChoice, StorageMethodChoice, DataChoice, Task, Project from cvat.apps.engine.task import _create_thread +from cvat.apps.dataset_manager.views import TASK_CACHE_TTL, PROJECT_CACHE_TTL, get_export_cache_dir, clear_export_cache, log_exception +from cvat.apps.dataset_manager.bindings import CvatImportError class Version(Enum): V1 = '1.0' -class _TaskBackupBase(): - MANIFEST_FILENAME = 'task.json' - ANNOTATIONS_FILENAME = 'annotations.json' - DATA_DIRNAME = 'data' - TASK_DIRNAME = 'task' + +def _get_label_mapping(db_labels): + label_mapping = {db_label.id: db_label.name for db_label in db_labels} + for db_label in db_labels: + label_mapping[db_label.id] = { + 'value': db_label.name, + 'attributes': {}, + } + for db_attribute in db_label.attributespec_set.all(): + label_mapping[db_label.id]['attributes'][db_attribute.id] = db_attribute.name + + return label_mapping + +class _BackupBase(): + def __init__(self, *args, logger=None, **kwargs): + super().__init__(*args, **kwargs) + self._logger = logger def _prepare_meta(self, allowed_keys, meta): keys_to_drop = set(meta.keys()) - allowed_keys if keys_to_drop: - logger = slogger.task[self._db_task.id] if hasattr(self, '_db_task') else slogger.glob - - logger.warning('the following keys are dropped {}'.format(keys_to_drop)) + if self._logger: + self._logger.warning('the following keys are dropped {}'.format(keys_to_drop)) for key in keys_to_drop: del meta[key] return meta + def _prepare_label_meta(self, labels): + allowed_fields = { + 'name', + 'color', + 'attributes', + } + return self._prepare_meta(allowed_fields, labels) + + def _prepare_attribute_meta(self, attribute): + allowed_fields = { + 'name', + 'mutable', + 'input_type', + 'default_value', + 'values', + } + return self._prepare_meta(allowed_fields, attribute) + +class _TaskBackupBase(_BackupBase): + MANIFEST_FILENAME = 'task.json' + ANNOTATIONS_FILENAME = 'annotations.json' + DATA_DIRNAME = 'data' + TASK_DIRNAME = 'task' + def _prepare_task_meta(self, task): allowed_fields = { 'name', @@ -80,24 +128,6 @@ class _TaskBackupBase(): } return self._prepare_meta(allowed_fields, job) - def _prepare_attribute_meta(self, attribute): - allowed_fields = { - 'name', - 'mutable', - 'input_type', - 'default_value', - 'values', - } - return self._prepare_meta(allowed_fields, attribute) - - def _prepare_label_meta(self, labels): - allowed_fields = { - 'name', - 'color', - 'attributes', - } - return self._prepare_meta(allowed_fields, labels) - def _prepare_annotations(self, annotations, label_mapping): allowed_fields = { 'label', @@ -190,27 +220,12 @@ class _TaskBackupBase(): return db_jobs return () -class TaskExporter(_TaskBackupBase): - def __init__(self, pk, version=Version.V1): - self._db_task = models.Task.objects.prefetch_related('data__images').select_related('data__video').get(pk=pk) - self._db_data = self._db_task.data - self._version = version - - db_labels = (self._db_task.project if self._db_task.project_id else self._db_task).label_set.all().prefetch_related( - 'attributespec_set') +class _ExporterBase(): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) - self._label_mapping = {} - self._label_mapping = {db_label.id: db_label.name for db_label in db_labels} - self._attribute_mapping = {} - for db_label in db_labels: - self._label_mapping[db_label.id] = { - 'value': db_label.name, - 'attributes': {}, - } - for db_attribute in db_label.attributespec_set.all(): - self._label_mapping[db_label.id]['attributes'][db_attribute.id] = db_attribute.name - - def _write_files(self, source_dir, zip_object, files, target_dir): + @staticmethod + def _write_files(source_dir, zip_object, files, target_dir): for filename in files: arcname = os.path.normpath( os.path.join( @@ -233,12 +248,24 @@ class TaskExporter(_TaskBackupBase): target_dir=target_dir, ) - def _write_data(self, zip_object): +class TaskExporter(_ExporterBase, _TaskBackupBase): + def __init__(self, pk, version=Version.V1): + super().__init__(logger=slogger.task[pk]) + self._db_task = models.Task.objects.prefetch_related('data__images').select_related('data__video').get(pk=pk) + self._db_data = self._db_task.data + self._version = version + + db_labels = (self._db_task.project if self._db_task.project_id else self._db_task).label_set.all().prefetch_related( + 'attributespec_set') + self._label_mapping = _get_label_mapping(db_labels) + + def _write_data(self, zip_object, target_dir=None): + target_data_dir = os.path.join(target_dir, self.DATA_DIRNAME) if target_dir else self.DATA_DIRNAME if self._db_data.storage == StorageChoice.LOCAL: self._write_directory( source_dir=self._db_data.get_upload_dirname(), zip_object=zip_object, - target_dir=self.DATA_DIRNAME, + target_dir=target_data_dir, ) elif self._db_data.storage == StorageChoice.SHARE: data_dir = settings.SHARE_ROOT @@ -251,7 +278,7 @@ class TaskExporter(_TaskBackupBase): source_dir=data_dir, zip_object=zip_object, files=media_files, - target_dir=self.DATA_DIRNAME + target_dir=target_data_dir, ) upload_dir = self._db_data.get_upload_dirname() @@ -259,27 +286,26 @@ class TaskExporter(_TaskBackupBase): source_dir=upload_dir, zip_object=zip_object, files=(os.path.join(upload_dir, f) for f in ('manifest.jsonl',)), - target_dir=self.DATA_DIRNAME + target_dir=target_data_dir, ) else: raise NotImplementedError() - def _write_task(self, zip_object): + def _write_task(self, zip_object, target_dir=None): task_dir = self._db_task.get_task_dirname() + target_task_dir = os.path.join(target_dir, self.TASK_DIRNAME) if target_dir else self.TASK_DIRNAME self._write_directory( source_dir=task_dir, zip_object=zip_object, - target_dir=self.TASK_DIRNAME, + target_dir=target_task_dir, recursive=False, ) - def _write_manifest(self, zip_object): + def _write_manifest(self, zip_object, target_dir=None): def serialize_task(): task_serializer = TaskSerializer(self._db_task) - task_serializer.fields.pop('url') - task_serializer.fields.pop('owner') - task_serializer.fields.pop('assignee') - task_serializer.fields.pop('segments') + for field in ('url', 'owner', 'assignee', 'segments'): + task_serializer.fields.pop(field) task = self._prepare_task_meta(task_serializer.data) task['labels'] = [self._prepare_label_meta(l) for l in task['labels']] @@ -317,9 +343,8 @@ class TaskExporter(_TaskBackupBase): def serialize_segment(db_segment): db_job = db_segment.job_set.first() job_serializer = SimpleJobSerializer(db_job) - job_serializer.fields.pop('url') - job_serializer.fields.pop('assignee') - job_serializer.fields.pop('reviewer') + for field in ('url', 'assignee', 'reviewer'): + job_serializer.fields.pop(field) job_data = self._prepare_job_meta(job_serializer.data) segment_serailizer = SegmentSerializer(db_segment) @@ -348,9 +373,10 @@ class TaskExporter(_TaskBackupBase): task['data'] = serialize_data() task['jobs'] = serialize_jobs() - zip_object.writestr(self.MANIFEST_FILENAME, data=JSONRenderer().render(task)) + target_manifest_file = os.path.join(target_dir, self.MANIFEST_FILENAME) if target_dir else self.MANIFEST_FILENAME + zip_object.writestr(target_manifest_file, data=JSONRenderer().render(task)) - def _write_annotations(self, zip_object): + def _write_annotations(self, zip_object, target_dir=None): def serialize_annotations(): job_annotations = [] db_jobs = self._get_db_jobs() @@ -364,36 +390,35 @@ class TaskExporter(_TaskBackupBase): return job_annotations annotations = serialize_annotations() - zip_object.writestr(self.ANNOTATIONS_FILENAME, data=JSONRenderer().render(annotations)) + target_annotations_file = os.path.join(target_dir, self.ANNOTATIONS_FILENAME) if target_dir else self.ANNOTATIONS_FILENAME + zip_object.writestr(target_annotations_file, data=JSONRenderer().render(annotations)) - def export_to(self, filename): + def _export_task(self, zip_obj, target_dir=None): + self._write_data(zip_obj, target_dir) + self._write_task(zip_obj, target_dir) + self._write_manifest(zip_obj, target_dir) + self._write_annotations(zip_obj, target_dir) + + def export_to(self, file, target_dir=None): if self._db_task.data.storage_method == StorageMethodChoice.FILE_SYSTEM and \ self._db_task.data.storage == StorageChoice.SHARE: raise Exception('The task cannot be exported because it does not contain any raw data') - with ZipFile(filename, 'w') as output_file: - self._write_data(output_file) - self._write_task(output_file) - self._write_manifest(output_file) - self._write_annotations(output_file) - -class TaskImporter(_TaskBackupBase): - def __init__(self, filename, user_id): - self._filename = filename - self._user_id = user_id - self._manifest, self._annotations = self._read_meta() - self._version = self._read_version() - self._labels_mapping = {} - self._db_task = None - def _read_meta(self): - with ZipFile(self._filename, 'r') as input_file: - manifest = JSONParser().parse(io.BytesIO(input_file.read(self.MANIFEST_FILENAME))) - annotations = JSONParser().parse(io.BytesIO(input_file.read(self.ANNOTATIONS_FILENAME))) + if isinstance(file, str): + with ZipFile(file, 'w') as zf: + self._export_task(zip_obj=zf, target_dir=target_dir) + elif isinstance(file, ZipFile): + self._export_task(zip_obj=file, target_dir=target_dir) + else: + raise ValueError('Unsuported type of file argument') - return manifest, annotations +class _ImporterBase(): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) - def _read_version(self): - version = self._manifest.pop('version') + @staticmethod + def _read_version(manifest): + version = manifest.pop('version') try: return Version(version) except ValueError: @@ -405,13 +430,23 @@ class TaskImporter(_TaskBackupBase): if not os.path.exists(target_dir): os.makedirs(target_dir) - def _create_labels(self, db_task, labels): + @staticmethod + def _create_labels(labels, db_task=None, db_project=None): label_mapping = {} + if db_task: + label_relation = { + 'task': db_task + } + else: + label_relation = { + 'project': db_project + } + for label in labels: label_name = label['name'] attributes = label.pop('attributes', []) - db_label = models.Label.objects.create(task=db_task, **label) + db_label = models.Label.objects.create(**label_relation, **label) label_mapping[label_name] = { 'value': db_label.id, 'attributes': {}, @@ -426,6 +461,34 @@ class TaskImporter(_TaskBackupBase): return label_mapping +class TaskImporter(_ImporterBase, _TaskBackupBase): + def __init__(self, file, user_id, project_id=None, subdir=None, label_mapping=None): + super().__init__(logger=slogger.glob) + self._file = file + self._subdir = subdir + self._user_id = user_id + self._manifest, self._annotations = self._read_meta() + self._version = self._read_version(self._manifest) + self._labels_mapping = label_mapping + self._db_task = None + self._project_id=project_id + + def _read_meta(self): + def read(zip_object): + manifest_filename = os.path.join(self._subdir, self.MANIFEST_FILENAME) if self._subdir else self.MANIFEST_FILENAME + annotations_filename = os.path.join(self._subdir, self.ANNOTATIONS_FILENAME) if self._subdir else self.ANNOTATIONS_FILENAME + manifest = JSONParser().parse(io.BytesIO(zip_object.read(manifest_filename))) + annotations = JSONParser().parse(io.BytesIO(zip_object.read(annotations_filename))) + return manifest, annotations + + if isinstance(self._file, str): + with ZipFile(self._file, 'r') as input_file: + return read(input_file) + elif isinstance(self._file, ZipFile): + return read(self._file) + + raise ValueError('Unsuported type of file argument') + def _create_annotations(self, db_job, annotations): self._prepare_annotations(annotations, self._labels_mapping) @@ -441,7 +504,6 @@ class TaskImporter(_TaskBackupBase): return segment_size, overlap def _import_task(self): - def _create_comment(comment, db_issue): comment['issue'] = db_issue.id comment_serializer = CommentSerializer(data=comment) @@ -476,13 +538,36 @@ class TaskImporter(_TaskBackupBase): return db_review + def _write_data(zip_object): + data_path = self._db_task.data.get_upload_dirname() + task_dirname = os.path.join(self._subdir, self.TASK_DIRNAME) if self._subdir else self.TASK_DIRNAME + data_dirname = os.path.join(self._subdir, self.DATA_DIRNAME) if self._subdir else self.DATA_DIRNAME + uploaded_files = [] + for f in zip_object.namelist(): + if f.endswith(os.path.sep): + continue + if f.startswith(data_dirname + os.path.sep): + target_file = os.path.join(data_path, os.path.relpath(f, data_dirname)) + self._prepare_dirs(target_file) + with open(target_file, "wb") as out: + out.write(zip_object.read(f)) + uploaded_files.append(os.path.relpath(f, data_dirname)) + elif f.startswith(task_dirname + os.path.sep): + target_file = os.path.join(task_path, os.path.relpath(f, task_dirname)) + self._prepare_dirs(target_file) + with open(target_file, "wb") as out: + out.write(zip_object.read(f)) + + return uploaded_files + data = self._manifest.pop('data') labels = self._manifest.pop('labels') jobs = self._manifest.pop('jobs') self._prepare_task_meta(self._manifest) self._manifest['segment_size'], self._manifest['overlap'] = self._calculate_segment_size(jobs) - self._manifest["owner_id"] = self._user_id + self._manifest['owner_id'] = self._user_id + self._manifest['project_id'] = self._project_id self._db_task = models.Task.objects.create(**self._manifest) task_path = self._db_task.get_task_dirname() @@ -492,7 +577,8 @@ class TaskImporter(_TaskBackupBase): os.makedirs(self._db_task.get_task_logs_dirname()) os.makedirs(self._db_task.get_task_artifacts_dirname()) - self._labels_mapping = self._create_labels(self._db_task, labels) + if not self._labels_mapping: + self._labels_mapping = self._create_labels(db_task=self._db_task, labels=labels) self._prepare_data_meta(data) data_serializer = DataSerializer(data=data) @@ -501,21 +587,11 @@ class TaskImporter(_TaskBackupBase): self._db_task.data = db_data self._db_task.save() - data_path = self._db_task.data.get_upload_dirname() - uploaded_files = [] - with ZipFile(self._filename, 'r') as input_file: - for f in input_file.namelist(): - if f.startswith(self.DATA_DIRNAME + os.path.sep): - target_file = os.path.join(data_path, os.path.relpath(f, self.DATA_DIRNAME)) - self._prepare_dirs(target_file) - with open(target_file, "wb") as out: - out.write(input_file.read(f)) - uploaded_files.append(os.path.relpath(f, self.DATA_DIRNAME)) - elif f.startswith(self.TASK_DIRNAME + os.path.sep): - target_file = os.path.join(task_path, os.path.relpath(f, self.TASK_DIRNAME)) - self._prepare_dirs(target_file) - with open(target_file, "wb") as out: - out.write(input_file.read(f)) + if isinstance(self._file, str): + with ZipFile(self._file, 'r') as zf: + uploaded_files = _write_data(zf) + else: + uploaded_files = _write_data(self._file) data['use_zip_chunks'] = data.pop('chunk_type') == DataChoice.IMAGESET data = data_serializer.data @@ -545,8 +621,301 @@ class TaskImporter(_TaskBackupBase): return self._db_task @transaction.atomic -def import_task(filename, user): +def _import_task(filename, user): av_scan_paths(filename) task_importer = TaskImporter(filename, user) db_task = task_importer.import_task() return db_task.id + + +class _ProjectBackupBase(_BackupBase): + MANIFEST_FILENAME = 'project.json' + TASKNAME_TEMPLATE = 'task_{}' + + def _prepare_project_meta(self, project): + allowed_fields = { + 'bug_tracker', + 'deimension', + 'labels', + 'name', + 'status', + } + + return self._prepare_meta(allowed_fields, project) + +class ProjectExporter(_ExporterBase, _ProjectBackupBase): + def __init__(self, pk, version=Version.V1): + super().__init__(logger=slogger.project[pk]) + self._db_project = models.Project.objects.prefetch_related('tasks').get(pk=pk) + self._version = version + + db_labels = self._db_project.label_set.all().prefetch_related('attributespec_set') + self._label_mapping = _get_label_mapping(db_labels) + + def _write_tasks(self, zip_object): + for idx, db_task in enumerate(self._db_project.tasks.all().order_by('id')): + TaskExporter(db_task.id, self._version).export_to(zip_object, self.TASKNAME_TEMPLATE.format(idx)) + + def _write_manifest(self, zip_object): + def serialize_project(): + project_serializer = ProjectSerializer(self._db_project) + for field in ('assignee', 'owner', 'tasks', 'training_project', 'url'): + project_serializer.fields.pop(field) + + project = self._prepare_project_meta(project_serializer.data) + project['labels'] = [self._prepare_label_meta(l) for l in project['labels']] + for label in project['labels']: + label['attributes'] = [self._prepare_attribute_meta(a) for a in label['attributes']] + + return project + + project = serialize_project() + project['version'] = self._version.value + + zip_object.writestr(self.MANIFEST_FILENAME, data=JSONRenderer().render(project)) + + def export_to(self, filename): + with ZipFile(filename, 'w') as output_file: + self._write_tasks(output_file) + self._write_manifest(output_file) + +class ProjectImporter(_ImporterBase, _ProjectBackupBase): + TASKNAME_RE = 'task_(\d+)/' + + def __init__(self, filename, user_id): + super().__init__(logger=slogger.glob) + self._filename = filename + self._user_id = user_id + self._manifest = self._read_meta() + self._version = self._read_version(self._manifest) + self._db_project = None + self._labels_mapping = {} + + def _read_meta(self): + with ZipFile(self._filename, 'r') as input_file: + manifest = JSONParser().parse(io.BytesIO(input_file.read(self.MANIFEST_FILENAME))) + + return manifest + + def _import_project(self): + labels = self._manifest.pop('labels') + + self._prepare_project_meta(self._manifest) + self._manifest["owner_id"] = self._user_id + + self._db_project = models.Project.objects.create(**self._manifest) + project_path = self._db_project.get_project_dirname() + if os.path.isdir(project_path): + shutil.rmtree(project_path) + os.makedirs(self._db_project.get_project_logs_dirname()) + + self._labels_mapping = self._create_labels(db_project=self._db_project, labels=labels) + + def _import_tasks(self): + def get_tasks(zip_object): + tasks = {} + for fname in zip_object.namelist(): + m = re.match(self.TASKNAME_RE, fname) + if m: + tasks[int(m.group(1))] = m.group(0) + return [v for _, v in sorted(tasks.items())] + + with ZipFile(self._filename, 'r') as zf: + task_dirs = get_tasks(zf) + for task_dir in task_dirs: + TaskImporter( + file=zf, + user_id=self._user_id, + project_id=self._db_project.id, + subdir=task_dir, + label_mapping=self._labels_mapping).import_task() + + def import_project(self): + self._import_project() + self._import_tasks() + + return self._db_project + +@transaction.atomic +def _import_project(filename, user): + av_scan_paths(filename) + project_importer = ProjectImporter(filename, user) + db_project = project_importer.import_project() + return db_project.id + + +def _create_backup(db_instance, Exporter, output_path, logger, cache_ttl): + try: + cache_dir = get_export_cache_dir(db_instance) + output_path = os.path.join(cache_dir, output_path) + + instance_time = timezone.localtime(db_instance.updated_date).timestamp() + if not (os.path.exists(output_path) and \ + instance_time <= os.path.getmtime(output_path)): + os.makedirs(cache_dir, exist_ok=True) + with tempfile.TemporaryDirectory(dir=cache_dir) as temp_dir: + temp_file = os.path.join(temp_dir, 'dump') + exporter = Exporter(db_instance.id) + exporter.export_to(temp_file) + os.replace(temp_file, output_path) + + archive_ctime = os.path.getctime(output_path) + scheduler = django_rq.get_scheduler() + cleaning_job = scheduler.enqueue_in(time_delta=cache_ttl, + func=clear_export_cache, + file_path=output_path, + file_ctime=archive_ctime, + logger=logger) + logger.info( + "The {} '{}' is backuped at '{}' " + "and available for downloading for the next {}. " + "Export cache cleaning job is enqueued, id '{}'".format( + "project" if isinstance(db_instance, Project) else 'task', + db_instance.name, output_path, cache_ttl, + cleaning_job.id)) + + return output_path + except Exception: + log_exception(logger) + raise + +def export(db_instance, request): + action = request.query_params.get('action', None) + if action not in (None, 'download'): + raise serializers.ValidationError( + "Unexpected action specified for the request") + + if isinstance(db_instance, Task): + filename_prefix = 'task' + logger = slogger.task[db_instance.pk] + Exporter = TaskExporter + cache_ttl = TASK_CACHE_TTL + elif isinstance(db_instance, Project): + filename_prefix = 'project' + logger = slogger.project[db_instance.pk] + Exporter = ProjectExporter + cache_ttl = PROJECT_CACHE_TTL + else: + raise Exception( + "Unexpected type of db_isntance: {}".format(type(db_instance))) + + queue = django_rq.get_queue("default") + rq_id = "/api/v1/{}s/{}/backup".format(filename_prefix, db_instance.pk) + rq_job = queue.fetch_job(rq_id) + if rq_job: + last_project_update_time = timezone.localtime(db_instance.updated_date) + request_time = rq_job.meta.get('request_time', None) + if request_time is None or request_time < last_project_update_time: + rq_job.cancel() + rq_job.delete() + else: + if rq_job.is_finished: + file_path = rq_job.return_value + if action == "download" and os.path.exists(file_path): + rq_job.delete() + + timestamp = datetime.strftime(last_project_update_time, + "%Y_%m_%d_%H_%M_%S") + filename = "{}_{}_backup_{}{}".format( + filename_prefix, db_instance.name, timestamp, + os.path.splitext(file_path)[1]) + return sendfile(request, file_path, attachment=True, + attachment_filename=filename.lower()) + else: + if os.path.exists(file_path): + return Response(status=status.HTTP_201_CREATED) + elif rq_job.is_failed: + exc_info = str(rq_job.exc_info) + rq_job.delete() + return Response(exc_info, + status=status.HTTP_500_INTERNAL_SERVER_ERROR) + else: + return Response(status=status.HTTP_202_ACCEPTED) + + ttl = dm.views.PROJECT_CACHE_TTL.total_seconds() + queue.enqueue_call( + func=_create_backup, + args=(db_instance, Exporter, '{}_backup.zip'.format(filename_prefix), logger, cache_ttl), + job_id=rq_id, + meta={ 'request_time': timezone.localtime() }, + result_ttl=ttl, failure_ttl=ttl) + return Response(status=status.HTTP_202_ACCEPTED) + +def _import(importer, request, rq_id, Serializer, file_field_name): + queue = django_rq.get_queue("default") + rq_job = queue.fetch_job(rq_id) + + if not rq_job: + serializer = Serializer(data=request.data) + serializer.is_valid(raise_exception=True) + payload_file = serializer.validated_data[file_field_name] + fd, filename = mkstemp(prefix='cvat_') + with open(filename, 'wb+') as f: + for chunk in payload_file.chunks(): + f.write(chunk) + rq_job = queue.enqueue_call( + func=importer, + args=(filename, request.user.id), + job_id=rq_id, + meta={ + 'tmp_file': filename, + 'tmp_file_descriptor': fd, + }, + ) + else: + if rq_job.is_finished: + project_id = rq_job.return_value + os.close(rq_job.meta['tmp_file_descriptor']) + os.remove(rq_job.meta['tmp_file']) + rq_job.delete() + return Response({'id': project_id}, status=status.HTTP_201_CREATED) + elif rq_job.is_failed: + os.close(rq_job.meta['tmp_file_descriptor']) + os.remove(rq_job.meta['tmp_file']) + exc_info = str(rq_job.exc_info) + rq_job.delete() + + # RQ adds a prefix with exception class name + import_error_prefix = '{}.{}'.format( + CvatImportError.__module__, CvatImportError.__name__) + if exc_info.startswith(import_error_prefix): + exc_info = exc_info.replace(import_error_prefix + ': ', '') + return Response(data=exc_info, + status=status.HTTP_400_BAD_REQUEST) + else: + return Response(data=exc_info, + status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + return Response({'rq_id': rq_id}, status=status.HTTP_202_ACCEPTED) + +def import_project(request): + if 'rq_id' in request.data: + rq_id = request.data['rq_id'] + else: + rq_id = "{}@/api/v1/projects/{}/import".format(request.user, uuid.uuid4()) + Serializer = ProjectFileSerializer + file_field_name = 'project_file' + + return _import( + importer=_import_project, + request=request, + rq_id=rq_id, + Serializer=Serializer, + file_field_name=file_field_name, + ) + +def import_task(request): + if 'rq_id' in request.data: + rq_id = request.data['rq_id'] + else: + rq_id = "{}@/api/v1/tasks/{}/import".format(request.user, uuid.uuid4()) + Serializer = TaskFileSerializer + file_field_name = 'task_file' + + return _import( + importer=_import_task, + request=request, + rq_id=rq_id, + Serializer=Serializer, + file_field_name=file_field_name, + ) diff --git a/cvat/apps/engine/serializers.py b/cvat/apps/engine/serializers.py index d6eeb9be..cbe05629 100644 --- a/cvat/apps/engine/serializers.py +++ b/cvat/apps/engine/serializers.py @@ -739,6 +739,9 @@ class DatasetFileSerializer(serializers.Serializer): class TaskFileSerializer(serializers.Serializer): task_file = serializers.FileField() +class ProjectFileSerializer(serializers.Serializer): + project_file = serializers.FileField() + class ReviewSerializer(serializers.ModelSerializer): assignee = BasicUserSerializer(allow_null=True, required=False) assignee_id = serializers.IntegerField(write_only=True, allow_null=True, required=False) diff --git a/cvat/apps/engine/tests/test_rest_api.py b/cvat/apps/engine/tests/test_rest_api.py index 7de2c199..3f7efae8 100644 --- a/cvat/apps/engine/tests/test_rest_api.py +++ b/cvat/apps/engine/tests/test_rest_api.py @@ -16,6 +16,7 @@ from enum import Enum from glob import glob from io import BytesIO from unittest import mock +import logging import av import numpy as np @@ -35,6 +36,9 @@ from cvat.apps.engine.models import (AttributeSpec, AttributeType, Data, Job, Pr from cvat.apps.engine.media_extractors import ValidateDimension, sort from utils.dataset_manifest import ImageManifestManager, VideoManifestManager +#supress av warnings +logging.getLogger('libav').setLevel(logging.ERROR) + def create_db_users(cls): (group_admin, _) = Group.objects.get_or_create(name="admin") (group_user, _) = Group.objects.get_or_create(name="user") @@ -1396,6 +1400,391 @@ class ProjectListOfTasksAPITestCase(APITestCase): response = self._run_api_v1_projects_id_tasks(None, project.id) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) +class ProjectBackupAPITestCase(APITestCase): + @classmethod + def setUpTestData(cls): + create_db_users(cls) + cls._create_media() + cls.client = APIClient() + cls._create_projects() + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + for task in cls.tasks: + shutil.rmtree(os.path.join(settings.TASKS_ROOT, str(task["id"]))) + shutil.rmtree(os.path.join(settings.MEDIA_DATA_ROOT, str(task["data_id"]))) + + for f in cls.media['files']: + os.remove(f) + for d in cls.media['dirs']: + shutil.rmtree(d) + + @classmethod + def _create_media(cls): + cls.media_data = [] + cls.media = { + 'files': [], + 'dirs': [], + } + image_count = 10 + imagename_pattern = "test_{}.jpg" + for i in range(image_count): + filename = imagename_pattern.format(i) + path = os.path.join(settings.SHARE_ROOT, filename) + cls.media['files'].append(path) + _, data = generate_image_file(filename) + with open(path, "wb") as image: + image.write(data.read()) + + cls.media_data.append( + { + **{"image_quality": 75, + "copy_data": True, + "start_frame": 2, + "stop_frame": 9, + "frame_filter": "step=2", + }, + **{"server_files[{}]".format(i): imagename_pattern.format(i) for i in range(image_count)}, + } + ) + + filename = "test_video_1.mp4" + path = os.path.join(settings.SHARE_ROOT, filename) + cls.media['files'].append(path) + _, data = generate_video_file(filename, width=1280, height=720) + with open(path, "wb") as video: + video.write(data.read()) + cls.media_data.append( + { + "image_quality": 75, + "copy_data": True, + "start_frame": 2, + "stop_frame": 24, + "frame_filter": "step=2", + "server_files[0]": filename, + } + ) + + filename = os.path.join("test_archive_1.zip") + path = os.path.join(settings.SHARE_ROOT, filename) + cls.media['files'].append(path) + _, data = generate_zip_archive_file(filename, count=5) + with open(path, "wb") as zip_archive: + zip_archive.write(data.read()) + cls.media_data.append( + { + "image_quality": 75, + "server_files[0]": filename, + } + ) + + filename = os.path.join("videos", "test_video_1.mp4") + path = os.path.join(settings.SHARE_ROOT, filename) + cls.media['dirs'].append(os.path.dirname(path)) + os.makedirs(os.path.dirname(path)) + _, data = generate_video_file(filename, width=1280, height=720) + with open(path, "wb") as video: + video.write(data.read()) + + manifest_path = os.path.join(settings.SHARE_ROOT, 'videos', 'manifest.jsonl') + generate_manifest_file(data_type='video', manifest_path=manifest_path, sources=[path]) + + cls.media_data.append( + { + "image_quality": 70, + "copy_data": True, + "server_files[0]": filename, + "server_files[1]": os.path.join("videos", "manifest.jsonl"), + "use_cache": True, + } + ) + + manifest_path = manifest_path=os.path.join(settings.SHARE_ROOT, 'manifest.jsonl') + generate_manifest_file(data_type='images', manifest_path=manifest_path, + sources=[os.path.join(settings.SHARE_ROOT, imagename_pattern.format(i)) for i in range(1, 8)]) + cls.media['files'].append(manifest_path) + cls.media_data.append( + { + **{"image_quality": 70, + "copy_data": True, + "use_cache": True, + "frame_filter": "step=2", + "server_files[0]": "manifest.jsonl", + }, + **{ + **{"server_files[{}]".format(i): imagename_pattern.format(i) for i in range(1, 8)}, + } + } + ) + + cls.media_data.extend([ + # image list local + { + "client_files[0]": generate_image_file("test_1.jpg")[1], + "client_files[1]": generate_image_file("test_2.jpg")[1], + "client_files[2]": generate_image_file("test_3.jpg")[1], + "image_quality": 75, + }, + # video local + { + "client_files[0]": generate_video_file("test_video.mp4")[1], + "image_quality": 75, + }, + # zip archive local + { + "client_files[0]": generate_zip_archive_file("test_archive_1.zip", 10)[1], + "image_quality": 50, + }, + # pdf local + { + "client_files[0]": generate_pdf_file("test_pdf_1.pdf", 7)[1], + "image_quality": 54, + }, + ]) + + @classmethod + def _create_tasks(cls, project): + def _create_task(task_data, media_data): + response = cls.client.post('/api/v1/tasks', data=task_data, format="json") + assert response.status_code == status.HTTP_201_CREATED + tid = response.data["id"] + + for media in media_data.values(): + if isinstance(media, io.BytesIO): + media.seek(0) + response = cls.client.post("/api/v1/tasks/{}/data".format(tid), data=media_data) + assert response.status_code == status.HTTP_202_ACCEPTED + response = cls.client.get("/api/v1/tasks/{}".format(tid)) + data_id = response.data["data"] + cls.tasks.append({ + "id": tid, + "data_id": data_id, + }) + + task_data = [ + { + "name": "my task #1", + "owner_id": cls.owner.id, + "assignee_id": cls.assignee.id, + "overlap": 0, + "segment_size": 100, + "project_id": project.id, + }, + { + "name": "my task #2", + "owner_id": cls.owner.id, + "assignee_id": cls.assignee.id, + "overlap": 1, + "segment_size": 3, + "project_id": project.id, + }, + ] + + with ForceLogin(cls.owner, cls.client): + for data in task_data: + for media in cls.media_data: + _create_task(data, media) + + @classmethod + def _create_projects(cls): + cls.projects = [] + cls.tasks = [] + data = { + "name": "my empty project", + "owner": cls.owner, + "assignee": cls.assignee, + "labels": [{ + "name": "car", + "color": "#ff00ff", + "attributes": [{ + "name": "bool_attribute", + "mutable": True, + "input_type": AttributeType.CHECKBOX, + "default_value": "true" + }], + }, { + "name": "person", + }, + ], + } + db_project = create_db_project(data) + cls.projects.append(db_project) + + data = { + "name": "my project without assignee", + "owner": cls.user, + "labels": [{ + "name": "car", + "color": "#ff00ff", + "attributes": [{ + "name": "bool_attribute", + "mutable": True, + "input_type": AttributeType.CHECKBOX, + "default_value": "true" + }], + }, { + "name": "person", + }, + ], + } + db_project = create_db_project(data) + cls._create_tasks(db_project) + cls.projects.append(db_project) + + data = { + "name": "my big project", + "owner": cls.owner, + "assignee": cls.assignee, + "labels": [{ + "name": "car", + "color": "#ff00ff", + "attributes": [{ + "name": "bool_attribute", + "mutable": True, + "input_type": AttributeType.CHECKBOX, + "default_value": "true" + }], + }, { + "name": "person", + }, + ], + } + db_project = create_db_project(data) + cls._create_tasks(db_project) + cls.projects.append(db_project) + + data = { + "name": "public project", + "labels": [{ + "name": "car", + "color": "#ff00ff", + "attributes": [{ + "name": "bool_attribute", + "mutable": True, + "input_type": AttributeType.CHECKBOX, + "default_value": "true" + }], + }, { + "name": "person", + }, + ], + } + db_project = create_db_project(data) + cls._create_tasks(db_project) + cls.projects.append(db_project) + + data = { + "name": "super project", + "owner": cls.admin, + "assignee": cls.assignee, + "labels": [{ + "name": "car", + "color": "#ff00ff", + "attributes": [{ + "name": "bool_attribute", + "mutable": True, + "input_type": AttributeType.CHECKBOX, + "default_value": "true" + }], + }, { + "name": "person", + }, + ], + } + db_project = create_db_project(data) + cls._create_tasks(db_project) + cls.projects.append(db_project) + + def _run_api_v1_projects_id_export(self, pid, user, query_params=""): + with ForceLogin(user, self.client): + response = self.client.get('/api/v1/projects/{}/backup?{}'.format(pid, query_params), format="json") + + return response + + def _run_api_v1_projects_import(self, user, data): + with ForceLogin(user, self.client): + response = self.client.post('/api/v1/projects/backup', data=data, format="multipart") + + return response + + def _run_api_v1_projects_id(self, pid, user): + with ForceLogin(user, self.client): + response = self.client.get('/api/v1/projects/{}'.format(pid), format="json") + + return response.data + + def _run_api_v1_projects_id_export_import(self, user): + for project in self.projects: + if user: + if user is self.user and (project.assignee or not project.owner): + HTTP_200_OK = status.HTTP_403_FORBIDDEN + HTTP_202_ACCEPTED = status.HTTP_403_FORBIDDEN + HTTP_201_CREATED = status.HTTP_403_FORBIDDEN + else: + HTTP_200_OK = status.HTTP_200_OK + HTTP_202_ACCEPTED = status.HTTP_202_ACCEPTED + HTTP_201_CREATED = status.HTTP_201_CREATED + else: + HTTP_200_OK = status.HTTP_401_UNAUTHORIZED + HTTP_202_ACCEPTED = status.HTTP_401_UNAUTHORIZED + HTTP_201_CREATED = status.HTTP_401_UNAUTHORIZED + + pid = project.id + response = self._run_api_v1_projects_id_export(pid, user) + self.assertEqual(response.status_code, HTTP_202_ACCEPTED) + + response = self._run_api_v1_projects_id_export(pid, user) + self.assertEqual(response.status_code, HTTP_201_CREATED) + + response = self._run_api_v1_projects_id_export(pid, user, "action=download") + self.assertEqual(response.status_code, HTTP_200_OK) + + if user and user is not self.observer and user is not self.user and user is not self.annotator: + self.assertTrue(response.streaming) + content = io.BytesIO(b"".join(response.streaming_content)) + content.seek(0) + + uploaded_data = { + "project_file": content, + } + response = self._run_api_v1_projects_import(user, uploaded_data) + self.assertEqual(response.status_code, HTTP_202_ACCEPTED) + if user is not self.observer and user is not self.user and user is not self.annotator: + rq_id = response.data["rq_id"] + response = self._run_api_v1_projects_import(user, {"rq_id": rq_id}) + self.assertEqual(response.status_code, HTTP_201_CREATED) + original_project = self._run_api_v1_projects_id(pid, user) + imported_project = self._run_api_v1_projects_id(response.data["id"], user) + compare_objects( + self=self, + obj1=original_project, + obj2=imported_project, + ignore_keys=( + "data", + "id", + "url", + "owner", + "assignee", + "created_date", + "updated_date", + "training_project", + "project_id", + "tasks", + ), + ) + + def test_api_v1_projects_id_export_admin(self): + self._run_api_v1_projects_id_export_import(self.admin) + + def test_api_v1_projects_id_export_user(self): + self._run_api_v1_projects_id_export_import(self.user) + + def test_api_v1_projects_id_export_observer(self): + self._run_api_v1_projects_id_export_import(self.observer) + + def test_api_v1_projects_id_export_no_auth(self): + self._run_api_v1_projects_id_export_import(None) class ProjectExportAPITestCase(APITestCase): def setUp(self): self.client = APIClient() @@ -1773,7 +2162,6 @@ class TaskDeleteAPITestCase(APITestCase): task_dir = task.get_task_dirname() self.assertFalse(os.path.exists(task_dir)) - class TaskUpdateAPITestCase(APITestCase): def setUp(self): @@ -2305,8 +2693,6 @@ class TaskCreateAPITestCase(APITestCase): } self._check_api_v1_tasks(None, data) - - class TaskImportExportAPITestCase(APITestCase): def setUp(self): @@ -2603,13 +2989,13 @@ class TaskImportExportAPITestCase(APITestCase): def _run_api_v1_tasks_id_export(self, tid, user, query_params=""): with ForceLogin(user, self.client): - response = self.client.get('/api/v1/tasks/{}?{}'.format(tid, query_params), format="json") + response = self.client.get('/api/v1/tasks/{}/backup?{}'.format(tid, query_params), format="json") return response def _run_api_v1_tasks_id_import(self, user, data): with ForceLogin(user, self.client): - response = self.client.post('/api/v1/tasks?action=import', data=data, format="multipart") + response = self.client.post('/api/v1/tasks/backup', data=data, format="multipart") return response @@ -2637,10 +3023,10 @@ class TaskImportExportAPITestCase(APITestCase): self._create_tasks() for task in self.tasks: tid = task["id"] - response = self._run_api_v1_tasks_id_export(tid, user, "action=export") + response = self._run_api_v1_tasks_id_export(tid, user) self.assertEqual(response.status_code, HTTP_202_ACCEPTED) - response = self._run_api_v1_tasks_id_export(tid, user, "action=export") + response = self._run_api_v1_tasks_id_export(tid, user) self.assertEqual(response.status_code, HTTP_201_CREATED) response = self._run_api_v1_tasks_id_export(tid, user, "action=download") diff --git a/cvat/apps/engine/views.py b/cvat/apps/engine/views.py index 61075df0..424a0051 100644 --- a/cvat/apps/engine/views.py +++ b/cvat/apps/engine/views.py @@ -9,7 +9,6 @@ import os.path as osp import pytz import shutil import traceback -import uuid from datetime import datetime from distutils.util import strtobool from tempfile import mkstemp, NamedTemporaryFile @@ -54,17 +53,36 @@ from cvat.apps.engine.models import ( ) from cvat.apps.engine.models import CloudStorage as CloudStorageModel from cvat.apps.engine.serializers import ( - AboutSerializer, AnnotationFileSerializer, BasicUserSerializer, - DataMetaSerializer, DataSerializer, ExceptionSerializer, - FileInfoSerializer, JobSerializer, LabeledDataSerializer, - LogEventSerializer, ProjectSerializer, ProjectSearchSerializer, - RqStatusSerializer, TaskSerializer, UserSerializer, PluginsSerializer, ReviewSerializer, - CombinedReviewSerializer, IssueSerializer, CombinedIssueSerializer, CommentSerializer, - CloudStorageSerializer, BaseCloudStorageSerializer, TaskFileSerializer, DatasetFileSerializer) + AboutSerializer, + AnnotationFileSerializer, + BaseCloudStorageSerializer, + BasicUserSerializer, + CloudStorageSerializer, + CombinedIssueSerializer, + CombinedReviewSerializer, + CommentSerializer, + DataMetaSerializer, + DataSerializer, + DatasetFileSerializer, + ExceptionSerializer, + FileInfoSerializer, + IssueSerializer, + JobSerializer, + LabeledDataSerializer, + LogEventSerializer, + PluginsSerializer, + ProjectSearchSerializer, + ProjectSerializer, + ReviewSerializer, + RqStatusSerializer, + TaskSerializer, + UserSerializer, + ) from utils.dataset_manifest import ImageManifestManager from cvat.apps.engine.utils import av_scan_paths -from cvat.apps.engine.backup import import_task +from cvat.apps.engine import backup from cvat.apps.engine.mixins import UploadMixin + from . import models, task from .log import clogger, slogger @@ -267,20 +285,20 @@ class ProjectViewSet(auth.ProjectGetQuerySetMixin, viewsets.ModelViewSet): return [perm() for perm in permissions] - def perform_create(self, serializer): - def validate_project_limit(owner): - admin_perm = auth.AdminRolePermission() - is_admin = admin_perm.has_permission(self.request, self) - if not is_admin and settings.RESTRICTIONS['project_limit'] is not None and \ - Project.objects.filter(owner=owner).count() >= settings.RESTRICTIONS['project_limit']: - raise serializers.ValidationError('The user has the maximum number of projects') + def _validate_project_limit(self, owner): + admin_perm = auth.AdminRolePermission() + is_admin = admin_perm.has_permission(self.request, self) + if not is_admin and settings.RESTRICTIONS['project_limit'] is not None and \ + Project.objects.filter(owner=owner).count() >= settings.RESTRICTIONS['project_limit']: + raise serializers.ValidationError('The user has the maximum number of projects') + def perform_create(self, serializer): owner = self.request.data.get('owner', None) if owner: - validate_project_limit(owner) + self._validate_project_limit(owner) serializer.save() else: - validate_project_limit(self.request.user) + self._validate_project_limit(self.request.user) serializer.save(owner=self.request.user) @swagger_auto_schema(method='get', operation_summary='Returns information of the tasks of the project with the selected id', @@ -420,6 +438,15 @@ class ProjectViewSet(auth.ProjectGetQuerySetMixin, viewsets.ModelViewSet): else: return Response("Format is not specified",status=status.HTTP_400_BAD_REQUEST) + @action(methods=['GET'], detail=True, url_path='backup') + def export(self, request, pk=None): + db_project = self.get_object() # force to call check_object_permissions + return backup.export(db_project, request) + + @action(detail=False, methods=['POST']) + def backup(self, request, pk=None): + self._validate_project_limit(owner=self.request.user) + return backup.import_project(request) @staticmethod def _get_rq_response(queue, job_id): queue = django_rq.get_queue(queue) @@ -515,130 +542,23 @@ class TaskViewSet(UploadMixin, auth.TaskGetQuerySetMixin, viewsets.ModelViewSet) Task.objects.filter(owner=owner).count() >= settings.RESTRICTIONS['task_limit']: raise serializers.ValidationError('The user has the maximum number of tasks') - def create(self, request): - action = self.request.query_params.get('action', None) - if action is None: - return super().create(request) - elif action == 'import': - self._validate_task_limit(owner=self.request.user) - if 'rq_id' in request.data: - rq_id = request.data['rq_id'] - else: - rq_id = "{}@/api/v1/tasks/{}/import".format(request.user, uuid.uuid4()) - - queue = django_rq.get_queue("default") - rq_job = queue.fetch_job(rq_id) - - if not rq_job: - serializer = TaskFileSerializer(data=request.data) - serializer.is_valid(raise_exception=True) - task_file = serializer.validated_data['task_file'] - fd, filename = mkstemp(prefix='cvat_') - with open(filename, 'wb+') as f: - for chunk in task_file.chunks(): - f.write(chunk) - rq_job = queue.enqueue_call( - func=import_task, - args=(filename, request.user.id), - job_id=rq_id, - meta={ - 'tmp_file': filename, - 'tmp_file_descriptor': fd, - }, - ) - - else: - if rq_job.is_finished: - task_id = rq_job.return_value - os.close(rq_job.meta['tmp_file_descriptor']) - os.remove(rq_job.meta['tmp_file']) - rq_job.delete() - return Response({'id': task_id}, status=status.HTTP_201_CREATED) - elif rq_job.is_failed: - os.close(rq_job.meta['tmp_file_descriptor']) - os.remove(rq_job.meta['tmp_file']) - exc_info = str(rq_job.exc_info) - rq_job.delete() - - # RQ adds a prefix with exception class name - import_error_prefix = '{}.{}'.format( - CvatImportError.__module__, CvatImportError.__name__) - if exc_info.startswith(import_error_prefix): - exc_info = exc_info.replace(import_error_prefix + ': ', '') - return Response(data=exc_info, - status=status.HTTP_400_BAD_REQUEST) - else: - return Response(data=exc_info, - status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - return Response({'rq_id': rq_id}, status=status.HTTP_202_ACCEPTED) - else: - raise serializers.ValidationError( - "Unexpected action specified for the request") + @action(detail=False, methods=['POST']) + def backup(self, request, pk=None): + self._validate_task_limit(owner=self.request.user) + return backup.import_task(request) - def retrieve(self, request, pk=None): + @action(methods=['GET'], detail=True, url_path='backup') + def export(self, request, pk=None): db_task = self.get_object() # force to call check_object_permissions - action = self.request.query_params.get('action', None) - if action is None: - return super().retrieve(request, pk) - elif action in ('export', 'download'): - queue = django_rq.get_queue("default") - rq_id = "/api/v1/tasks/{}/export".format(pk) - - rq_job = queue.fetch_job(rq_id) - if rq_job: - last_task_update_time = timezone.localtime(db_task.updated_date) - request_time = rq_job.meta.get('request_time', None) - if request_time is None or request_time < last_task_update_time: - rq_job.cancel() - rq_job.delete() - else: - if rq_job.is_finished: - file_path = rq_job.return_value - if action == "download" and osp.exists(file_path): - rq_job.delete() - - timestamp = datetime.strftime(last_task_update_time, - "%Y_%m_%d_%H_%M_%S") - filename = "task_{}_backup_{}{}".format( - db_task.name, timestamp, - osp.splitext(file_path)[1]) - return sendfile(request, file_path, attachment=True, - attachment_filename=filename.lower()) - else: - if osp.exists(file_path): - return Response(status=status.HTTP_201_CREATED) - elif rq_job.is_failed: - exc_info = str(rq_job.exc_info) - rq_job.delete() - return Response(exc_info, - status=status.HTTP_500_INTERNAL_SERVER_ERROR) - else: - return Response(status=status.HTTP_202_ACCEPTED) - - ttl = dm.views.TASK_CACHE_TTL.total_seconds() - queue.enqueue_call( - func=dm.views.backup_task, - args=(pk, 'task_dump.zip'), - job_id=rq_id, - meta={ 'request_time': timezone.localtime() }, - result_ttl=ttl, failure_ttl=ttl) - return Response(status=status.HTTP_202_ACCEPTED) - - else: - raise serializers.ValidationError( - "Unexpected action specified for the request") + return backup.export(db_task, request) def perform_update(self, serializer): instance = serializer.instance - project_id = instance.project_id updated_instance = serializer.save() - if project_id != updated_instance.project_id: - if project_id is not None: - Project.objects.get(id=project_id).save() - if updated_instance.project_id is not None: - Project.objects.get(id=updated_instance.project_id).save() - + if instance.project: + instance.project.save() + if updated_instance.project: + updated_instance.project.save() def perform_create(self, serializer): owner = self.request.data.get('owner', None) diff --git a/tests/cypress/integration/actions_tasks2/case_97_export_import_task.js b/tests/cypress/integration/actions_tasks2/case_97_export_import_task.js index 49db8828..382d5cf9 100644 --- a/tests/cypress/integration/actions_tasks2/case_97_export_import_task.js +++ b/tests/cypress/integration/actions_tasks2/case_97_export_import_task.js @@ -83,7 +83,7 @@ context('Export, import an annotation task.', { browser: '!firefox' }, () => { cy.get('.ant-dropdown') .not('.ant-dropdown-hidden') .within(() => { - cy.contains('[role="menuitem"]', new RegExp('^Export task$')).click().trigger('mouseout'); + cy.contains('[role="menuitem"]', new RegExp('^Backup Task$')).click().trigger('mouseout'); }); cy.getDownloadFileName().then((file) => { taskBackupArchiveFullName = file; @@ -93,7 +93,7 @@ context('Export, import an annotation task.', { browser: '!firefox' }, () => { }); it('Import the task. Check id, labels, shape.', () => { - cy.intercept('POST', '/api/v1/tasks?action=import').as('importTask'); + cy.intercept('POST', '/api/v1/tasks/backup').as('importTask'); cy.get('.cvat-import-task').click().find('input[type=file]').attachFile(taskBackupArchiveFullName); cy.wait('@importTask', { timeout: 5000 }).its('response.statusCode').should('equal', 202); cy.wait('@importTask').its('response.statusCode').should('equal', 201); diff --git a/utils/cli/core/core.py b/utils/cli/core/core.py index 70dc0589..ebfc36e3 100644 --- a/utils/cli/core/core.py +++ b/utils/cli/core/core.py @@ -221,8 +221,7 @@ class CLI(): def tasks_export(self, task_id, filename, export_verification_period=3, **kwargs): """ Export and download a whole task """ - url = self.api.tasks_id(task_id) - export_url = url + '?action=export' + export_url = self.api.tasks_id(task_id) + '/backup' while True: response = self.session.get(export_url) @@ -232,7 +231,7 @@ class CLI(): break sleep(export_verification_period) - response = self.session.get(url + '?action=download') + response = self.session.get(export_url + '?action=download') response.raise_for_status() with open(filename, 'wb') as fp: @@ -243,7 +242,7 @@ class CLI(): def tasks_import(self, filename, import_verification_period=3, **kwargs): """ Import a task""" - url = self.api.tasks + '?action=import' + url = self.api.tasks + '/backup' with open(filename, 'rb') as input_file: response = self.session.post( url,