Add tf import checks and tests (#1567)

* Add tf import checks and tests

* implement disabled formats on server

* python 3.5 compatibility

* add checks to dm tests

* fix tests

* Support for disabled formats in UI

* add sorting for formats, mark grey disabled items

* update changelog

* advance package versions
main
zhiltsov-max 6 years ago committed by GitHub
parent ac922e33a5
commit da3fa34ce0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added ### Added
- Throttling policy for unauthenticated users (<https://github.com/opencv/cvat/pull/1531>) - Throttling policy for unauthenticated users (<https://github.com/opencv/cvat/pull/1531>)
- Added default label color table for mask export (https://github.com/opencv/cvat/pull/1549) - Added default label color table for mask export (https://github.com/opencv/cvat/pull/1549)
- Added visual identification for unavailable formats (https://github.com/opencv/cvat/pull/1567)
### Changed ### Changed
- Removed information about e-mail from the basic user information (<https://github.com/opencv/cvat/pull/1627>) - Removed information about e-mail from the basic user information (<https://github.com/opencv/cvat/pull/1627>)
@ -19,7 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- -
### Fixed ### Fixed
- - Fixed interpreter crash when trying to import `tensorflow` with no AVX instructions available (https://github.com/opencv/cvat/pull/1567)
### Security ### Security
- -

@ -1,6 +1,6 @@
{ {
"name": "cvat-core", "name": "cvat-core",
"version": "2.0.1", "version": "2.1.1",
"description": "Part of Computer Vision Tool which presents an interface for client-side integration", "description": "Part of Computer Vision Tool which presents an interface for client-side integration",
"main": "babel.config.js", "main": "babel.config.js",
"scripts": { "scripts": {

@ -15,6 +15,7 @@
name: initialData.name, name: initialData.name,
format: initialData.ext, format: initialData.ext,
version: initialData.version, version: initialData.version,
enabled: initialData.enabled,
}; };
Object.defineProperties(this, { Object.defineProperties(this, {
@ -48,6 +49,16 @@
*/ */
get: () => data.version, get: () => data.version,
}, },
enabled: {
/**
* @name enabled
* @type {string}
* @memberof module:API.cvat.classes.Loader
* @readonly
* @instance
*/
get: () => data.enabled,
},
}); });
} }
} }
@ -63,6 +74,7 @@
name: initialData.name, name: initialData.name,
format: initialData.ext, format: initialData.ext,
version: initialData.version, version: initialData.version,
enabled: initialData.enabled,
}; };
Object.defineProperties(this, { Object.defineProperties(this, {
@ -96,6 +108,16 @@
*/ */
get: () => data.version, get: () => data.version,
}, },
enabled: {
/**
* @name enabled
* @type {string}
* @memberof module:API.cvat.classes.Loader
* @readonly
* @instance
*/
get: () => data.enabled,
},
}); });
} }
} }

