You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

105 lines
2.6 KiB
Python

# Copyright (C) 2022 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from __future__ import annotations
import io
import json
from typing import Any, Dict, Iterable, List, Optional, Union
import tqdm
import urllib3
from cvat_sdk import exceptions
from cvat_sdk.api_client.api_client import Endpoint
from cvat_sdk.core.progress import ProgressReporter
def get_paginated_collection(
endpoint: Endpoint, *, return_json: bool = False, **kwargs
) -> Union[List, List[Dict[str, Any]]]:
"""
Accumulates results from all the pages
"""
results = []
page = 1
while True:
(page_contents, response) = endpoint.call_with_http_info(**kwargs, page=page)
expect_status(200, response)
if return_json:
results.extend(json.loads(response.data).get("results", []))
else:
results.extend(page_contents.results)
if not page_contents.next:
break
page += 1
return results
class TqdmProgressReporter(ProgressReporter):
def __init__(self, instance: tqdm.tqdm) -> None:
super().__init__()
self.tqdm = instance
@property
def period(self) -> float:
return 0
def start(self, total: int, *, desc: Optional[str] = None):
self.tqdm.reset(total)
self.tqdm.set_description_str(desc)
def report_status(self, progress: int):
self.tqdm.update(progress - self.tqdm.n)
def advance(self, delta: int):
self.tqdm.update(delta)
class StreamWithProgress:
def __init__(self, stream: io.RawIOBase, pbar: ProgressReporter, length: Optional[int] = None):
self.stream = stream
self.pbar = pbar
if hasattr(stream, "__len__"):
length = len(stream)
self.length = length
pbar.start(length)
def read(self, size=-1):
chunk = self.stream.read(size)
if chunk is not None:
self.pbar.advance(len(chunk))
return chunk
def __len__(self):
return self.length
def seek(self, pos, start=0):
self.stream.seek(pos, start)
self.pbar.report_status(pos)
def tell(self):
return self.stream.tell()
def expect_status(codes: Union[int, Iterable[int]], response: urllib3.HTTPResponse) -> None:
if not hasattr(codes, "__iter__"):
codes = [codes]
if response.status in codes:
return
if 300 <= response.status <= 500:
raise exceptions.ApiException(response.status, reason=response.msg, http_resp=response)
else:
raise exceptions.ApiException(
response.status, reason="Unexpected status code received", http_resp=response
)