SDK: Add an adapter layer that presents a CVAT task as a torchvision dataset (#5417)
parent 82adde42aa
commit 487c60ce2b

@@ -0,0 +1,359 @@
import base64
import collections
import json
import os
import shutil
import types
import zipfile
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import (
    Callable,
    Dict,
    FrozenSet,
    List,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypeVar,
)

import appdirs
import attrs
import attrs.validators
import PIL.Image
import torchvision.datasets
from typing_extensions import TypedDict

import cvat_sdk.core
import cvat_sdk.core.exceptions
from cvat_sdk.api_client.model_utils import to_json
from cvat_sdk.core.utils import atomic_writer
from cvat_sdk.models import DataMetaRead, LabeledData, LabeledImage, LabeledShape, TaskRead

_ModelType = TypeVar("_ModelType")

_CACHE_DIR = Path(appdirs.user_cache_dir("cvat-sdk", "CVAT.ai"))
_NUM_DOWNLOAD_THREADS = 4


class UnsupportedDatasetError(cvat_sdk.core.exceptions.CvatSdkException):
    pass


@attrs.frozen
class FrameAnnotations:
    """
    Contains annotations that pertain to a single frame.
    """

    tags: List[LabeledImage] = attrs.Factory(list)
    shapes: List[LabeledShape] = attrs.Factory(list)


@attrs.frozen
class Target:
    """
    Non-image data for a dataset sample.
    """

    annotations: FrameAnnotations
    """Annotations for the frame corresponding to the sample."""

    label_id_to_index: Mapping[int, int]
    """
    A mapping from label_id values in `LabeledImage` and `LabeledShape` objects
    to an index in the range [0, num_labels), where num_labels is the number of labels
    defined in the task. This mapping is consistent across all samples for a given task.
    """


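# Illustrative note: label_id_to_index sorts the task's label IDs and enumerates
# them, so a task whose labels have (hypothetical) IDs 11 and 17 would map them
# as {11: 0, 17: 1}.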
class TaskVisionDataset(torchvision.datasets.VisionDataset):
    """
    Represents a task on a CVAT server as a PyTorch Dataset.

    This dataset contains one sample for each frame in the task, in the same
    order as the frames are in the task. Deleted frames are omitted.
    Before transforms are applied, each sample is a tuple of
    (image, target), where:

    * image is a `PIL.Image.Image` object for the corresponding frame.
    * target is a `Target` object containing annotations for the frame.

    This class caches all data and annotations for the task on the local file system
    during construction. If the task is updated on the server, the cache is updated.

    Limitations:

    * Only tasks with image (not video) data are supported at the moment.
    * Track annotations are currently not accessible.
    """

    def __init__(
        self,
        client: cvat_sdk.core.Client,
        task_id: int,
        *,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        """
        Creates a dataset corresponding to the task with ID `task_id` on the
        server that `client` is connected to.

        `transforms`, `transform` and `target_transform` are optional transformation
        functions; see the documentation for `torchvision.datasets.VisionDataset` for
        more information.
        """

        self._logger = client.logger

        self._logger.info(f"Fetching task {task_id}...")
        self._task = client.tasks.retrieve(task_id)

        if not self._task.size or not self._task.data_chunk_size:
            raise UnsupportedDatasetError("The task has no data")

        if self._task.data_original_chunk_type != "imageset":
            raise UnsupportedDatasetError(
                f"{self.__class__.__name__} only supports tasks with image chunks;"
                f" current chunk type is {self._task.data_original_chunk_type!r}"
            )

        # Base64-encode the name to avoid FS-unsafe characters (like slashes)
        server_dir_name = (
            base64.urlsafe_b64encode(client.api_map.host.encode()).rstrip(b"=").decode()
        )
        server_dir = _CACHE_DIR / f"servers/{server_dir_name}"

        self._task_dir = server_dir / f"tasks/{self._task.id}"
        self._initialize_task_dir()

        super().__init__(
            os.fspath(self._task_dir),
            transforms=transforms,
            transform=transform,
            target_transform=target_transform,
        )

        data_meta = self._ensure_model(
            "data_meta.json", DataMetaRead, self._task.get_meta, "data metadata"
        )
        self._active_frame_indexes = sorted(
            set(range(self._task.size)) - set(data_meta.deleted_frames)
        )

        self._logger.info("Downloading chunks...")

        self._chunk_dir = self._task_dir / "chunks"
        self._chunk_dir.mkdir(exist_ok=True, parents=True)

        needed_chunks = {
            index // self._task.data_chunk_size for index in self._active_frame_indexes
        }

        with ThreadPoolExecutor(_NUM_DOWNLOAD_THREADS) as pool:
            for _ in pool.map(self._ensure_chunk, sorted(needed_chunks)):
                # just need to loop through all results so that any exceptions are propagated
                pass

        self._logger.info("All chunks downloaded")

        self._label_id_to_index = types.MappingProxyType(
            {
                label["id"]: label_index
                for label_index, label in enumerate(sorted(self._task.labels, key=lambda l: l.id))
            }
        )

        annotations = self._ensure_model(
            "annotations.json", LabeledData, self._task.get_annotations, "annotations"
        )

        self._frame_annotations: Dict[int, FrameAnnotations] = collections.defaultdict(
            FrameAnnotations
        )

        for tag in annotations.tags:
            self._frame_annotations[tag.frame].tags.append(tag)

        for shape in annotations.shapes:
            self._frame_annotations[shape.frame].shapes.append(shape)

        # TODO: tracks?

    def _initialize_task_dir(self) -> None:
        task_json_path = self._task_dir / "task.json"

        try:
            with open(task_json_path, "rb") as task_json_file:
                saved_task = TaskRead._new_from_openapi_data(**json.load(task_json_file))
        except Exception:
            self._logger.info("Task is not yet cached or the cache is corrupted")

            # If the cache was corrupted, the directory might already be there; clear it.
            if self._task_dir.exists():
                shutil.rmtree(self._task_dir)
        else:
            if saved_task.updated_date < self._task.updated_date:
                self._logger.info(
                    "Task has been updated on the server since it was cached; purging the cache"
                )
                shutil.rmtree(self._task_dir)

        self._task_dir.mkdir(exist_ok=True, parents=True)

        with atomic_writer(task_json_path, "w", encoding="UTF-8") as task_json_file:
            json.dump(to_json(self._task._model), task_json_file, indent=4)
            print(file=task_json_file)  # add final newline

    def _ensure_chunk(self, chunk_index: int) -> None:
        chunk_path = self._chunk_dir / f"{chunk_index}.zip"
        if chunk_path.exists():
            return  # already downloaded previously

        self._logger.info(f"Downloading chunk #{chunk_index}...")

        with atomic_writer(chunk_path, "wb") as chunk_file:
            self._task.download_chunk(chunk_index, chunk_file, quality="original")

    def _ensure_model(
        self,
        filename: str,
        model_type: Type[_ModelType],
        download: Callable[[], _ModelType],
        model_description: str,
    ) -> _ModelType:
        path = self._task_dir / filename

        try:
            with open(path, "rb") as f:
                model = model_type._new_from_openapi_data(**json.load(f))
            self._logger.info(f"Loaded {model_description} from cache")
            return model
        except FileNotFoundError:
            pass
        except Exception:
            self._logger.warning(f"Failed to load {model_description} from cache", exc_info=True)

        self._logger.info(f"Downloading {model_description}...")
        model = download()
        self._logger.info(f"Downloaded {model_description}")

        with atomic_writer(path, "w", encoding="UTF-8") as f:
            json.dump(to_json(model), f, indent=4)
            print(file=f)  # add final newline

        return model

    def __getitem__(self, sample_index: int):
        """
        Returns the sample with index `sample_index`.

        `sample_index` must satisfy the condition `0 <= sample_index < len(self)`.
        """

        frame_index = self._active_frame_indexes[sample_index]
        chunk_index = frame_index // self._task.data_chunk_size
        member_index = frame_index % self._task.data_chunk_size
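
        # For illustration: with data_chunk_size == 3, frame 7 would live in
        # chunk 2 (7 // 3) at member position 1 (7 % 3) within that archive.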
        with zipfile.ZipFile(self._chunk_dir / f"{chunk_index}.zip", "r") as chunk_zip:
            with chunk_zip.open(chunk_zip.infolist()[member_index]) as chunk_member:
                sample_image = PIL.Image.open(chunk_member)
                sample_image.load()

        sample_target = Target(
            annotations=self._frame_annotations[frame_index],
            label_id_to_index=self._label_id_to_index,
        )

        if self.transforms:
            sample_image, sample_target = self.transforms(sample_image, sample_target)
        return sample_image, sample_target

    def __len__(self) -> int:
        """Returns the number of samples in the dataset."""
        return len(self._active_frame_indexes)


@attrs.frozen
class ExtractSingleLabelIndex:
    """
    A target transform that takes a `Target` object and produces a single label index
    based on the tag in that object.

    This makes the dataset samples compatible with the image classification networks
    in torchvision.

    If the annotations contain no tags, or multiple tags, raises a `ValueError`.
    """

    def __call__(self, target: Target) -> int:
        tags = target.annotations.tags
        if not tags:
            raise ValueError("sample has no tags")

        if len(tags) > 1:
            raise ValueError("sample has multiple tags")

        return target.label_id_to_index[tags[0].label_id]


class LabeledBoxes(TypedDict):
    boxes: Sequence[Tuple[float, float, float, float]]
    labels: Sequence[int]


_SUPPORTED_SHAPE_TYPES = frozenset(["rectangle", "polygon", "polyline", "points", "ellipse"])


@attrs.frozen
class ExtractBoundingBoxes:
    """
    A target transform that takes a `Target` object and returns a dictionary compatible
    with the object detection networks in torchvision.

    The dictionary contains the following entries:

    "boxes": a sequence of (xmin, ymin, xmax, ymax) tuples, one for each shape
    in the annotations.
    "labels": a sequence of corresponding label indices.

    Limitations:

    * Only the following shape types are supported: rectangle, polygon, polyline,
      points, ellipse.
    * Rotated shapes are not supported.

    Unsupported shapes will cause an `UnsupportedDatasetError` exception to be
    raised unless they are filtered out by `include_shape_types`.
    """

    include_shape_types: FrozenSet[str] = attrs.field(
        converter=frozenset,
        validator=attrs.validators.deep_iterable(attrs.validators.in_(_SUPPORTED_SHAPE_TYPES)),
        kw_only=True,
    )
    """Shapes whose type is not in this set will be ignored."""

    def __call__(self, target: Target) -> LabeledBoxes:
        boxes = []
        labels = []

        for shape in target.annotations.shapes:
            if shape.type.value not in self.include_shape_types:
                continue

            if shape.rotation != 0:
                raise UnsupportedDatasetError("Rotated shapes are not supported")

            x_coords = shape.points[0::2]
            y_coords = shape.points[1::2]

            boxes.append((min(x_coords), min(y_coords), max(x_coords), max(y_coords)))
            labels.append(target.label_id_to_index[shape.label_id])

        return LabeledBoxes(boxes=boxes, labels=labels)
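
For illustration, a minimal sketch of how this module is meant to be used for classification; the host, credentials, and task ID are placeholders, and `make_client` is the SDK's client factory:

import torchvision.transforms
from torch.utils.data import DataLoader

from cvat_sdk import make_client
from cvat_sdk.pytorch import ExtractSingleLabelIndex, TaskVisionDataset

# Placeholder host, credentials, and task ID.
with make_client("http://localhost:8080", credentials=("user", "password")) as client:
    dataset = TaskVisionDataset(
        client,
        task_id=42,
        transform=torchvision.transforms.PILToTensor(),
        # Collapse each sample's single tag into a class index.
        target_transform=ExtractSingleLabelIndex(),
    )

    # Samples are (image tensor, label index) pairs, so the default collation
    # can batch them, provided all frames have the same dimensions.
    loader = DataLoader(dataset, batch_size=4)
    images, labels = next(iter(loader))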
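A matching detection sketch, under the same placeholder assumptions. Each sample's target dict can hold a different number of boxes, so the default collation no longer applies; a common workaround is to collate each batch into plain tuples, matching the list-of-targets convention of torchvision's detection models:

import torchvision.transforms
from torch.utils.data import DataLoader

from cvat_sdk import make_client
from cvat_sdk.pytorch import ExtractBoundingBoxes, TaskVisionDataset

with make_client("http://localhost:8080", credentials=("user", "password")) as client:
    dataset = TaskVisionDataset(
        client,
        task_id=42,
        transform=torchvision.transforms.PILToTensor(),
        # Keep only rectangle shapes; other shape types are filtered out.
        target_transform=ExtractBoundingBoxes(include_shape_types={"rectangle"}),
    )

    # Collate a batch into (images, targets) tuples instead of stacked tensors.
    loader = DataLoader(dataset, batch_size=2, collate_fn=lambda batch: tuple(zip(*batch)))
    images, targets = next(iter(loader))
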
@@ -0,0 +1,207 @@
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

import io
import os
from logging import Logger
from pathlib import Path
from typing import Tuple

import pytest
from cvat_sdk import Client, models
from cvat_sdk.core.proxies.tasks import ResourceType

try:
    import cvat_sdk.pytorch as cvatpt
    import PIL.Image
    import torch
    import torchvision.transforms
    import torchvision.transforms.functional as TF
    from torch.utils.data import DataLoader
except ImportError:
    cvatpt = None

from shared.utils.helpers import generate_image_files


@pytest.mark.skipif(cvatpt is None, reason="PyTorch dependencies are not installed")
class TestTaskVisionDataset:
    @pytest.fixture(autouse=True)
    def setup(
        self,
        monkeypatch: pytest.MonkeyPatch,
        tmp_path: Path,
        fxt_login: Tuple[Client, str],
        fxt_logger: Tuple[Logger, io.StringIO],
        fxt_stdout: io.StringIO,
    ):
        self.tmp_path = tmp_path
        logger, self.logger_stream = fxt_logger
        self.stdout = fxt_stdout
        self.client, self.user = fxt_login
        self.client.logger = logger

        api_client = self.client.api_client
        for k in api_client.configuration.logger:
            api_client.configuration.logger[k] = logger

        monkeypatch.setattr(cvatpt, "_CACHE_DIR", self.tmp_path / "cache")

        self._create_task()

        yield

    def _create_task(self):
        self.images = generate_image_files(10)

        image_dir = self.tmp_path / "images"
        image_dir.mkdir()

        image_paths = []
        for image in self.images:
            image_path = image_dir / image.name
            image_path.write_bytes(image.getbuffer())
            image_paths.append(image_path)

        self.task = self.client.tasks.create_from_data(
            models.TaskWriteRequest(
                "PyTorch integration test task",
                labels=[
                    models.PatchedLabelRequest(name="person"),
                    models.PatchedLabelRequest(name="car"),
                ],
            ),
            ResourceType.LOCAL,
            list(map(os.fspath, image_paths)),
            data_params={"chunk_size": 3},
        )

        self.label_ids = sorted(l.id for l in self.task.labels)

        self.task.update_annotations(
            models.PatchedLabeledDataRequest(
                tags=[
                    models.LabeledImageRequest(frame=5, label_id=self.label_ids[0]),
                    models.LabeledImageRequest(frame=6, label_id=self.label_ids[1]),
                    models.LabeledImageRequest(frame=8, label_id=self.label_ids[0]),
                    models.LabeledImageRequest(frame=8, label_id=self.label_ids[1]),
                ],
                shapes=[
                    models.LabeledShapeRequest(
                        frame=6,
                        label_id=self.label_ids[1],
                        type=models.ShapeType("rectangle"),
                        points=[1.0, 2.0, 3.0, 4.0],
                    ),
                    models.LabeledShapeRequest(
                        frame=7,
                        label_id=self.label_ids[0],
                        type=models.ShapeType("points"),
                        points=[1.1, 2.1, 3.1, 4.1],
                    ),
                ],
            )
        )

    def test_basic(self):
        dataset = cvatpt.TaskVisionDataset(self.client, self.task.id)

        assert len(dataset) == self.task.size

        for index, (sample_image, sample_target) in enumerate(dataset):
            sample_image_tensor = TF.pil_to_tensor(sample_image)
            reference_tensor = TF.pil_to_tensor(PIL.Image.open(self.images[index]))
            assert torch.equal(sample_image_tensor, reference_tensor)

            for index, label_id in enumerate(self.label_ids):
                assert sample_target.label_id_to_index[label_id] == index

        assert not dataset[0][1].annotations.tags
        assert not dataset[0][1].annotations.shapes

        assert len(dataset[5][1].annotations.tags) == 1
        assert dataset[5][1].annotations.tags[0].label_id == self.label_ids[0]
        assert not dataset[5][1].annotations.shapes

        assert len(dataset[6][1].annotations.tags) == 1
        assert dataset[6][1].annotations.tags[0].label_id == self.label_ids[1]
        assert len(dataset[6][1].annotations.shapes) == 1
        assert dataset[6][1].annotations.shapes[0].type.value == "rectangle"
        assert dataset[6][1].annotations.shapes[0].points == [1.0, 2.0, 3.0, 4.0]

        assert not dataset[7][1].annotations.tags
        assert len(dataset[7][1].annotations.shapes) == 1
        assert dataset[7][1].annotations.shapes[0].type.value == "points"
        assert dataset[7][1].annotations.shapes[0].points == [1.1, 2.1, 3.1, 4.1]

    def test_deleted_frame(self):
        self.task.remove_frames_by_ids([1])

        dataset = cvatpt.TaskVisionDataset(self.client, self.task.id)

        assert len(dataset) == self.task.size - 1

        # sample #0 is still frame #0
        assert torch.equal(
            TF.pil_to_tensor(dataset[0][0]), TF.pil_to_tensor(PIL.Image.open(self.images[0]))
        )

        # sample #1 is now frame #2
        assert torch.equal(
            TF.pil_to_tensor(dataset[1][0]), TF.pil_to_tensor(PIL.Image.open(self.images[2]))
        )

        # sample #4 is now frame #5
        assert len(dataset[4][1].annotations.tags) == 1
        assert dataset[4][1].annotations.tags[0].label_id == self.label_ids[0]
        assert not dataset[4][1].annotations.shapes

    def test_extract_single_label_index(self):
        dataset = cvatpt.TaskVisionDataset(
            self.client,
            self.task.id,
            transform=torchvision.transforms.PILToTensor(),
            target_transform=cvatpt.ExtractSingleLabelIndex(),
        )

        assert dataset[5][1] == 0
        assert dataset[6][1] == 1

        with pytest.raises(ValueError):
            # no tags
            _ = dataset[7]

        with pytest.raises(ValueError):
            # multiple tags
            _ = dataset[8]

        # make sure the samples can be batched with the default collator
        loader = DataLoader(dataset, batch_size=2, sampler=[5, 6])

        batch = next(iter(loader))
        assert torch.equal(batch[0][0], TF.pil_to_tensor(PIL.Image.open(self.images[5])))
        assert torch.equal(batch[0][1], TF.pil_to_tensor(PIL.Image.open(self.images[6])))
        assert torch.equal(batch[1], torch.tensor([0, 1]))

    def test_extract_bounding_boxes(self):
        dataset = cvatpt.TaskVisionDataset(
            self.client,
            self.task.id,
            transform=torchvision.transforms.PILToTensor(),
            target_transform=cvatpt.ExtractBoundingBoxes(include_shape_types={"rectangle"}),
        )

        assert dataset[0][1] == {"boxes": [], "labels": []}
        assert dataset[6][1] == {"boxes": [(1.0, 2.0, 3.0, 4.0)], "labels": [1]}
        assert dataset[7][1] == {"boxes": [], "labels": []}  # points are filtered out

    def test_transforms(self):
        dataset = cvatpt.TaskVisionDataset(
            self.client,
            self.task.id,
            transforms=lambda x, y: (y, x),
        )

        assert isinstance(dataset[0][0], cvatpt.Target)
        assert isinstance(dataset[0][1], PIL.Image.Image)