diff --git a/cvat/apps/auto_annotation/import_modules.py b/cvat/apps/auto_annotation/import_modules.py new file mode 100644 index 00000000..6f764b88 --- /dev/null +++ b/cvat/apps/auto_annotation/import_modules.py @@ -0,0 +1,36 @@ +import ast +from collections import namedtuple +import importlib + +Import = namedtuple("Import", ["module", "name", "alias"]) + +def import_modules(source_code: str): + results = {} + imports = parse_imports(source_code) + for import_ in imports: + module = import_.module if import_.module else import_.name + loaded_module = importlib.import_module(module) + + if not import_.name == module: + loaded_module = getattr(loaded_module, import_.name) + + if import_.alias: + results[import_.alias] = loaded_module + else: + results[import_.name] = loaded_module + + return results + +def parse_imports(source_code: str): + root = ast.parse(source_code) + + for node in ast.iter_child_nodes(root): + if isinstance(node, ast.Import): + module = [] + elif isinstance(node, ast.ImportFrom): + module = node.module + else: + continue + + for n in node.names: + yield Import(module, n.name, n.asname) diff --git a/cvat/apps/auto_annotation/model_manager.py b/cvat/apps/auto_annotation/model_manager.py index ee57ab6f..4f17d4d7 100644 --- a/cvat/apps/auto_annotation/model_manager.py +++ b/cvat/apps/auto_annotation/model_manager.py @@ -24,6 +24,8 @@ from cvat.apps.engine.annotation import put_task_data, patch_task_data from .models import AnnotationModel, FrameworkChoice from .model_loader import ModelLoader from .image_loader import ImageLoader +from .import_modules import import_modules + def _remove_old_file(model_file_field): if model_file_field and os.path.exists(model_file_field.name): @@ -270,6 +272,7 @@ def _process_detections(detections, path_to_conv_script, restricted=True): "detections": detections, "results": results, } + source_code = open(path_to_conv_script).read() if restricted: global_vars = { @@ -284,8 +287,10 @@ def _process_detections(detections, path_to_conv_script, restricted=True): } else: global_vars = globals() + imports = import_modules(source_code) + global_vars.update(imports) - exec (open(path_to_conv_script).read(), global_vars, local_vars) + exec(source_code, global_vars, local_vars) return results