You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

184 lines
7.0 KiB
Python

from collections import OrderedDict
from typing import List
from django_rq import job
from cvat.apps import dataset_manager as dm
from cvat.apps.engine.frame_provider import FrameProvider
from cvat.apps.engine.models import (
Project,
Task,
TrainingProjectImage,
Label,
Image,
TrainingProjectLabel,
Data,
Job,
ShapeType,
)
from cvat.apps.training.apis import TrainingServerAPI
@job
def save_prediction_server_status_to_cache_job(cache_key,
                                               cvat_project_id,
                                               timeout=60):
    """Query the training server for a CVAT project's status and mark it done.

    Placeholder implementation: the training app is due to be removed in a
    separate PR.

    Args:
        cache_key: Accepted for interface compatibility; unused in this body.
        cvat_project_id: Primary key of the CVAT ``Project`` whose linked
            training project is queried.
        timeout: Accepted for interface compatibility; unused in this body.

    Returns:
        A dict with every entry of ``get_project_status``'s result, with the
        ``'status'`` key forced to ``'done'``.
    """
    training = Project.objects.get(pk=cvat_project_id).training_project
    client = TrainingServerAPI(
        host=training.host,
        username=training.username,
        password=training.password,
    )
    server_status = client.get_project_status(project_id=training.training_id)
    # Keep the server-reported fields, but always report this job as finished.
    return {**server_status, 'status': 'done'}
@job
def save_frame_prediction_to_cache_job(cache_key: str,
                                       task_id: int,
                                       frame: int,
                                       timeout: int = 60):
    """Fetch the training server's predicted annotation for one task frame.

    Placeholder implementation: the training app is due to be removed in a
    separate PR.

    Args:
        cache_key: Accepted for interface compatibility; unused in this body.
        task_id: Primary key of the CVAT ``Task``.
        frame: Frame index within the task; matched against
            ``TrainingProjectImage.idx`` and ``Image.frame``.
        timeout: Accepted for interface compatibility; unused in this body.

    Returns:
        ``None`` when the frame was never uploaded to the training server.
        Otherwise a dict ``{'annotation': <api result>, 'status': 'done'}``.

    Raises:
        TrainingProjectLabel.DoesNotExist: if a project label has no mapping
            (raised by ``.get``).
    """
    task = Task.objects.get(pk=task_id)
    uploaded_image = TrainingProjectImage.objects.filter(idx=frame, task=task).first()
    if uploaded_image is None:
        return None

    project_labels = Label.objects.filter(project__id=task.project_id).all()
    training = Project.objects.get(pk=task.project_id).training_project
    client = TrainingServerAPI(host=training.host,
                               username=training.username,
                               password=training.password)
    frame_image = Image.objects.get(frame=frame, data=task.data)

    # Translate training-server label ids back to CVAT label ids.
    training_to_cvat_label = {}
    for label in project_labels:
        mapped = TrainingProjectLabel.objects.get(cvat_label=label)
        training_to_cvat_label[mapped.training_label_id] = label.id

    prediction = client.get_annotation(project_id=training.training_id,
                                       image_id=uploaded_image.training_image_id,
                                       width=frame_image.width,
                                       height=frame_image.height,
                                       labels_mapping=training_to_cvat_label,
                                       frame=frame)
    return {'annotation': prediction, 'status': 'done'}
