SDK: make the dataset cache directory customizable (#5535)

This is useful for people whose home directory is too small/not fast
enough. It also lets us make the tests less hacky.
main
Roman Donchenko 3 years ago committed by GitHub
parent 72b612507a
commit 4d32c3c686
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## \[2.4.0] - Unreleased
### Added
- Filename pattern to simplify uploading cloud storage data for a task (<https://github.com/opencv/cvat/pull/5498>)
- \[SDK\] Configuration setting to change the dataset cache directory
(<https://github.com/opencv/cvat/pull/5535>)
### Changed
- The Docker Compose files now use the Compose Specification version

@ -8,9 +8,11 @@ from __future__ import annotations
import logging
import urllib.parse
from contextlib import suppress
from pathlib import Path
from time import sleep
from typing import Any, Dict, Optional, Sequence, Tuple
import appdirs
import attrs
import packaging.version as pv
import urllib3
@ -27,6 +29,8 @@ from cvat_sdk.core.proxies.tasks import TasksRepo
from cvat_sdk.core.proxies.users import UsersRepo
from cvat_sdk.version import VERSION
_DEFAULT_CACHE_DIR = Path(appdirs.user_cache_dir("cvat-sdk", "CVAT.ai"))
@attrs.define
class Config:
@ -43,6 +47,9 @@ class Config:
verify_ssl: Optional[bool] = None
"""Whether to verify host SSL certificate or not"""
cache_dir: Path = attrs.field(converter=Path, default=_DEFAULT_CACHE_DIR)
"""Directory in which to store cached server data"""
class Client:
"""

@ -6,7 +6,6 @@ import shutil
import types
import zipfile
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import (
Callable,
Dict,
@ -20,7 +19,6 @@ from typing import (
TypeVar,
)
import appdirs
import attrs
import attrs.validators
import PIL.Image
@ -36,7 +34,6 @@ from cvat_sdk.models import DataMetaRead, LabeledData, LabeledImage, LabeledShap
_ModelType = TypeVar("_ModelType")
_CACHE_DIR = Path(appdirs.user_cache_dir("cvat-sdk", "CVAT.ai"))
_NUM_DOWNLOAD_THREADS = 4
@ -139,7 +136,7 @@ class TaskVisionDataset(torchvision.datasets.VisionDataset):
server_dir_name = (
base64.urlsafe_b64encode(client.api_map.host.encode()).rstrip(b"=").decode()
)
server_dir = _CACHE_DIR / f"servers/{server_dir_name}"
server_dir = client.config.cache_dir / f"servers/{server_dir_name}"
self._task_dir = server_dir / f"tasks/{self._task.id}"
self._initialize_task_dir()

@ -77,7 +77,7 @@ setup(
python_requires="{{{generatorLanguageVersion}}}",
install_requires=BASE_REQUIREMENTS,
extras_require={
"pytorch": ['appdirs', 'torch', 'torchvision'],
"pytorch": ['torch', 'torchvision'],
},
package_dir={"": "."},
packages=find_packages(include=["cvat_sdk*"]),

@ -1,5 +1,6 @@
-r api_client.txt
appdirs
attrs >= 21.4.0
packaging >= 21.3
Pillow >= 9.0.1

@ -30,7 +30,6 @@ class TestTaskVisionDataset:
@pytest.fixture(autouse=True)
def setup(
self,
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
fxt_login: Tuple[Client, str],
fxt_logger: Tuple[Logger, io.StringIO],
@ -41,13 +40,12 @@ class TestTaskVisionDataset:
self.stdout = fxt_stdout
self.client, self.user = fxt_login
self.client.logger = logger
self.client.config.cache_dir = tmp_path / "cache"
api_client = self.client.api_client
for k in api_client.configuration.logger:
api_client.configuration.logger[k] = logger
monkeypatch.setattr(cvatpt, "_CACHE_DIR", self.tmp_path / "cache")
self._create_task()
yield
@ -107,6 +105,9 @@ class TestTaskVisionDataset:
def test_basic(self):
dataset = cvatpt.TaskVisionDataset(self.client, self.task.id)
# verify that the cache is not empty
assert list(self.client.config.cache_dir.iterdir())
assert len(dataset) == self.task.size
for index, (sample_image, sample_target) in enumerate(dataset):

Loading…
Cancel
Save