Integration with an internal training server (#2785)
Co-authored-by: Boris Sekachev <boris.sekachev@intel.com> Co-authored-by: Nikita Manovich <nikita.manovich@intel.com>main
parent
babf1a3f54
commit
d2a1d12fba
@ -0,0 +1,56 @@
|
||||
<?xml version="1.0" encoding="iso-8859-1"?>
|
||||
<!-- The icon received from: https://www.svgrepo.com/svg/25187/brain -->
|
||||
<!-- License: CC0 Creative Commons License -->
|
||||
<svg version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px"
|
||||
viewBox="0 0 463 463" width="40px" height="40px" style="enable-background:new 0 0 463 463;" xml:space="preserve">
|
||||
<g>
|
||||
<path d="M151.245,222.446C148.054,237.039,135.036,248,119.5,248c-4.142,0-7.5,3.357-7.5,7.5s3.358,7.5,7.5,7.5
|
||||
c23.774,0,43.522-17.557,46.966-40.386c14.556-1.574,27.993-8.06,38.395-18.677c2.899-2.959,2.85-7.708-0.109-10.606
|
||||
c-2.958-2.897-7.707-2.851-10.606,0.108C184.947,202.829,172.643,208,159.5,208c-26.743,0-48.5-21.757-48.5-48.5
|
||||
c0-4.143-3.358-7.5-7.5-7.5s-7.5,3.357-7.5,7.5C96,191.715,120.119,218.384,151.245,222.446z"/>
|
||||
<path d="M183,287.5c0-4.143-3.358-7.5-7.5-7.5c-35.014,0-63.5,28.486-63.5,63.5c0,0.362,0.013,0.725,0.019,1.088
|
||||
C109.23,344.212,106.39,344,103.5,344c-4.142,0-7.5,3.357-7.5,7.5s3.358,7.5,7.5,7.5c26.743,0,48.5,21.757,48.5,48.5
|
||||
c0,4.143,3.358,7.5,7.5,7.5s7.5-3.357,7.5-7.5c0-26.611-16.462-49.437-39.731-58.867c-0.178-1.699-0.269-3.418-0.269-5.133
|
||||
c0-26.743,21.757-48.5,48.5-48.5C179.642,295,183,291.643,183,287.5z"/>
|
||||
<path d="M439,223.5c0-17.075-6.82-33.256-18.875-45.156c1.909-6.108,2.875-12.426,2.875-18.844
|
||||
c0-30.874-22.152-56.659-51.394-62.329C373.841,91.6,375,85.628,375,79.5c0-19.557-11.883-36.387-28.806-43.661
|
||||
C317.999,13.383,287.162,0,263.5,0c-13.153,0-24.817,6.468-32,16.384C224.317,6.468,212.653,0,199.5,0
|
||||
c-23.662,0-54.499,13.383-82.694,35.839C99.883,43.113,88,59.943,88,79.5c0,6.128,1.159,12.1,3.394,17.671
|
||||
C62.152,102.841,40,128.626,40,159.5c0,6.418,0.965,12.735,2.875,18.844C30.82,190.244,24,206.425,24,223.5
|
||||
c0,13.348,4.149,25.741,11.213,35.975C27.872,270.087,24,282.466,24,295.5c0,23.088,12.587,44.242,32.516,55.396
|
||||
C56.173,353.748,56,356.626,56,359.5c0,31.144,20.315,58.679,49.79,68.063C118.611,449.505,141.965,463,167.5,463
|
||||
c27.995,0,52.269-16.181,64-39.674c11.731,23.493,36.005,39.674,64,39.674c25.535,0,48.889-13.495,61.71-35.437
|
||||
c29.475-9.385,49.79-36.92,49.79-68.063c0-2.874-0.173-5.752-0.516-8.604C426.413,339.742,439,318.588,439,295.5
|
||||
c0-13.034-3.872-25.413-11.213-36.025C434.851,249.241,439,236.848,439,223.5z M167.5,448c-21.029,0-40.191-11.594-50.009-30.256
|
||||
c-0.973-1.849-2.671-3.208-4.688-3.751C88.19,407.369,71,384.961,71,359.5c0-3.81,0.384-7.626,1.141-11.344
|
||||
c0.702-3.447-1.087-6.92-4.302-8.35C50.32,332.018,39,314.626,39,295.5c0-8.699,2.256-17.014,6.561-24.379
|
||||
C56.757,280.992,71.436,287,87.5,287c4.142,0,7.5-3.357,7.5-7.5s-3.358-7.5-7.5-7.5C60.757,272,39,250.243,39,223.5
|
||||
c0-14.396,6.352-27.964,17.428-37.221c2.5-2.09,3.365-5.555,2.14-8.574C56.2,171.869,55,165.744,55,159.5
|
||||
c0-26.743,21.757-48.5,48.5-48.5s48.5,21.757,48.5,48.5c0,4.143,3.358,7.5,7.5,7.5s7.5-3.357,7.5-7.5
|
||||
c0-33.642-26.302-61.243-59.421-63.355C104.577,91.127,103,85.421,103,79.5c0-13.369,8.116-24.875,19.678-29.859
|
||||
c0.447-0.133,0.885-0.307,1.308-0.527C127.568,47.752,131.447,47,135.5,47c12.557,0,23.767,7.021,29.256,18.325
|
||||
c1.81,3.727,6.298,5.281,10.023,3.47c3.726-1.809,5.28-6.296,3.47-10.022c-6.266-12.903-18.125-22.177-31.782-25.462
|
||||
C165.609,21.631,184.454,15,199.5,15c13.509,0,24.5,10.99,24.5,24.5v97.051c-6.739-5.346-15.25-8.551-24.5-8.551
|
||||
c-4.142,0-7.5,3.357-7.5,7.5s3.358,7.5,7.5,7.5c13.509,0,24.5,10.99,24.5,24.5v180.279c-9.325-12.031-22.471-21.111-37.935-25.266
|
||||
c-3.999-1.071-8.114,1.297-9.189,5.297c-1.075,4.001,1.297,8.115,5.297,9.189C206.8,343.616,224,366.027,224,391.5
|
||||
C224,422.654,198.654,448,167.5,448z M395.161,339.807c-3.215,1.43-5.004,4.902-4.302,8.35c0.757,3.718,1.141,7.534,1.141,11.344
|
||||
c0,25.461-17.19,47.869-41.803,54.493c-2.017,0.543-3.716,1.902-4.688,3.751C335.691,436.406,316.529,448,295.5,448
|
||||
c-31.154,0-56.5-25.346-56.5-56.5c0-2.109-0.098-4.2-0.281-6.271c0.178-0.641,0.281-1.314,0.281-2.012V135.5
|
||||
c0-13.51,10.991-24.5,24.5-24.5c4.142,0,7.5-3.357,7.5-7.5s-3.358-7.5-7.5-7.5c-9.25,0-17.761,3.205-24.5,8.551V39.5
|
||||
c0-13.51,10.991-24.5,24.5-24.5c15.046,0,33.891,6.631,53.033,18.311c-13.657,3.284-25.516,12.559-31.782,25.462
|
||||
c-1.81,3.727-0.256,8.214,3.47,10.022c3.726,1.81,8.213,0.257,10.023-3.47C303.733,54.021,314.943,47,327.5,47
|
||||
c4.053,0,7.933,0.752,11.514,2.114c0.422,0.22,0.86,0.393,1.305,0.526C351.883,54.624,360,66.13,360,79.5
|
||||
c0,5.921-1.577,11.627-4.579,16.645C322.302,98.257,296,125.858,296,159.5c0,4.143,3.358,7.5,7.5,7.5s7.5-3.357,7.5-7.5
|
||||
c0-26.743,21.757-48.5,48.5-48.5s48.5,21.757,48.5,48.5c0,6.244-1.2,12.369-3.567,18.205c-1.225,3.02-0.36,6.484,2.14,8.574
|
||||
C417.648,195.536,424,209.104,424,223.5c0,26.743-21.757,48.5-48.5,48.5c-4.142,0-7.5,3.357-7.5,7.5s3.358,7.5,7.5,7.5
|
||||
c16.064,0,30.743-6.008,41.939-15.879c4.306,7.365,6.561,15.68,6.561,24.379C424,314.626,412.68,332.018,395.161,339.807z"/>
|
||||
<path d="M359.5,240c-15.536,0-28.554-10.961-31.745-25.554C358.881,210.384,383,183.715,383,151.5c0-4.143-3.358-7.5-7.5-7.5
|
||||
s-7.5,3.357-7.5,7.5c0,26.743-21.757,48.5-48.5,48.5c-13.143,0-25.447-5.171-34.646-14.561c-2.898-2.958-7.647-3.007-10.606-0.108
|
||||
s-3.008,7.647-0.109,10.606c10.402,10.617,23.839,17.103,38.395,18.677C315.978,237.443,335.726,255,359.5,255
|
||||
c4.142,0,7.5-3.357,7.5-7.5S363.642,240,359.5,240z"/>
|
||||
<path d="M335.5,328c-2.89,0-5.73,0.212-8.519,0.588c0.006-0.363,0.019-0.726,0.019-1.088c0-35.014-28.486-63.5-63.5-63.5
|
||||
c-4.142,0-7.5,3.357-7.5,7.5s3.358,7.5,7.5,7.5c26.743,0,48.5,21.757,48.5,48.5c0,1.714-0.091,3.434-0.269,5.133
|
||||
C288.462,342.063,272,364.889,272,391.5c0,4.143,3.358,7.5,7.5,7.5s7.5-3.357,7.5-7.5c0-26.743,21.757-48.5,48.5-48.5
|
||||
c4.142,0,7.5-3.357,7.5-7.5S339.642,328,335.5,328z"/>
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 5.5 KiB |
@ -1,21 +1,56 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// Copyright (C) 2020-2021 Intel Corporation
|
||||
//
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
import './styles.scss';
|
||||
import React from 'react';
|
||||
import React, { useState } from 'react';
|
||||
import { Row, Col } from 'antd/lib/grid';
|
||||
import Text from 'antd/lib/typography/Text';
|
||||
|
||||
import { connect } from 'react-redux';
|
||||
import CreateProjectContent from './create-project-content';
|
||||
import { CombinedState } from '../../reducers/interfaces';
|
||||
import CreateProjectContext, { ICreateProjectContext } from './create-project.context';
|
||||
|
||||
export default function CreateProjectPageComponent(): JSX.Element {
|
||||
function CreateProjectPageComponent(props: StateToProps): JSX.Element {
|
||||
const { isTrainingActive } = props;
|
||||
const [projectClass, setProjectClass] = useState('');
|
||||
const [trainingEnabled, setTrainingEnabled] = useState(false);
|
||||
const [isTrainingActiveState] = useState(isTrainingActive);
|
||||
|
||||
const defaultContext: ICreateProjectContext = {
|
||||
projectClass: {
|
||||
value: projectClass,
|
||||
set: setProjectClass,
|
||||
},
|
||||
trainingEnabled: {
|
||||
value: trainingEnabled,
|
||||
set: setTrainingEnabled,
|
||||
},
|
||||
isTrainingActive: {
|
||||
value: isTrainingActiveState,
|
||||
},
|
||||
};
|
||||
return (
|
||||
<Row justify='center' align='top' className='cvat-create-task-form-wrapper'>
|
||||
<Col md={20} lg={16} xl={14} xxl={9}>
|
||||
<Text className='cvat-title'>Create a new project</Text>
|
||||
<CreateProjectContent />
|
||||
</Col>
|
||||
</Row>
|
||||
<CreateProjectContext.Provider value={defaultContext}>
|
||||
<Row justify='center' align='top' className='cvat-create-task-form-wrapper'>
|
||||
<Col md={20} lg={16} xl={14} xxl={9}>
|
||||
<Text className='cvat-title'>Create a new project</Text>
|
||||
<CreateProjectContent />
|
||||
</Col>
|
||||
</Row>
|
||||
</CreateProjectContext.Provider>
|
||||
);
|
||||
}
|
||||
|
||||
interface StateToProps {
|
||||
isTrainingActive: boolean;
|
||||
}
|
||||
|
||||
function mapStateToProps(state: CombinedState): StateToProps {
|
||||
return {
|
||||
isTrainingActive: state.plugins.list.PREDICT,
|
||||
};
|
||||
}
|
||||
|
||||
export default connect(mapStateToProps)(CreateProjectPageComponent);
|
||||
|
||||
@ -0,0 +1,31 @@
|
||||
// Copyright (C) 2020-2021 Intel Corporation
|
||||
//
|
||||
// SPDX-License-Identifier: MIT
|
||||
import { createContext, Dispatch, SetStateAction } from 'react';
|
||||
|
||||
export interface IState<T> {
|
||||
value: T;
|
||||
set?: Dispatch<SetStateAction<T>>;
|
||||
}
|
||||
|
||||
export function getDefaultState<T>(v: T): IState<T> {
|
||||
return {
|
||||
value: v,
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
set: (value: SetStateAction<T>): void => {},
|
||||
};
|
||||
}
|
||||
|
||||
export interface ICreateProjectContext {
|
||||
projectClass: IState<string>;
|
||||
trainingEnabled: IState<boolean>;
|
||||
isTrainingActive: IState<boolean>;
|
||||
}
|
||||
|
||||
export const defaultState: ICreateProjectContext = {
|
||||
projectClass: getDefaultState<string>(''),
|
||||
trainingEnabled: getDefaultState<boolean>(false),
|
||||
isTrainingActive: getDefaultState<boolean>(false),
|
||||
};
|
||||
|
||||
export default createContext<ICreateProjectContext>(defaultState);
|
||||
@ -0,0 +1,48 @@
|
||||
# Generated by Django 3.1.7 on 2021-04-02 13:17
|
||||
|
||||
from django.db import migrations, models
|
||||
import django.db.models.deletion
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('engine', '0038_manifest'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name='TrainingProject',
|
||||
fields=[
|
||||
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
||||
('host', models.CharField(max_length=256)),
|
||||
('username', models.CharField(max_length=256)),
|
||||
('password', models.CharField(max_length=256)),
|
||||
('training_id', models.CharField(max_length=64)),
|
||||
('enabled', models.BooleanField(null=True)),
|
||||
('project_class', models.CharField(blank=True, choices=[('OD', 'Object Detection')], max_length=2, null=True)),
|
||||
],
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='TrainingProjectLabel',
|
||||
fields=[
|
||||
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
||||
('training_label_id', models.CharField(max_length=64)),
|
||||
('cvat_label', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='training_project_label', to='engine.label')),
|
||||
],
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='TrainingProjectImage',
|
||||
fields=[
|
||||
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
||||
('idx', models.PositiveIntegerField()),
|
||||
('training_image_id', models.CharField(max_length=64)),
|
||||
('task', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='engine.task')),
|
||||
],
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='project',
|
||||
name='training_project',
|
||||
field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='engine.trainingproject'),
|
||||
),
|
||||
]
|
||||
@ -0,0 +1 @@
|
||||
default_app_config = 'cvat.apps.training.apps.TrainingConfig'
|
||||
@ -0,0 +1,362 @@
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import OrderedDict
|
||||
from functools import wraps
|
||||
from typing import Callable, List, Union
|
||||
|
||||
import requests
|
||||
|
||||
from cacheops import cache, CacheMiss
|
||||
|
||||
from cvat.apps.engine.models import TrainingProject, ShapeType
|
||||
|
||||
|
||||
class TrainingServerAPIAbs(ABC):
|
||||
|
||||
def __init__(self, host, username, password):
|
||||
self.host = host
|
||||
self.username = username
|
||||
self.password = password
|
||||
|
||||
@abstractmethod
|
||||
def get_server_status(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_project(self, name: str, description: str = '', project_class: TrainingProject.ProjectClass = None,
|
||||
labels: List[dict] = None):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def upload_annotations(self, project_id: str, frames_data: List[dict]):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_project_status(self, project_id: str) -> dict:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_annotation(self, project_id: str, image_id: str, width: int, height: int, frame: int,
|
||||
labels_mapping: dict) -> dict:
|
||||
pass
|
||||
|
||||
|
||||
def retry(amount: int = 2) -> Callable:
|
||||
def dec(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
__amount = amount
|
||||
while __amount > 0:
|
||||
__amount -= 1
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
return result
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return wrapper
|
||||
|
||||
return dec
|
||||
|
||||
|
||||
class TrainingServerAPI(TrainingServerAPIAbs):
|
||||
TRAINING_CLASS = {
|
||||
TrainingProject.ProjectClass.DETECTION: "DETECTION"
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def __convert_annotation_from_cvat(shapes):
|
||||
data = []
|
||||
for shape in shapes:
|
||||
x0, y0, x1, y1 = shape['points']
|
||||
x = x0 / shape['width']
|
||||
y = y0 / shape['height']
|
||||
width = (x1 - x0) / shape['width']
|
||||
height = (y1 - y0) / shape['height']
|
||||
data.append({
|
||||
"id": str(uuid.uuid4()),
|
||||
"shapes": [
|
||||
{
|
||||
"type": "rect",
|
||||
"geometry": {
|
||||
"x": x,
|
||||
"y": y,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"points": None,
|
||||
}
|
||||
}
|
||||
],
|
||||
"editor": None,
|
||||
"labels": [
|
||||
{
|
||||
"id": shape['third_party_label_id'],
|
||||
"probability": 1.0,
|
||||
},
|
||||
],
|
||||
})
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def __convert_annotation_to_cvat(annotation: dict, image_width: int, image_height: int, frame: int,
|
||||
labels_mapping: dict) -> List[OrderedDict]:
|
||||
shapes = []
|
||||
for i, annotation in enumerate(annotation.get('data', [])):
|
||||
label_id = annotation['labels'][0]['id']
|
||||
if not labels_mapping.get(label_id):
|
||||
continue
|
||||
shape = annotation['shapes'][0]
|
||||
if shape['type'] != 'rect':
|
||||
continue
|
||||
x = shape['geometry']['x']
|
||||
y = shape['geometry']['y']
|
||||
w = shape['geometry']['width']
|
||||
h = shape['geometry']['height']
|
||||
x0 = x * image_width
|
||||
y0 = y * image_height
|
||||
x1 = image_width * w + x0
|
||||
y1 = image_height * h + y0
|
||||
shapes.append(OrderedDict([
|
||||
('type', ShapeType.RECTANGLE),
|
||||
('occluded', False),
|
||||
('z_order', 0),
|
||||
('points', [x0, y0, x1, y1]),
|
||||
('id', i),
|
||||
('frame', int(frame)),
|
||||
('label', labels_mapping.get(label_id)),
|
||||
('group', 0),
|
||||
('source', 'auto'),
|
||||
('attributes', {})
|
||||
]))
|
||||
return shapes
|
||||
|
||||
@retry()
|
||||
def __create_project(self, name: str, description: str = None,
|
||||
labels: List[dict] = None, tasks: List[dict] = None) -> dict:
|
||||
url = f'{self.host}/v2/projects'
|
||||
headers = {
|
||||
'Context-Type': 'application/json',
|
||||
'Authorization': f'bearer_token {self.token}',
|
||||
}
|
||||
tasks[1]['properties'] = [
|
||||
{
|
||||
"id": "labels",
|
||||
"user_value": labels
|
||||
}
|
||||
]
|
||||
data = {
|
||||
'name': name,
|
||||
'description': description,
|
||||
"dimensions": [],
|
||||
"group_type": "normal",
|
||||
'pipeline': {
|
||||
'connections': [{
|
||||
'from': {
|
||||
**tasks[0]['output_ports'][0],
|
||||
'task_id': tasks[0]['temp_id'],
|
||||
},
|
||||
'to': {
|
||||
**tasks[1]['input_ports'][0],
|
||||
'task_id': tasks[1]['temp_id'],
|
||||
}
|
||||
}],
|
||||
'tasks': tasks,
|
||||
},
|
||||
"pipeline_representation": 'Detection',
|
||||
"type": "project",
|
||||
}
|
||||
response = self.request(method='POST', url=url, json=data, headers=headers)
|
||||
return response
|
||||
|
||||
@retry()
|
||||
def __get_annotation(self, project_id: str, image_id: str) -> dict:
|
||||
url = f'{self.host}/v2/projects/{project_id}/media/images/{image_id}/results/online'
|
||||
headers = {
|
||||
'Authorization': f'bearer_token {self.token}',
|
||||
}
|
||||
response = self.request(method='GET', url=url, headers=headers)
|
||||
return response
|
||||
|
||||
@retry()
|
||||
def __get_job_status(self, project_id: str) -> dict:
|
||||
url = f'{self.host}/v2/projects/{project_id}/jobs'
|
||||
headers = {
|
||||
'Authorization': f'bearer_token {self.token}',
|
||||
}
|
||||
response = self.request(method='GET', url=url, headers=headers)
|
||||
return response
|
||||
|
||||
@retry()
|
||||
def __get_project_summary(self, project_id: str) -> dict:
|
||||
url = f'{self.host}/v2/projects/{project_id}/statistics/summary'
|
||||
headers = {
|
||||
'Authorization': f'bearer_token {self.token}',
|
||||
}
|
||||
response = self.request(method='GET', url=url, headers=headers)
|
||||
return response
|
||||
|
||||
@retry()
|
||||
def __get_project(self, project_id: str) -> dict:
|
||||
url = f'{self.host}/v2/projects/{project_id}'
|
||||
headers = {
|
||||
'Authorization': f'bearer_token {self.token}',
|
||||
}
|
||||
response = self.request(method='GET', url=url, headers=headers)
|
||||
return response
|
||||
|
||||
@retry()
|
||||
def __get_server_status(self) -> dict:
|
||||
url = f'{self.host}/v2/status'
|
||||
headers = {
|
||||
'Authorization': f'bearer_token {self.token}',
|
||||
}
|
||||
response = self.request(method='GET', url=url, headers=headers)
|
||||
return response
|
||||
|
||||
@retry()
|
||||
def __get_tasks(self) -> List[dict]:
|
||||
url = f'{self.host}/v2/tasks'
|
||||
headers = {
|
||||
'Authorization': f'bearer_token {self.token}',
|
||||
}
|
||||
response = self.request(method='GET', url=url, headers=headers)
|
||||
return response
|
||||
|
||||
def __delete_token(self):
|
||||
cache.delete(self.token_key)
|
||||
|
||||
@retry()
|
||||
def __upload_annotation(self, project_id: str, image_id: str, annotation: List[dict]):
|
||||
url = f'{self.host}/v2/projects/{project_id}/media/images/{image_id}/annotations'
|
||||
headers = {
|
||||
'Authorization': f'bearer_token {self.token}',
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
data = {
|
||||
'image_id': image_id,
|
||||
'data': annotation
|
||||
}
|
||||
response = self.request(method='POST', url=url, headers=headers, json=data)
|
||||
return response
|
||||
|
||||
@retry()
|
||||
def __upload_image(self, project_id: str, buffer) -> dict:
|
||||
url = f'{self.host}/v2/projects/{project_id}/media/images'
|
||||
files = {'file': buffer}
|
||||
headers = {
|
||||
'Authorization': f'bearer_token {self.token}',
|
||||
}
|
||||
response = self.request(method='POST', url=url, headers=headers, files=files)
|
||||
return response
|
||||
|
||||
@property
|
||||
def project_id_key(self):
|
||||
return f'{self.host}_{self.username}_project_id'
|
||||
|
||||
@property
|
||||
def token(self) -> str:
|
||||
def get_token(host: str, username: str, password: str) -> dict:
|
||||
url = f'{host}/v2/authentication'
|
||||
data = {
|
||||
'username': (None, username),
|
||||
'password': (None, password),
|
||||
}
|
||||
r = requests.post(url=url, files=data, verify=False) # nosec
|
||||
return r.json()
|
||||
|
||||
try:
|
||||
token = cache.get(self.token_key)
|
||||
except CacheMiss:
|
||||
response = get_token(self.host, self.username, self.password)
|
||||
token = response.get('secure_token', '')
|
||||
expires_in = response.get('expires_in', 3600)
|
||||
cache.set(cache_key=self.token_key, data=token, timeout=expires_in)
|
||||
return token
|
||||
|
||||
@property
|
||||
def token_key(self):
|
||||
return f'{self.host}_{self.username}_token'
|
||||
|
||||
def request(self, method: str, url: str, **kwargs) -> Union[list, dict, str]:
|
||||
response = requests.request(method=method, url=url, verify=False, **kwargs)
|
||||
if response.status_code == 401:
|
||||
self.__delete_token()
|
||||
raise Exception("401")
|
||||
result = response.json()
|
||||
return result
|
||||
|
||||
def create_project(self, name: str, description: str = '', project_class: TrainingProject.ProjectClass = None,
|
||||
labels: List[dict] = None) -> dict:
|
||||
all_tasks = self.__get_tasks()
|
||||
task_type = self.TRAINING_CLASS.get(project_class)
|
||||
task_algo = 'Retinanet - TF2'
|
||||
tasks = [
|
||||
next(({'temp_id': '_1_', **task}
|
||||
for task in all_tasks
|
||||
if task['task_type'] == 'DATASET'), {}),
|
||||
next(({'temp_id': '_2_', **task}
|
||||
for task in all_tasks
|
||||
if task['task_type'] == task_type and
|
||||
task['algorithm_name'] == task_algo), {}),
|
||||
]
|
||||
labels = [{
|
||||
'name': label['name'],
|
||||
'temp_id': label['name']
|
||||
} for label in labels]
|
||||
r = self.__create_project(name=name, description=description, tasks=tasks, labels=labels)
|
||||
return r
|
||||
|
||||
def get_server_status(self) -> dict:
|
||||
return self.__get_server_status()
|
||||
|
||||
def upload_annotations(self, project_id: str, frames_data: List[dict]):
|
||||
for frame in frames_data:
|
||||
annotation = self.__convert_annotation_from_cvat(frame['shapes'])
|
||||
self.__upload_annotation(project_id=project_id, image_id=frame['third_party_id'], annotation=annotation)
|
||||
|
||||
def upload_image(self, training_id: str, buffer):
|
||||
response = self.__upload_image(project_id=training_id, buffer=buffer)
|
||||
return response.get('id')
|
||||
|
||||
def get_project_status(self, project_id) -> dict:
|
||||
summary = self.__get_project_summary(project_id=project_id)
|
||||
if not summary or not isinstance(summary, list):
|
||||
return {'message': 'Not available'}
|
||||
jobs = self.__get_job_status(project_id=project_id)
|
||||
media_amount = next(item.get('value', 0) for item in summary if item.get('key') == 'Media')
|
||||
annotation_amount = next(item.get('value', 0) for item in summary if item.get('key') == 'Annotation')
|
||||
score = next(item.get('value', 0) for item in summary if item.get('key') == 'Score')
|
||||
job_items = jobs.get('items', 0)
|
||||
if len(job_items) == 0 and score == 0:
|
||||
message = 'Not started'
|
||||
elif len(job_items) == 0 and score > 0:
|
||||
message = ''
|
||||
else:
|
||||
message = 'In progress'
|
||||
progress = 0 if len(job_items) == 0 else job_items[0]["status"]["progress"]
|
||||
time_remaining = 0 if len(job_items) == 0 else job_items[0]["status"]['time_remaining']
|
||||
result = {
|
||||
'media_amount': media_amount if media_amount else 0,
|
||||
'annotation_amount': annotation_amount,
|
||||
'score': score,
|
||||
'message': message,
|
||||
'progress': progress,
|
||||
'time_remaining': time_remaining,
|
||||
}
|
||||
return result
|
||||
|
||||
def get_annotation(self, project_id: str, image_id: str, width: int, height: int, frame: int,
|
||||
labels_mapping: dict) -> List[OrderedDict]:
|
||||
annotation = self.__get_annotation(project_id=project_id, image_id=image_id)
|
||||
cvat_annotation = self.__convert_annotation_to_cvat(annotation=annotation, image_width=width,
|
||||
image_height=height, frame=frame,
|
||||
labels_mapping=labels_mapping)
|
||||
return cvat_annotation
|
||||
|
||||
def get_labels(self, project_id: str) -> List[dict]:
|
||||
project = self.__get_project(project_id=project_id)
|
||||
labels = [{
|
||||
'id': label['id'],
|
||||
'name': label['name']
|
||||
} for label in project.get('labels')]
|
||||
return labels
|
||||
@ -0,0 +1,11 @@
|
||||
from django.apps import AppConfig
|
||||
|
||||
|
||||
class TrainingConfig(AppConfig):
|
||||
name = 'cvat.apps.training'
|
||||
|
||||
def ready(self):
|
||||
# Required to define signals in application
|
||||
import cvat.apps.training.signals
|
||||
# Required in order to silent "unused-import" in pyflake
|
||||
assert cvat.apps.training.signals
|
||||
@ -0,0 +1,186 @@
|
||||
from collections import OrderedDict
|
||||
from typing import List
|
||||
|
||||
from cacheops import cache
|
||||
from django_rq import job
|
||||
|
||||
from cvat.apps import dataset_manager as dm
|
||||
from cvat.apps.engine.frame_provider import FrameProvider
|
||||
from cvat.apps.engine.models import (
|
||||
Project,
|
||||
Task,
|
||||
TrainingProjectImage,
|
||||
Label,
|
||||
Image,
|
||||
TrainingProjectLabel,
|
||||
Data,
|
||||
Job,
|
||||
ShapeType,
|
||||
)
|
||||
from cvat.apps.training.apis import TrainingServerAPI
|
||||
|
||||
|
||||
@job
|
||||
def save_prediction_server_status_to_cache_job(cache_key,
|
||||
cvat_project_id,
|
||||
timeout=60):
|
||||
cvat_project = Project.objects.get(pk=cvat_project_id)
|
||||
api = TrainingServerAPI(host=cvat_project.training_project.host, username=cvat_project.training_project.username,
|
||||
password=cvat_project.training_project.password)
|
||||
status = api.get_project_status(project_id=cvat_project.training_project.training_id)
|
||||
|
||||
resp = {
|
||||
**status,
|
||||
'status': 'done'
|
||||
}
|
||||
cache.set(cache_key=cache_key, data=resp, timeout=timeout)
|
||||
|
||||
|
||||
@job
|
||||
def save_frame_prediction_to_cache_job(cache_key: str,
|
||||
task_id: int,
|
||||
frame: int,
|
||||
timeout: int = 60):
|
||||
task = Task.objects.get(pk=task_id)
|
||||
training_project_image = TrainingProjectImage.objects.filter(idx=frame, task=task).first()
|
||||
if not training_project_image:
|
||||
cache.set(cache_key=cache_key, data={
|
||||
'annotation': [],
|
||||
'status': 'done'
|
||||
}, timeout=timeout)
|
||||
return
|
||||
|
||||
cvat_labels = Label.objects.filter(project__id=task.project_id).all()
|
||||
training_project = Project.objects.get(pk=task.project_id).training_project
|
||||
api = TrainingServerAPI(host=training_project.host,
|
||||
username=training_project.username,
|
||||
password=training_project.password)
|
||||
image = Image.objects.get(frame=frame, data=task.data)
|
||||
labels_mapping = {
|
||||
TrainingProjectLabel.objects.get(cvat_label=cvat_label).training_label_id: cvat_label.id
|
||||
for cvat_label in cvat_labels
|
||||
}
|
||||
annotation = api.get_annotation(project_id=training_project.training_id,
|
||||
image_id=training_project_image.training_image_id,
|
||||
width=image.width,
|
||||
height=image.height,
|
||||
labels_mapping=labels_mapping,
|
||||
frame=frame)
|
||||
resp = {
|
||||
'annotation': annotation,
|
||||
'status': 'done'
|
||||
}
|
||||
cache.set(cache_key=cache_key, data=resp, timeout=timeout)
|
||||
|
||||
|
||||
@job
|
||||
def upload_images_job(task_id: int):
|
||||
if TrainingProjectImage.objects.filter(task_id=task_id).count() is 0:
|
||||
task = Task.objects.get(pk=task_id)
|
||||
frame_provider = FrameProvider(task.data)
|
||||
frames = frame_provider.get_frames()
|
||||
api = TrainingServerAPI(
|
||||
host=task.project.training_project.host,
|
||||
username=task.project.training_project.username,
|
||||
password=task.project.training_project.password,
|
||||
)
|
||||
|
||||
for i, (buffer, _) in enumerate(frames):
|
||||
training_image_id = api.upload_image(training_id=task.project.training_project.training_id, buffer=buffer)
|
||||
if training_image_id:
|
||||
TrainingProjectImage.objects.create(task=task, idx=i,
|
||||
training_image_id=training_image_id)
|
||||
|
||||
def __add_fields_to_shape(shape: dict, frame: int, data: Data, labels_mapping: dict) -> dict:
|
||||
image = Image.objects.get(frame=frame, data=data)
|
||||
return {
|
||||
**shape,
|
||||
'height': image.height,
|
||||
'width': image.width,
|
||||
'third_party_label_id': labels_mapping[shape['label_id']],
|
||||
}
|
||||
|
||||
|
||||
@job
|
||||
def upload_annotation_to_training_project_job(job_id: int):
|
||||
cvat_job = Job.objects.get(pk=job_id)
|
||||
cvat_project = cvat_job.segment.task.project
|
||||
training_project = cvat_project.training_project
|
||||
start = cvat_job.segment.start_frame
|
||||
stop = cvat_job.segment.stop_frame
|
||||
data = dm.task.get_job_data(job_id)
|
||||
shapes: List[OrderedDict] = data.get('shapes', [])
|
||||
frames_data = []
|
||||
api = TrainingServerAPI(
|
||||
host=cvat_project.training_project.host,
|
||||
username=cvat_project.training_project.username,
|
||||
password=cvat_project.training_project.password,
|
||||
)
|
||||
cvat_labels = Label.objects.filter(project=cvat_project).all()
|
||||
labels_mapping = {
|
||||
cvat_label.id: TrainingProjectLabel.objects.get(cvat_label=cvat_label).training_label_id
|
||||
for cvat_label in cvat_labels
|
||||
}
|
||||
|
||||
for frame in range(start, stop + 1):
|
||||
frame_shapes = list(
|
||||
map(
|
||||
lambda x: __add_fields_to_shape(x, frame, cvat_job.segment.task.data, labels_mapping),
|
||||
filter(
|
||||
lambda x: x['frame'] == frame and x['type'] == ShapeType.RECTANGLE,
|
||||
shapes,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if frame_shapes:
|
||||
training_project_image = TrainingProjectImage.objects.get(task=cvat_job.segment.task, idx=frame)
|
||||
frames_data.append({
|
||||
'third_party_id': training_project_image.training_image_id,
|
||||
'shapes': frame_shapes
|
||||
})
|
||||
|
||||
api.upload_annotations(project_id=training_project.training_id, frames_data=frames_data)
|
||||
|
||||
|
||||
@job
|
||||
def create_training_project_job(project_id: int):
|
||||
cvat_project = Project.objects.get(pk=project_id)
|
||||
training_project = cvat_project.training_project
|
||||
api = TrainingServerAPI(
|
||||
host=cvat_project.training_project.host,
|
||||
username=cvat_project.training_project.username,
|
||||
password=cvat_project.training_project.password,
|
||||
)
|
||||
create_training_project(cvat_project=cvat_project, training_project=training_project, api=api)
|
||||
|
||||
|
||||
def create_training_project(cvat_project, training_project, api):
|
||||
labels = cvat_project.label_set.all()
|
||||
training_project_resp = api.create_project(
|
||||
name=f'{cvat_project.name}_cvat',
|
||||
project_class=training_project.project_class,
|
||||
labels=[{'name': label.name} for label in labels]
|
||||
)
|
||||
if training_project_resp.get('id'):
|
||||
training_project.training_id = training_project_resp['id']
|
||||
training_project.save()
|
||||
|
||||
for cvat_label in labels:
|
||||
training_label = list(filter(lambda x: x['name'] == cvat_label.name, training_project_resp.get('labels', [])))
|
||||
if training_label:
|
||||
TrainingProjectLabel.objects.create(cvat_label=cvat_label, training_label_id=training_label[0]['id'])
|
||||
|
||||
|
||||
async def upload_images(cvat_project_id, training_id, api):
|
||||
project = Project.objects.get(pk=cvat_project_id)
|
||||
tasks: List[Task] = project.tasks.all()
|
||||
for task in tasks:
|
||||
frame_provider = FrameProvider(task)
|
||||
frames = frame_provider.get_frames()
|
||||
for i, (buffer, _) in enumerate(frames):
|
||||
training_image_id = api.upload_image(training_id=training_id, buffer=buffer)
|
||||
if training_image_id:
|
||||
TrainingProjectImage.objects.create(project=project, task=task, idx=i,
|
||||
training_image_id=training_image_id)
|
||||
|
||||
@ -0,0 +1,30 @@
|
||||
from django.db.models.signals import post_save
|
||||
from django.dispatch import receiver
|
||||
|
||||
from cvat.apps.engine.models import Job, StatusChoice, Project, Task
|
||||
from cvat.apps.training.jobs import (
|
||||
create_training_project_job,
|
||||
upload_images_job,
|
||||
upload_annotation_to_training_project_job,
|
||||
)
|
||||
|
||||
|
||||
@receiver(post_save, sender=Project, dispatch_uid="create_training_project")
|
||||
def create_training_project(instance: Project, **kwargs):
|
||||
if instance.training_project:
|
||||
create_training_project_job.delay(instance.id)
|
||||
|
||||
|
||||
@receiver(post_save, sender=Task, dispatch_uid='upload_images_to_training_project')
|
||||
def upload_images_to_training_project(instance: Task, **kwargs):
|
||||
if (instance.status == StatusChoice.ANNOTATION and
|
||||
instance.data and instance.data.size != 0 and \
|
||||
instance.project_id and instance.project.training_project):
|
||||
|
||||
upload_images_job.delay(instance.id)
|
||||
|
||||
|
||||
@receiver(post_save, sender=Job, dispatch_uid="upload_annotation_to_training_project")
|
||||
def upload_annotation_to_training_project(instance: Job, **kwargs):
|
||||
if instance.status == StatusChoice.COMPLETED:
|
||||
upload_annotation_to_training_project_job.delay(instance.id)
|
||||
@ -0,0 +1,11 @@
|
||||
from django.urls import path, include
|
||||
from rest_framework import routers
|
||||
|
||||
from cvat.apps.training.views import PredictView
|
||||
|
||||
router = routers.DefaultRouter(trailing_slash=False)
|
||||
router.register('', PredictView, basename='predict')
|
||||
|
||||
urlpatterns = [
|
||||
path('', include((router.urls, 'predict'), namespace='predict'))
|
||||
]
|
||||
@ -0,0 +1,68 @@
|
||||
from cacheops import cache, CacheMiss
|
||||
from drf_yasg.utils import swagger_auto_schema
|
||||
from rest_framework import viewsets, status
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.permissions import IsAuthenticated, SAFE_METHODS
|
||||
from rest_framework.response import Response
|
||||
|
||||
from cvat.apps.authentication import auth
|
||||
from cvat.apps.engine.models import Project
|
||||
from cvat.apps.training.jobs import save_frame_prediction_to_cache_job, save_prediction_server_status_to_cache_job
|
||||
|
||||
|
||||
class PredictView(viewsets.ViewSet):
|
||||
def get_permissions(self):
|
||||
http_method = self.request.method
|
||||
permissions = [IsAuthenticated]
|
||||
|
||||
if http_method in SAFE_METHODS:
|
||||
permissions.append(auth.ProjectAccessPermission)
|
||||
else:
|
||||
permissions.append(auth.AdminRolePermission)
|
||||
|
||||
return [perm() for perm in permissions]
|
||||
|
||||
@swagger_auto_schema(method='get', operation_summary='Returns prediction for image')
|
||||
@action(detail=False, methods=['GET'], url_path='frame')
|
||||
def predict_image(self, request):
|
||||
frame = self.request.query_params.get('frame')
|
||||
task_id = self.request.query_params.get('task')
|
||||
if not task_id:
|
||||
return Response(data='query param "task" empty or not provided', status=status.HTTP_400_BAD_REQUEST)
|
||||
if not frame:
|
||||
return Response(data='query param "frame" empty or not provided', status=status.HTTP_400_BAD_REQUEST)
|
||||
cache_key = f'predict_image_{task_id}_{frame}'
|
||||
try:
|
||||
resp = cache.get(cache_key)
|
||||
except CacheMiss:
|
||||
save_frame_prediction_to_cache_job.delay(cache_key, task_id=task_id,
|
||||
frame=frame)
|
||||
resp = {
|
||||
'status': 'queued',
|
||||
}
|
||||
cache.set(cache_key=cache_key, data=resp, timeout=60)
|
||||
|
||||
return Response(resp)
|
||||
|
||||
@swagger_auto_schema(method='get',
|
||||
operation_summary='Returns information of the tasks of the project with the selected id')
|
||||
@action(detail=False, methods=['GET'], url_path='status')
|
||||
def predict_status(self, request):
|
||||
project_id = self.request.query_params.get('project')
|
||||
if not project_id:
|
||||
return Response(data='query param "project" empty or not provided', status=status.HTTP_400_BAD_REQUEST)
|
||||
project = Project.objects.get(pk=project_id)
|
||||
if not project.training_project:
|
||||
Response({'status': 'done'})
|
||||
|
||||
cache_key = f'predict_status_{project_id}'
|
||||
try:
|
||||
resp = cache.get(cache_key)
|
||||
except CacheMiss:
|
||||
save_prediction_server_status_to_cache_job.delay(cache_key, cvat_project_id=project_id)
|
||||
resp = {
|
||||
'status': 'queued',
|
||||
}
|
||||
cache.set(cache_key=cache_key, data=resp, timeout=60)
|
||||
|
||||
return Response(resp)
|
||||
Loading…
Reference in New Issue