You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
184 lines
7.0 KiB
Python
184 lines
7.0 KiB
Python
from collections import OrderedDict
|
|
from typing import List
|
|
|
|
from django_rq import job
|
|
|
|
from cvat.apps import dataset_manager as dm
|
|
from cvat.apps.engine.frame_provider import FrameProvider
|
|
from cvat.apps.engine.models import (
|
|
Project,
|
|
Task,
|
|
TrainingProjectImage,
|
|
Label,
|
|
Image,
|
|
TrainingProjectLabel,
|
|
Data,
|
|
Job,
|
|
ShapeType,
|
|
)
|
|
from cvat.apps.training.apis import TrainingServerAPI
|
|
|
|
|
|
@job
def save_prediction_server_status_to_cache_job(cache_key,
                                               cvat_project_id,
                                               timeout=60):
    """Ask the training server for the status of a CVAT project.

    The project is loaded by primary key, a ``TrainingServerAPI`` client is
    built from the host and credentials stored on its training project, and
    the server is queried for the status of the linked training project.

    Args:
        cache_key: Accepted for the caller's benefit; not read by this body.
        cvat_project_id: Primary key of the CVAT ``Project`` to query.
        timeout: Accepted for the caller's benefit; not read by this body.

    Returns:
        dict: the mapping returned by the server, with its ``'status'`` entry
        set (or overwritten) to ``'done'``.
    """
    training_project = Project.objects.get(pk=cvat_project_id).training_project
    server = TrainingServerAPI(
        host=training_project.host,
        username=training_project.username,
        password=training_project.password,
    )
    server_status = server.get_project_status(project_id=training_project.training_id)

    # dummy code, need to delete training app in a separate PR
    return {**server_status, 'status': 'done'}
|
|
|
|
|
|
@job
def save_frame_prediction_to_cache_job(cache_key: str,
                                       task_id: int,
                                       frame: int,
                                       timeout: int = 60):
    """Fetch the training server's annotation for one frame of a task.

    Args:
        cache_key: Accepted for the caller's benefit; not read by this body.
        task_id: Primary key of the CVAT ``Task`` that owns the frame.
        frame: Frame index inside the task.
        timeout: Accepted for the caller's benefit; not read by this body.

    Returns:
        ``None`` when no ``TrainingProjectImage`` row exists for the frame;
        otherwise a dict with the server's ``'annotation'`` and a ``'status'``
        of ``'done'``.
    """
    db_task = Task.objects.get(pk=task_id)
    uploaded_image = TrainingProjectImage.objects.filter(idx=frame, task=db_task).first()
    if not uploaded_image:
        # No training-side image is recorded for this frame, so there is
        # nothing to request from the server.
        return None

    project_labels = Label.objects.filter(project__id=db_task.project_id).all()
    training_project = Project.objects.get(pk=db_task.project_id).training_project
    server = TrainingServerAPI(
        host=training_project.host,
        username=training_project.username,
        password=training_project.password,
    )
    frame_image = Image.objects.get(frame=frame, data=db_task.data)

    # Training-server label id -> CVAT label id, one lookup per project label.
    labels_mapping = {}
    for project_label in project_labels:
        training_label_id = TrainingProjectLabel.objects.get(cvat_label=project_label).training_label_id
        labels_mapping[training_label_id] = project_label.id

    prediction = server.get_annotation(
        project_id=training_project.training_id,
        image_id=uploaded_image.training_image_id,
        width=frame_image.width,
        height=frame_image.height,
        labels_mapping=labels_mapping,
        frame=frame,
    )

    # dummy code, need to delete training app in a separate PR
    return {
        'annotation': prediction,
        'status': 'done',
    }
|
|
|
|
|
|
@job
def upload_images_job(task_id: int):
    """Upload all frames of a task to the training server, unless already done.

    If any ``TrainingProjectImage`` row already exists for the task, the job
    does nothing. Otherwise every frame yielded by the task's
    ``FrameProvider`` is uploaded, and a ``TrainingProjectImage`` row is
    created for each upload that returns a truthy image id.

    Args:
        task_id: Primary key of the CVAT ``Task`` whose frames are uploaded.
    """
    # The previous check compared ``count()`` to 0 with ``is``; identity
    # comparison against an int literal only works by accident of CPython's
    # small-int cache and raises a SyntaxWarning. Ask the database directly
    # whether a row exists instead.
    if TrainingProjectImage.objects.filter(task_id=task_id).exists():
        return

    task = Task.objects.get(pk=task_id)
    frame_provider = FrameProvider(task.data)
    frames = frame_provider.get_frames()
    training_project = task.project.training_project
    api = TrainingServerAPI(
        host=training_project.host,
        username=training_project.username,
        password=training_project.password,
    )

    for i, (buffer, _) in enumerate(frames):
        training_image_id = api.upload_image(training_id=training_project.training_id, buffer=buffer)
        # A falsy id is skipped without recording anything for that frame.
        if training_image_id:
            TrainingProjectImage.objects.create(task=task, idx=i,
                                                training_image_id=training_image_id)
|
|
|
|
def __add_fields_to_shape(shape: dict, frame: int, data: Data, labels_mapping: dict) -> dict:
    """Return a copy of ``shape`` extended with image size and mapped label id.

    Args:
        shape: Shape mapping; must contain a ``'label_id'`` key.
        frame: Frame index used to look up the ``Image`` row.
        data: ``Data`` object the ``Image`` row belongs to.
        labels_mapping: Lookup table indexed by the shape's ``'label_id'``.

    Returns:
        A new dict holding every entry of ``shape`` plus ``'height'``,
        ``'width'`` and ``'third_party_label_id'``. ``shape`` is not modified.
    """
    frame_image = Image.objects.get(frame=frame, data=data)
    enriched = {**shape}
    enriched.update(
        height=frame_image.height,
        width=frame_image.width,
        third_party_label_id=labels_mapping[shape['label_id']],
    )
    return enriched