@ -1,6 +1,6 @@
{ {
"name": "cvat-ui", "name": "cvat-ui",
"version": "1.2.0", "version": "1.2.1",
"lockfileVersion": 1, "lockfileVersion": 1,
"requires": true, "requires": true,
"dependencies": { "dependencies": {

@ -1,6 +1,6 @@
{ {
"name": "cvat-ui", "name": "cvat-ui",
"version": "1.2.0", "version": "1.2.1",
"description": "CVAT single-page application", "description": "CVAT single-page application",
"main": "src/index.tsx", "main": "src/index.tsx",
"scripts": { "scripts": {

@ -16,8 +16,8 @@ interface Props {
taskMode: string; taskMode: string;
bugTracker: string; bugTracker: string;
loaders: string[]; loaders: any[];
dumpers: string[]; dumpers: any[];
loadActivity: string | null; loadActivity: string | null;
dumpActivities: string[] | null; dumpActivities: string[] | null;
exportActivities: string[] | null; exportActivities: string[] | null;

@ -15,7 +15,7 @@ function isDefaultFormat(dumperName: string, taskMode: string): boolean {
interface Props { interface Props {
taskMode: string; taskMode: string;
menuKey: string; menuKey: string;
dumpers: string[]; dumpers: any[];
dumpActivities: string[] | null; dumpActivities: string[] | null;
} }
@ -30,17 +30,21 @@ export default function DumpSubmenu(props: Props): JSX.Element {
return ( return (
<Menu.SubMenu key={menuKey} title='Dump annotations'> <Menu.SubMenu key={menuKey} title='Dump annotations'>
{ {
dumpers.map((dumper: string): JSX.Element => { dumpers
const pending = (dumpActivities || []).includes(dumper); .sort((a: any, b: any) => a.name.localeCompare(b.name))
const isDefault = isDefaultFormat(dumper, taskMode); .map((dumper: any): JSX.Element =>
{
const pending = (dumpActivities || []).includes(dumper.name);
const disabled = !dumper.enabled || pending;
const isDefault = isDefaultFormat(dumper.name, taskMode);
return ( return (
<Menu.Item <Menu.Item
key={dumper} key={dumper.name}
disabled={pending} disabled={disabled}
className='cvat-menu-dump-submenu-item' className='cvat-menu-dump-submenu-item'
> >
<Icon type='download' /> <Icon type='download' />
<Text strong={isDefault}>{dumper}</Text> <Text strong={isDefault} disabled={disabled}>{dumper.name}</Text>
{pending && <Icon style={{ marginLeft: 10 }} type='loading' />} {pending && <Icon style={{ marginLeft: 10 }} type='loading' />}
</Menu.Item> </Menu.Item>
); );

@ -9,7 +9,7 @@ import Text from 'antd/lib/typography/Text';
interface Props { interface Props {
menuKey: string; menuKey: string;
exporters: string[]; exporters: any[];
exportActivities: string[] | null; exportActivities: string[] | null;
} }
@ -23,16 +23,20 @@ export default function ExportSubmenu(props: Props): JSX.Element {
return ( return (
<Menu.SubMenu key={menuKey} title='Export as a dataset'> <Menu.SubMenu key={menuKey} title='Export as a dataset'>
{ {
exporters.map((exporter: string): JSX.Element => { exporters
const pending = (exportActivities || []).includes(exporter); .sort((a: any, b: any) => a.name.localeCompare(b.name))
.map((exporter: any): JSX.Element =>
{
const pending = (exportActivities || []).includes(exporter.name);
const disabled = !exporter.enabled || pending;
return ( return (
<Menu.Item <Menu.Item
key={exporter} key={exporter.name}
disabled={pending} disabled={disabled}
className='cvat-menu-export-submenu-item' className='cvat-menu-export-submenu-item'
> >
<Icon type='export' /> <Icon type='export' />
<Text>{exporter}</Text> <Text disabled={disabled}>{exporter.name}</Text>
{pending && <Icon style={{ marginLeft: 10 }} type='loading' />} {pending && <Icon style={{ marginLeft: 10 }} type='loading' />}
</Menu.Item> </Menu.Item>
); );

@ -11,7 +11,7 @@ import Text from 'antd/lib/typography/Text';
interface Props { interface Props {
menuKey: string; menuKey: string;
loaders: string[]; loaders: any[];
loadActivity: string | null; loadActivity: string | null;
onFileUpload(file: File): void; onFileUpload(file: File): void;
} }
@ -27,13 +27,20 @@ export default function LoadSubmenu(props: Props): JSX.Element {
return ( return (
<Menu.SubMenu key={menuKey} title='Upload annotations'> <Menu.SubMenu key={menuKey} title='Upload annotations'>
{ {
loaders.map((_loader: string): JSX.Element => { loaders
const [loader, accept] = _loader.split('::'); .sort((a: any, b: any) => a.name.localeCompare(b.name))
const pending = loadActivity === loader; .map((loader: any): JSX.Element =>
{
const accept = loader.format
.split(',')
.map((x: string) => '.' + x.trimStart())
.join(', '); // add '.' to each extension in a list
const pending = loadActivity === loader.name;
const disabled = !loader.enabled || !!loadActivity;
return ( return (
<Menu.Item <Menu.Item
key={loader} key={loader.name}
disabled={!!loadActivity} disabled={disabled}
className='cvat-menu-load-submenu-item' className='cvat-menu-load-submenu-item'
> >
<Upload <Upload
@ -45,9 +52,9 @@ export default function LoadSubmenu(props: Props): JSX.Element {
return false; return false;
}} }}
> >
<Button block type='link' disabled={!!loadActivity}> <Button block type='link' disabled={disabled}>
<Icon type='upload' /> <Icon type='upload' />
<Text>{loader}</Text> <Text>{loader.name}</Text>
{pending && <Icon style={{ marginLeft: 10 }} type='loading' />} {pending && <Icon style={{ marginLeft: 10 }} type='loading' />}
</Button> </Button>
</Upload> </Upload>

@ -13,8 +13,8 @@ import ReIDPlugin from './reid-plugin';
interface Props { interface Props {
taskMode: string; taskMode: string;
loaders: string[]; loaders: any[];
dumpers: string[]; dumpers: any[];
loadActivity: string | null; loadActivity: string | null;
dumpActivities: string[] | null; dumpActivities: string[] | null;
exportActivities: string[] | null; exportActivities: string[] | null;

@ -134,7 +134,7 @@ function ActionsMenuContainer(props: OwnProps & StateToProps & DispatchToProps):
dumpAnnotations(taskInstance, dumper); dumpAnnotations(taskInstance, dumper);
} }
} else if (action === Actions.LOAD_TASK_ANNO) { } else if (action === Actions.LOAD_TASK_ANNO) {
const [format] = additionalKey.split('::'); const format = additionalKey;
const [loader] = loaders const [loader] = loaders
.filter((_loader: any): boolean => _loader.name === format); .filter((_loader: any): boolean => _loader.name === format);
if (loader && file) { if (loader && file) {
@ -166,8 +166,8 @@ function ActionsMenuContainer(props: OwnProps & StateToProps & DispatchToProps):
taskID={taskInstance.id} taskID={taskInstance.id}
taskMode={taskInstance.mode} taskMode={taskInstance.mode}
bugTracker={taskInstance.bugTracker} bugTracker={taskInstance.bugTracker}
loaders={loaders.map((loader: any): string => `${loader.name}::${loader.format}`)} loaders={loaders}
dumpers={dumpers.map((dumper: any): string => dumper.name)} dumpers={dumpers}
loadActivity={loadActivity} loadActivity={loadActivity}
dumpActivities={dumpActivities} dumpActivities={dumpActivities}
exportActivities={exportActivities} exportActivities={exportActivities}

@ -123,7 +123,7 @@ function AnnotationMenuContainer(props: Props): JSX.Element {
dumpAnnotations(jobInstance.task, dumper); dumpAnnotations(jobInstance.task, dumper);
} }
} else if (action === Actions.LOAD_JOB_ANNO) { } else if (action === Actions.LOAD_JOB_ANNO) {
const [format] = additionalKey.split('::'); const format = additionalKey;
const [loader] = loaders const [loader] = loaders
.filter((_loader: any): boolean => _loader.name === format); .filter((_loader: any): boolean => _loader.name === format);
if (loader && file) { if (loader && file) {
@ -150,8 +150,8 @@ function AnnotationMenuContainer(props: Props): JSX.Element {
return ( return (
<AnnotationMenuComponent <AnnotationMenuComponent
taskMode={jobInstance.task.mode} taskMode={jobInstance.task.mode}
loaders={loaders.map((loader: any): string => loader.name)} loaders={loaders}
dumpers={dumpers.map((dumper: any): string => dumper.name)} dumpers={dumpers}
loadActivity={loadActivity} loadActivity={loadActivity}
dumpActivities={dumpActivities} dumpActivities={dumpActivities}
exportActivities={exportActivities} exportActivities={exportActivities}

@ -13,6 +13,7 @@ class _Format:
EXT = '' EXT = ''
VERSION = '' VERSION = ''
DISPLAY_NAME = '{NAME} {VERSION}' DISPLAY_NAME = '{NAME} {VERSION}'
ENABLED = True
class Exporter(_Format): class Exporter(_Format):
def __call__(self, dst_file, task_data, **options): def __call__(self, dst_file, task_data, **options):
@ -22,7 +23,7 @@ class Importer(_Format):
def __call__(self, src_file, task_data, **options): def __call__(self, src_file, task_data, **options):
raise NotImplementedError() raise NotImplementedError()
def _wrap_format(f_or_cls, klass, name, version, ext, display_name): def _wrap_format(f_or_cls, klass, name, version, ext, display_name, enabled):
import inspect import inspect
assert inspect.isclass(f_or_cls) or inspect.isfunction(f_or_cls) assert inspect.isclass(f_or_cls) or inspect.isfunction(f_or_cls)
if inspect.isclass(f_or_cls): if inspect.isclass(f_or_cls):
@ -44,14 +45,17 @@ def _wrap_format(f_or_cls, klass, name, version, ext, display_name):
target.DISPLAY_NAME = (display_name or klass.DISPLAY_NAME).format( target.DISPLAY_NAME = (display_name or klass.DISPLAY_NAME).format(
NAME=name, VERSION=version, EXT=ext) NAME=name, VERSION=version, EXT=ext)
assert all([target.NAME, target.VERSION, target.EXT, target.DISPLAY_NAME]) assert all([target.NAME, target.VERSION, target.EXT, target.DISPLAY_NAME])
target.ENABLED = enabled
return target return target
EXPORT_FORMATS = {} EXPORT_FORMATS = {}
def exporter(name, version, ext, display_name=None): def exporter(name, version, ext, display_name=None, enabled=True):
assert name not in EXPORT_FORMATS, "Export format '%s' already registered" % name assert name not in EXPORT_FORMATS, "Export format '%s' already registered" % name
def wrap_with_params(f_or_cls): def wrap_with_params(f_or_cls):
t = _wrap_format(f_or_cls, Exporter, t = _wrap_format(f_or_cls, Exporter,
name=name, ext=ext, version=version, display_name=display_name) name=name, ext=ext, version=version, display_name=display_name,
enabled=enabled)
key = t.DISPLAY_NAME key = t.DISPLAY_NAME
assert key not in EXPORT_FORMATS, "Export format '%s' already registered" % name assert key not in EXPORT_FORMATS, "Export format '%s' already registered" % name
EXPORT_FORMATS[key] = t EXPORT_FORMATS[key] = t
@ -59,10 +63,11 @@ def exporter(name, version, ext, display_name=None):
return wrap_with_params return wrap_with_params
IMPORT_FORMATS = {} IMPORT_FORMATS = {}
def importer(name, version, ext, display_name=None): def importer(name, version, ext, display_name=None, enabled=True):
def wrap_with_params(f_or_cls): def wrap_with_params(f_or_cls):
t = _wrap_format(f_or_cls, Importer, t = _wrap_format(f_or_cls, Importer,
name=name, ext=ext, version=version, display_name=display_name) name=name, ext=ext, version=version, display_name=display_name,
enabled=enabled)
key = t.DISPLAY_NAME key = t.DISPLAY_NAME
assert key not in IMPORT_FORMATS, "Import format '%s' already registered" % name assert key not in IMPORT_FORMATS, "Import format '%s' already registered" % name
IMPORT_FORMATS[key] = t IMPORT_FORMATS[key] = t

@ -14,7 +14,15 @@ from datumaro.components.project import Dataset
from .registry import dm_env, exporter, importer from .registry import dm_env, exporter, importer
@exporter(name='TFRecord', ext='ZIP', version='1.0') from datumaro.util.tf_util import import_tf
try:
import_tf()
tf_available = True
except ImportError:
tf_available = False
@exporter(name='TFRecord', ext='ZIP', version='1.0', enabled=tf_available)
def _export(dst_file, task_data, save_images=False): def _export(dst_file, task_data, save_images=False):
extractor = CvatTaskDataExtractor(task_data, include_images=save_images) extractor = CvatTaskDataExtractor(task_data, include_images=save_images)
extractor = Dataset.from_extractors(extractor) # apply lazy transforms extractor = Dataset.from_extractors(extractor) # apply lazy transforms
@ -25,7 +33,7 @@ def _export(dst_file, task_data, save_images=False):
make_zip_archive(temp_dir, dst_file) make_zip_archive(temp_dir, dst_file)
@importer(name='TFRecord', ext='ZIP', version='1.0') @importer(name='TFRecord', ext='ZIP', version='1.0', enabled=tf_available)
def _import(src_file, task_data): def _import(src_file, task_data):
with TemporaryDirectory() as tmp_dir: with TemporaryDirectory() as tmp_dir:
Archive(src_file.name).extractall(tmp_dir) Archive(src_file.name).extractall(tmp_dir)

@ -9,6 +9,7 @@ class DatasetFormatSerializer(serializers.Serializer):
name = serializers.CharField(max_length=64, source='DISPLAY_NAME') name = serializers.CharField(max_length=64, source='DISPLAY_NAME')
ext = serializers.CharField(max_length=64, source='EXT') ext = serializers.CharField(max_length=64, source='EXT')
version = serializers.CharField(max_length=64, source='VERSION') version = serializers.CharField(max_length=64, source='VERSION')
enabled = serializers.BooleanField(source='ENABLED')
class DatasetFormatsSerializer(serializers.Serializer): class DatasetFormatsSerializer(serializers.Serializer):
importers = DatasetFormatSerializer(many=True) importers = DatasetFormatSerializer(many=True)

@ -335,6 +335,9 @@ class TaskExportTest(APITestCase):
self.assertTrue(len(f.read()) != 0) self.assertTrue(len(f.read()) != 0)
for f in dm.views.get_export_formats(): for f in dm.views.get_export_formats():
if not f.ENABLED:
self.skipTest("Format is disabled")
format_name = f.DISPLAY_NAME format_name = f.DISPLAY_NAME
for save_images in { True, False }: for save_images in { True, False }:
with self.subTest(format=format_name, save_images=save_images): with self.subTest(format=format_name, save_images=save_images):
@ -359,6 +362,9 @@ class TaskExportTest(APITestCase):
('YOLO 1.1', 'yolo'), ('YOLO 1.1', 'yolo'),
]: ]:
with self.subTest(format=format_name): with self.subTest(format=format_name):
if not dm.formats.registry.EXPORT_FORMATS[format_name].ENABLED:
self.skipTest("Format is disabled")
task = self._generate_task() task = self._generate_task()
def check(file_path): def check(file_path):

@ -3149,8 +3149,8 @@ class TaskAnnotationAPITestCase(JobAnnotationAPITestCase):
export_formats = data['exporters'] export_formats = data['exporters']
self.assertTrue(isinstance(import_formats, list) and import_formats) self.assertTrue(isinstance(import_formats, list) and import_formats)
self.assertTrue(isinstance(export_formats, list) and export_formats) self.assertTrue(isinstance(export_formats, list) and export_formats)
import_formats = { v['name'] for v in import_formats } import_formats = { v['name']: v for v in import_formats }
export_formats = { v['name'] for v in export_formats } export_formats = { v['name']: v for v in export_formats }
formats = { exp: exp if exp in import_formats else None formats = { exp: exp if exp in import_formats else None
for exp in export_formats } for exp in export_formats }
@ -3159,12 +3159,12 @@ class TaskAnnotationAPITestCase(JobAnnotationAPITestCase):
formats['CVAT for video 1.1'] = 'CVAT 1.1' formats['CVAT for video 1.1'] = 'CVAT 1.1'
if 'CVAT for images 1.1' in export_formats: if 'CVAT for images 1.1' in export_formats:
formats['CVAT for images 1.1'] = 'CVAT 1.1' formats['CVAT for images 1.1'] = 'CVAT 1.1'
if import_formats ^ export_formats: if set(import_formats) ^ set(export_formats):
# NOTE: this may not be an error, so we should not fail # NOTE: this may not be an error, so we should not fail
print("The following import formats have no pair:", print("The following import formats have no pair:",
import_formats - export_formats) set(import_formats) - set(export_formats))
print("The following export formats have no pair:", print("The following export formats have no pair:",
export_formats - import_formats) set(export_formats) - set(import_formats))
for export_format, import_format in formats.items(): for export_format, import_format in formats.items():
with self.subTest(export_format=export_format, with self.subTest(export_format=export_format,
@ -3183,7 +3183,12 @@ class TaskAnnotationAPITestCase(JobAnnotationAPITestCase):
# 3. download annotation # 3. download annotation
response = self._dump_api_v1_tasks_id_annotations(task["id"], annotator, response = self._dump_api_v1_tasks_id_annotations(task["id"], annotator,
"?format={}".format(export_format)) "?format={}".format(export_format))
self.assertEqual(response.status_code, HTTP_202_ACCEPTED) if annotator and not export_formats[export_format]['enabled']:
self.assertEqual(response.status_code,
status.HTTP_405_METHOD_NOT_ALLOWED)
continue
else:
self.assertEqual(response.status_code, HTTP_202_ACCEPTED)
response = self._dump_api_v1_tasks_id_annotations(task["id"], annotator, response = self._dump_api_v1_tasks_id_annotations(task["id"], annotator,
"?format={}".format(export_format)) "?format={}".format(export_format))

@ -482,7 +482,8 @@ class TaskViewSet(auth.TaskGetQuerySetMixin, viewsets.ModelViewSet):
responses={ responses={
'202': openapi.Response(description='Dump of annotations has been started'), '202': openapi.Response(description='Dump of annotations has been started'),
'201': openapi.Response(description='Annotations file is ready to download'), '201': openapi.Response(description='Annotations file is ready to download'),
'200': openapi.Response(description='Download of file started') '200': openapi.Response(description='Download of file started'),
'405': openapi.Response(description='Format is not available'),
} }
) )
@swagger_auto_schema(method='put', operation_summary='Method allows to upload task annotations', @swagger_auto_schema(method='put', operation_summary='Method allows to upload task annotations',
@ -494,6 +495,7 @@ class TaskViewSet(auth.TaskGetQuerySetMixin, viewsets.ModelViewSet):
responses={ responses={
'202': openapi.Response(description='Uploading has been started'), '202': openapi.Response(description='Uploading has been started'),
'201': openapi.Response(description='Uploading has finished'), '201': openapi.Response(description='Uploading has finished'),
'405': openapi.Response(description='Format is not available'),
} }
) )
@swagger_auto_schema(method='patch', operation_summary='Method performs a partial update of annotations in a specific task', @swagger_auto_schema(method='patch', operation_summary='Method performs a partial update of annotations in a specific task',
@ -619,7 +621,8 @@ class TaskViewSet(auth.TaskGetQuerySetMixin, viewsets.ModelViewSet):
], ],
responses={'202': openapi.Response(description='Exporting has been started'), responses={'202': openapi.Response(description='Exporting has been started'),
'201': openapi.Response(description='Output file is ready for downloading'), '201': openapi.Response(description='Output file is ready for downloading'),
'200': openapi.Response(description='Download of file started') '200': openapi.Response(description='Download of file started'),
'405': openapi.Response(description='Format is not available'),
} }
) )
@action(detail=True, methods=['GET'], serializer_class=None, @action(detail=True, methods=['GET'], serializer_class=None,
@ -799,17 +802,20 @@ def rq_handler(job, exc_type, exc_value, tb):
# tags=['tasks']) # tags=['tasks'])
# @api_view(['PUT']) # @api_view(['PUT'])
def _import_annotations(request, rq_id, rq_func, pk, format_name): def _import_annotations(request, rq_id, rq_func, pk, format_name):
format_desc = {f.DISPLAY_NAME: f
for f in dm.views.get_import_formats()}.get(format_name)
if format_desc is None:
raise serializers.ValidationError(
"Unknown input format '{}'".format(format_name))
elif not format_desc.ENABLED:
return Response(status=status.HTTP_405_METHOD_NOT_ALLOWED)
queue = django_rq.get_queue("default") queue = django_rq.get_queue("default")
rq_job = queue.fetch_job(rq_id) rq_job = queue.fetch_job(rq_id)
if not rq_job: if not rq_job:
serializer = AnnotationFileSerializer(data=request.data) serializer = AnnotationFileSerializer(data=request.data)
if serializer.is_valid(raise_exception=True): if serializer.is_valid(raise_exception=True):
if format_name not in \
[f.DISPLAY_NAME for f in dm.views.get_import_formats()]:
raise serializers.ValidationError(
"Unknown input format '{}'".format(format_name))
anno_file = serializer.validated_data['annotation_file'] anno_file = serializer.validated_data['annotation_file']
fd, filename = mkstemp(prefix='cvat_{}'.format(pk)) fd, filename = mkstemp(prefix='cvat_{}'.format(pk))
with open(filename, 'wb+') as f: with open(filename, 'wb+') as f:
@ -843,9 +849,13 @@ def _export_annotations(db_task, rq_id, request, format_name, action, callback,
raise serializers.ValidationError( raise serializers.ValidationError(
"Unexpected action specified for the request") "Unexpected action specified for the request")
if format_name not in [f.DISPLAY_NAME for f in dm.views.get_export_formats()]: format_desc = {f.DISPLAY_NAME: f
for f in dm.views.get_export_formats()}.get(format_name)
if format_desc is None:
raise serializers.ValidationError( raise serializers.ValidationError(
"Unknown format specified for the request") "Unknown format specified for the request")
elif not format_desc.ENABLED:
return Response(status=status.HTTP_405_METHOD_NOT_ALLOWED)
queue = django_rq.get_queue("default") queue = django_rq.get_queue("default")

@ -235,7 +235,17 @@ class Environment:
exports = cls._import_module(module_dir, module_name, types, exports = cls._import_module(module_dir, module_name, types,
package) package)
except Exception as e: except Exception as e:
log.debug("Failed to import module '%s': %s" % (module_name, e)) module_search_error = ImportError
try:
module_search_error = ModuleNotFoundError # python 3.6+
except NameError:
pass
message = ["Failed to import module '%s': %s", module_name, e]
if isinstance(e, module_search_error):
log.debug(*message)
else:
log.warning(*message)
continue continue
log.debug("Imported the following symbols from %s: %s" % \ log.debug("Imported the following symbols from %s: %s" % \

@ -10,11 +10,11 @@ import numpy as np
import os import os
import os.path as osp import os.path as osp
import platform import platform
import subprocess
from openvino.inference_engine import IENetwork, IEPlugin from openvino.inference_engine import IENetwork, IEPlugin
from datumaro.components.launcher import Launcher from datumaro.components.launcher import Launcher
from datumaro.util.os_util import check_instruction_set
class InterpreterScript: class InterpreterScript:
@ -45,17 +45,6 @@ class OpenVinoLauncher(Launcher):
_DEFAULT_IE_PLUGINS_PATH = "/opt/intel/openvino_2019.1.144/deployment_tools/inference_engine/lib/intel64" _DEFAULT_IE_PLUGINS_PATH = "/opt/intel/openvino_2019.1.144/deployment_tools/inference_engine/lib/intel64"
_IE_PLUGINS_PATH = os.getenv("IE_PLUGINS_PATH", _DEFAULT_IE_PLUGINS_PATH) _IE_PLUGINS_PATH = os.getenv("IE_PLUGINS_PATH", _DEFAULT_IE_PLUGINS_PATH)
@staticmethod
def _check_instruction_set(instruction):
return instruction == str.strip(
# Let's ignore a warning from bandit about using shell=True.
# In this case it isn't a security issue and we use some
# shell features like pipes.
subprocess.check_output(
'lscpu | grep -o "{}" | head -1'.format(instruction),
shell=True).decode('utf-8') # nosec
)
@staticmethod @staticmethod
def make_plugin(device='cpu', plugins_path=_IE_PLUGINS_PATH): def make_plugin(device='cpu', plugins_path=_IE_PLUGINS_PATH):
if plugins_path is None or not osp.isdir(plugins_path): if plugins_path is None or not osp.isdir(plugins_path):
@ -63,10 +52,10 @@ class OpenVinoLauncher(Launcher):
(plugins_path)) (plugins_path))
plugin = IEPlugin(device='CPU', plugin_dirs=[plugins_path]) plugin = IEPlugin(device='CPU', plugin_dirs=[plugins_path])
if (OpenVinoLauncher._check_instruction_set('avx2')): if (check_instruction_set('avx2')):
plugin.add_cpu_extension(os.path.join(plugins_path, plugin.add_cpu_extension(os.path.join(plugins_path,
'libcpu_extension_avx2.so')) 'libcpu_extension_avx2.so'))
elif (OpenVinoLauncher._check_instruction_set('sse4')): elif (check_instruction_set('sse4')):
plugin.add_cpu_extension(os.path.join(plugins_path, plugin.add_cpu_extension(os.path.join(plugins_path,
'libcpu_extension_sse4.so')) 'libcpu_extension_sse4.so'))
elif platform.system() == 'Darwin': elif platform.system() == 'Darwin':

@ -0,0 +1,17 @@
# Copyright (C) 2020 Intel Corporation
#
# SPDX-License-Identifier: MIT
import subprocess
def check_instruction_set(instruction):
return instruction == str.strip(
# Let's ignore a warning from bandit about using shell=True.
# In this case it isn't a security issue and we use some
# shell features like pipes.
subprocess.check_output(
'lscpu | grep -o "%s" | head -1' % instruction,
shell=True).decode('utf-8') # nosec
)

@ -3,7 +3,36 @@
# #
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
def import_tf():
def check_import():
# Workaround for checking import availability:
# Official TF builds include AVX instructions. Once we try to import,
# the program crashes. We raise an exception instead.
import subprocess
import sys
from .os_util import check_instruction_set
result = subprocess.run([sys.executable, '-c', 'import tensorflow'],
timeout=60,
universal_newlines=True, # use text mode for output stream
stdout=subprocess.PIPE, stderr=subprocess.PIPE) # capture output
if result.returncode != 0:
message = result.stderr
if not message:
message = "Can't import tensorflow. " \
"Test process exit code: %s." % result.returncode
if not check_instruction_set('avx'):
# The process has probably crashed for AVX unavalability
message += " This is likely because your CPU does not " \
"support AVX instructions, " \
"which are required for tensorflow."
raise ImportError(message)
def import_tf(check=True):
import sys import sys
tf = sys.modules.get('tensorflow', None) tf = sys.modules.get('tensorflow', None)
@ -14,6 +43,9 @@ def import_tf():
import os import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
if check:
check_import()
import tensorflow as tf import tensorflow as tf
try: try:

@ -1,17 +1,34 @@
import numpy as np import numpy as np
from unittest import TestCase from unittest import TestCase, skipIf
from datumaro.components.extractor import (Extractor, DatasetItem, from datumaro.components.extractor import (Extractor, DatasetItem,
AnnotationType, Bbox, Mask, LabelCategories AnnotationType, Bbox, Mask, LabelCategories
) )
from datumaro.plugins.tf_detection_api_format.importer import TfDetectionApiImporter
from datumaro.plugins.tf_detection_api_format.extractor import TfDetectionApiExtractor
from datumaro.plugins.tf_detection_api_format.converter import TfDetectionApiConverter
from datumaro.util.image import Image from datumaro.util.image import Image
from datumaro.util.test_utils import TestDir, compare_datasets from datumaro.util.test_utils import TestDir, compare_datasets
from datumaro.util.tf_util import check_import
try:
from datumaro.plugins.tf_detection_api_format.importer import TfDetectionApiImporter
from datumaro.plugins.tf_detection_api_format.extractor import TfDetectionApiExtractor
from datumaro.plugins.tf_detection_api_format.converter import TfDetectionApiConverter
import_failed = False
except ImportError:
import_failed = True
import importlib
module_found = importlib.util.find_spec('tensorflow') is not None
@skipIf(not module_found, "Tensorflow package is not found")
class TfImportTest(TestCase):
def test_raises_when_crashes_on_import(self):
# Should fire if import can't be done for any reason except
# module unavailability and import crash
with self.assertRaisesRegex(ImportError, 'Test process exit code'):
check_import()
@skipIf(import_failed, "Failed to import tensorflow")
class TfrecordConverterTest(TestCase): class TfrecordConverterTest(TestCase):
def _test_save_and_load(self, source_dataset, converter, test_dir, def _test_save_and_load(self, source_dataset, converter, test_dir,
target_dataset=None, importer_args=None): target_dataset=None, importer_args=None):
@ -171,6 +188,7 @@ class TfrecordConverterTest(TestCase):
self.assertEqual(expected, parsed) self.assertEqual(expected, parsed)
@skipIf(import_failed, "Failed to import tensorflow")
class TfrecordImporterTest(TestCase): class TfrecordImporterTest(TestCase):
def test_can_detect(self): def test_can_detect(self):
class TestExtractor(Extractor): class TestExtractor(Extractor):

Loading…
Cancel
Save