SDK: Add an adapter layer that presents a CVAT task as a torchvision dataset (#5417)

main
Roman Donchenko 3 years ago committed by GitHub
parent 82adde42aa
commit 487c60ce2b

@@ -196,7 +196,7 @@ jobs:
- name: Running REST API and SDK tests
run: |
pip3 install --user /tmp/cvat_sdk/
pip3 install --user '/tmp/cvat_sdk/[pytorch]'
pip3 install --user cvat-cli/
pip3 install --user -r tests/python/requirements.txt
pytest tests/python -s -v

@@ -164,7 +164,7 @@ jobs:
- name: Running REST API and SDK tests
run: |
pip3 install --user /tmp/cvat_sdk/
pip3 install --user '/tmp/cvat_sdk/[pytorch]'
pip3 install --user cvat-cli/
pip3 install --user -r tests/python/requirements.txt
pytest tests/python/ -s -v

@@ -235,7 +235,7 @@ jobs:
gen/generate.sh
cd ..
pip3 install --user cvat-sdk/
pip3 install --user 'cvat-sdk/[pytorch]'
pip3 install --user cvat-cli/
pip3 install --user -r tests/python/requirements.txt
pytest tests/python/

@@ -21,6 +21,8 @@ from online detectors & interactors) (<https://github.com/opencv/cvat/pull/4543>
- Authentication with social accounts google & github (<https://github.com/opencv/cvat/pull/5147>, <https://github.com/opencv/cvat/pull/5181>, <https://github.com/opencv/cvat/pull/5295>)
- REST API tests to export job datasets & annotations and validate their structure (<https://github.com/opencv/cvat/pull/5160>)
- Propagation backward on UI (<https://github.com/opencv/cvat/pull/5355>)
- A PyTorch dataset adapter layer in the SDK
(<https://github.com/opencv/cvat/pull/5417>)
### Changed
- `api/docs`, `api/swagger`, `api/schema`, `server/about` endpoints now allow unauthorized access (<https://github.com/opencv/cvat/pull/4928>, <https://github.com/opencv/cvat/pull/4935>)

@@ -0,0 +1,359 @@
import base64
import collections
import json
import os
import shutil
import types
import zipfile
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import (
Callable,
Dict,
FrozenSet,
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
)
import appdirs
import attrs
import attrs.validators
import PIL.Image
import torchvision.datasets
from typing_extensions import TypedDict
import cvat_sdk.core
import cvat_sdk.core.exceptions
from cvat_sdk.api_client.model_utils import to_json
from cvat_sdk.core.utils import atomic_writer
from cvat_sdk.models import DataMetaRead, LabeledData, LabeledImage, LabeledShape, TaskRead
_ModelType = TypeVar("_ModelType")
_CACHE_DIR = Path(appdirs.user_cache_dir("cvat-sdk", "CVAT.ai"))
_NUM_DOWNLOAD_THREADS = 4
class UnsupportedDatasetError(cvat_sdk.core.exceptions.CvatSdkException):
pass
@attrs.frozen
class FrameAnnotations:
"""
Contains annotations that pertain to a single frame.
"""
tags: List[LabeledImage] = attrs.Factory(list)
shapes: List[LabeledShape] = attrs.Factory(list)
@attrs.frozen
class Target:
"""
Non-image data for a dataset sample.
"""
annotations: FrameAnnotations
"""Annotations for the frame corresponding to the sample."""
label_id_to_index: Mapping[int, int]
"""
A mapping from label_id values in `LabeledImage` and `LabeledShape` objects
to an index in the range [0, num_labels), where num_labels is the number of labels
defined in the task. This mapping is consistent across all samples for a given task.
"""
class TaskVisionDataset(torchvision.datasets.VisionDataset):
"""
Represents a task on a CVAT server as a PyTorch Dataset.
This dataset contains one sample for each frame in the task, in the same
order as the frames are in the task. Deleted frames are omitted.
Before transforms are applied, each sample is a tuple of
(image, target), where:
* image is a `PIL.Image.Image` object for the corresponding frame.
* target is a `Target` object containing annotations for the frame.
This class caches all data and annotations for the task on the local file system
during construction. If the task is updated on the server, the cache is updated.
Limitations:
* Only tasks with image (not video) data are supported at the moment.
* Track annotations are currently not accessible.
"""
def __init__(
self,
client: cvat_sdk.core.Client,
task_id: int,
*,
transforms: Optional[Callable] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
"""
Creates a dataset corresponding to the task with ID `task_id` on the
server that `client` is connected to.
`transforms`, `transform` and `target_transform` are optional transformation
functions; see the documentation for `torchvision.datasets.VisionDataset` for
more information.
"""
self._logger = client.logger
self._logger.info(f"Fetching task {task_id}...")
self._task = client.tasks.retrieve(task_id)
if not self._task.size or not self._task.data_chunk_size:
raise UnsupportedDatasetError("The task has no data")
if self._task.data_original_chunk_type != "imageset":
raise UnsupportedDatasetError(
f"{self.__class__.__name__} only supports tasks with image chunks;"
f" current chunk type is {self._task.data_original_chunk_type!r}"
)
# Base64-encode the name to avoid FS-unsafe characters (like slashes)
server_dir_name = (
base64.urlsafe_b64encode(client.api_map.host.encode()).rstrip(b"=").decode()
)
server_dir = _CACHE_DIR / f"servers/{server_dir_name}"
self._task_dir = server_dir / f"tasks/{self._task.id}"
self._initialize_task_dir()
super().__init__(
os.fspath(self._task_dir),
transforms=transforms,
transform=transform,
target_transform=target_transform,
)
data_meta = self._ensure_model(
"data_meta.json", DataMetaRead, self._task.get_meta, "data metadata"
)
self._active_frame_indexes = sorted(
set(range(self._task.size)) - set(data_meta.deleted_frames)
)
self._logger.info("Downloading chunks...")
self._chunk_dir = self._task_dir / "chunks"
self._chunk_dir.mkdir(exist_ok=True, parents=True)
needed_chunks = {
index // self._task.data_chunk_size for index in self._active_frame_indexes
}
with ThreadPoolExecutor(_NUM_DOWNLOAD_THREADS) as pool:
for _ in pool.map(self._ensure_chunk, sorted(needed_chunks)):
# just need to loop through all results so that any exceptions are propagated
pass
self._logger.info("All chunks downloaded")
self._label_id_to_index = types.MappingProxyType(
{
label["id"]: label_index
for label_index, label in enumerate(sorted(self._task.labels, key=lambda l: l.id))
}
)
annotations = self._ensure_model(
"annotations.json", LabeledData, self._task.get_annotations, "annotations"
)
self._frame_annotations: Dict[int, FrameAnnotations] = collections.defaultdict(
FrameAnnotations
)
for tag in annotations.tags:
self._frame_annotations[tag.frame].tags.append(tag)
for shape in annotations.shapes:
self._frame_annotations[shape.frame].shapes.append(shape)
# TODO: tracks?
def _initialize_task_dir(self) -> None:
task_json_path = self._task_dir / "task.json"
try:
with open(task_json_path, "rb") as task_json_file:
saved_task = TaskRead._new_from_openapi_data(**json.load(task_json_file))
except Exception:
self._logger.info("Task is not yet cached or the cache is corrupted")
# If the cache was corrupted, the directory might already be there; clear it.
if self._task_dir.exists():
shutil.rmtree(self._task_dir)
else:
if saved_task.updated_date < self._task.updated_date:
self._logger.info(
"Task has been updated on the server since it was cached; purging the cache"
)
shutil.rmtree(self._task_dir)
self._task_dir.mkdir(exist_ok=True, parents=True)
with atomic_writer(task_json_path, "w", encoding="UTF-8") as task_json_file:
json.dump(to_json(self._task._model), task_json_file, indent=4)
print(file=task_json_file) # add final newline
def _ensure_chunk(self, chunk_index: int) -> None:
chunk_path = self._chunk_dir / f"{chunk_index}.zip"
if chunk_path.exists():
return # already downloaded previously
self._logger.info(f"Downloading chunk #{chunk_index}...")
with atomic_writer(chunk_path, "wb") as chunk_file:
self._task.download_chunk(chunk_index, chunk_file, quality="original")
def _ensure_model(
self,
filename: str,
model_type: Type[_ModelType],
download: Callable[[], _ModelType],
model_description: str,
) -> _ModelType:
path = self._task_dir / filename
try:
with open(path, "rb") as f:
model = model_type._new_from_openapi_data(**json.load(f))
self._logger.info(f"Loaded {model_description} from cache")
return model
except FileNotFoundError:
pass
except Exception:
self._logger.warning(f"Failed to load {model_description} from cache", exc_info=True)
self._logger.info(f"Downloading {model_description}...")
model = download()
self._logger.info(f"Downloaded {model_description}")
with atomic_writer(path, "w", encoding="UTF-8") as f:
json.dump(to_json(model), f, indent=4)
print(file=f) # add final newline
return model
def __getitem__(self, sample_index: int):
"""
Returns the sample with index `sample_index`.
`sample_index` must satisfy the condition `0 <= sample_index < len(self)`.
"""
frame_index = self._active_frame_indexes[sample_index]
chunk_index = frame_index // self._task.data_chunk_size
member_index = frame_index % self._task.data_chunk_size
with zipfile.ZipFile(self._chunk_dir / f"{chunk_index}.zip", "r") as chunk_zip:
with chunk_zip.open(chunk_zip.infolist()[member_index]) as chunk_member:
sample_image = PIL.Image.open(chunk_member)
sample_image.load()
sample_target = Target(
annotations=self._frame_annotations[frame_index],
label_id_to_index=self._label_id_to_index,
)
if self.transforms:
sample_image, sample_target = self.transforms(sample_image, sample_target)
return sample_image, sample_target
def __len__(self) -> int:
"""Returns the number of samples in the dataset."""
return len(self._active_frame_indexes)
@attrs.frozen
class ExtractSingleLabelIndex:
"""
A target transform that takes a `Target` object and produces a single label index
based on the tag in that object.
This makes the dataset samples compatible with the image classification networks
in torchvision.
If the annotations contain no tags or multiple tags, a `ValueError` is raised.
"""
def __call__(self, target: Target) -> int:
tags = target.annotations.tags
if not tags:
raise ValueError("sample has no tags")
if len(tags) > 1:
raise ValueError("sample has multiple tags")
return target.label_id_to_index[tags[0].label_id]
class LabeledBoxes(TypedDict):
boxes: Sequence[Tuple[float, float, float, float]]
labels: Sequence[int]
_SUPPORTED_SHAPE_TYPES = frozenset(["rectangle", "polygon", "polyline", "points", "ellipse"])
@attrs.frozen
class ExtractBoundingBoxes:
"""
A target transform that takes a `Target` object and returns a dictionary compatible
with the object detection networks in torchvision.
The dictionary contains the following entries:
"boxes": a sequence of (xmin, ymin, xmax, ymax) tuples, one for each shape
in the annotations.
"labels": a sequence of corresponding label indices.
Limitations:
* Only the following shape types are supported: rectangle, polygon, polyline,
points, ellipse.
* Rotated shapes are not supported.
Unsupported shapes will cause an `UnsupportedDatasetError` exception to be
raised unless they are filtered out by `include_shape_types`.
"""
include_shape_types: FrozenSet[str] = attrs.field(
converter=frozenset,
validator=attrs.validators.deep_iterable(attrs.validators.in_(_SUPPORTED_SHAPE_TYPES)),
kw_only=True,
)
"""Shapes whose type is not in this set will be ignored."""
def __call__(self, target: Target) -> LabeledBoxes:
boxes = []
labels = []
for shape in target.annotations.shapes:
if shape.type.value not in self.include_shape_types:
continue
if shape.rotation != 0:
raise UnsupportedDatasetError("Rotated shapes are not supported")
x_coords = shape.points[0::2]
y_coords = shape.points[1::2]
boxes.append((min(x_coords), min(y_coords), max(x_coords), max(y_coords)))
labels.append(target.label_id_to_index[shape.label_id])
return LabeledBoxes(boxes=boxes, labels=labels)

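A minimal usage sketch of the adapter above: the server URL, credentials, and task ID are placeholders, and it assumes every frame of the task carries exactly one tag, as `ExtractSingleLabelIndex` requires.

import torchvision.transforms
from torch.utils.data import DataLoader

from cvat_sdk import Client
from cvat_sdk.pytorch import ExtractSingleLabelIndex, TaskVisionDataset

# Placeholders: substitute a real server, account, and task ID.
client = Client("http://localhost:8080")
client.login(("user", "password"))

# Each sample becomes (image tensor, label index), ready for a torchvision classifier.
dataset = TaskVisionDataset(
    client,
    task_id=123,
    transform=torchvision.transforms.PILToTensor(),
    target_transform=ExtractSingleLabelIndex(),
)

loader = DataLoader(dataset, batch_size=4)
for images, label_indices in loader:
    pass  # feed the batch to a model here
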
@@ -76,6 +76,9 @@ setup(
],
python_requires="{{{generatorLanguageVersion}}}",
install_requires=BASE_REQUIREMENTS,
extras_require={
"pytorch": ['appdirs', 'torch', 'torchvision'],
},
package_dir={"": "."},
packages=find_packages(include=["cvat_sdk*"]),
include_package_data=True,

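The PyTorch dependencies are declared as an optional `pytorch` extra rather than base requirements, so code that should keep working without them can guard the import, mirroring the pattern used in the tests below; a sketch:

try:
    import cvat_sdk.pytorch as cvatpt
except ImportError:
    cvatpt = None  # the SDK was installed without the 'pytorch' extra

if cvatpt is not None:
    ...  # construct cvatpt.TaskVisionDataset and friends here
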
@@ -0,0 +1,207 @@
# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
import io
import os
from logging import Logger
from pathlib import Path
from typing import Tuple
import pytest
from cvat_sdk import Client, models
from cvat_sdk.core.proxies.tasks import ResourceType
try:
import cvat_sdk.pytorch as cvatpt
import PIL.Image
import torch
import torchvision.transforms
import torchvision.transforms.functional as TF
from torch.utils.data import DataLoader
except ImportError:
cvatpt = None
from shared.utils.helpers import generate_image_files
@pytest.mark.skipif(cvatpt is None, reason="PyTorch dependencies are not installed")
class TestTaskVisionDataset:
@pytest.fixture(autouse=True)
def setup(
self,
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
fxt_login: Tuple[Client, str],
fxt_logger: Tuple[Logger, io.StringIO],
fxt_stdout: io.StringIO,
):
self.tmp_path = tmp_path
logger, self.logger_stream = fxt_logger
self.stdout = fxt_stdout
self.client, self.user = fxt_login
self.client.logger = logger
api_client = self.client.api_client
for k in api_client.configuration.logger:
api_client.configuration.logger[k] = logger
monkeypatch.setattr(cvatpt, "_CACHE_DIR", self.tmp_path / "cache")
self._create_task()
yield
def _create_task(self):
self.images = generate_image_files(10)
image_dir = self.tmp_path / "images"
image_dir.mkdir()
image_paths = []
for image in self.images:
image_path = image_dir / image.name
image_path.write_bytes(image.getbuffer())
image_paths.append(image_path)
self.task = self.client.tasks.create_from_data(
models.TaskWriteRequest(
"PyTorch integration test task",
labels=[
models.PatchedLabelRequest(name="person"),
models.PatchedLabelRequest(name="car"),
],
),
ResourceType.LOCAL,
list(map(os.fspath, image_paths)),
data_params={"chunk_size": 3},
)
self.label_ids = sorted(l.id for l in self.task.labels)
self.task.update_annotations(
models.PatchedLabeledDataRequest(
tags=[
models.LabeledImageRequest(frame=5, label_id=self.label_ids[0]),
models.LabeledImageRequest(frame=6, label_id=self.label_ids[1]),
models.LabeledImageRequest(frame=8, label_id=self.label_ids[0]),
models.LabeledImageRequest(frame=8, label_id=self.label_ids[1]),
],
shapes=[
models.LabeledShapeRequest(
frame=6,
label_id=self.label_ids[1],
type=models.ShapeType("rectangle"),
points=[1.0, 2.0, 3.0, 4.0],
),
models.LabeledShapeRequest(
frame=7,
label_id=self.label_ids[0],
type=models.ShapeType("points"),
points=[1.1, 2.1, 3.1, 4.1],
),
],
)
)
def test_basic(self):
dataset = cvatpt.TaskVisionDataset(self.client, self.task.id)
assert len(dataset) == self.task.size
for index, (sample_image, sample_target) in enumerate(dataset):
sample_image_tensor = TF.pil_to_tensor(sample_image)
reference_tensor = TF.pil_to_tensor(PIL.Image.open(self.images[index]))
assert torch.equal(sample_image_tensor, reference_tensor)
for index, label_id in enumerate(self.label_ids):
assert sample_target.label_id_to_index[label_id] == index
assert not dataset[0][1].annotations.tags
assert not dataset[0][1].annotations.shapes
assert len(dataset[5][1].annotations.tags) == 1
assert dataset[5][1].annotations.tags[0].label_id == self.label_ids[0]
assert not dataset[5][1].annotations.shapes
assert len(dataset[6][1].annotations.tags) == 1
assert dataset[6][1].annotations.tags[0].label_id == self.label_ids[1]
assert len(dataset[6][1].annotations.shapes) == 1
assert dataset[6][1].annotations.shapes[0].type.value == "rectangle"
assert dataset[6][1].annotations.shapes[0].points == [1.0, 2.0, 3.0, 4.0]
assert not dataset[7][1].annotations.tags
assert len(dataset[7][1].annotations.shapes) == 1
assert dataset[7][1].annotations.shapes[0].type.value == "points"
assert dataset[7][1].annotations.shapes[0].points == [1.1, 2.1, 3.1, 4.1]
def test_deleted_frame(self):
self.task.remove_frames_by_ids([1])
dataset = cvatpt.TaskVisionDataset(self.client, self.task.id)
assert len(dataset) == self.task.size - 1
# sample #0 is still frame #0
assert torch.equal(
TF.pil_to_tensor(dataset[0][0]), TF.pil_to_tensor(PIL.Image.open(self.images[0]))
)
# sample #1 is now frame #2
assert torch.equal(
TF.pil_to_tensor(dataset[1][0]), TF.pil_to_tensor(PIL.Image.open(self.images[2]))
)
# sample #4 is now frame #5
assert len(dataset[4][1].annotations.tags) == 1
assert dataset[4][1].annotations.tags[0].label_id == self.label_ids[0]
assert not dataset[4][1].annotations.shapes
def test_extract_single_label_index(self):
dataset = cvatpt.TaskVisionDataset(
self.client,
self.task.id,
transform=torchvision.transforms.PILToTensor(),
target_transform=cvatpt.ExtractSingleLabelIndex(),
)
assert dataset[5][1] == 0
assert dataset[6][1] == 1
with pytest.raises(ValueError):
# no tags
_ = dataset[7]
with pytest.raises(ValueError):
# multiple tags
_ = dataset[8]
# make sure the samples can be batched with the default collate function
loader = DataLoader(dataset, batch_size=2, sampler=[5, 6])
batch = next(iter(loader))
assert torch.equal(batch[0][0], TF.pil_to_tensor(PIL.Image.open(self.images[5])))
assert torch.equal(batch[0][1], TF.pil_to_tensor(PIL.Image.open(self.images[6])))
assert torch.equal(batch[1], torch.tensor([0, 1]))
def test_extract_bounding_boxes(self):
dataset = cvatpt.TaskVisionDataset(
self.client,
self.task.id,
transform=torchvision.transforms.PILToTensor(),
target_transform=cvatpt.ExtractBoundingBoxes(include_shape_types={"rectangle"}),
)
assert dataset[0][1] == {"boxes": [], "labels": []}
assert dataset[6][1] == {"boxes": [(1.0, 2.0, 3.0, 4.0)], "labels": [1]}
assert dataset[7][1] == {"boxes": [], "labels": []} # points are filtered out
def test_transforms(self):
dataset = cvatpt.TaskVisionDataset(
self.client,
self.task.id,
transforms=lambda x, y: (y, x),
)
assert isinstance(dataset[0][0], cvatpt.Target)
assert isinstance(dataset[0][1], PIL.Image.Image)

@@ -8,9 +8,9 @@ from typing import List
from PIL import Image
def generate_image_file(filename="image.png", size=(50, 50)):
def generate_image_file(filename="image.png", size=(50, 50), color=(0, 0, 0)):
f = BytesIO()
image = Image.new("RGB", size=size)
image = Image.new("RGB", size=size, color=color)
image.save(f, "jpeg")
f.name = filename
f.seek(0)
@@ -21,7 +21,7 @@ def generate_image_file(filename="image.png", size=(50, 50)):
def generate_image_files(count) -> List[BytesIO]:
images = []
for i in range(count):
image = generate_image_file(f"{i}.jpeg")
image = generate_image_file(f"{i}.jpeg", color=(i, i, i))
images.append(image)
return images