|
|
|
|
|
|
@job
def upload_annotation_to_training_project_job(job_id: int):
    """Send a job's rectangle annotations to the training server.

    For every frame of the job's segment (start and stop frame inclusive),
    the rectangle shapes on that frame are extended via
    ``__add_fields_to_shape`` and grouped under the training-side image id.
    Frames without rectangle shapes are left out. The collected list is then
    passed to the server in a single ``upload_annotations`` call.

    Args:
        job_id: Primary key of the CVAT ``Job`` whose annotations are sent.
    """
    cvat_job = Job.objects.get(pk=job_id)
    segment = cvat_job.segment
    cvat_project = segment.task.project
    training_project = cvat_project.training_project
    first_frame = segment.start_frame
    last_frame = segment.stop_frame
    job_data = dm.task.get_job_data(job_id)
    shapes: List[OrderedDict] = job_data.get('shapes', [])
    frames_data = []
    api = TrainingServerAPI(
        host=training_project.host,
        username=training_project.username,
        password=training_project.password,
    )

    # CVAT label id -> training-server label id, one lookup per project label.
    labels_mapping = {}
    for cvat_label in Label.objects.filter(project=cvat_project).all():
        labels_mapping[cvat_label.id] = TrainingProjectLabel.objects.get(cvat_label=cvat_label).training_label_id

    for frame in range(first_frame, last_frame + 1):
        frame_shapes = [
            __add_fields_to_shape(shape, frame, cvat_job.segment.task.data, labels_mapping)
            for shape in shapes
            if shape['frame'] == frame and shape['type'] == ShapeType.RECTANGLE
        ]
        if not frame_shapes:
            continue

        training_project_image = TrainingProjectImage.objects.get(task=cvat_job.segment.task, idx=frame)
        frames_data.append({
            'third_party_id': training_project_image.training_image_id,
            'shapes': frame_shapes,
        })

    api.upload_annotations(project_id=training_project.training_id, frames_data=frames_data)
|
|
|
|
|
|
@job
def create_training_project_job(project_id: int):
    """Create the training-server counterpart of a CVAT project.

    Loads the project, builds a ``TrainingServerAPI`` client from the
    credentials stored on its training project and delegates the actual work
    to ``create_training_project``.

    Args:
        project_id: Primary key of the CVAT ``Project``.
    """
    cvat_project = Project.objects.get(pk=project_id)
    training_project = cvat_project.training_project
    server = TrainingServerAPI(
        host=training_project.host,
        username=training_project.username,
        password=training_project.password,
    )
    create_training_project(
        cvat_project=cvat_project,
        training_project=training_project,
        api=server,
    )
|
|
|
|
|
|
def create_training_project(cvat_project, training_project, api):
    """Create a project on the training server and record the returned ids.

    The remote project is named ``<cvat project name>_cvat`` and receives one
    label per CVAT label. If the response carries a truthy ``'id'`` it is
    stored on ``training_project`` and saved. For every CVAT label whose name
    appears among the response's ``'labels'``, a ``TrainingProjectLabel`` row
    linking it to the first matching remote label id is created.

    Args:
        cvat_project: CVAT project; its ``name`` and ``label_set`` are read.
        training_project: Object providing ``project_class``; receives
            ``training_id`` and is saved when the server returns an id.
        api: Client exposing ``create_project(name, project_class, labels)``.
    """
    cvat_labels = cvat_project.label_set.all()

    remote_name = f'{cvat_project.name}_cvat'
    project_class = training_project.project_class
    label_payload = [{'name': cvat_label.name} for cvat_label in cvat_labels]
    response = api.create_project(
        name=remote_name,
        project_class=project_class,
        labels=label_payload,
    )

    remote_id = response.get('id')
    if remote_id:
        training_project.training_id = remote_id
        training_project.save()

    for cvat_label in cvat_labels:
        # Labels are matched by name; only the first match is used.
        matches = [
            remote_label
            for remote_label in response.get('labels', [])
            if remote_label['name'] == cvat_label.name
        ]
        if matches:
            TrainingProjectLabel.objects.create(cvat_label=cvat_label, training_label_id=matches[0]['id'])
|
|
|
|
|
|
async def upload_images(cvat_project_id, training_id, api):
    """Upload the frames of every task in a CVAT project to the training server.

    For each task of the project, every frame yielded by its ``FrameProvider``
    is uploaded through ``api.upload_image``; a ``TrainingProjectImage`` row
    is created for each upload that returns a truthy image id.

    Args:
        cvat_project_id: Primary key of the CVAT ``Project``.
        training_id: Training-server project id passed to ``api.upload_image``.
        api: Client exposing ``upload_image(training_id, buffer)``.
    """
    # NOTE(review): this coroutine contains no ``await`` and uses the ORM
    # synchronously - confirm how (and whether) it is called.
    project = Project.objects.get(pk=cvat_project_id)
    tasks: List[Task] = project.tasks.all()
    for task in tasks:
        # Build the provider from the task's data object, matching how
        # upload_images_job constructs it; the task itself was passed before.
        frame_provider = FrameProvider(task.data)
        frames = frame_provider.get_frames()
        for i, (buffer, _) in enumerate(frames):
            training_image_id = api.upload_image(training_id=training_id, buffer=buffer)
            if training_image_id:
                # NOTE(review): upload_images_job creates these rows without a
                # ``project`` argument - confirm the model accepts it.
                TrainingProjectImage.objects.create(project=project, task=task, idx=i,
                                                    training_image_id=training_image_id)
|
|
|