diff --git a/CHANGELOG.md b/CHANGELOG.md index 09009e6b..cb76509d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## \[2.4.0] - Unreleased ### Added +- \[SDK\] An arg to wait for data processing in the task data uploading function + () - Filename pattern to simplify uploading cloud storage data for a task (, ) - \[SDK\] Configuration setting to change the dataset cache directory () @@ -17,6 +19,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - The Docker Compose files now use the Compose Specification version of the format. This version is supported by Docker Compose 1.27.0+ (). +- \[SDK\] The `resource_type` args now have the default value of `local` in task creation functions. + The corresponding arguments are keyword-only now. + () ### Deprecated - TDB diff --git a/cvat-cli/src/cvat_cli/cli.py b/cvat-cli/src/cvat_cli/cli.py index 00122341..d8edf8ce 100644 --- a/cvat-cli/src/cvat_cli/cli.py +++ b/cvat-cli/src/cvat_cli/cli.py @@ -38,9 +38,9 @@ class CLI: self, name: str, labels: List[Dict[str, str]], - resource_type: ResourceType, resources: Sequence[str], *, + resource_type: ResourceType = ResourceType.LOCAL, annotation_path: str = "", annotation_format: str = "CVAT XML 1.1", status_check_period: int = 2, diff --git a/cvat-sdk/cvat_sdk/core/proxies/tasks.py b/cvat-sdk/cvat_sdk/core/proxies/tasks.py index 330a6ab9..97dcbdbc 100644 --- a/cvat-sdk/cvat_sdk/core/proxies/tasks.py +++ b/cvat-sdk/cvat_sdk/core/proxies/tasks.py @@ -65,11 +65,13 @@ class Task( def upload_data( self, - resource_type: ResourceType, resources: Sequence[StrPath], *, + resource_type: ResourceType = ResourceType.LOCAL, pbar: Optional[ProgressReporter] = None, params: Optional[Dict[str, Any]] = None, + wait_for_completion: bool = True, + status_check_period: Optional[int] = None, ) -> None: """ Add local, remote, or shared files to an existing task. @@ -121,6 +123,37 @@ class Task( url, list(map(Path, resources)), pbar=pbar, **data ) + if wait_for_completion: + if status_check_period is None: + status_check_period = self._client.config.status_check_period + + self._client.logger.info("Awaiting for task %s creation...", self.id) + while True: + sleep(status_check_period) + (status, response) = self.api.retrieve_status(self.id) + + self._client.logger.info( + "Task %s creation status: %s (message=%s)", + self.id, + status.state.value, + status.message, + ) + + if ( + status.state.value + == models.RqStatusStateEnum.allowed_values[("value",)]["FINISHED"] + ): + break + elif ( + status.state.value + == models.RqStatusStateEnum.allowed_values[("value",)]["FAILED"] + ): + raise exceptions.ApiException( + status=status.state.value, reason=status.message, http_resp=response + ) + + self.fetch() + def import_annotations( self, format_name: str, @@ -296,9 +329,9 @@ class TasksRepo( def create_from_data( self, spec: models.ITaskWriteRequest, - resource_type: ResourceType, resources: Sequence[str], *, + resource_type: ResourceType = ResourceType.LOCAL, data_params: Optional[Dict[str, Any]] = None, annotation_path: str = "", annotation_format: str = "CVAT XML 1.1", @@ -313,9 +346,6 @@ class TasksRepo( Returns: id of the created task """ - if status_check_period is None: - status_check_period = self._client.config.status_check_period - if getattr(spec, "project_id", None) and getattr(spec, "labels", None): raise exceptions.ApiValueError( "Can't set labels to a task inside a project. " @@ -326,27 +356,14 @@ class TasksRepo( task = self.create(spec=spec) self._client.logger.info("Created task ID: %s NAME: %s", task.id, task.name) - task.upload_data(resource_type, resources, pbar=pbar, params=data_params) - - self._client.logger.info("Awaiting for task %s creation...", task.id) - status: models.RqStatus = None - while status != models.RqStatusStateEnum.allowed_values[("value",)]["FINISHED"]: - sleep(status_check_period) - (status, response) = self.api.retrieve_status(task.id) - - self._client.logger.info( - "Task %s creation status=%s, message=%s", - task.id, - status.state.value, - status.message, - ) - - if status.state.value == models.RqStatusStateEnum.allowed_values[("value",)]["FAILED"]: - raise exceptions.ApiException( - status=status.state.value, reason=status.message, http_resp=response - ) - - status = status.state.value + task.upload_data( + resource_type=resource_type, + resources=resources, + pbar=pbar, + params=data_params, + wait_for_completion=True, + status_check_period=status_check_period, + ) if annotation_path: task.import_annotations(annotation_format, annotation_path, pbar=pbar) diff --git a/cvat-sdk/cvat_sdk/core/uploading.py b/cvat-sdk/cvat_sdk/core/uploading.py index c6f592f7..95d129a0 100644 --- a/cvat-sdk/cvat_sdk/core/uploading.py +++ b/cvat-sdk/cvat_sdk/core/uploading.py @@ -5,7 +5,6 @@ from __future__ import annotations import os -from contextlib import ExitStack, closing from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple @@ -206,40 +205,6 @@ class Uploader: positive_statuses=positive_statuses, ) - def _split_files_by_requests( - self, filenames: List[Path] - ) -> Tuple[List[Tuple[List[Path], int]], List[Path], int]: - bulk_files: Dict[str, int] = {} - separate_files: Dict[str, int] = {} - - # sort by size - for filename in filenames: - filename = filename.resolve() - file_size = filename.stat().st_size - if MAX_REQUEST_SIZE < file_size: - separate_files[filename] = file_size - else: - bulk_files[filename] = file_size - - total_size = sum(bulk_files.values()) + sum(separate_files.values()) - - # group small files by requests - bulk_file_groups: List[Tuple[List[str], int]] = [] - current_group_size: int = 0 - current_group: List[str] = [] - for filename, file_size in bulk_files.items(): - if MAX_REQUEST_SIZE < current_group_size + file_size: - bulk_file_groups.append((current_group, current_group_size)) - current_group_size = 0 - current_group = [] - - current_group.append(filename) - current_group_size += file_size - if current_group: - bulk_file_groups.append((current_group, current_group_size)) - - return bulk_file_groups, separate_files, total_size - @staticmethod def _make_tus_uploader(api_client: ApiClient, url: str, **kwargs): # Add headers required by CVAT server @@ -353,6 +318,10 @@ class DatasetUploader(Uploader): class DataUploader(Uploader): + def __init__(self, client: Client, *, max_request_size: int = MAX_REQUEST_SIZE): + super().__init__(client) + self.max_request_size = max_request_size + def upload_files( self, url: str, @@ -369,22 +338,21 @@ class DataUploader(Uploader): self._tus_start_upload(url) for group, group_size in bulk_file_groups: - with ExitStack() as es: - files = {} - for i, filename in enumerate(group): - files[f"client_files[{i}]"] = ( - os.fspath(filename), - es.enter_context(closing(open(filename, "rb"))).read(), - ) - response = self._client.api_client.rest_client.POST( - url, - post_params=dict(**kwargs, **files), - headers={ - "Content-Type": "multipart/form-data", - "Upload-Multiple": "", - **self._client.api_client.get_common_headers(), - }, + files = {} + for i, filename in enumerate(group): + files[f"client_files[{i}]"] = ( + os.fspath(filename), + filename.read_bytes(), ) + response = self._client.api_client.rest_client.POST( + url, + post_params=dict(**kwargs, **files), + headers={ + "Content-Type": "multipart/form-data", + "Upload-Multiple": "", + **self._client.api_client.get_common_headers(), + }, + ) expect_status(200, response) if pbar is not None: @@ -401,3 +369,38 @@ class DataUploader(Uploader): ) self._tus_finish_upload(url, fields=kwargs) + + def _split_files_by_requests( + self, filenames: List[Path] + ) -> Tuple[List[Tuple[List[Path], int]], List[Path], int]: + bulk_files: Dict[str, int] = {} + separate_files: Dict[str, int] = {} + max_request_size = self.max_request_size + + # sort by size + for filename in filenames: + filename = filename.resolve() + file_size = filename.stat().st_size + if max_request_size < file_size: + separate_files[filename] = file_size + else: + bulk_files[filename] = file_size + + total_size = sum(bulk_files.values()) + sum(separate_files.values()) + + # group small files by requests + bulk_file_groups: List[Tuple[List[str], int]] = [] + current_group_size: int = 0 + current_group: List[str] = [] + for filename, file_size in bulk_files.items(): + if max_request_size < current_group_size + file_size: + bulk_file_groups.append((current_group, current_group_size)) + current_group_size = 0 + current_group = [] + + current_group.append(filename) + current_group_size += file_size + if current_group: + bulk_file_groups.append((current_group, current_group_size)) + + return bulk_file_groups, separate_files, total_size diff --git a/tests/python/sdk/test_pytorch.py b/tests/python/sdk/test_pytorch.py index 35c5ad2d..c6e12ba5 100644 --- a/tests/python/sdk/test_pytorch.py +++ b/tests/python/sdk/test_pytorch.py @@ -70,8 +70,8 @@ class TestTaskVisionDataset: models.PatchedLabelRequest(name="car"), ], ), - ResourceType.LOCAL, - list(map(os.fspath, image_paths)), + resource_type=ResourceType.LOCAL, + resources=list(map(os.fspath, image_paths)), data_params={"chunk_size": 3}, ) @@ -274,8 +274,8 @@ class TestProjectVisionDataset: project_id=self.project.id, subset=subset, ), - ResourceType.LOCAL, - image_paths, + resource_type=ResourceType.LOCAL, + resources=image_paths, data_params={"image_quality": 70}, ) for subset, image_paths in zip(subsets, image_paths_per_task) diff --git a/tests/python/sdk/test_tasks.py b/tests/python/sdk/test_tasks.py index 4398d74d..24c37783 100644 --- a/tests/python/sdk/test_tasks.py +++ b/tests/python/sdk/test_tasks.py @@ -58,7 +58,6 @@ class TestTaskUsecases: "name": "test_task", "labels": [{"name": "car"}, {"name": "person"}], }, - resource_type=ResourceType.LOCAL, resources=[fxt_image_file], data_params={"image_quality": 80}, ) @@ -202,6 +201,38 @@ class TestTaskUsecases: assert response_json["format"] == "CVAT for images 1.1" assert response_json["lfs"] is False + def test_can_upload_data_to_empty_task(self): + pbar_out = io.StringIO() + pbar = make_pbar(file=pbar_out) + + task = self.client.tasks.create( + { + "name": f"test task", + "labels": [{"name": "car"}], + } + ) + + data_params = { + "image_quality": 75, + } + + task_files = generate_image_files(7) + for i, f in enumerate(task_files): + fname = self.tmp_path / f.name + fname.write_bytes(f.getvalue()) + task_files[i] = fname + + task.upload_data( + resources=task_files, + resource_type=ResourceType.LOCAL, + params=data_params, + pbar=pbar, + ) + + assert task.size == 7 + assert "100%" in pbar_out.getvalue().strip("\r").split("\r")[-1] + assert self.stdout.getvalue() == "" + def test_can_retrieve_task(self, fxt_new_task: Task): task_id = fxt_new_task.id