From fa92ccb987c72fd8ddc0aac8b7c718557068976f Mon Sep 17 00:00:00 2001 From: Roman Donchenko Date: Fri, 16 Dec 2022 16:55:31 +0300 Subject: [PATCH] SDK: Improve the PyTorch adapter layer (#5455) * Make the extractors return tensors instead of Python data structures. * Let the user specify custom label IDs. --- cvat-sdk/cvat_sdk/pytorch/__init__.py | 54 +++++++++++++++++++-------- tests/python/sdk/test_pytorch.py | 29 +++++++++++--- 2 files changed, 62 insertions(+), 21 deletions(-) diff --git a/cvat-sdk/cvat_sdk/pytorch/__init__.py b/cvat-sdk/cvat_sdk/pytorch/__init__.py index 55b88186..fa6b38a0 100644 --- a/cvat-sdk/cvat_sdk/pytorch/__init__.py +++ b/cvat-sdk/cvat_sdk/pytorch/__init__.py @@ -24,6 +24,7 @@ import appdirs import attrs import attrs.validators import PIL.Image +import torch import torchvision.datasets from typing_extensions import TypedDict @@ -65,8 +66,7 @@ class Target: label_id_to_index: Mapping[int, int] """ A mapping from label_id values in `LabeledImage` and `LabeledShape` objects - to an index in the range [0, num_labels), where num_labels is the number of labels - defined in the task. This mapping is consistent across all samples for a given task. + to an integer index. This mapping is consistent across all samples for a given task. """ @@ -99,6 +99,7 @@ class TaskVisionDataset(torchvision.datasets.VisionDataset): transforms: Optional[Callable] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, + label_name_to_index: Optional[Mapping[str, int]] = None, ) -> None: """ Creates a dataset corresponding to the task with ID `task_id` on the @@ -107,6 +108,17 @@ class TaskVisionDataset(torchvision.datasets.VisionDataset): `transforms`, `transform` and `target_transforms` are optional transformation functions; see the documentation for `torchvision.datasets.VisionDataset` for more information. + + `label_name_to_index` affects the `label_id_to_index` member in `Target` objects + returned by the dataset. 
If it is specified, then it must contain an entry for + each label name in the task. The `label_id_to_index` mapping will be constructed + so that each label will be mapped to the index corresponding to the label's name + in `label_name_to_index`. + + If `label_name_to_index` is unspecified or set to `None`, then `label_id_to_index` + will map each label ID to a distinct integer in the range [0, `num_labels`), where + `num_labels` is the number of labels defined in the task. This mapping will be + generally unpredictable, but consistent for a given task. """ self._logger = client.logger @@ -162,12 +174,19 @@ class TaskVisionDataset(torchvision.datasets.VisionDataset): self._logger.info("All chunks downloaded") - self._label_id_to_index = types.MappingProxyType( - { - label["id"]: label_index - for label_index, label in enumerate(sorted(self._task.labels, key=lambda l: l.id)) - } - ) + if label_name_to_index is None: + self._label_id_to_index = types.MappingProxyType( + { + label.id: label_index + for label_index, label in enumerate( + sorted(self._task.labels, key=lambda l: l.id) + ) + } + ) + else: + self._label_id_to_index = types.MappingProxyType( + {label.id: label_name_to_index[label.name] for label in self._task.labels} + ) annotations = self._ensure_model( "annotations.json", LabeledData, self._task.get_annotations, "annotations" @@ -283,7 +302,7 @@ class TaskVisionDataset(torchvision.datasets.VisionDataset): class ExtractSingleLabelIndex: """ A target transform that takes a `Target` object and produces a single label index - based on the tag in that object. + based on the tag in that object, as a 0-dimensional tensor. This makes the dataset samples compatible with the image classification networks in torchvision. 
@@ -299,12 +318,12 @@ class ExtractSingleLabelIndex: if len(tags) > 1: raise ValueError("sample has multiple tags") - return target.label_id_to_index[tags[0].label_id] + return torch.tensor(target.label_id_to_index[tags[0].label_id], dtype=torch.long) class LabeledBoxes(TypedDict): - boxes: Sequence[Tuple[float, float, float, float]] - labels: Sequence[int] + boxes: torch.Tensor + labels: torch.Tensor _SUPPORTED_SHAPE_TYPES = frozenset(["rectangle", "polygon", "polyline", "points", "ellipse"]) @@ -318,9 +337,9 @@ class ExtractBoundingBoxes: The dictionary contains the following entries: - "boxes": a sequence of (xmin, ymin, xmax, ymax) tuples, one for each shape - in the annotations. - "labels": a sequence of corresponding label indices. + "boxes": a tensor with shape [N, 4], where each row represents a bounding box of a shape + in the annotations in the (xmin, ymin, xmax, ymax) format. + "labels": a tensor with shape [N] containing corresponding label indices. Limitations: @@ -356,4 +375,7 @@ class ExtractBoundingBoxes: boxes.append((min(x_coords), min(y_coords), max(x_coords), max(y_coords))) labels.append(target.label_id_to_index[shape.label_id]) - return LabeledBoxes(boxes=boxes, labels=labels) + return LabeledBoxes( + boxes=torch.tensor(boxes, dtype=torch.float).reshape(-1, 4), + labels=torch.tensor(labels, dtype=torch.long), + ) diff --git a/tests/python/sdk/test_pytorch.py b/tests/python/sdk/test_pytorch.py index 69d329b7..1aa61174 100644 --- a/tests/python/sdk/test_pytorch.py +++ b/tests/python/sdk/test_pytorch.py @@ -165,8 +165,8 @@ class TestTaskVisionDataset: target_transform=cvatpt.ExtractSingleLabelIndex(), ) - assert dataset[5][1] == 0 - assert dataset[6][1] == 1 + assert torch.equal(dataset[5][1], torch.tensor(0)) + assert torch.equal(dataset[6][1], torch.tensor(1)) with pytest.raises(ValueError): # no tags @@ -192,9 +192,15 @@ class TestTaskVisionDataset: target_transform=cvatpt.ExtractBoundingBoxes(include_shape_types={"rectangle"}), ) - assert dataset[0][1] == 
{"boxes": [], "labels": []} - assert dataset[6][1] == {"boxes": [(1.0, 2.0, 3.0, 4.0)], "labels": [1]} - assert dataset[7][1] == {"boxes": [], "labels": []} # points are filtered out + assert torch.equal(dataset[0][1]["boxes"], torch.empty((0, 4))) + assert torch.equal(dataset[0][1]["labels"], torch.empty(0, dtype=torch.long)) + + assert torch.equal(dataset[6][1]["boxes"], torch.tensor([(1.0, 2.0, 3.0, 4.0)])) + assert torch.equal(dataset[6][1]["labels"], torch.tensor([1])) + + # points are filtered out + assert torch.equal(dataset[7][1]["boxes"], torch.empty((0, 4))) + assert torch.equal(dataset[7][1]["labels"], torch.empty(0, dtype=torch.long)) def test_transforms(self): dataset = cvatpt.TaskVisionDataset( @@ -205,3 +211,16 @@ class TestTaskVisionDataset: assert isinstance(dataset[0][0], cvatpt.Target) assert isinstance(dataset[0][1], PIL.Image.Image) + + def test_custom_label_mapping(self): + label_name_to_id = {label.name: label.id for label in self.task.labels} + + dataset = cvatpt.TaskVisionDataset( + self.client, + self.task.id, + label_name_to_index={"person": 123, "car": 456}, + ) + + _, target = dataset[5] + assert target.label_id_to_index[label_name_to_id["person"]] == 123 + assert target.label_id_to_index[label_name_to_id["car"]] == 456