@job
def upload_images_job(task_id: int):
    """Upload every frame of a task to its project's training server, once.

    Does nothing if any ``TrainingProjectImage`` already exists for the task.
    Otherwise each frame from ``FrameProvider(task.data).get_frames()`` is
    uploaded in order. A ``TrainingProjectImage(task, idx, training_image_id)``
    row is recorded for every frame the server returned a truthy id for.

    Args:
        task_id: Primary key of the CVAT ``Task`` whose frames are uploaded.
    """
    # Previously ``count() is 0``: an identity comparison with an int literal.
    # It only worked thanks to CPython's small-int cache and triggers a
    # SyntaxWarning on Python 3.8+. ``exists()`` states the intent directly.
    if TrainingProjectImage.objects.filter(task_id=task_id).exists():
        return

    task = Task.objects.get(pk=task_id)
    frame_provider = FrameProvider(task.data)
    frames = frame_provider.get_frames()
    training_project = task.project.training_project
    api = TrainingServerAPI(
        host=training_project.host,
        username=training_project.username,
        password=training_project.password,
    )
    for idx, (buffer, _) in enumerate(frames):
        training_image_id = api.upload_image(training_id=training_project.training_id, buffer=buffer)
        # Frames the server did not accept (falsy id) are skipped, not recorded.
        if training_image_id:
            TrainingProjectImage.objects.create(task=task, idx=idx,
                                                training_image_id=training_image_id)
def __add_fields_to_shape(shape: dict, frame: int, data: Data, labels_mapping: dict) -> dict:
    """Return a copy of *shape* extended with image size and training label id.

    Args:
        shape: Shape dict; must contain ``'label_id'``. It is not mutated.
        frame: Frame number used to look up the ``Image`` row.
        data: The ``Data`` object the image belongs to.
        labels_mapping: Maps CVAT label id -> training-server label id.

    Returns:
        A new dict holding all of *shape*'s entries plus ``'height'`` and
        ``'width'`` (from the frame's ``Image``) and ``'third_party_label_id'``.
        These three keys override same-named keys in *shape*.

    Raises:
        Image.DoesNotExist: if no ``Image`` matches ``frame``/``data``.
        KeyError: if ``shape['label_id']`` is missing from ``labels_mapping``.
    """
    frame_image = Image.objects.get(frame=frame, data=data)
    return dict(
        shape,
        height=frame_image.height,
        width=frame_image.width,
        third_party_label_id=labels_mapping[shape['label_id']],
    )
@job
def upload_annotation_to_training_project_job(job_id: int):
    """Upload the rectangle annotations of a CVAT job to the training server.

    Covers every frame of the job's segment, ``start_frame`` through
    ``stop_frame`` inclusive. For each frame with at least one rectangle
    shape, the shapes are enriched by ``__add_fields_to_shape`` and paired
    with the frame's ``TrainingProjectImage.training_image_id``. All frames
    are then sent in one ``api.upload_annotations`` call. That call is made
    even when no frame contained rectangles.

    Args:
        job_id: Primary key of the CVAT ``Job`` to export.

    Raises:
        TrainingProjectLabel.DoesNotExist: if a project label has no mapping.
        TrainingProjectImage.DoesNotExist: if a frame with rectangles has no
            uploaded image record.
        KeyError: if a rectangle's ``label_id`` is absent from the mapping.
    """
    cvat_job = Job.objects.get(pk=job_id)
    task = cvat_job.segment.task
    cvat_project = task.project
    training_project = cvat_project.training_project
    start = cvat_job.segment.start_frame
    stop = cvat_job.segment.stop_frame
    job_data = dm.task.get_job_data(job_id)
    shapes: List[OrderedDict] = job_data.get('shapes', [])
    api = TrainingServerAPI(
        host=training_project.host,
        username=training_project.username,
        password=training_project.password,
    )
    cvat_labels = Label.objects.filter(project=cvat_project).all()
    # CVAT label id -> training-server label id.
    labels_mapping = {
        cvat_label.id: TrainingProjectLabel.objects.get(cvat_label=cvat_label).training_label_id
        for cvat_label in cvat_labels
    }

    # Bucket rectangles by frame in a single pass. The old code re-filtered
    # the full shape list for every frame (O(frames * shapes)). Shapes keep
    # their original relative order within each frame.
    rectangles_by_frame = {}
    for shape in shapes:
        if shape['type'] == ShapeType.RECTANGLE:
            rectangles_by_frame.setdefault(shape['frame'], []).append(shape)

    frames_data = []
    for frame in range(start, stop + 1):
        frame_shapes = [
            __add_fields_to_shape(shape, frame, task.data, labels_mapping)
            for shape in rectangles_by_frame.get(frame, [])
        ]
        if frame_shapes:
            training_project_image = TrainingProjectImage.objects.get(task=task, idx=frame)
            frames_data.append({
                'third_party_id': training_project_image.training_image_id,
                'shapes': frame_shapes
            })
    api.upload_annotations(project_id=training_project.training_id, frames_data=frames_data)
