Auto segmentation using Mask_RCNN (#767)
parent
310e208229
commit
d99140a0fb
@ -0,0 +1,38 @@
|
|||||||
|
## [Keras+Tensorflow Mask R-CNN Segmentation](https://github.com/matterport/Mask_RCNN)
|
||||||
|
|
||||||
|
### What is it?
|
||||||
|
- This application allows you to automatically segment various objects in images.
|
||||||
|
- It's based on Feature Pyramid Network (FPN) and a ResNet101 backbone.
|
||||||
|
|
||||||
|
- It uses a model pre-trained on the MS COCO dataset.
|
||||||
|
- It supports the following classes (use them in the "labels" row):
|
||||||
|
```python
|
||||||
|
'BG', 'person', 'bicycle', 'car', 'motorcycle', 'airplane',
|
||||||
|
'bus', 'train', 'truck', 'boat', 'traffic light',
|
||||||
|
'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird',
|
||||||
|
'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear',
|
||||||
|
'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
|
||||||
|
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
|
||||||
|
'kite', 'baseball bat', 'baseball glove', 'skateboard',
|
||||||
|
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
|
||||||
|
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
|
||||||
|
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
|
||||||
|
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
|
||||||
|
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
|
||||||
|
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster',
|
||||||
|
'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
|
||||||
|
'teddy bear', 'hair drier', 'toothbrush'.
|
||||||
|
```
|
||||||
|
- Component adds "Run Auto Segmentation" button into dashboard.
|
||||||
|
|
||||||
|
### Build docker image
|
||||||
|
```bash
|
||||||
|
# From project root directory
|
||||||
|
docker-compose -f docker-compose.yml -f components/auto_segmentation/docker-compose.auto_segmentation.yml build
|
||||||
|
```
|
||||||
|
|
||||||
|
### Run docker container
|
||||||
|
```bash
|
||||||
|
# From project root directory
|
||||||
|
docker-compose -f docker-compose.yml -f components/auto_segmentation/docker-compose.auto_segmentation.yml up -d
|
||||||
|
```
|
||||||
@ -0,0 +1,13 @@
|
|||||||
|
#
|
||||||
|
# Copyright (C) 2018 Intel Corporation
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
#
|
||||||
|
version: "2.3"

services:
  cvat:
    build:
      context: .
      args:
        # Build-time switch: the cvat Dockerfile installs the Mask_RCNN
        # auto-segmentation component when this arg is set to "yes"
        AUTO_SEGMENTATION: "yes"
||||||
@ -0,0 +1,12 @@
|
|||||||
|
#!/bin/bash
#
# Fetch the Mask_RCNN sources and the pre-trained MS COCO weights into ${HOME}.
# Used at image build time by the auto_segmentation component.

set -e

# BUGFIX: quote "${HOME}" so the script does not break on paths with spaces
cd "${HOME}" && \
git clone https://github.com/matterport/Mask_RCNN.git && \
wget https://github.com/matterport/Mask_RCNN/releases/download/v2.0/mask_rcnn_coco.h5 && \
mv mask_rcnn_coco.h5 Mask_RCNN/mask_rcnn_coco.h5

# TODO remove useless files
# tensorflow and Keras are installed globally
||||||
@ -0,0 +1,9 @@
|
|||||||
|
|
||||||
|
# Copyright (C) 2018 Intel Corporation
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
from cvat.settings.base import JS_3RDPARTY
|
||||||
|
|
||||||
|
# Register the component's dashboard plugin script with the engine so the
# "Run Auto Segmentation" button is injected into the dashboard page.
JS_3RDPARTY['dashboard'] = JS_3RDPARTY.get('dashboard', []) + ['auto_segmentation/js/dashboardPlugin.js']
|
||||||
|
|
||||||
@ -0,0 +1,8 @@
|
|||||||
|
|
||||||
|
# Copyright (C) 2018 Intel Corporation
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
|
||||||
|
# Register your models here.
|
||||||
|
|
||||||
@ -0,0 +1,11 @@
|
|||||||
|
|
||||||
|
# Copyright (C) 2018 Intel Corporation
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
from django.apps import AppConfig
|
||||||
|
|
||||||
|
|
||||||
|
class AutoSegmentationConfig(AppConfig):
    """Django application config for the auto segmentation component."""

    # Dotted module name under which Django registers this app
    name = 'auto_segmentation'
|
||||||
|
|
||||||
@ -0,0 +1,5 @@
|
|||||||
|
|
||||||
|
# Copyright (C) 2018 Intel Corporation
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
@ -0,0 +1,8 @@
|
|||||||
|
|
||||||
|
# Copyright (C) 2018 Intel Corporation
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
|
||||||
|
# Create your models here.
|
||||||
|
|
||||||
@ -0,0 +1,112 @@
|
|||||||
|
/*
|
||||||
|
* Copyright (C) 2018 Intel Corporation
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: MIT
|
||||||
|
*/
|
||||||
|
|
||||||
|
/* global
|
||||||
|
userConfirm:false
|
||||||
|
showMessage:false
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Dashboard plugin: adds a "Run Auto Segmentation" button to every task card
// and keeps it in sync with the server-side rq job (run / progress / cancel).
window.addEventListener('dashboardReady', () => {
    // Poll the check endpoint every 5s and reflect job status on the button.
    function checkProcess(tid, button) {
        function checkCallback() {
            $.get(`/tensorflow/segmentation/check/task/${tid}`).done((statusData) => {
                if (['started', 'queued'].includes(statusData.status)) {
                    // '0' fallback covers a missing/NaN progress value
                    const progress = Math.round(statusData.progress) || '0';
                    button.text(`Cancel Auto Segmentation (${progress}%)`);
                    setTimeout(checkCallback, 5000);
                } else {
                    // Terminal state: restore the button to its idle appearance
                    button.text('Run Auto Segmentation');
                    button.removeClass('tfAnnotationProcess');
                    button.prop('disabled', false);

                    if (statusData.status === 'failed') {
                        const message = `Tensorflow Segmentation failed. Error: ${statusData.stderr}`;
                        showMessage(message);
                    } else if (statusData.status !== 'finished') {
                        const message = `Tensorflow segmentation check request returned status "${statusData.status}"`;
                        showMessage(message);
                    }
                }
            }).fail((errorData) => {
                const message = `Can not sent tensorflow segmentation check request. Code: ${errorData.status}. `
                    + `Message: ${errorData.responseText || errorData.statusText}`;
                showMessage(message);
            });
        }

        setTimeout(checkCallback, 5000);
    }


    // Ask the server to enqueue the segmentation job, then start polling.
    function runProcess(tid, button) {
        $.get(`/tensorflow/segmentation/create/task/${tid}`).done(() => {
            showMessage('Process has started');
            button.text('Cancel Auto Segmentation (0%)');
            button.addClass('tfAnnotationProcess');
            checkProcess(tid, button);
        }).fail((errorData) => {
            const message = `Can not run Auto Segmentation. Code: ${errorData.status}. `
                + `Message: ${errorData.responseText || errorData.statusText}`;
            showMessage(message);
        });
    }


    // Request cancellation; the button is re-enabled by the poller once the
    // server reports a terminal status.
    function cancelProcess(tid, button) {
        $.get(`/tensorflow/segmentation/cancel/task/${tid}`).done(() => {
            button.prop('disabled', true);
        }).fail((errorData) => {
            const message = `Can not cancel Auto Segmentation. Code: ${errorData.status}. `
                + `Message: ${errorData.responseText || errorData.statusText}`;
            showMessage(message);
        });
    }


    // Wire the run/cancel button into one dashboard task card.
    // metaData maps tid -> {active, success} as returned by meta/get.
    function setupDashboardItem(item, metaData) {
        const tid = +item.attr('tid');
        const button = $('<button> Run Auto Segmentation </button>');

        button.on('click', () => {
            // The 'tfAnnotationProcess' class marks a job in progress
            if (button.hasClass('tfAnnotationProcess')) {
                userConfirm('The process will be canceled. Continue?', () => {
                    cancelProcess(tid, button);
                });
            } else {
                userConfirm('The current annotation will be lost. Are you sure?', () => {
                    runProcess(tid, button);
                });
            }
        });

        button.addClass('dashboardTFAnnotationButton regular dashboardButtonUI');
        button.appendTo(item.find('div.dashboardButtonsUI'));

        // If a job is already running for this task, resume polling it
        if ((tid in metaData) && (metaData[tid].active)) {
            button.text('Cancel Auto Segmentation');
            button.addClass('tfAnnotationProcess');
            checkProcess(tid, button);
        }
    }

    const elements = $('.dashboardItem');
    const tids = Array.from(elements, el => +el.getAttribute('tid'));

    // Fetch job states for all visible tasks in one request, then set up
    // every card's button.
    $.ajax({
        type: 'POST',
        url: '/tensorflow/segmentation/meta/get',
        data: JSON.stringify(tids),
        contentType: 'application/json; charset=utf-8',
    }).done((metaData) => {
        elements.each(function setupDashboardItemWrapper() {
            setupDashboardItem($(this), metaData);
        });
    }).fail((errorData) => {
        const message = `Can not get Auto Segmentation meta info. Code: ${errorData.status}. `
            + `Message: ${errorData.responseText || errorData.statusText}`;
        showMessage(message);
    });
});
|
||||||
@ -0,0 +1,8 @@
|
|||||||
|
|
||||||
|
# Copyright (C) 2018 Intel Corporation
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
|
||||||
|
# Create your tests here.
|
||||||
|
|
||||||
@ -0,0 +1,14 @@
|
|||||||
|
|
||||||
|
# Copyright (C) 2018 Intel Corporation
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
from django.urls import path
|
||||||
|
from . import views
|
||||||
|
|
||||||
|
# REST endpoints of the auto segmentation component (mounted under
# /tensorflow/segmentation/ by the engine's URL configuration)
urlpatterns = [
    path('create/task/<int:tid>', views.create),
    path('check/task/<int:tid>', views.check),
    path('cancel/task/<int:tid>', views.cancel),
    path('meta/get', views.get_meta_info),
]
|
||||||
@ -0,0 +1,322 @@
|
|||||||
|
|
||||||
|
# Copyright (C) 2018 Intel Corporation
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
from django.http import HttpResponse, JsonResponse, HttpResponseBadRequest
|
||||||
|
from rules.contrib.views import permission_required, objectgetter
|
||||||
|
from cvat.apps.authentication.decorators import login_required
|
||||||
|
from cvat.apps.engine.models import Task as TaskModel
|
||||||
|
from cvat.apps.engine.serializers import LabeledDataSerializer
|
||||||
|
from cvat.apps.engine.annotation import put_task_data
|
||||||
|
|
||||||
|
import django_rq
|
||||||
|
import fnmatch
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import rq
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from cvat.apps.engine.log import slogger
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import skimage.io
|
||||||
|
from skimage.measure import find_contours, approximate_polygon
|
||||||
|
|
||||||
|
|
||||||
|
def load_image_into_numpy(image):
    """Convert a PIL-style image into an (H, W, 3) uint8 numpy array."""
    width, height = image.size
    flat = np.array(image.getdata())
    return flat.reshape((height, width, 3)).astype(np.uint8)
|
||||||
|
|
||||||
|
|
||||||
|
def run_tensorflow_auto_segmentation(image_list, labels_mapping, treshold):
    """Run Mask R-CNN over every image and collect polygon segmentations.

    Args:
        image_list: ordered list of image file paths, one per frame.
        labels_mapping: {COCO class id: CVAT label id} for classes to keep.
        treshold: minimum detection score for a mask to be kept.

    Returns:
        {cvat_label_id: [[frame_number, flat_polygon_points], ...]},
        or None if the rq job was canceled by the user.
    """
    def _convert_to_int(boolean_mask):
        # find_contours needs a numeric array, not a boolean one
        return boolean_mask.astype(np.uint8)

    def _convert_to_segmentation(mask):
        contours = find_contours(mask, 0.5)
        # only one contour exist in our case
        contour = contours[0]
        # find_contours yields (row, col); CVAT expects (x, y)
        contour = np.flip(contour, axis=1)
        # Approximate the contour and reduce the number of points
        contour = approximate_polygon(contour, tolerance=2.5)
        segmentation = contour.ravel().tolist()
        return segmentation

    ## INITIALIZATION

    # Root directory of the Mask_RCNN checkout, provided via environment
    ROOT_DIR = os.environ.get('AUTO_SEGMENTATION_PATH')
    # BUGFIX: the original checked `COCO_MODEL_PATH is None`, which can never be
    # true since os.path.join always returns a string; the env var itself must
    # be validated, and before it is used in sys.path.append/os.path.join
    # (those would raise TypeError on None).
    if ROOT_DIR is None:
        raise OSError('Model path env not found in the system.')

    # Import Mask RCNN
    sys.path.append(ROOT_DIR)  # To find local version of the library
    import mrcnn.model as modellib

    # Import COCO config
    sys.path.append(os.path.join(ROOT_DIR, "samples/coco/"))  # To find local version
    import coco

    # Directory to save logs and trained model
    MODEL_DIR = os.path.join(ROOT_DIR, "logs")

    # Local path to trained weights file
    COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")
    job = rq.get_current_job()

    ## CONFIGURATION

    class InferenceConfig(coco.CocoConfig):
        # Set batch size to 1 since we'll be running inference on
        # one image at a time. Batch size = GPU_COUNT * IMAGES_PER_GPU
        GPU_COUNT = 1
        IMAGES_PER_GPU = 1

    # Print config details
    config = InferenceConfig()
    config.display()

    ## CREATE MODEL AND LOAD TRAINED WEIGHTS

    # Create model object in inference mode.
    model = modellib.MaskRCNN(mode="inference", model_dir=MODEL_DIR, config=config)
    # Load weights trained on MS-COCO
    model.load_weights(COCO_MODEL_PATH, by_name=True)

    ## RUN OBJECT DETECTION
    result = {}
    for image_num, image_path in enumerate(image_list):
        job.refresh()
        # a 'cancel' key in job.meta is the cooperative cancellation signal
        # set by the cancel() view; acknowledge it and stop
        if 'cancel' in job.meta:
            del job.meta['cancel']
            job.save()
            return None
        job.meta['progress'] = image_num * 100 / len(image_list)
        job.save_meta()

        image = skimage.io.imread(image_path)

        # for multiple image detection, "batch size" must be equal to number of images
        r = model.detect([image], verbose=1)

        r = r[0]
        # "r['rois'][index]" gives bounding box around the object
        for index, c_id in enumerate(r['class_ids']):
            if c_id in labels_mapping.keys():
                if r['scores'][index] >= treshold:
                    mask = _convert_to_int(r['masks'][:, :, index])
                    segmentation = _convert_to_segmentation(mask)
                    label = labels_mapping[c_id]
                    if label not in result:
                        result[label] = []
                    result[label].append(
                        [image_num, segmentation])

    return result
|
||||||
|
|
||||||
|
|
||||||
|
def make_image_list(path_to_data):
    """Return every '*.jpg' under path_to_data, sorted by numeric basename.

    File names are expected to be frame numbers (e.g. '17.jpg'), so a
    numeric sort yields the frames in playback order.
    """
    def _frame_number(path):
        stem, _ = os.path.splitext(os.path.basename(path))
        return int(stem)

    found = [
        os.path.join(root, name)
        for root, _, names in os.walk(path_to_data)
        for name in fnmatch.filter(names, '*.jpg')
    ]
    return sorted(found, key=_frame_number)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_cvat_format(data):
    """Translate {label_id: [[frame, points], ...]} into a CVAT upload dict.

    Every segment becomes one polygon shape; tracks and tags stay empty.
    """
    shapes = []
    for label_id, segments in data.items():
        for seg in segments:
            shapes.append({
                "type": "polygon",
                "label_id": label_id,
                "frame": seg[0],
                "points": seg[1],
                "z_order": 0,
                "group": None,
                "occluded": False,
                "attributes": [],
            })

    return {
        "tracks": [],
        "shapes": shapes,
        "tags": [],
        "version": 0,
    }
|
||||||
|
|
||||||
|
def create_thread(tid, labels_mapping, user):
    """rq worker entry point: auto-segment every frame of task `tid`.

    Runs Mask R-CNN over the task's images, converts the masks to CVAT
    polygon annotations and saves them via put_task_data. Returns early
    (without saving) when the job was canceled by the user.

    Args:
        tid: CVAT task id.
        labels_mapping: {COCO class id: CVAT label id}.
        user: requesting user, forwarded to put_task_data.
    """
    try:
        # Only detections with at least this confidence score are kept
        TRESHOLD = 0.5
        # Init rq job
        job = rq.get_current_job()
        job.meta['progress'] = 0
        job.save_meta()
        # Get the task and its image list
        db_task = TaskModel.objects.get(pk=tid)
        image_list = make_image_list(db_task.get_data_dirname())

        # Run auto segmentation by tf
        slogger.glob.info("auto segmentation with tensorflow framework for task {}".format(tid))
        result = run_tensorflow_auto_segmentation(image_list, labels_mapping, TRESHOLD)

        # None means the user canceled the job mid-run; nothing to save
        if result is None:
            slogger.glob.info('auto segmentation for task {} canceled by user'.format(tid))
            return

        # Modify data format and save
        result = convert_to_cvat_format(result)
        serializer = LabeledDataSerializer(data=result)
        if serializer.is_valid(raise_exception=True):
            put_task_data(tid, user, result)
        slogger.glob.info('auto segmentation for task {} done'.format(tid))
    except Exception as ex:
        try:
            slogger.task[tid].exception('exception was occured during auto segmentation of the task', exc_info=True)
        except Exception:
            # BUGFIX: was `exc_into=True` (typo); the bad keyword raised a
            # TypeError here and masked the original exception
            slogger.glob.exception('exception was occured during auto segmentation of the task {}'.format(tid), exc_info=True)
        raise ex
|
||||||
|
|
||||||
|
@login_required
def get_meta_info(request):
    """POST endpoint: body is a JSON list of task ids.

    Returns {tid: {"active": bool, "success": bool}} for every task that
    has a known segmentation job; tasks with no job are omitted.
    """
    try:
        queue = django_rq.get_queue('low')
        tids = json.loads(request.body.decode('utf-8'))
        result = {}
        for tid in tids:
            job = queue.fetch_job('auto_segmentation.create/{}'.format(tid))
            if job is not None:
                result[tid] = {
                    "active": job.is_queued or job.is_started,
                    "success": not job.is_failed
                }

        return JsonResponse(result)
    except Exception as ex:
        # BUGFIX: was `exc_into=True` (typo); the bad keyword made this call
        # raise TypeError, so the 400 response below was never returned
        slogger.glob.exception('exception was occured during tf meta request', exc_info=True)
        return HttpResponseBadRequest(str(ex))
|
||||||
|
|
||||||
|
|
||||||
|
@login_required
@permission_required(perm=['engine.task.change'],
    fn=objectgetter(TaskModel, 'tid'), raise_exception=True)
def create(request, tid):
    """Enqueue an auto segmentation job for task `tid`.

    Maps the task's label names onto COCO class ids and enqueues
    create_thread on the 'low' rq queue. Returns 200 on success and a
    400 with the error text if the job is already running or no task
    label matches a COCO class.
    """
    slogger.glob.info('auto segmentation create request for task {}'.format(tid))
    try:
        db_task = TaskModel.objects.get(pk=tid)
        queue = django_rq.get_queue('low')
        # The job id doubles as a uniqueness key: one job per task
        job = queue.fetch_job('auto_segmentation.create/{}'.format(tid))
        if job is not None and (job.is_started or job.is_queued):
            raise Exception("The process is already running")

        # {label id: label name} for all labels defined on this task
        db_labels = db_task.label_set.prefetch_related('attributespec_set').all()
        db_labels = {db_label.id:db_label.name for db_label in db_labels}

        # COCO Labels (class name -> COCO class id; ids are not contiguous
        # because some original COCO categories were removed)
        auto_segmentation_labels = { "BG": 0,
            "person": 1, "bicycle": 2, "car": 3, "motorcycle": 4, "airplane": 5,
            "bus": 6, "train": 7, "truck": 8, "boat": 9, "traffic_light": 10,
            "fire_hydrant": 11, "stop_sign": 13, "parking_meter": 14, "bench": 15,
            "bird": 16, "cat": 17, "dog": 18, "horse": 19, "sheep": 20, "cow": 21,
            "elephant": 22, "bear": 23, "zebra": 24, "giraffe": 25, "backpack": 27,
            "umbrella": 28, "handbag": 31, "tie": 32, "suitcase": 33, "frisbee": 34,
            "skis": 35, "snowboard": 36, "sports_ball": 37, "kite": 38, "baseball_bat": 39,
            "baseball_glove": 40, "skateboard": 41, "surfboard": 42, "tennis_racket": 43,
            "bottle": 44, "wine_glass": 46, "cup": 47, "fork": 48, "knife": 49, "spoon": 50,
            "bowl": 51, "banana": 52, "apple": 53, "sandwich": 54, "orange": 55, "broccoli": 56,
            "carrot": 57, "hot_dog": 58, "pizza": 59, "donut": 60, "cake": 61, "chair": 62,
            "couch": 63, "potted_plant": 64, "bed": 65, "dining_table": 67, "toilet": 70,
            "tv": 72, "laptop": 73, "mouse": 74, "remote": 75, "keyboard": 76, "cell_phone": 77,
            "microwave": 78, "oven": 79, "toaster": 80, "sink": 81, "refrigerator": 83,
            "book": 84, "clock": 85, "vase": 86, "scissors": 87, "teddy_bear": 88, "hair_drier": 89,
            "toothbrush": 90
        }

        # {COCO class id: CVAT label id} for task labels that match a COCO class
        labels_mapping = {}
        for key, labels in db_labels.items():
            if labels in auto_segmentation_labels.keys():
                labels_mapping[auto_segmentation_labels[labels]] = key

        if not len(labels_mapping.values()):
            raise Exception('No labels found for auto segmentation')

        # Run auto segmentation job
        queue.enqueue_call(func=create_thread,
            args=(tid, labels_mapping, request.user),
            job_id='auto_segmentation.create/{}'.format(tid),
            timeout=604800)  # 7 days

        slogger.task[tid].info('tensorflow segmentation job enqueued with labels {}'.format(labels_mapping))

    except Exception as ex:
        try:
            slogger.task[tid].exception("exception was occured during tensorflow segmentation request", exc_info=True)
        except Exception:
            # the task-scoped logger itself may be unavailable (e.g. bad tid)
            pass
        return HttpResponseBadRequest(str(ex))

    return HttpResponse()
|
||||||
|
|
||||||
|
@login_required
@permission_required(perm=['engine.task.access'],
    fn=objectgetter(TaskModel, 'tid'), raise_exception=True)
def check(request, tid):
    """Report the segmentation job status for task `tid`.

    Returns JSON with 'status' in {'unknown', 'queued', 'started',
    'finished', 'failed'}; 'progress' accompanies 'started' and 'stderr'
    accompanies 'failed'. Finished/failed jobs are deleted so the next
    check starts clean.
    """
    # BUGFIX: `data` was declared inside `try`, so an exception raised before
    # that point (e.g. in get_queue/fetch_job) hit the handler with a
    # NameError instead of producing the intended 'unknown' response.
    data = {}
    try:
        queue = django_rq.get_queue('low')
        job = queue.fetch_job('auto_segmentation.create/{}'.format(tid))
        # A pending cancel flag means the worker is about to stop; report
        # 'finished' so the UI resets the button
        if job is not None and 'cancel' in job.meta:
            return JsonResponse({'status': 'finished'})
        if job is None:
            data['status'] = 'unknown'
        elif job.is_queued:
            data['status'] = 'queued'
        elif job.is_started:
            data['status'] = 'started'
            data['progress'] = job.meta['progress']
        elif job.is_finished:
            data['status'] = 'finished'
            job.delete()
        else:
            data['status'] = 'failed'
            data['stderr'] = job.exc_info
            job.delete()

    except Exception:
        data['status'] = 'unknown'

    return JsonResponse(data)
|
||||||
|
|
||||||
|
|
||||||
|
@login_required
@permission_required(perm=['engine.task.change'],
    fn=objectgetter(TaskModel, 'tid'), raise_exception=True)
def cancel(request, tid):
    """Request cancellation of the running segmentation job for task `tid`.

    Sets a 'cancel' flag in the job's meta; the worker polls for it and
    stops cooperatively. Responds 400 when no job is running.
    """
    try:
        queue = django_rq.get_queue('low')
        job = queue.fetch_job('auto_segmentation.create/{}'.format(tid))
        not_running = job is None or job.is_finished or job.is_failed
        if not_running:
            raise Exception('Task is not being segmented currently')
        if 'cancel' not in job.meta:
            # worker checks job.meta between frames and exits on this flag
            job.meta['cancel'] = True
            job.save()

    except Exception as ex:
        try:
            slogger.task[tid].exception("cannot cancel tensorflow segmentation for task #{}".format(tid), exc_info=True)
        except Exception:
            pass
        return HttpResponseBadRequest(str(ex))

    return HttpResponse()
|
||||||
Loading…
Reference in New Issue