[Datumaro] Fix TFrecord converter constructor (#993)

main
zhiltsov-max 6 years ago committed by Nikita Manovich
parent 19b86b9731
commit 7c9043974b

@ -593,7 +593,6 @@ class TaskViewSet(auth.TaskGetQuerySetMixin, viewsets.ModelViewSet):
@action(detail=True, methods=['GET'], serializer_class=None,
url_path='dataset')
def dataset_export(self, request, pk):
db_task = self.get_object()
action = request.query_params.get("action", "")
@ -611,7 +610,7 @@ class TaskViewSet(auth.TaskGetQuerySetMixin, viewsets.ModelViewSet):
raise serializers.ValidationError(
"Unexpected parameter 'format' specified for the request")
rq_id = "task_dataset_export.{}.{}".format(pk, dst_format)
rq_id = "/api/v1/tasks/{}/dataset/{}".format(pk, dst_format)
queue = django_rq.get_queue("default")
rq_job = queue.fetch_job(rq_id)

@ -98,8 +98,27 @@ def _make_tf_example(item, get_label_id, get_label, save_images=False):
return tf_example
class DetectionApiConverter:
def __init__(self, save_images=True):
self.save_images = save_images
def __init__(self, save_images=False, cmdline_args=None):
super().__init__()
self._save_images = save_images
if cmdline_args is not None:
options = self._parse_cmdline(cmdline_args)
for k, v in options.items():
if hasattr(self, '_' + str(k)):
setattr(self, '_' + str(k), v)
@classmethod
def build_cmdline_parser(cls, parser=None):
import argparse
if not parser:
parser = argparse.ArgumentParser()
parser.add_argument('--save-images', action='store_true',
help="Save images (default: %(default)s)")
return parser
def __call__(self, extractor, save_dir):
tf = _import_tf()
@ -141,6 +160,6 @@ class DetectionApiConverter:
item,
get_label=get_label,
get_label_id=map_label_id,
save_images=self.save_images,
save_images=self._save_images,
)
writer.write(tf_example.SerializeToString())

@ -112,7 +112,8 @@ class TfrecordConverterTest(TestCase):
with TestDir() as test_dir:
self._test_can_save_and_load(
TestExtractor(), DetectionApiConverter(), test_dir)
TestExtractor(), DetectionApiConverter(save_images=True),
test_dir)
def test_labelmap_parsing(self):
text = """

Loading…
Cancel
Save