Re-Identification application (#299)
parent
b4e6f22f58
commit
e225c38092
@ -0,0 +1,16 @@
|
|||||||
|
exports.settings = {bullet: '*', paddedTable: false}
|
||||||
|
|
||||||
|
exports.plugins = [
|
||||||
|
require('remark-preset-lint-recommended'),
|
||||||
|
require('remark-preset-lint-consistent'),
|
||||||
|
require('remark-validate-links'),
|
||||||
|
[require("remark-lint-no-dead-urls"), { skipOffline: true }],
|
||||||
|
[require("remark-lint-maximum-line-length"), 120],
|
||||||
|
[require("remark-lint-maximum-heading-length"), 120],
|
||||||
|
[require("remark-lint-list-item-indent"), "tab-size"],
|
||||||
|
[require("remark-lint-list-item-spacing"), false],
|
||||||
|
[require("remark-lint-strong-marker"), "*"],
|
||||||
|
[require("remark-lint-emphasis-marker"), "_"],
|
||||||
|
[require("remark-lint-unordered-list-marker-style"), "-"],
|
||||||
|
[require("remark-lint-ordered-list-marker-style"), "."],
|
||||||
|
]
|
||||||
@ -0,0 +1,22 @@
|
|||||||
|
# Re-Identification Application
|
||||||
|
|
||||||
|
## About the application
|
||||||
|
|
||||||
|
The ReID application uses deep learning model to perform an automatic bbox merging between neighbor frames.
|
||||||
|
You can use "Merge" and "Split" functionality to edit automatically generated annotation.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
This application will be installed automatically with the [OpenVINO](https://github.com/opencv/cvat/blob/develop/components/openvino/README.md) component.
|
||||||
|
|
||||||
|
## Running
|
||||||
|
|
||||||
|
For starting the ReID merge process:
|
||||||
|
|
||||||
|
- Open an annotation job
|
||||||
|
- Open the menu
|
||||||
|
- Click the "Run ReID Merge" button
|
||||||
|
- Click the "Submit" button. Also here you can experiment with values of model threshold or maximum distance.
|
||||||
|
- Model threshold is maximum cosine distance between objects embeddings.
|
||||||
|
- Maximum distance defines a maximum radius that an object can diverge between neightbor frames.
|
||||||
|
- The process will be run. You can cancel it in the menu.
|
||||||
@ -0,0 +1,9 @@
|
|||||||
|
# Copyright (C) 2018 Intel Corporation
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
from cvat.settings.base import JS_3RDPARTY
|
||||||
|
|
||||||
|
default_app_config = 'cvat.apps.reid.apps.ReidConfig'
|
||||||
|
|
||||||
|
JS_3RDPARTY['engine'] = JS_3RDPARTY.get('engine', []) + ['reid/js/enginePlugin.js']
|
||||||
@ -0,0 +1,8 @@
|
|||||||
|
# Copyright (C) 2018 Intel Corporation
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
from django.apps import AppConfig
|
||||||
|
|
||||||
|
class ReidConfig(AppConfig):
|
||||||
|
name = 'cvat.apps.reid'
|
||||||
@ -0,0 +1,226 @@
|
|||||||
|
# Copyright (C) 2018 Intel Corporation
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import os
|
||||||
|
import rq
|
||||||
|
import cv2
|
||||||
|
import math
|
||||||
|
import numpy
|
||||||
|
import fnmatch
|
||||||
|
|
||||||
|
from openvino.inference_engine import IENetwork, IEPlugin
|
||||||
|
from scipy.optimize import linear_sum_assignment
|
||||||
|
from scipy.spatial.distance import euclidean, cosine
|
||||||
|
|
||||||
|
from cvat.apps.engine.models import Job
|
||||||
|
|
||||||
|
|
||||||
|
class ReID:
|
||||||
|
__threshold = None
|
||||||
|
__max_distance = None
|
||||||
|
__frame_urls = None
|
||||||
|
__frame_boxes = None
|
||||||
|
__stop_frame = None
|
||||||
|
__plugin = None
|
||||||
|
__executable_network = None
|
||||||
|
__input_blob_name = None
|
||||||
|
__output_blob_name = None
|
||||||
|
__input_height = None
|
||||||
|
__input_width = None
|
||||||
|
|
||||||
|
|
||||||
|
def __init__(self, jid, data):
|
||||||
|
self.__threshold = data["threshold"]
|
||||||
|
self.__max_distance = data["maxDistance"]
|
||||||
|
self.__frame_urls = {}
|
||||||
|
self.__frame_boxes = {}
|
||||||
|
|
||||||
|
db_job = Job.objects.select_related('segment__task').get(pk = jid)
|
||||||
|
db_segment = db_job.segment
|
||||||
|
db_task = db_segment.task
|
||||||
|
|
||||||
|
self.__stop_frame = db_segment.stop_frame
|
||||||
|
|
||||||
|
for root, _, filenames in os.walk(db_task.get_data_dirname()):
|
||||||
|
for filename in fnmatch.filter(filenames, '*.jpg'):
|
||||||
|
frame = int(os.path.splitext(filename)[0])
|
||||||
|
if frame >= db_segment.start_frame and frame <= db_segment.stop_frame:
|
||||||
|
self.__frame_urls[frame] = os.path.join(root, filename)
|
||||||
|
|
||||||
|
for frame in self.__frame_urls:
|
||||||
|
self.__frame_boxes[frame] = [box for box in data["boxes"] if box["frame"] == frame]
|
||||||
|
|
||||||
|
IE_PLUGINS_PATH = os.getenv('IE_PLUGINS_PATH', None)
|
||||||
|
REID_MODEL_DIR = os.getenv('REID_MODEL_DIR', None)
|
||||||
|
|
||||||
|
if not IE_PLUGINS_PATH:
|
||||||
|
raise Exception("Environment variable 'IE_PLUGINS_PATH' isn't defined")
|
||||||
|
if not REID_MODEL_DIR:
|
||||||
|
raise Exception("Environment variable 'REID_MODEL_DIR' isn't defined")
|
||||||
|
|
||||||
|
REID_XML = os.path.join(REID_MODEL_DIR, "reid.xml")
|
||||||
|
REID_BIN = os.path.join(REID_MODEL_DIR, "reid.bin")
|
||||||
|
|
||||||
|
self.__plugin = IEPlugin(device="CPU", plugin_dirs=[IE_PLUGINS_PATH])
|
||||||
|
network = IENetwork.from_ir(model=REID_XML, weights=REID_BIN)
|
||||||
|
self.__input_blob_name = next(iter(network.inputs))
|
||||||
|
self.__output_blob_name = next(iter(network.outputs))
|
||||||
|
self.__input_height, self.__input_width = network.inputs[self.__input_blob_name].shape[-2:]
|
||||||
|
self.__executable_network = self.__plugin.load(network=network)
|
||||||
|
del network
|
||||||
|
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
if self.__executable_network:
|
||||||
|
del self.__executable_network
|
||||||
|
self.__executable_network = None
|
||||||
|
|
||||||
|
if self.__plugin:
|
||||||
|
del self.__plugin
|
||||||
|
self.__plugin = None
|
||||||
|
|
||||||
|
|
||||||
|
def __boxes_are_compatible(self, cur_box, next_box):
|
||||||
|
cur_c_x = (cur_box["xtl"] + cur_box["xbr"]) / 2
|
||||||
|
cur_c_y = (cur_box["ytl"] + cur_box["ybr"]) / 2
|
||||||
|
next_c_x = (next_box["xtl"] + next_box["xbr"]) / 2
|
||||||
|
next_c_y = (next_box["ytl"] + next_box["ybr"]) / 2
|
||||||
|
compatible_distance = euclidean([cur_c_x, cur_c_y], [next_c_x, next_c_y]) <= self.__max_distance
|
||||||
|
compatible_label = cur_box["label_id"] == next_box["label_id"]
|
||||||
|
return compatible_distance and compatible_label and "path_id" not in next_box
|
||||||
|
|
||||||
|
|
||||||
|
def __compute_difference(self, image_1, image_2):
|
||||||
|
image_1 = cv2.resize(image_1, (self.__input_width, self.__input_height)).transpose((2,0,1))
|
||||||
|
image_2 = cv2.resize(image_2, (self.__input_width, self.__input_height)).transpose((2,0,1))
|
||||||
|
|
||||||
|
input_1 = {
|
||||||
|
self.__input_blob_name: image_1[numpy.newaxis, ...]
|
||||||
|
}
|
||||||
|
|
||||||
|
input_2 = {
|
||||||
|
self.__input_blob_name: image_2[numpy.newaxis, ...]
|
||||||
|
}
|
||||||
|
|
||||||
|
embedding_1 = self.__executable_network.infer(inputs = input_1)[self.__output_blob_name]
|
||||||
|
embedding_2 = self.__executable_network.infer(inputs = input_2)[self.__output_blob_name]
|
||||||
|
|
||||||
|
embedding_1 = embedding_1.reshape(embedding_1.size)
|
||||||
|
embedding_2 = embedding_2.reshape(embedding_2.size)
|
||||||
|
|
||||||
|
return cosine(embedding_1, embedding_2)
|
||||||
|
|
||||||
|
|
||||||
|
def __compute_difference_matrix(self, cur_boxes, next_boxes, cur_image, next_image):
|
||||||
|
def _int(number, upper):
|
||||||
|
return math.floor(numpy.clip(number, 0, upper - 1))
|
||||||
|
|
||||||
|
default_mat_value = 1000.0
|
||||||
|
|
||||||
|
matrix = numpy.full([len(cur_boxes), len(next_boxes)], default_mat_value, dtype=float)
|
||||||
|
for row, cur_box in enumerate(cur_boxes):
|
||||||
|
cur_width = cur_image.shape[1]
|
||||||
|
cur_height = cur_image.shape[0]
|
||||||
|
cur_xtl, cur_xbr, cur_ytl, cur_ybr = (
|
||||||
|
_int(cur_box["xtl"], cur_width), _int(cur_box["xbr"], cur_width),
|
||||||
|
_int(cur_box["ytl"], cur_height), _int(cur_box["ybr"], cur_height)
|
||||||
|
)
|
||||||
|
|
||||||
|
for col, next_box in enumerate(next_boxes):
|
||||||
|
next_box = next_boxes[col]
|
||||||
|
next_width = next_image.shape[1]
|
||||||
|
next_height = next_image.shape[0]
|
||||||
|
next_xtl, next_xbr, next_ytl, next_ybr = (
|
||||||
|
_int(next_box["xtl"], next_width), _int(next_box["xbr"], next_width),
|
||||||
|
_int(next_box["ytl"], next_height), _int(next_box["ybr"], next_height)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.__boxes_are_compatible(cur_box, next_box):
|
||||||
|
continue
|
||||||
|
|
||||||
|
crop_1 = cur_image[cur_ytl:cur_ybr, cur_xtl:cur_xbr]
|
||||||
|
crop_2 = next_image[next_ytl:next_ybr, next_xtl:next_xbr]
|
||||||
|
matrix[row][col] = self.__compute_difference(crop_1, crop_2)
|
||||||
|
|
||||||
|
return matrix
|
||||||
|
|
||||||
|
|
||||||
|
def __apply_matching(self):
|
||||||
|
frames = sorted(list(self.__frame_boxes.keys()))
|
||||||
|
job = rq.get_current_job()
|
||||||
|
box_paths = {}
|
||||||
|
|
||||||
|
for idx, (cur_frame, next_frame) in enumerate(list(zip(frames[:-1], frames[1:]))):
|
||||||
|
job.refresh()
|
||||||
|
if "cancel" in job.meta:
|
||||||
|
return None
|
||||||
|
|
||||||
|
job.meta["progress"] = idx * 100.0 / len(frames)
|
||||||
|
job.save_meta()
|
||||||
|
|
||||||
|
cur_boxes = self.__frame_boxes[cur_frame]
|
||||||
|
next_boxes = self.__frame_boxes[next_frame]
|
||||||
|
|
||||||
|
for box in cur_boxes:
|
||||||
|
if "path_id" not in box:
|
||||||
|
path_id = len(box_paths)
|
||||||
|
box_paths[path_id] = [box]
|
||||||
|
box["path_id"] = path_id
|
||||||
|
|
||||||
|
if not (len(cur_boxes) and len(next_boxes)):
|
||||||
|
continue
|
||||||
|
|
||||||
|
cur_image = cv2.imread(self.__frame_urls[cur_frame], cv2.IMREAD_COLOR)
|
||||||
|
next_image = cv2.imread(self.__frame_urls[next_frame], cv2.IMREAD_COLOR)
|
||||||
|
difference_matrix = self.__compute_difference_matrix(cur_boxes, next_boxes, cur_image, next_image)
|
||||||
|
cur_idxs, next_idxs = linear_sum_assignment(difference_matrix)
|
||||||
|
for idx, cur_idx in enumerate(cur_idxs):
|
||||||
|
if (difference_matrix[cur_idx][next_idxs[idx]]) <= self.__threshold:
|
||||||
|
cur_box = cur_boxes[cur_idx]
|
||||||
|
next_box = next_boxes[next_idxs[idx]]
|
||||||
|
next_box["path_id"] = cur_box["path_id"]
|
||||||
|
box_paths[cur_box["path_id"]].append(next_box)
|
||||||
|
|
||||||
|
for box in self.__frame_boxes[frames[-1]]:
|
||||||
|
if "path_id" not in box:
|
||||||
|
path_id = len(box_paths)
|
||||||
|
box["path_id"] = path_id
|
||||||
|
box_paths[path_id] = [box]
|
||||||
|
|
||||||
|
return box_paths
|
||||||
|
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
box_paths = self.__apply_matching()
|
||||||
|
output = []
|
||||||
|
|
||||||
|
# ReID process has been canceled
|
||||||
|
if box_paths is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
for path_id in box_paths:
|
||||||
|
output.append({
|
||||||
|
"label_id": box_paths[path_id][0]["label_id"],
|
||||||
|
"group_id": 0,
|
||||||
|
"attributes": [],
|
||||||
|
"frame": box_paths[path_id][0]["frame"],
|
||||||
|
"shapes": box_paths[path_id]
|
||||||
|
})
|
||||||
|
|
||||||
|
for box in output[-1]["shapes"]:
|
||||||
|
del box["id"]
|
||||||
|
del box["path_id"]
|
||||||
|
del box["group_id"]
|
||||||
|
del box["label_id"]
|
||||||
|
box["outside"] = False
|
||||||
|
box["attributes"] = []
|
||||||
|
|
||||||
|
for path in output:
|
||||||
|
if path["shapes"][-1]["frame"] != self.__stop_frame:
|
||||||
|
copy = path["shapes"][-1].copy()
|
||||||
|
copy["outside"] = True
|
||||||
|
copy["frame"] += 1
|
||||||
|
path["shapes"].append(copy)
|
||||||
|
|
||||||
|
return output
|
||||||
@ -0,0 +1,170 @@
|
|||||||
|
/*
|
||||||
|
* Copyright (C) 2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: MIT
|
||||||
|
*/
|
||||||
|
|
||||||
|
/* global showMessage userConfirm */
|
||||||
|
|
||||||
|
|
||||||
|
document.addEventListener('DOMContentLoaded', () => {
|
||||||
|
function run(overlay, cancelButton, thresholdInput, distanceInput) {
|
||||||
|
const collection = window.cvat.data.get();
|
||||||
|
const data = {
|
||||||
|
threshold: +thresholdInput.prop('value'),
|
||||||
|
maxDistance: +distanceInput.prop('value'),
|
||||||
|
boxes: collection.boxes,
|
||||||
|
};
|
||||||
|
|
||||||
|
overlay.removeClass('hidden');
|
||||||
|
cancelButton.prop('disabled', true);
|
||||||
|
$.ajax({
|
||||||
|
url: `reid/start/job/${window.cvat.job.id}`,
|
||||||
|
type: 'POST',
|
||||||
|
data: JSON.stringify(data),
|
||||||
|
contentType: 'application/json',
|
||||||
|
success: () => {
|
||||||
|
function checkCallback() {
|
||||||
|
$.ajax({
|
||||||
|
url: `/reid/check/${window.cvat.job.id}`,
|
||||||
|
type: 'GET',
|
||||||
|
success: (jobData) => {
|
||||||
|
if (jobData.progress) {
|
||||||
|
cancelButton.text(`Cancel ReID Merge (${jobData.progress.toString().slice(0, 4)}%)`);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (['queued', 'started'].includes(jobData.status)) {
|
||||||
|
setTimeout(checkCallback, 1000);
|
||||||
|
} else {
|
||||||
|
overlay.addClass('hidden');
|
||||||
|
|
||||||
|
if (jobData.status === 'finished') {
|
||||||
|
if (jobData.result) {
|
||||||
|
collection.boxes = [];
|
||||||
|
collection.box_paths = collection.box_paths
|
||||||
|
.concat(JSON.parse(jobData.result));
|
||||||
|
window.cvat.data.clear();
|
||||||
|
window.cvat.data.set(collection);
|
||||||
|
showMessage('ReID merge has done.');
|
||||||
|
} else {
|
||||||
|
showMessage('ReID merge been canceled.');
|
||||||
|
}
|
||||||
|
} else if (jobData.status === 'failed') {
|
||||||
|
const message = `ReID merge has fallen. Error: '${jobData.stderr}'`;
|
||||||
|
showMessage(message);
|
||||||
|
} else {
|
||||||
|
let message = `Check request returned "${jobData.status}" status.`;
|
||||||
|
if (jobData.stderr) {
|
||||||
|
message += ` Error: ${jobData.stderr}`;
|
||||||
|
}
|
||||||
|
showMessage(message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
error: (errorData) => {
|
||||||
|
overlay.addClass('hidden');
|
||||||
|
const message = `Can not check ReID merge. Code: ${errorData.status}. Message: ${errorData.responseText || errorData.statusText}`;
|
||||||
|
showMessage(message);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
setTimeout(checkCallback, 1000);
|
||||||
|
},
|
||||||
|
error: (errorData) => {
|
||||||
|
overlay.addClass('hidden');
|
||||||
|
const message = `Can not start ReID merge. Code: ${errorData.status}. Message: ${errorData.responseText || errorData.statusText}`;
|
||||||
|
showMessage(message);
|
||||||
|
},
|
||||||
|
complete: () => {
|
||||||
|
cancelButton.prop('disabled', false);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function cancel(overlay, cancelButton) {
|
||||||
|
cancelButton.prop('disabled', true);
|
||||||
|
$.ajax({
|
||||||
|
url: `/reid/cancel/${window.cvat.job.id}`,
|
||||||
|
type: 'GET',
|
||||||
|
success: () => {
|
||||||
|
overlay.addClass('hidden');
|
||||||
|
cancelButton.text('Cancel ReID Merge (0%)');
|
||||||
|
},
|
||||||
|
error: (errorData) => {
|
||||||
|
const message = `Can not cancel ReID process. Code: ${errorData.status}. Message: ${errorData.responseText || errorData.statusText}`;
|
||||||
|
showMessage(message);
|
||||||
|
},
|
||||||
|
complete: () => {
|
||||||
|
cancelButton.prop('disabled', false);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
const buttonsUI = $('#engineMenuButtons');
|
||||||
|
const reidWindowId = 'reidSubmitWindow';
|
||||||
|
const reidThresholdValueId = 'reidThresholdValue';
|
||||||
|
const reidDistanceValueId = 'reidDistanceValue';
|
||||||
|
const reidCancelMergeId = 'reidCancelMerge';
|
||||||
|
const reidSubmitMergeId = 'reidSubmitMerge';
|
||||||
|
const reidCancelButtonId = 'reidCancelReID';
|
||||||
|
const reidOverlay = 'reidOverlay';
|
||||||
|
|
||||||
|
$('<button> Run ReID Merge </button>').on('click', () => {
|
||||||
|
$('#annotationMenu').addClass('hidden');
|
||||||
|
$(`#${reidWindowId}`).removeClass('hidden');
|
||||||
|
}).addClass('menuButton semiBold h2').prependTo(buttonsUI);
|
||||||
|
|
||||||
|
$(`
|
||||||
|
<div class="modal hidden" id="${reidWindowId}">
|
||||||
|
<div class="modal-content" style="width: 300px; height: 170px;">
|
||||||
|
<table>
|
||||||
|
<tr>
|
||||||
|
<td> <label class="regular h2"> Threshold: </label> </td>
|
||||||
|
<td> <input id="${reidThresholdValueId}" class="regular h1" type="number"` +
|
||||||
|
`title="Maximum cosine distance between embeddings of objects" min="0.05" max="0.95" value="0.5" step="0.05"> </td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td> <label class="regular h2"> Max Pixel Distance </label> </td>
|
||||||
|
<td> <input id="${reidDistanceValueId}" class="regular h1" type="number"` +
|
||||||
|
`title="Maximum radius that an object can diverge between neighbor frames" min="10" max="1000" value="50" step="10"> </td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td colspan="2"> <label class="regular h2" style="color: red;"> All boxes will be translated to box paths. Continue? </label> </td>
|
||||||
|
</tr>
|
||||||
|
</table>
|
||||||
|
<center style="margin-top: 10px;">
|
||||||
|
<button id="${reidCancelMergeId}" class="regular h2"> Cancel </button>
|
||||||
|
<button id="${reidSubmitMergeId}" class="regular h2"> Merge </button>
|
||||||
|
</center>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
`).appendTo('body');
|
||||||
|
|
||||||
|
$(`
|
||||||
|
<div class="modal hidden force-modal" id="${reidOverlay}">
|
||||||
|
<div class="modal-content" style="width: 300px; height: 70px;">
|
||||||
|
<center> <label class="regular h2"> ReID is processing the data </label></center>
|
||||||
|
<center style="margin-top: 5px;">
|
||||||
|
<button id="${reidCancelButtonId}" class="regular h2" style="width: 250px;"> Cancel ReID Merge (0%) </button>
|
||||||
|
</center>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
`).appendTo('body');
|
||||||
|
|
||||||
|
$(`#${reidCancelMergeId}`).on('click', () => {
|
||||||
|
$(`#${reidWindowId}`).addClass('hidden');
|
||||||
|
});
|
||||||
|
|
||||||
|
$(`#${reidCancelButtonId}`).on('click', () => {
|
||||||
|
userConfirm('ReID process will be canceld. Are you sure?', () => {
|
||||||
|
cancel($(`#${reidOverlay}`), $(`#${reidCancelButtonId}`));
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
$(`#${reidSubmitMergeId}`).on('click', () => {
|
||||||
|
$(`#${reidWindowId}`).addClass('hidden');
|
||||||
|
run($(`#${reidOverlay}`), $(`#${reidCancelButtonId}`),
|
||||||
|
$(`#${reidThresholdValueId}`), $(`#${reidDistanceValueId}`));
|
||||||
|
});
|
||||||
|
});
|
||||||
@ -0,0 +1,12 @@
|
|||||||
|
# Copyright (C) 2018 Intel Corporation
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
from django.urls import path
|
||||||
|
from . import views
|
||||||
|
|
||||||
|
urlpatterns = [
|
||||||
|
path('start/job/<int:jid>', views.start),
|
||||||
|
path('cancel/<int:jid>', views.cancel),
|
||||||
|
path('check/<int:jid>', views.check),
|
||||||
|
]
|
||||||
@ -0,0 +1,96 @@
|
|||||||
|
# Copyright (C) 2018 Intel Corporation
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
from django.http import HttpResponse, HttpResponseBadRequest, JsonResponse
|
||||||
|
from cvat.apps.authentication.decorators import login_required
|
||||||
|
from rules.contrib.views import permission_required, objectgetter
|
||||||
|
|
||||||
|
from cvat.apps.engine.models import Job
|
||||||
|
from cvat.apps.reid.reid import ReID
|
||||||
|
|
||||||
|
import django_rq
|
||||||
|
import json
|
||||||
|
import rq
|
||||||
|
|
||||||
|
|
||||||
|
def _create_thread(jid, data):
|
||||||
|
job = rq.get_current_job()
|
||||||
|
reid_obj = ReID(jid, data)
|
||||||
|
job.meta["result"] = json.dumps(reid_obj.run())
|
||||||
|
job.save_meta()
|
||||||
|
|
||||||
|
|
||||||
|
@login_required
|
||||||
|
@permission_required(perm=["engine.job.change"],
|
||||||
|
fn=objectgetter(Job, 'jid'), raise_exception=True)
|
||||||
|
def start(request, jid):
|
||||||
|
try:
|
||||||
|
data = json.loads(request.body.decode('utf-8'))
|
||||||
|
queue = django_rq.get_queue("low")
|
||||||
|
job_id = "reid.create.{}".format(jid)
|
||||||
|
job = queue.fetch_job(job_id)
|
||||||
|
if job is not None and (job.is_started or job.is_queued):
|
||||||
|
raise Exception('ReID process has been already started')
|
||||||
|
queue.enqueue_call(func=_create_thread, args=(jid, data), job_id=job_id, timeout=7200)
|
||||||
|
job = queue.fetch_job(job_id)
|
||||||
|
job.meta = {}
|
||||||
|
job.save_meta()
|
||||||
|
except Exception as e:
|
||||||
|
return HttpResponseBadRequest(str(e))
|
||||||
|
|
||||||
|
return HttpResponse()
|
||||||
|
|
||||||
|
|
||||||
|
@login_required
|
||||||
|
@permission_required(perm=["engine.job.change"],
|
||||||
|
fn=objectgetter(Job, 'jid'), raise_exception=True)
|
||||||
|
def check(request, jid):
|
||||||
|
try:
|
||||||
|
queue = django_rq.get_queue("low")
|
||||||
|
rq_id = "reid.create.{}".format(jid)
|
||||||
|
job = queue.fetch_job(rq_id)
|
||||||
|
if job is not None and "cancel" in job.meta:
|
||||||
|
return JsonResponse({"status": "finished"})
|
||||||
|
data = {}
|
||||||
|
if job is None:
|
||||||
|
data["status"] = "unknown"
|
||||||
|
elif job.is_queued:
|
||||||
|
data["status"] = "queued"
|
||||||
|
elif job.is_started:
|
||||||
|
data["status"] = "started"
|
||||||
|
if "progress" in job.meta:
|
||||||
|
data["progress"] = job.meta["progress"]
|
||||||
|
elif job.is_finished:
|
||||||
|
data["status"] = "finished"
|
||||||
|
data["result"] = job.meta["result"]
|
||||||
|
job.delete()
|
||||||
|
else:
|
||||||
|
data["status"] = "failed"
|
||||||
|
data["stderr"] = job.exc_info
|
||||||
|
job.delete()
|
||||||
|
|
||||||
|
except Exception as ex:
|
||||||
|
data["stderr"] = str(ex)
|
||||||
|
data["status"] = "unknown"
|
||||||
|
|
||||||
|
return JsonResponse(data)
|
||||||
|
|
||||||
|
|
||||||
|
@login_required
|
||||||
|
@permission_required(perm=["engine.job.change"],
|
||||||
|
fn=objectgetter(Job, 'jid'), raise_exception=True)
|
||||||
|
def cancel(request, jid):
|
||||||
|
try:
|
||||||
|
queue = django_rq.get_queue("low")
|
||||||
|
rq_id = "reid.create.{}".format(jid)
|
||||||
|
job = queue.fetch_job(rq_id)
|
||||||
|
if job is None or job.is_finished or job.is_failed:
|
||||||
|
raise Exception("Task is not being annotated currently")
|
||||||
|
elif "cancel" not in job.meta:
|
||||||
|
job.meta["cancel"] = True
|
||||||
|
job.save_meta()
|
||||||
|
except Exception as e:
|
||||||
|
return HttpResponseBadRequest(str(e))
|
||||||
|
|
||||||
|
return HttpResponse()
|
||||||
Loading…
Reference in New Issue