You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

223 lines
7.3 KiB
Python

# Copyright (C) 2023 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
import base64
import json
import shutil
from abc import ABCMeta, abstractmethod
from enum import Enum, auto
from pathlib import Path
from typing import Callable, Mapping, Type, TypeVar
import cvat_sdk.models as models
from cvat_sdk.api_client.model_utils import OpenApiModel, to_json
from cvat_sdk.core.client import Client
from cvat_sdk.core.proxies.projects import Project
from cvat_sdk.core.proxies.tasks import Task
from cvat_sdk.core.utils import atomic_writer
class UpdatePolicy(Enum):
"""
Defines policies for when the local cache is updated from the CVAT server.
"""
IF_MISSING_OR_STALE = auto()
"""
Update the cache whenever cached data is missing or the server has a newer version.
"""
NEVER = auto()
"""
Never update the cache. If an operation requires data that is not cached,
it will fail.
No network access will be performed if this policy is used.
"""
_ModelType = TypeVar("_ModelType", bound=OpenApiModel)
class CacheManager(metaclass=ABCMeta):
def __init__(self, client: Client) -> None:
self._client = client
self._logger = client.logger
self._server_dir = client.config.cache_dir / f"servers/{self.server_dir_name}"
@property
def server_dir_name(self) -> str:
# Base64-encode the name to avoid FS-unsafe characters (like slashes)
return base64.urlsafe_b64encode(self._client.api_map.host.encode()).rstrip(b"=").decode()
def task_dir(self, task_id: int) -> Path:
return self._server_dir / f"tasks/{task_id}"
def task_json_path(self, task_id: int) -> Path:
return self.task_dir(task_id) / "task.json"
def chunk_dir(self, task_id: int) -> Path:
return self.task_dir(task_id) / "chunks"
def project_dir(self, project_id: int) -> Path:
return self._server_dir / f"projects/{project_id}"
def project_json_path(self, project_id: int) -> Path:
return self.project_dir(project_id) / "project.json"
def load_model(self, path: Path, model_type: Type[_ModelType]) -> _ModelType:
with open(path, "rb") as f:
return model_type._new_from_openapi_data(**json.load(f))
def save_model(self, path: Path, model: OpenApiModel) -> None:
with atomic_writer(path, "w", encoding="UTF-8") as f:
json.dump(to_json(model), f, indent=4)
print(file=f) # add final newline
@abstractmethod
def retrieve_task(self, task_id: int) -> Task:
...
@abstractmethod
def ensure_task_model(
self,
task_id: int,
filename: str,
model_type: Type[_ModelType],
downloader: Callable[[], _ModelType],
model_description: str,
) -> _ModelType:
...
@abstractmethod
def ensure_chunk(self, task: Task, chunk_index: int) -> None:
...
@abstractmethod
def retrieve_project(self, project_id: int) -> Project:
...
class _CacheManagerOnline(CacheManager):
def retrieve_task(self, task_id: int) -> Task:
self._logger.info(f"Fetching task {task_id}...")
task = self._client.tasks.retrieve(task_id)
self._initialize_task_dir(task)
return task
def _initialize_task_dir(self, task: Task) -> None:
task_dir = self.task_dir(task.id)
task_json_path = self.task_json_path(task.id)
try:
saved_task = self.load_model(task_json_path, models.TaskRead)
except Exception:
self._logger.info(f"Task {task.id} is not yet cached or the cache is corrupted")
# If the cache was corrupted, the directory might already be there; clear it.
if task_dir.exists():
shutil.rmtree(task_dir)
else:
if saved_task.updated_date < task.updated_date:
self._logger.info(
f"Task {task.id} has been updated on the server since it was cached; purging the cache"
)
shutil.rmtree(task_dir)
task_dir.mkdir(exist_ok=True, parents=True)
self.save_model(task_json_path, task._model)
def ensure_task_model(
self,
task_id: int,
filename: str,
model_type: Type[_ModelType],
downloader: Callable[[], _ModelType],
model_description: str,
) -> _ModelType:
path = self.task_dir(task_id) / filename
try:
model = self.load_model(path, model_type)
self._logger.info(f"Loaded {model_description} from cache")
return model
except FileNotFoundError:
pass
except Exception:
self._logger.warning(f"Failed to load {model_description} from cache", exc_info=True)
self._logger.info(f"Downloading {model_description}...")
model = downloader()
self._logger.info(f"Downloaded {model_description}")
self.save_model(path, model)
return model
def ensure_chunk(self, task: Task, chunk_index: int) -> None:
chunk_path = self.chunk_dir(task.id) / f"{chunk_index}.zip"
if chunk_path.exists():
return # already downloaded previously
self._logger.info(f"Downloading chunk #{chunk_index}...")
with atomic_writer(chunk_path, "wb") as chunk_file:
task.download_chunk(chunk_index, chunk_file, quality="original")
def retrieve_project(self, project_id: int) -> Project:
self._logger.info(f"Fetching project {project_id}...")
project = self._client.projects.retrieve(project_id)
project_dir = self.project_dir(project_id)
project_dir.mkdir(parents=True, exist_ok=True)
project_json_path = self.project_json_path(project_id)
# There are currently no files cached alongside project.json,
# so we don't need to check if we need to purge them.
self.save_model(project_json_path, project._model)
return project
class _CacheManagerOffline(CacheManager):
def retrieve_task(self, task_id: int) -> Task:
self._logger.info(f"Retrieving task {task_id} from cache...")
return Task(self._client, self.load_model(self.task_json_path(task_id), models.TaskRead))
def ensure_task_model(
self,
task_id: int,
filename: str,
model_type: Type[_ModelType],
downloader: Callable[[], _ModelType],
model_description: str,
) -> _ModelType:
self._logger.info(f"Loading {model_description} from cache...")
return self.load_model(self.task_dir(task_id) / filename, model_type)
def ensure_chunk(self, task: Task, chunk_index: int) -> None:
chunk_path = self.chunk_dir(task.id) / f"{chunk_index}.zip"
if not chunk_path.exists():
raise FileNotFoundError(f"Chunk {chunk_index} of task {task.id} is not cached")
def retrieve_project(self, project_id: int) -> Project:
self._logger.info(f"Retrieving project {project_id} from cache...")
return Project(
self._client, self.load_model(self.project_json_path(project_id), models.ProjectRead)
)
_CACHE_MANAGER_CLASSES: Mapping[UpdatePolicy, Type[CacheManager]] = {
UpdatePolicy.IF_MISSING_OR_STALE: _CacheManagerOnline,
UpdatePolicy.NEVER: _CacheManagerOffline,
}
def make_cache_manager(client: Client, update_policy: UpdatePolicy) -> CacheManager:
return _CACHE_MANAGER_CLASSES[update_policy](client)