@job
def create_training_project_job(project_id: int):
    """Create the training-server project for a CVAT project.

    Builds a ``TrainingServerAPI`` client from the credentials stored on the
    project's linked training project, then delegates to
    ``create_training_project``.

    Args:
        project_id: Primary key of the CVAT ``Project``.
    """
    cvat_project = Project.objects.get(pk=project_id)
    training_project = cvat_project.training_project
    create_training_project(
        cvat_project=cvat_project,
        training_project=training_project,
        api=TrainingServerAPI(
            host=training_project.host,
            username=training_project.username,
            password=training_project.password,
        ),
    )
def create_training_project(cvat_project, training_project, api):
    """Create the remote training project and link its labels back to CVAT.

    Requests a project named ``'<cvat name>_cvat'`` with all of the CVAT
    project's label names. If the response carries a truthy ``'id'``, it is
    stored in ``training_project.training_id`` and saved. Then, for each CVAT
    label whose name appears in the response's ``'labels'`` list, a
    ``TrainingProjectLabel`` row is created with the matching label's ``'id'``.

    Args:
        cvat_project: CVAT ``Project`` (reads ``name`` and ``label_set``).
        training_project: Linked training project (reads ``project_class``;
            writes ``training_id``).
        api: Client exposing ``create_project(name, project_class, labels)``;
            the returned value is used as a dict.
    """
    labels = cvat_project.label_set.all()
    training_project_resp = api.create_project(
        name=f'{cvat_project.name}_cvat',
        project_class=training_project.project_class,
        labels=[{'name': label.name} for label in labels]
    )
    if training_project_resp.get('id'):
        training_project.training_id = training_project_resp['id']
        training_project.save()

    # Index returned labels by name once, instead of filtering the whole list
    # for every CVAT label. setdefault keeps the FIRST entry per name, exactly
    # like the former ``list(filter(...))[0]``.
    training_labels_by_name = {}
    for training_label in training_project_resp.get('labels', []):
        training_labels_by_name.setdefault(training_label['name'], training_label)

    for cvat_label in labels:
        training_label = training_labels_by_name.get(cvat_label.name)
        if training_label is not None:
            TrainingProjectLabel.objects.create(cvat_label=cvat_label, training_label_id=training_label['id'])
async def upload_images(cvat_project_id, training_id, api):
    """Upload every frame of every task in a CVAT project to the training server.

    Each frame with a truthy id from ``api.upload_image`` is recorded as a
    ``TrainingProjectImage(task, idx, training_image_id)``.

    NOTE(review): declared ``async`` but contains no ``await``. Every call
    below (ORM and ``api``) is synchronous, so it blocks the event loop, and
    Django may refuse ORM access from async context
    (SynchronousOnlyOperation). Confirm how callers invoke this.

    Args:
        cvat_project_id: Primary key of the CVAT ``Project``.
        training_id: Training-server project id passed to ``upload_image``.
        api: Client exposing ``upload_image(training_id, buffer)``.
    """
    project = Project.objects.get(pk=cvat_project_id)
    tasks: List[Task] = project.tasks.all()
    for task in tasks:
        # FrameProvider is built from the task's Data object, as in
        # upload_images_job; the task itself was passed here before.
        frame_provider = FrameProvider(task.data)
        for idx, (buffer, _) in enumerate(frame_provider.get_frames()):
            training_image_id = api.upload_image(training_id=training_id, buffer=buffer)
            if training_image_id:
                # Same fields as upload_images_job, and every lookup in this
                # module keys on task/idx. The former extra ``project=``
                # kwarg was dropped. NOTE(review): confirm against the model.
                TrainingProjectImage.objects.create(task=task, idx=idx,
                                                    training_image_id=training_image_id)