From ce37be1f60f502b68993955cf559e7a68ac2dd59 Mon Sep 17 00:00:00 2001 From: Roman Donchenko Date: Fri, 30 Dec 2022 15:30:12 +0300 Subject: [PATCH] SDK: add a ProjectVisionDataset class (#5523) --- CHANGELOG.md | 2 + cvat-sdk/cvat_sdk/pytorch/__init__.py | 144 +++++++++++++++++---- tests/python/sdk/test_pytorch.py | 172 +++++++++++++++++++++++--- 3 files changed, 274 insertions(+), 44 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2e517700..49b82e3d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Filename pattern to simplify uploading cloud storage data for a task () - \[SDK\] Configuration setting to change the dataset cache directory () +- \[SDK\] Class to represent a project as a PyTorch dataset + () ### Changed - The Docker Compose files now use the Compose Specification version diff --git a/cvat-sdk/cvat_sdk/pytorch/__init__.py b/cvat-sdk/cvat_sdk/pytorch/__init__.py index 9bd24201..37002283 100644 --- a/cvat-sdk/cvat_sdk/pytorch/__init__.py +++ b/cvat-sdk/cvat_sdk/pytorch/__init__.py @@ -6,31 +6,22 @@ import shutil import types import zipfile from concurrent.futures import ThreadPoolExecutor -from typing import ( - Callable, - Dict, - FrozenSet, - List, - Mapping, - Optional, - Sequence, - Tuple, - Type, - TypeVar, -) +from pathlib import Path +from typing import Callable, Container, Dict, FrozenSet, List, Mapping, Optional, Type, TypeVar import attrs import attrs.validators import PIL.Image import torch +import torch.utils.data import torchvision.datasets from typing_extensions import TypedDict import cvat_sdk.core import cvat_sdk.core.exceptions +import cvat_sdk.models as models from cvat_sdk.api_client.model_utils import to_json from cvat_sdk.core.utils import atomic_writer -from cvat_sdk.models import DataMetaRead, LabeledData, LabeledImage, LabeledShape, TaskRead _ModelType = TypeVar("_ModelType") @@ -47,8 +38,8 @@ class FrameAnnotations: 
Contains annotations that pertain to a single frame. """ - tags: List[LabeledImage] = attrs.Factory(list) - shapes: List[LabeledShape] = attrs.Factory(list) + tags: List[models.LabeledImage] = attrs.Factory(list) + shapes: List[models.LabeledShape] = attrs.Factory(list) @attrs.frozen @@ -67,6 +58,12 @@ class Target: """ +def _get_server_dir(client: cvat_sdk.core.Client) -> Path: + # Base64-encode the name to avoid FS-unsafe characters (like slashes) + server_dir_name = base64.urlsafe_b64encode(client.api_map.host.encode()).rstrip(b"=").decode() + return client.config.cache_dir / f"servers/{server_dir_name}" + + class TaskVisionDataset(torchvision.datasets.VisionDataset): """ Represents a task on a CVAT server as a PyTorch Dataset. @@ -132,13 +129,7 @@ class TaskVisionDataset(torchvision.datasets.VisionDataset): f" current chunk type is {self._task.data_original_chunk_type!r}" ) - # Base64-encode the name to avoid FS-unsafe characters (like slashes) - server_dir_name = ( - base64.urlsafe_b64encode(client.api_map.host.encode()).rstrip(b"=").decode() - ) - server_dir = client.config.cache_dir / f"servers/{server_dir_name}" - - self._task_dir = server_dir / f"tasks/{self._task.id}" + self._task_dir = _get_server_dir(client) / f"tasks/{self._task.id}" self._initialize_task_dir() super().__init__( @@ -149,7 +140,7 @@ class TaskVisionDataset(torchvision.datasets.VisionDataset): ) data_meta = self._ensure_model( - "data_meta.json", DataMetaRead, self._task.get_meta, "data metadata" + "data_meta.json", models.DataMetaRead, self._task.get_meta, "data metadata" ) self._active_frame_indexes = sorted( set(range(self._task.size)) - set(data_meta.deleted_frames) @@ -186,7 +177,7 @@ class TaskVisionDataset(torchvision.datasets.VisionDataset): ) annotations = self._ensure_model( - "annotations.json", LabeledData, self._task.get_annotations, "annotations" + "annotations.json", models.LabeledData, self._task.get_annotations, "annotations" ) self._frame_annotations: Dict[int, 
FrameAnnotations] = collections.defaultdict( @@ -206,7 +197,7 @@ class TaskVisionDataset(torchvision.datasets.VisionDataset): try: with open(task_json_path, "rb") as task_json_file: - saved_task = TaskRead._new_from_openapi_data(**json.load(task_json_file)) + saved_task = models.TaskRead._new_from_openapi_data(**json.load(task_json_file)) except Exception: self._logger.info("Task is not yet cached or the cache is corrupted") @@ -295,6 +286,109 @@ class TaskVisionDataset(torchvision.datasets.VisionDataset): return len(self._active_frame_indexes) +class ProjectVisionDataset(torchvision.datasets.VisionDataset): + """ + Represents a project on a CVAT server as a PyTorch Dataset. + + The dataset contains one sample for each frame of each task in the project + (except for tasks that are filtered out - see the description of `task_filter` + in the constructor). The sequence of samples is formed by concatenating sequences + of samples from all included tasks in an arbitrary order that's consistent + between executions. Each task's sequence of samples corresponds to the sequence + of frames on the server. + + See `TaskVisionDataset` for information on sample format, caching, and + current limitations. + """ + + def __init__( + self, + client: cvat_sdk.core.Client, + project_id: int, + *, + transforms: Optional[Callable] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + label_name_to_index: Mapping[str, int] = None, + task_filter: Optional[Callable[[models.ITaskRead], bool]] = None, + include_subsets: Optional[Container[str]] = None, + ) -> None: + """ + Creates a dataset corresponding to the project with ID `project_id` on the + server that `client` is connected to. + + `transforms`, `transform` and `target_transform` are optional transformation + functions; see the documentation for `torchvision.datasets.VisionDataset` for + more information. + + See `TaskVisionDataset.__init__` for information on `label_name_to_index`.
+ + By default, all of the project's tasks will be included in the dataset. + The following parameters can be specified to exclude some tasks: + + * If `task_filter` is set to a callable object, it will be applied to every task. + Tasks for which it returns a false value will be excluded. + + * If `include_subsets` is set to a container, then tasks whose subset is + not a member of this container will be excluded. + """ + + self._logger = client.logger + + self._logger.info(f"Fetching project {project_id}...") + project = client.projects.retrieve(project_id) + + # We don't actually need to save anything to this directory (yet), + # but VisionDataset.__init__ requires a root, so make one. + # It could be useful in the future to store the project data for + # offline-only mode. + project_dir = _get_server_dir(client) / f"projects/{project_id}" + project_dir.mkdir(parents=True, exist_ok=True) + + super().__init__( + os.fspath(project_dir), + transforms=transforms, + transform=transform, + target_transform=target_transform, + ) + + self._logger.info("Fetching project tasks...") + tasks = project.get_tasks() + + if task_filter is not None: + tasks = list(filter(task_filter, tasks)) + + if include_subsets is not None: + tasks = [task for task in tasks if task.subset in include_subsets] + + tasks.sort(key=lambda t: t.id) # ensure consistent order between executions + + self._underlying = torch.utils.data.ConcatDataset( + [ + TaskVisionDataset(client, task.id, label_name_to_index=label_name_to_index) + for task in tasks + ] + ) + + def __getitem__(self, sample_index: int): + """ + Returns the sample with index `sample_index`. + + `sample_index` must satisfy the condition `0 <= sample_index < len(self)`. 
+ """ + + sample_image, sample_target = self._underlying[sample_index] + + if self.transforms: + sample_image, sample_target = self.transforms(sample_image, sample_target) + + return sample_image, sample_target + + def __len__(self) -> int: + """Returns the number of samples in the dataset.""" + return len(self._underlying) + + @attrs.frozen class ExtractSingleLabelIndex: """ diff --git a/tests/python/sdk/test_pytorch.py b/tests/python/sdk/test_pytorch.py index 77cd6ecd..35c5ad2d 100644 --- a/tests/python/sdk/test_pytorch.py +++ b/tests/python/sdk/test_pytorch.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: MIT import io +import itertools import os from logging import Logger from pathlib import Path @@ -25,6 +26,22 @@ except ImportError: from shared.utils.helpers import generate_image_files +@pytest.fixture(autouse=True) +def _common_setup( + tmp_path: Path, + fxt_login: Tuple[Client, str], + fxt_logger: Tuple[Logger, io.StringIO], +): + logger = fxt_logger[0] + client = fxt_login[0] + client.logger = logger + client.config.cache_dir = tmp_path / "cache" + + api_client = client.api_client + for k in api_client.configuration.logger: + api_client.configuration.logger[k] = logger + + @pytest.mark.skipif(cvatpt is None, reason="PyTorch dependencies are not installed") class TestTaskVisionDataset: @pytest.fixture(autouse=True) @@ -32,28 +49,11 @@ class TestTaskVisionDataset: self, tmp_path: Path, fxt_login: Tuple[Client, str], - fxt_logger: Tuple[Logger, io.StringIO], - fxt_stdout: io.StringIO, ): - self.tmp_path = tmp_path - logger, self.logger_stream = fxt_logger - self.stdout = fxt_stdout - self.client, self.user = fxt_login - self.client.logger = logger - self.client.config.cache_dir = tmp_path / "cache" - - api_client = self.client.api_client - for k in api_client.configuration.logger: - api_client.configuration.logger[k] = logger - - self._create_task() - - yield - - def _create_task(self): + self.client = fxt_login[0] self.images = generate_image_files(10) - 
image_dir = self.tmp_path / "images" + image_dir = tmp_path / "images" image_dir.mkdir() image_paths = [] @@ -225,3 +225,137 @@ class TestTaskVisionDataset: _, target = dataset[5] assert target.label_id_to_index[label_name_to_id["person"]] == 123 assert target.label_id_to_index[label_name_to_id["car"]] == 456 + + +@pytest.mark.skipif(cvatpt is None, reason="PyTorch dependencies are not installed") +class TestProjectVisionDataset: + @pytest.fixture(autouse=True) + def setup( + self, + tmp_path: Path, + fxt_login: Tuple[Client, str], + ): + self.client = fxt_login[0] + + self.project = self.client.projects.create( + models.ProjectWriteRequest( + "PyTorch integration test project", + labels=[ + models.PatchedLabelRequest(name="person"), + models.PatchedLabelRequest(name="car"), + ], + ) + ) + self.label_ids = sorted(l.id for l in self.project.labels) + + subsets = ["Train", "Test", "Val"] + num_images_per_task = 3 + + all_images = generate_image_files(num_images_per_task * len(subsets)) + + self.images_per_task = list(zip(*[iter(all_images)] * num_images_per_task)) + + image_dir = tmp_path / "images" + image_dir.mkdir() + + image_paths_per_task = [] + for images in self.images_per_task: + image_paths = [] + for image in images: + image_path = image_dir / image.name + image_path.write_bytes(image.getbuffer()) + image_paths.append(image_path) + image_paths_per_task.append(image_paths) + + self.tasks = [ + self.client.tasks.create_from_data( + models.TaskWriteRequest( + "PyTorch integration test task", + project_id=self.project.id, + subset=subset, + ), + ResourceType.LOCAL, + image_paths, + data_params={"image_quality": 70}, + ) + for subset, image_paths in zip(subsets, image_paths_per_task) + ] + + # sort both self.tasks and self.images_per_task in the order that ProjectVisionDataset uses + self.tasks, self.images_per_task = zip( + *sorted(zip(self.tasks, self.images_per_task), key=lambda t: t[0].id) + ) + + for task_id, label_index in ((0, 0), (1, 1), (2, 0)): + 
self.tasks[task_id].update_annotations( + models.PatchedLabeledDataRequest( + tags=[ + models.LabeledImageRequest( + frame=task_id, label_id=self.label_ids[label_index] + ), + ], + ) + ) + + def test_basic(self): + dataset = cvatpt.ProjectVisionDataset(self.client, self.project.id) + + assert len(dataset) == sum(task.size for task in self.tasks) + + for sample, image in zip(dataset, itertools.chain.from_iterable(self.images_per_task)): + assert torch.equal(TF.pil_to_tensor(sample[0]), TF.pil_to_tensor(PIL.Image.open(image))) + + assert dataset[0][1].annotations.tags[0].label_id == self.label_ids[0] + assert dataset[4][1].annotations.tags[0].label_id == self.label_ids[1] + assert dataset[8][1].annotations.tags[0].label_id == self.label_ids[0] + + def _test_filtering(self, **kwargs): + dataset = cvatpt.ProjectVisionDataset(self.client, self.project.id, **kwargs) + + assert len(dataset) == sum(task.size for task in self.tasks[1:]) + + for sample, image in zip(dataset, itertools.chain.from_iterable(self.images_per_task[1:])): + assert torch.equal(TF.pil_to_tensor(sample[0]), TF.pil_to_tensor(PIL.Image.open(image))) + + assert dataset[1][1].annotations.tags[0].label_id == self.label_ids[1] + assert dataset[5][1].annotations.tags[0].label_id == self.label_ids[0] + + def test_task_filter(self): + self._test_filtering(task_filter=lambda t: t.subset != self.tasks[0].subset) + + def test_include_subsets(self): + self._test_filtering(include_subsets={self.tasks[1].subset, self.tasks[2].subset}) + + def test_custom_label_mapping(self): + label_name_to_id = {label.name: label.id for label in self.project.labels} + + dataset = cvatpt.ProjectVisionDataset( + self.client, self.project.id, label_name_to_index={"person": 123, "car": 456} + ) + + _, target = dataset[5] + assert target.label_id_to_index[label_name_to_id["person"]] == 123 + assert target.label_id_to_index[label_name_to_id["car"]] == 456 + + def test_separate_transforms(self): + dataset = cvatpt.ProjectVisionDataset( + 
self.client, + self.project.id, + transform=torchvision.transforms.ToTensor(), + target_transform=cvatpt.ExtractSingleLabelIndex(), + ) + + assert torch.equal( + dataset[0][0], TF.pil_to_tensor(PIL.Image.open(self.images_per_task[0][0])) + ) + assert torch.equal(dataset[0][1], torch.tensor(0)) + + def test_combined_transforms(self): + dataset = cvatpt.ProjectVisionDataset( + self.client, + self.project.id, + transforms=lambda x, y: (y, x), + ) + + assert isinstance(dataset[0][0], cvatpt.Target) + assert isinstance(dataset[0][1], PIL.Image.Image)