You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

83 lines
2.1 KiB
Python

# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from __future__ import annotations
import contextlib
import itertools
import os
from typing import (
IO,
Any,
BinaryIO,
ContextManager,
Dict,
Iterator,
Sequence,
TextIO,
Union,
overload,
)
from typing_extensions import Literal
def filter_dict(
d: Dict[str, Any], *, keep: Sequence[str] = None, drop: Sequence[str] = None
) -> Dict[str, Any]:
return {k: v for k, v in d.items() if (not keep or k in keep) and (not drop or k not in drop)}
@overload
def atomic_writer(path: Union[os.PathLike, str], mode: Literal["wb"]) -> ContextManager[BinaryIO]:
...
@overload
def atomic_writer(
path: Union[os.PathLike, str], mode: Literal["w"], encoding: str = "UTF-8"
) -> ContextManager[TextIO]:
...
@contextlib.contextmanager
def atomic_writer(
path: Union[os.PathLike, str], mode: Literal["w", "wb"], encoding: str = "UTF-8"
) -> Iterator[IO]:
"""
Returns a context manager that, when entered, returns a handle to a temporary
file opened with the specified `mode` and `encoding`. If the context manager
is exited via an exception, the temporary file is deleted. If the context manager
is exited normally, the file is renamed to `path`.
In other words, this function works like `open()`, but the file does not appear
at the specified path until and unless the context manager is exited
normally.
"""
path_str = os.fspath(path)
for counter in itertools.count():
tmp_path = f"{path_str}.tmp{counter}"
try:
if mode == "w":
tmp_file = open(tmp_path, "xt", encoding=encoding)
elif mode == "wb":
tmp_file = open(tmp_path, "xb")
else:
raise ValueError(f"Unsupported mode: {mode!r}")
break
except FileExistsError:
pass # try next counter value
try:
with tmp_file:
yield tmp_file
os.rename(tmp_path, path)
except:
os.unlink(tmp_path)
raise