# Copyright (C) 2022 Intel Corporation
#
# SPDX-License-Identifier: MIT

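"""Proxy layer between CVAT and serverless functions deployed on Nuclio.

Exposes REST endpoints to list and invoke lambda functions (detectors,
interactors, trackers, ReID) and to manage the RQ jobs that run them.
"""
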
import base64
import json
from functools import wraps
from enum import Enum
from copy import deepcopy

import django_rq
import requests
import rq
import os
from django.conf import settings
from django.core.exceptions import ObjectDoesNotExist, ValidationError
from rest_framework import status, viewsets
from rest_framework.response import Response

import cvat.apps.dataset_manager as dm
from cvat.apps.engine.frame_provider import FrameProvider
from cvat.apps.engine.models import Task as TaskModel
from cvat.apps.engine.serializers import LabeledDataSerializer
from cvat.apps.engine.models import ShapeType, SourceType

from drf_spectacular.utils import extend_schema, extend_schema_view, OpenApiResponse, OpenApiParameter
from drf_spectacular.types import OpenApiTypes

class LambdaType(Enum):
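    """Kinds of serverless functions recognized by CVAT."""
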
    DETECTOR = "detector"
    INTERACTOR = "interactor"
    REID = "reid"
    TRACKER = "tracker"
    UNKNOWN = "unknown"

    def __str__(self):
        return self.value

class LambdaGateway:
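    """Thin HTTP client for the nuclio dashboard API."""
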
    NUCLIO_ROOT_URL = '/api/functions'

    def _http(self, method="get", scheme=None, host=None, port=None,
            url=None, headers=None, data=None):
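        """Send one request to the nuclio dashboard and return the parsed JSON reply."""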
        NUCLIO_GATEWAY = '{}://{}:{}'.format(
            scheme or settings.NUCLIO['SCHEME'],
            host or settings.NUCLIO['HOST'],
            port or settings.NUCLIO['PORT'])
        extra_headers = {
            'x-nuclio-project-name': 'cvat',
            'x-nuclio-function-namespace': 'nuclio',
        }
        if headers:
            extra_headers.update(headers)
        NUCLIO_TIMEOUT = settings.NUCLIO['DEFAULT_TIMEOUT']

        if url:
            url = "{}{}".format(NUCLIO_GATEWAY, url)
        else:
            url = NUCLIO_GATEWAY

        reply = getattr(requests, method)(url, headers=extra_headers,
            timeout=NUCLIO_TIMEOUT, json=data)
        reply.raise_for_status()
        response = reply.json()

        return response

    def list(self):
        data = self._http(url=self.NUCLIO_ROOT_URL)
        response = [LambdaFunction(self, item) for item in data.values()]
        return response

    def get(self, func_id):
        data = self._http(url=self.NUCLIO_ROOT_URL + '/' + func_id)
        response = LambdaFunction(self, data)
        return response

    def invoke(self, func, payload):
        # Note: call the function directly without the nuclio dashboard
        # host.docker.internal for Linux will work only with Docker 20.10+
        NUCLIO_TIMEOUT = settings.NUCLIO['DEFAULT_TIMEOUT']
        if os.path.exists('/.dockerenv'): # inside a docker container
            url = f'http://host.docker.internal:{func.port}'
        else:
            url = f'http://localhost:{func.port}'
        reply = requests.post(url, timeout=NUCLIO_TIMEOUT, json=payload)
        reply.raise_for_status()
        response = reply.json()

        return response

class LambdaFunction:
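    """Metadata and invocation logic for a single deployed serverless function."""
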
    def __init__(self, gateway, data):
        # ID of the function (e.g. omz.public.yolo-v3)
        self.id = data['metadata']['name']
        # type of the function (e.g. detector, interactor)
        meta_anno = data['metadata']['annotations']
        kind = meta_anno.get('type')
        try:
            self.kind = LambdaType(kind)
        except ValueError:
            self.kind = LambdaType.UNKNOWN
        # dictionary of labels for the function (e.g. car, person)
        spec = json.loads(meta_anno.get('spec') or '[]')
        labels = [item['name'] for item in spec]
        if len(labels) != len(set(labels)):
            raise ValidationError(
                "`{}` lambda function has non-unique labels".format(self.id),
                code=status.HTTP_404_NOT_FOUND)
        self.labels = labels
        # mapping of labels and corresponding supported attributes
        self.func_attributes = {item['name']: item.get('attributes', []) for item in spec}
        for label, attributes in self.func_attributes.items():
            if len([attr['name'] for attr in attributes]) != len(set([attr['name'] for attr in attributes])):
                raise ValidationError(
                    "`{}` lambda function has non-unique attributes for label {}".format(self.id, label),
                    code=status.HTTP_404_NOT_FOUND)
        # state of the function
        self.state = data['status']['state']
        # description of the function
        self.description = data['spec']['description']
        # http port to access the serverless function
        self.port = data["status"].get("httpPort")
        # framework which is used for the function (e.g. tensorflow, openvino)
        self.framework = meta_anno.get('framework')
        # display name for the function
        self.name = meta_anno.get('name', self.id)
        self.min_pos_points = int(meta_anno.get('min_pos_points', 1))
        self.min_neg_points = int(meta_anno.get('min_neg_points', -1))
        self.startswith_box = bool(meta_anno.get('startswith_box', False))
        self.animated_gif = meta_anno.get('animated_gif', '')
        self.help_message = meta_anno.get('help_message', '')
        self.gateway = gateway

    def to_dict(self):
        response = {
            'id': self.id,
            'kind': str(self.kind),
            'labels': self.labels,
            'description': self.description,
            'framework': self.framework,
            'name': self.name
        }

        if self.kind is LambdaType.INTERACTOR:
            response.update({
                'min_pos_points': self.min_pos_points,
                'min_neg_points': self.min_neg_points,
                'startswith_box': self.startswith_box,
                'help_message': self.help_message,
                'animated_gif': self.animated_gif
            })

        if self.kind is LambdaType.TRACKER:
            response.update({
                'state': self.state
            })
        if self.kind is LambdaType.DETECTOR:
            response.update({
                'attributes': self.func_attributes
            })

        return response

    def invoke(self, db_task, data):
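        # Build the payload expected by this kind of serverless function,
        # invoke it through the gateway, and (for detectors) filter the reply
        # so only labels and attributes that exist in the task remain.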
        try:
            payload = {}
            data = {k: v for k, v in data.items() if v is not None}
            threshold = data.get("threshold")
            if threshold:
                payload.update({ "threshold": threshold })
            quality = data.get("quality")
            mapping = data.get("mapping", {})

            task_attributes = {}
            mapping_by_default = {}
            for db_label in (db_task.project.label_set if db_task.project_id else db_task.label_set).prefetch_related("attributespec_set").all():
                mapping_by_default[db_label.name] = {
                    'name': db_label.name,
                    'attributes': {}
                }
                task_attributes[db_label.name] = {}
                for attribute in db_label.attributespec_set.all():
                    task_attributes[db_label.name][attribute.name] = {
                        'input_type': attribute.input_type,
                        'values': attribute.values.split('\n')
                    }
            if not mapping:
                # use mapping by default to avoid labels in mapping which
                # don't exist in the task
                mapping = mapping_by_default
            else:
                # filter labels in mapping which don't exist in the task
                mapping = {k: v for k, v in mapping.items() if v['name'] in mapping_by_default}

            attr_mapping = { label: mapping[label]['attributes'] if 'attributes' in mapping[label] else {} for label in mapping }
            mapping = { modelLabel: mapping[modelLabel]['name'] for modelLabel in mapping }

            supported_attrs = {}
            for func_label, func_attrs in self.func_attributes.items():
                if func_label not in mapping:
                    continue

                mapped_label = mapping[func_label]
                mapped_attributes = attr_mapping.get(func_label, {})
                supported_attrs[func_label] = {}

                if mapped_attributes:
                    task_attr_names = [task_attr for task_attr in task_attributes[mapped_label]]
                    for attr in func_attrs:
                        mapped_attr = mapped_attributes.get(attr["name"])
                        if mapped_attr in task_attr_names:
                            supported_attrs[func_label].update({ attr["name"]: task_attributes[mapped_label][mapped_attr] })

            if self.kind == LambdaType.DETECTOR:
                payload.update({
                    "image": self._get_image(db_task, data["frame"], quality)
                })
            elif self.kind == LambdaType.INTERACTOR:
                payload.update({
                    "image": self._get_image(db_task, data["frame"], quality),
                    "pos_points": data["pos_points"][2:] if self.startswith_box else data["pos_points"],
                    "neg_points": data["neg_points"],
                    "obj_bbox": data["pos_points"][0:2] if self.startswith_box else None
                })
            elif self.kind == LambdaType.REID:
                payload.update({
                    "image0": self._get_image(db_task, data["frame0"], quality),
                    "image1": self._get_image(db_task, data["frame1"], quality),
                    "boxes0": data["boxes0"],
                    "boxes1": data["boxes1"]
                })
                max_distance = data.get("max_distance")
                if max_distance:
                    payload.update({
                        "max_distance": max_distance
                    })
            elif self.kind == LambdaType.TRACKER:
                payload.update({
                    "image": self._get_image(db_task, data["frame"], quality),
                    "shapes": data.get("shapes", []),
                    "states": data.get("states", [])
                })
            else:
                raise ValidationError(
                    '`{}` lambda function has incorrect type: {}'
                    .format(self.id, self.kind),
                    code=status.HTTP_500_INTERNAL_SERVER_ERROR)
        except KeyError as err:
            raise ValidationError(
                "`{}` lambda function was called without mandatory argument: {}"
                .format(self.id, str(err)),
                code=status.HTTP_400_BAD_REQUEST)

        response = self.gateway.invoke(self, payload)
        response_filtered = []
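
        # Decide whether a predicted attribute value fits the task's attribute
        # spec: equal input types are accepted as-is; otherwise only a few
        # safe conversions are allowed (e.g. a numeric value into a text field).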
        def check_attr_value(value, func_attr, db_attr):
            if db_attr is None:
                return False
            func_attr_type = func_attr["input_type"]
            db_attr_type = db_attr["input_type"]
            # Check if attribute types are equal for function configuration and db spec
            if func_attr_type == db_attr_type:
                if func_attr_type == "number":
                    return value.isnumeric()
                elif func_attr_type == "checkbox":
                    return value in ["true", "false"]
                elif func_attr_type in ["select", "radio", "text"]:
                    return True
                else:
                    return False
            else:
                if func_attr_type == "number":
                    return db_attr_type in ["select", "radio", "text"] and value.isnumeric()
                elif func_attr_type == "text":
                    return db_attr_type == "text" or \
                        (db_attr_type in ["select", "radio"] and len(value.split(" ")) == 1)
                elif func_attr_type == "select":
                    return db_attr_type in ["radio", "text"]
                elif func_attr_type == "radio":
                    return db_attr_type in ["select", "text"]
                elif func_attr_type == "checkbox":
                    return value in ["true", "false"]
                else:
                    return False

        if self.kind == LambdaType.DETECTOR:
            for item in response:
                item_label = item['label']

                if item_label not in mapping:
                    continue

                attributes = deepcopy(item.get("attributes", []))
                item["attributes"] = []
                mapped_attributes = attr_mapping[item_label]

                for attr in attributes:
                    if attr['name'] not in mapped_attributes:
                        continue

                    func_attr = [func_attr for func_attr in self.func_attributes.get(item_label, []) if func_attr['name'] == attr["name"]]
                    # Skip current attribute if it was not declared as supported in function config
                    if not func_attr:
                        continue

                    db_attr = supported_attrs.get(item_label, {}).get(attr["name"])

                    if check_attr_value(attr["value"], func_attr[0], db_attr):
                        attr["name"] = mapped_attributes[attr['name']]
                        item["attributes"].append(attr)

                item['label'] = mapping[item['label']]
                response_filtered.append(item)
            response = response_filtered

        return response

    def _get_image(self, db_task, frame, quality):
        if quality is None or quality == "original":
            quality = FrameProvider.Quality.ORIGINAL
        elif quality == "compressed":
            quality = FrameProvider.Quality.COMPRESSED
        else:
            raise ValidationError(
                '`{}` lambda function was run '.format(self.id) +
                'with wrong arguments (quality={})'.format(quality),
                code=status.HTTP_400_BAD_REQUEST)

        frame_provider = FrameProvider(db_task.data)
        image = frame_provider.get_frame(frame, quality=quality)

        return base64.b64encode(image[0].getvalue()).decode('utf-8')

class LambdaQueue:
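    """Helper around the RQ "low" queue used for lambda annotation jobs."""
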
    def _get_queue(self):
        QUEUE_NAME = "low"
        return django_rq.get_queue(QUEUE_NAME)

    def get_jobs(self):
        queue = self._get_queue()
        # Failed jobs are the only ones not included in the list below.
        job_ids = set(queue.get_job_ids() +
            queue.started_job_registry.get_job_ids() +
            queue.finished_job_registry.get_job_ids() +
            queue.scheduled_job_registry.get_job_ids() +
            queue.deferred_job_registry.get_job_ids())
        jobs = queue.job_class.fetch_many(job_ids, queue.connection)

        return [LambdaJob(job) for job in jobs if job.meta.get("lambda")]

    def enqueue(self, lambda_func, threshold, task, quality, mapping, cleanup, max_distance):
        jobs = self.get_jobs()
        # It is still possible to run several concurrent jobs for the same task.
        # But the race isn't critical. The filtration is just a lightweight
        # protection.
        if list(filter(lambda job: job.get_task() == task and not job.is_finished, jobs)):
            raise ValidationError(
                "Only one running request is allowed for the same task #{}".format(task),
                code=status.HTTP_409_CONFLICT)

        queue = self._get_queue()
        # LambdaJob(None) is a workaround for python-rq. It has multiple issues
        # with invocation of non-trivial functions. For example, it cannot run
        # a staticmethod or a callable class. Thus we provide an object
        # which has a __call__ method.
        job = queue.create_job(LambdaJob(None),
            meta = { "lambda": True },
            kwargs = {
                "function": lambda_func,
                "threshold": threshold,
                "task": task,
                "quality": quality,
                "cleanup": cleanup,
                "mapping": mapping,
                "max_distance": max_distance
            })

        queue.enqueue_job(job)

        return LambdaJob(job)

    def fetch_job(self, pk):
        queue = self._get_queue()
        job = queue.fetch_job(pk)
        if job is None or not job.meta.get("lambda"):
            raise ValidationError("{} lambda job is not found".format(pk),
                code=status.HTTP_404_NOT_FOUND)

        return LambdaJob(job)

class LambdaJob:
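    """Wrapper around an RQ job that runs a lambda function over a task."""
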
    def __init__(self, job):
        self.job = job

    def to_dict(self):
        lambda_func = self.job.kwargs.get("function")
        return {
            "id": self.job.id,
            "function": {
                "id": lambda_func.id if lambda_func else None,
                "threshold": self.job.kwargs.get("threshold"),
                "task": self.job.kwargs.get("task")
            },
            "status": self.job.get_status(),
            "progress": self.job.meta.get('progress', 0),
            "enqueued": self.job.enqueued_at,
            "started": self.job.started_at,
            "ended": self.job.ended_at,
            "exc_info": self.job.exc_info
        }

    def get_task(self):
        return self.job.kwargs.get("task")

    def get_status(self):
        return self.job.get_status()

    @property
    def is_finished(self):
        return self.get_status() == rq.job.JobStatus.FINISHED

    @property
    def is_queued(self):
        return self.get_status() == rq.job.JobStatus.QUEUED

    @property
    def is_failed(self):
        return self.get_status() == rq.job.JobStatus.FAILED

    @property
    def is_started(self):
        return self.get_status() == rq.job.JobStatus.STARTED

    @property
    def is_deferred(self):
        return self.get_status() == rq.job.JobStatus.DEFERRED

    @property
    def is_scheduled(self):
        return self.get_status() == rq.job.JobStatus.SCHEDULED

    def delete(self):
        self.job.delete()

    @staticmethod
    def _call_detector(function, db_task, labels, quality, threshold, mapping):
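        # Run the detector on every frame of the task and submit the produced
        # shapes/tags in batches via the Results accumulator below.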
        class Results:
            def __init__(self, task_id):
                self.task_id = task_id
                self.reset()

            def append_shape(self, shape):
                self.data["shapes"].append(shape)

            def append_tag(self, tag):
                self.data["tags"].append(tag)

            def submit(self):
                if not self.is_empty():
                    serializer = LabeledDataSerializer(data=self.data)
                    if serializer.is_valid(raise_exception=True):
                        dm.task.patch_task_data(self.task_id, serializer.data, "create")
                    self.reset()

            def is_empty(self):
                return not (self.data["tags"] or self.data["shapes"] or self.data["tracks"])

            def reset(self):
                # TODO: need to make "tags" and "tracks" optional
                # FIXME: need to provide the correct version here
                self.data = {"version": 0, "tags": [], "shapes": [], "tracks": []}

        results = Results(db_task.id)

        for frame in range(db_task.data.size):
            if frame in db_task.data.deleted_frames:
                continue
            annotations = function.invoke(db_task, data={
                "frame": frame, "quality": quality, "mapping": mapping,
                "threshold": threshold })
            progress = (frame + 1) / db_task.data.size
            if not LambdaJob._update_progress(progress):
                break

            for anno in annotations:
                label = labels.get(anno["label"])
                if label is None:
                    continue # Invalid label provided
                if anno.get('attributes'):
                    attrs = [{'spec_id': label['attributes'][attr['name']], 'value': attr['value']} for attr in anno.get('attributes') if attr['name'] in label['attributes']]
                else:
                    attrs = []
                if anno["type"].lower() == "tag":
                    results.append_tag({
                        "frame": frame,
                        "label_id": label['id'],
                        "source": "auto",
                        "attributes": attrs,
                        "group": None,
                    })
                else:
                    results.append_shape({
                        "frame": frame,
                        "label_id": label['id'],
                        "type": anno["type"],
                        "occluded": False,
                        "points": anno["points"],
                        "z_order": 0,
                        "group": None,
                        "attributes": attrs,
                        "source": "auto"
                    })

            # Accumulate data for 100 frames before submitting results.
            # It is an optimization to make fewer calls to our server. Also
            # it isn't possible to keep all results in memory.
            if frame % 100 == 0:
                results.submit()

        results.submit()

    @staticmethod
    # progress is in [0, 1] range
    def _update_progress(progress):
        job = rq.get_current_job()
        # If the job has been deleted, get_status will return None, and the
        # caller will exit its loop.
        job.meta["progress"] = int(progress * 100)
        job.save_meta()

        return job.get_status()

    @staticmethod
    def _call_reid(function, db_task, quality, threshold, max_distance):
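        # Match rectangle shapes between neighboring frames and chain the
        # matches into tracks; unmatched boxes start new tracks.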
        data = dm.task.get_task_data(db_task.id)
        boxes_by_frame = [[] for _ in range(db_task.data.size)]
        shapes_without_boxes = []
        for shape in data["shapes"]:
            if shape["type"] == str(ShapeType.RECTANGLE):
                boxes_by_frame[shape["frame"]].append(shape)
            else:
                shapes_without_boxes.append(shape)

        paths = {}
        for frame in range(db_task.data.size - 1):
            boxes0 = boxes_by_frame[frame]
            for box in boxes0:
                if "path_id" not in box:
                    path_id = len(paths)
                    paths[path_id] = [box]
                    box["path_id"] = path_id

            boxes1 = boxes_by_frame[frame + 1]
            if boxes0 and boxes1:
                matching = function.invoke(db_task, data={
                    "frame0": frame, "frame1": frame + 1, "quality": quality,
                    "boxes0": boxes0, "boxes1": boxes1, "threshold": threshold,
                    "max_distance": max_distance})

                for idx0, idx1 in enumerate(matching):
                    if idx1 >= 0:
                        path_id = boxes0[idx0]["path_id"]
                        boxes1[idx1]["path_id"] = path_id
                        paths[path_id].append(boxes1[idx1])

            progress = (frame + 2) / db_task.data.size
            if not LambdaJob._update_progress(progress):
                break

        for box in boxes_by_frame[db_task.data.size - 1]:
            if "path_id" not in box:
                path_id = len(paths)
                paths[path_id] = [box]
                box["path_id"] = path_id

        tracks = []
        for path_id in paths:
            box0 = paths[path_id][0]
            tracks.append({
                "label_id": box0["label_id"],
                "group": None,
                "attributes": [],
                "frame": box0["frame"],
                "shapes": paths[path_id],
                "source": str(SourceType.AUTO)
            })

            for box in tracks[-1]["shapes"]:
                box.pop("id", None)
                box.pop("path_id")
                box.pop("group")
                box.pop("label_id")
                box.pop("source")
                box["outside"] = False
                box["attributes"] = []

        for track in tracks:
            if track["shapes"][-1]["frame"] != db_task.data.size - 1:
                box = track["shapes"][-1].copy()
                box["outside"] = True
                box["frame"] += 1
                track["shapes"].append(box)

        if tracks:
            data["shapes"] = shapes_without_boxes
            data["tracks"].extend(tracks)

            serializer = LabeledDataSerializer(data=data)
            if serializer.is_valid(raise_exception=True):
                dm.task.put_task_data(db_task.id, serializer.data)

    @staticmethod
    def __call__(function, task, quality, cleanup, **kwargs):
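        # Entry point executed by the RQ worker; see LambdaQueue.enqueue for
        # why this is a staticmethod named __call__.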
        # TODO: need logging
        db_task = TaskModel.objects.get(pk=task)
        if cleanup:
            dm.task.delete_task_data(db_task.id)
        db_labels = (db_task.project.label_set if db_task.project_id else db_task.label_set).prefetch_related("attributespec_set").all()
        labels = {}
        for label in db_labels:
            labels[label.name] = {'id': label.id, 'attributes': {}}
            for attr in label.attributespec_set.values():
                labels[label.name]['attributes'][attr['name']] = attr['id']

        if function.kind == LambdaType.DETECTOR:
            LambdaJob._call_detector(function, db_task, labels, quality,
                kwargs.get("threshold"), kwargs.get("mapping"))
        elif function.kind == LambdaType.REID:
            LambdaJob._call_reid(function, db_task, quality,
                kwargs.get("threshold"), kwargs.get("max_distance"))

def return_response(success_code=status.HTTP_200_OK):
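    """Wrap a view method so that raised errors become DRF Responses.

    Exceptions from requests, ValidationError, and ObjectDoesNotExist are
    converted into a Response with a matching HTTP status code. A minimal
    usage sketch (hypothetical view method):

        @return_response(status.HTTP_204_NO_CONTENT)
        def delete(self, request, pk):
            ...
    """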
    def wrap_response(func):
        @wraps(func)
        def func_wrapper(*args, **kwargs):
            data = None
            status_code = success_code
            try:
                data = func(*args, **kwargs)
            except requests.ConnectionError as err:
                status_code = status.HTTP_503_SERVICE_UNAVAILABLE
                data = str(err)
            except requests.HTTPError as err:
                status_code = err.response.status_code
                data = str(err)
            except requests.Timeout as err:
                status_code = status.HTTP_504_GATEWAY_TIMEOUT
                data = str(err)
            except requests.RequestException as err:
                status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
                data = str(err)
            except ValidationError as err:
                status_code = err.code
                data = err.message
            except ObjectDoesNotExist as err:
                status_code = status.HTTP_400_BAD_REQUEST
                data = str(err)

            return Response(data=data, status=status_code)

        return func_wrapper
    return wrap_response

@extend_schema(tags=['lambda'])
@extend_schema_view(
    retrieve=extend_schema(
        summary='Method returns the information about the function',
        responses={
            '200': OpenApiResponse(response=OpenApiTypes.OBJECT, description='Information about the function'),
        }),
    list=extend_schema(
        summary='Method returns a list of functions')
)
class FunctionViewSet(viewsets.ViewSet):
    lookup_value_regex = '[a-zA-Z0-9_.-]+'
    lookup_field = 'func_id'
    iam_organization_field = None
    serializer_class = None

    @return_response()
    def list(self, request):
        gateway = LambdaGateway()
        return [f.to_dict() for f in gateway.list()]

    @return_response()
    def retrieve(self, request, func_id):
        self.check_object_permissions(request, func_id)
        gateway = LambdaGateway()
        return gateway.get(func_id).to_dict()

    @return_response()
    def call(self, request, func_id):
        self.check_object_permissions(request, func_id)
        try:
            task_id = request.data['task']
            db_task = TaskModel.objects.get(pk=task_id)
        except (KeyError, ObjectDoesNotExist) as err:
            raise ValidationError(
                '`{}` lambda function was run '.format(func_id) +
                'with wrong arguments ({})'.format(str(err)),
                code=status.HTTP_400_BAD_REQUEST)

        gateway = LambdaGateway()
        lambda_func = gateway.get(func_id)

        return lambda_func.invoke(db_task, request.data)

@extend_schema(tags=['lambda'])
@extend_schema_view(
    retrieve=extend_schema(
        summary='Method returns the status of the request',
        parameters=[
            # specify correct type
            OpenApiParameter('id', location=OpenApiParameter.PATH, type=OpenApiTypes.INT,
                description='Request id'),
        ]),
    list=extend_schema(
        summary='Method returns a list of requests'),
    # TODO
    create=extend_schema(
        summary='Method calls the function'),
    delete=extend_schema(
        summary='Method cancels the request')
)
class RequestViewSet(viewsets.ViewSet):
    iam_organization_field = None
    serializer_class = None

    @return_response()
    def list(self, request):
        queue = LambdaQueue()
        return [job.to_dict() for job in queue.get_jobs()]

    @return_response()
    def create(self, request):
        try:
            function = request.data['function']
            threshold = request.data.get('threshold')
            task = request.data['task']
            quality = request.data.get("quality")
            cleanup = request.data.get('cleanup', False)
            mapping = request.data.get('mapping')
            max_distance = request.data.get('max_distance')
        except KeyError as err:
            raise ValidationError(
                '`{}` lambda function was run '.format(request.data.get('function', 'undefined')) +
                'with wrong arguments ({})'.format(str(err)),
                code=status.HTTP_400_BAD_REQUEST)

        gateway = LambdaGateway()
        queue = LambdaQueue()
        lambda_func = gateway.get(function)
        job = queue.enqueue(lambda_func, threshold, task, quality,
            mapping, cleanup, max_distance)

        return job.to_dict()

    @return_response()
    def retrieve(self, request, pk):
        self.check_object_permissions(request, pk)
        queue = LambdaQueue()
        job = queue.fetch_job(pk)

        return job.to_dict()

    @return_response(status.HTTP_204_NO_CONTENT)
    def delete(self, request, pk):
        self.check_object_permissions(request, pk)
        queue = LambdaQueue()
        job = queue.fetch_job(pk)
        job.delete()