diff --git a/CHANGELOG.md b/CHANGELOG.md index de92930e..8e560d05 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Ability to dump/load annotations in several formats from UI (CVAT, Pascal VOC, YOLO, MS COCO, png mask, TFRecord) - Auth for REST API (api/v1/auth/): login, logout, register, ... - Preview for the new CVAT UI (dashboard only) is available: http://localhost:9080/ +- Added command line tool for performing common task operations (/utils/cli/) ### Changed - Outside and keyframe buttons in the side panel for all interpolation shapes (they were only for boxes before) diff --git a/utils/cli/cli.py b/utils/cli/cli.py new file mode 100755 index 00000000..3407b23b --- /dev/null +++ b/utils/cli/cli.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +# +# SPDX-License-Identifier: MIT +import logging +import requests +import sys +from http.client import HTTPConnection +from core.core import CLI, CVAT_API_V1 +from core.definition import parser +log = logging.getLogger(__name__) + + +def config_log(level): + log = logging.getLogger('core') + log.addHandler(logging.StreamHandler(sys.stdout)) + log.setLevel(level) + if level <= logging.DEBUG: + HTTPConnection.debuglevel = 1 + + +def main(): + actions = {'create': CLI.tasks_create, + 'delete': CLI.tasks_delete, + 'ls': CLI.tasks_list, + 'frames': CLI.tasks_frame, + 'dump': CLI.tasks_dump} + args = parser.parse_args() + config_log(args.loglevel) + with requests.Session() as session: + session.auth = args.auth + api = CVAT_API_V1(args.server_host, args.server_port) + cli = CLI(session, api) + try: + actions[args.action](cli, **args.__dict__) + except (requests.exceptions.HTTPError, + requests.exceptions.ConnectionError, + requests.exceptions.RequestException) as e: + log.info(e) + + +if __name__ == '__main__': + main() diff --git a/utils/cli/core/__init__.py b/utils/cli/core/__init__.py new file mode 100644 index 00000000..c5319fc6 --- /dev/null +++ b/utils/cli/core/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: MIT +from .definition import parser, ResourceType # noqa +from .core import CLI, CVAT_API_V1 # noqa diff --git a/utils/cli/core/core.py b/utils/cli/core/core.py new file mode 100644 index 00000000..157e30d1 --- /dev/null +++ b/utils/cli/core/core.py @@ -0,0 +1,140 @@ +# SPDX-License-Identifier: MIT +import json +import logging +import os +import requests +from io import BytesIO +from PIL import Image +from .definition import ResourceType +log = logging.getLogger(__name__) + + +class CLI(): + + def __init__(self, session, api): + self.api = api + self.session = session + + def tasks_data(self, task_id, resource_type, resources): + """ Add local, remote, or shared files to an existing task. """ + url = self.api.tasks_id_data(task_id) + data = None + files = None + if resource_type == ResourceType.LOCAL: + files = {f'client_files[{i}]': open(f, 'rb') for i, f in enumerate(resources)} + elif resource_type == ResourceType.REMOTE: + data = {f'remote_files[{i}]': f for i, f in enumerate(resources)} + elif resource_type == ResourceType.SHARE: + data = {f'server_files[{i}]': f for i, f in enumerate(resources)} + response = self.session.post(url, data=data, files=files) + response.raise_for_status() + + def tasks_list(self, use_json_output, **kwargs): + """ List all tasks in either basic or JSON format. """ + url = self.api.tasks + response = self.session.get(url) + response.raise_for_status() + page = 1 + while True: + response_json = response.json() + for r in response_json['results']: + if use_json_output: + log.info(json.dumps(r, indent=4)) + else: + log.info(f'{r["id"]},{r["name"]},{r["status"]}') + if not response_json['next']: + return + page += 1 + url = self.api.tasks_page(page) + response = self.session.get(url) + response.raise_for_status() + + def tasks_create(self, name, labels, bug, resource_type, resources, **kwargs): + """ Create a new task with the given name and labels JSON and + add the files to it. """ + url = self.api.tasks + data = {'name': name, + 'labels': labels, + 'bug_tracker': bug, + 'image_quality': 50} + response = self.session.post(url, json=data) + response.raise_for_status() + response_json = response.json() + log.info(f'Created task ID: {response_json["id"]} ' + f'NAME: {response_json["name"]}') + self.tasks_data(response_json['id'], resource_type, resources) + + def tasks_delete(self, task_ids, **kwargs): + """ Delete a list of tasks, ignoring those which don't exist. """ + for task_id in task_ids: + url = self.api.tasks_id(task_id) + response = self.session.delete(url) + try: + response.raise_for_status() + log.info(f'Task ID {task_id} deleted') + except requests.exceptions.HTTPError as e: + if response.status_code == 404: + log.info(f'Task ID {task_id} not found') + else: + raise e + + def tasks_frame(self, task_id, frame_ids, outdir='', **kwargs): + """ Download the requested frame numbers for a task and save images as + task__frame_.jpg.""" + for frame_id in frame_ids: + url = self.api.tasks_id_frame_id(task_id, frame_id) + response = self.session.get(url) + response.raise_for_status() + im = Image.open(BytesIO(response.content)) + outfile = f'task_{task_id}_frame_{frame_id:06d}.jpg' + im.save(os.path.join(outdir, outfile)) + + def tasks_dump(self, task_id, fileformat, filename, **kwargs): + """ Download annotations for a task in the specified format + (e.g. 'YOLO ZIP 1.0').""" + url = self.api.tasks_id(task_id) + response = self.session.get(url) + response.raise_for_status() + response_json = response.json() + + url = self.api.tasks_id_annotations_filename(task_id, + response_json['name'], + fileformat) + while True: + response = self.session.get(url) + response.raise_for_status() + log.info(f'STATUS {response.status_code}') + if response.status_code == 201: + break + + response = self.session.get(url + '&action=download') + response.raise_for_status() + + with open(filename, 'wb') as fp: + fp.write(response.content) + + +class CVAT_API_V1(): + """ Build parameterized API URLs """ + + def __init__(self, host, port): + self.base = f'http://{host}:{port}/api/v1/' + + @property + def tasks(self): + return f'{self.base}tasks' + + def tasks_page(self, page_id): + return f'{self.tasks}?page={page_id}' + + def tasks_id(self, task_id): + return f'{self.tasks}/{task_id}' + + def tasks_id_data(self, task_id): + return f'{self.tasks}/{task_id}/data' + + def tasks_id_frame_id(self, task_id, frame_id): + return f'{self.tasks}/{task_id}/frames/{frame_id}' + + def tasks_id_annotations_filename(self, task_id, name, fileformat): + return f'{self.tasks}/{task_id}/annotations/{name}?format={fileformat}' diff --git a/utils/cli/core/definition.py b/utils/cli/core/definition.py new file mode 100644 index 00000000..c9e32e25 --- /dev/null +++ b/utils/cli/core/definition.py @@ -0,0 +1,210 @@ +# SPDX-License-Identifier: MIT +import argparse +import getpass +import json +import logging +import os +from enum import Enum + + +def get_auth(s): + """ Parse USER[:PASS] strings and prompt for password if none was + supplied. """ + user, _, password = s.partition(':') + password = password or os.environ.get('PASS') or getpass.getpass() + return user, password + + +def parse_label_arg(s): + """ If s is a file load it as JSON, otherwise parse s as JSON.""" + if os.path.exists(s): + fp = open(s, 'r') + return json.load(fp) + else: + return json.loads(s) + + +class ResourceType(Enum): + + LOCAL = 0 + SHARE = 1 + REMOTE = 2 + + def __str__(self): + return self.name.lower() + + def __repr__(self): + return str(self) + + @staticmethod + def argparse(s): + try: + return ResourceType[s.upper()] + except KeyError: + return s + + +####################################################################### +# Command line interface definition +####################################################################### + +parser = argparse.ArgumentParser( + description='Perform common operations related to CVAT tasks.\n\n' +) +task_subparser = parser.add_subparsers(dest='action') + +####################################################################### +# Positional arguments +####################################################################### + +parser.add_argument( + '--auth', + type=get_auth, + metavar='USER:[PASS]', + default=getpass.getuser(), + help='''defaults to the current user and supports the PASS + environment variable or password prompt + (default user: %(default)s).''' +) +parser.add_argument( + '--server-host', + type=str, + default='localhost', + help='host (default: %(default)s)' +) +parser.add_argument( + '--server-port', + type=int, + default='8080', + help='port (default: %(default)s)' +) +parser.add_argument( + '--debug', + action='store_const', + dest='loglevel', + const=logging.DEBUG, + default=logging.INFO, + help='show debug output' +) + +####################################################################### +# Create +####################################################################### + +task_create_parser = task_subparser.add_parser( + 'create', + description='Create a new CVAT task.' +) +task_create_parser.add_argument( + 'name', + type=str, + help='name of the task' +) +task_create_parser.add_argument( + '--labels', + default='[]', + type=parse_label_arg, + help='string or file containing JSON labels specification' +) +task_create_parser.add_argument( + '--bug', + default='', + type=str, + help='bug tracker URL' +) +task_create_parser.add_argument( + 'resource_type', + default='local', + choices=list(ResourceType), + type=ResourceType.argparse, + help='type of files specified' +) +task_create_parser.add_argument( + 'resources', + type=str, + help='list of paths or URLs', + nargs='+' +) + +####################################################################### +# Delete +####################################################################### + +delete_parser = task_subparser.add_parser( + 'delete', + description='Delete a CVAT task.' +) +delete_parser.add_argument( + 'task_ids', + type=int, + help='list of task IDs', + nargs='+' +) + +####################################################################### +# List +####################################################################### + +ls_parser = task_subparser.add_parser( + 'ls', + description='List all CVAT tasks in simple or JSON format.' +) +ls_parser.add_argument( + '--json', + dest='use_json_output', + default=False, + action='store_true', + help='output JSON data' +) + +####################################################################### +# Frames +####################################################################### + +frames_parser = task_subparser.add_parser( + 'frames', + description='Download all frame images for a CVAT task.' +) +frames_parser.add_argument( + 'task_id', + type=int, + help='task ID' +) +frames_parser.add_argument( + 'frame_ids', + type=int, + help='list of frame IDs to download', + nargs='+' +) +frames_parser.add_argument( + '--outdir', + type=str, + default='', + help='directory to save images' +) + +####################################################################### +# Dump +####################################################################### + +dump_parser = task_subparser.add_parser( + 'dump', + description='Download annotations for a CVAT task.' +) +dump_parser.add_argument( + 'task_id', + type=int, + help='task ID' +) +dump_parser.add_argument( + 'filename', + type=str, + help='output file' +) +dump_parser.add_argument( + '--format', + dest='fileformat', + type=str, + default='CVAT XML 1.1 for images', + help='annotation format (default: %(default)s)' +) diff --git a/utils/cli/requirements.txt b/utils/cli/requirements.txt new file mode 100644 index 00000000..65a8f149 --- /dev/null +++ b/utils/cli/requirements.txt @@ -0,0 +1,2 @@ +Pillow==5.3.0 +requests==2.20.1 diff --git a/utils/cli/tests/__init__.py b/utils/cli/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/utils/cli/tests/test_cli.py b/utils/cli/tests/test_cli.py new file mode 100644 index 00000000..b09e7371 --- /dev/null +++ b/utils/cli/tests/test_cli.py @@ -0,0 +1,71 @@ +# SPDX-License-Identifier: MIT +import logging +import io +import os +import sys +import unittest +from django.conf import settings +from requests.auth import HTTPBasicAuth +from utils.cli.core import CLI, CVAT_API_V1, ResourceType +from rest_framework.test import APITestCase, RequestsClient +from cvat.apps.engine.tests.test_rest_api import create_db_users +from cvat.apps.engine.tests.test_rest_api import generate_image_file + + +class TestCLI(APITestCase): + + @unittest.mock.patch('sys.stdout', new_callable=io.StringIO) + def setUp(self, mock_stdout): + self.client = RequestsClient() + self.client.auth = HTTPBasicAuth('admin', 'admin') + self.api = CVAT_API_V1('testserver', '') + self.cli = CLI(self.client, self.api) + self.taskname = 'test_task' + self.cli.tasks_create(self.taskname, + [], + '', + ResourceType.LOCAL, + [self.img_file]) + # redirect logging to mocked stdout to test program output + self.mock_stdout = mock_stdout + log = logging.getLogger('utils.cli.core') + log.setLevel(logging.INFO) + log.addHandler(logging.StreamHandler(sys.stdout)) + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.img_file = os.path.join(settings.SHARE_ROOT, 'test_cli.jpg') + data = generate_image_file(cls.img_file) + with open(cls.img_file, 'wb') as image: + image.write(data.read()) + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + os.remove(cls.img_file) + + @classmethod + def setUpTestData(cls): + create_db_users(cls) + + def test_tasks_list(self): + self.cli.tasks_list(False) + self.assertRegex(self.mock_stdout.getvalue(), f'.*{self.taskname}.*') + + def test_tasks_delete(self): + self.cli.tasks_delete([1]) + self.cli.tasks_list(False) + self.assertNotRegex(self.mock_stdout.getvalue(), f'.*{self.taskname}.*') + + def test_tasks_dump(self): + path = os.path.join(settings.SHARE_ROOT, 'test_cli.xml') + self.cli.tasks_dump(1, 'CVAT XML 1.1 for images', path) + self.assertTrue(os.path.exists(path)) + os.remove(path) + + def test_tasks_frame(self): + path = os.path.join(settings.SHARE_ROOT, 'task_1_frame_000000.jpg') + self.cli.tasks_frame(1, [0], outdir=settings.SHARE_ROOT) + self.assertTrue(os.path.exists(path)) + os.remove(path)