Refactor frame provider (#1355)

* Refactor frame provider

* fix
main
zhiltsov-max 6 years ago committed by GitHub
parent c17303bbed
commit 457454821f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@@ -1,21 +1,21 @@
# Copyright (C) 2019 Intel Corporation # Copyright (C) 2020 Intel Corporation
# #
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
import itertools
import math import math
from io import BytesIO
from enum import Enum from enum import Enum
import itertools from io import BytesIO
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from cvat.apps.engine.media_extractors import VideoReader, ZipReader from cvat.apps.engine.media_extractors import VideoReader, ZipReader
from cvat.apps.engine.models import DataChoice
from cvat.apps.engine.mime_types import mimetypes from cvat.apps.engine.mime_types import mimetypes
from cvat.apps.engine.models import DataChoice
class FrameProvider(): class FrameProvider:
class Quality(Enum): class Quality(Enum):
COMPRESSED = 0 COMPRESSED = 0
ORIGINAL = 100 ORIGINAL = 100
@@ -25,26 +25,33 @@ class FrameProvider():
PIL = 1 PIL = 1
NUMPY_ARRAY = 2 NUMPY_ARRAY = 2
def __init__(self, db_data): class ChunkLoader:
self._db_data = db_data def __init__(self, reader_class, path_getter):
if db_data.compressed_chunk_type == DataChoice.IMAGESET: self.chunk_id = None
self._compressed_chunk_reader_class = ZipReader self.chunk_reader = None
elif db_data.compressed_chunk_type == DataChoice.VIDEO: self.reader_class = reader_class
self._compressed_chunk_reader_class = VideoReader self.get_chunk_path = path_getter
else:
raise Exception('Unsupported chunk type')
if db_data.original_chunk_type == DataChoice.IMAGESET: def load(self, chunk_id):
self._original_chunk_reader_class = ZipReader if self.chunk_id != chunk_id:
elif db_data.original_chunk_type == DataChoice.VIDEO: self.chunk_id = chunk_id
self._original_chunk_reader_class = VideoReader self.chunk_reader = self.reader_class([self.get_chunk_path(chunk_id)])
else: return self.chunk_reader
raise Exception('Unsupported chunk type')
self._extracted_compressed_chunk = None def __init__(self, db_data):
self._compressed_chunk_reader = None self._db_data = db_data
self._extracted_original_chunk = None self._loaders = {}
self._original_chunk_reader = None
reader_class = {
DataChoice.IMAGESET: ZipReader,
DataChoice.VIDEO: VideoReader,
}
self._loaders[self.Quality.COMPRESSED] = self.ChunkLoader(
reader_class[db_data.compressed_chunk_type],
db_data.get_compressed_chunk_path)
self._loaders[self.Quality.ORIGINAL] = self.ChunkLoader(
reader_class[db_data.original_chunk_type],
db_data.get_original_chunk_path)
def __len__(self): def __len__(self):
return self._db_data.size return self._db_data.size
@@ -74,77 +81,41 @@ class FrameProvider():
buf.seek(0) buf.seek(0)
return buf return buf
def _get_frame(self, frame_number, chunk_path_getter, extracted_chunk, chunk_reader, reader_class): def _convert_frame(self, frame, reader_class, out_type):
_, chunk_number, frame_offset = self._validate_frame_number(frame_number) if out_type == self.Type.BUFFER:
chunk_path = chunk_path_getter(chunk_number) return self._av_frame_to_png_bytes(frame) if reader_class is VideoReader else frame
if chunk_number != extracted_chunk: elif out_type == self.Type.PIL:
extracted_chunk = chunk_number return frame.to_image() if reader_class is VideoReader else Image.open(frame)
chunk_reader = reader_class([chunk_path]) elif out_type == self.Type.NUMPY_ARRAY:
if reader_class is VideoReader:
frame, frame_name, _ = next(itertools.islice(chunk_reader, frame_offset, None)) image = np.array(frame.to_image())
if reader_class is VideoReader: else:
return (self._av_frame_to_png_bytes(frame), 'image/png') image = np.array(Image.open(frame))
if len(image.shape) == 3 and image.shape[2] in {3, 4}:
return (frame, mimetypes.guess_type(frame_name)) image[:, :, :3] = image[:, :, 2::-1] # RGB to BGR
return image
def _get_frames(self, chunk_path_getter, reader_class, out_type): else:
for chunk_idx in range(math.ceil(self._db_data.size / self._db_data.chunk_size)): raise Exception('unsupported output type')
chunk_path = chunk_path_getter(chunk_idx)
chunk_reader = reader_class([chunk_path])
for frame, _, _ in chunk_reader:
if out_type == self.Type.BUFFER:
yield self._av_frame_to_png_bytes(frame) if reader_class is VideoReader else frame
elif out_type == self.Type.PIL:
yield frame.to_image() if reader_class is VideoReader else Image.open(frame)
elif out_type == self.Type.NUMPY_ARRAY:
if reader_class is VideoReader:
image = np.array(frame.to_image())
else:
image = np.array(Image.open(frame))
if len(image.shape) == 3 and image.shape[2] in {3, 4}:
image[:, :, :3] = image[:, :, 2::-1] # RGB to BGR
yield image
else:
raise Exception('unsupported output type')
def get_preview(self): def get_preview(self):
return self._db_data.get_preview_path() return self._db_data.get_preview_path()
def get_chunk(self, chunk_number, quality=Quality.ORIGINAL): def get_chunk(self, chunk_number, quality=Quality.ORIGINAL):
chunk_number = self._validate_chunk_number(chunk_number) chunk_number = self._validate_chunk_number(chunk_number)
if quality == self.Quality.ORIGINAL: return self._loaders[quality].get_chunk_path(chunk_number)
return self._db_data.get_original_chunk_path(chunk_number)
elif quality == self.Quality.COMPRESSED:
return self._db_data.get_compressed_chunk_path(chunk_number)
def get_frame(self, frame_number, quality=Quality.ORIGINAL): def get_frame(self, frame_number, quality=Quality.ORIGINAL):
if quality == self.Quality.ORIGINAL: _, chunk_number, frame_offset = self._validate_frame_number(frame_number)
return self._get_frame(
frame_number=frame_number, chunk_reader = self._loaders[quality].load(chunk_number)
chunk_path_getter=self._db_data.get_original_chunk_path,
extracted_chunk=self._extracted_original_chunk, frame, frame_name, _ = next(itertools.islice(chunk_reader, frame_offset, None))
chunk_reader=self._original_chunk_reader, if self._loaders[quality].reader_class is VideoReader:
reader_class=self._original_chunk_reader_class, return (self._av_frame_to_png_bytes(frame), 'image/png')
) return (frame, mimetypes.guess_type(frame_name))
elif quality == self.Quality.COMPRESSED:
return self._get_frame(
frame_number=frame_number,
chunk_path_getter=self._db_data.get_compressed_chunk_path,
extracted_chunk=self._extracted_compressed_chunk,
chunk_reader=self._compressed_chunk_reader,
reader_class=self._compressed_chunk_reader_class,
)
def get_frames(self, quality=Quality.ORIGINAL, out_type=Type.BUFFER): def get_frames(self, quality=Quality.ORIGINAL, out_type=Type.BUFFER):
if quality == self.Quality.ORIGINAL: loader = self._loaders[quality]
return self._get_frames( for chunk_idx in range(math.ceil(self._db_data.size / self._db_data.chunk_size)):
chunk_path_getter=self._db_data.get_original_chunk_path, for frame, _, _ in loader.load(chunk_idx):
reader_class=self._original_chunk_reader_class, yield self._convert_frame(frame, loader.reader_class, out_type)
out_type=out_type,
)
elif quality == self.Quality.COMPRESSED:
return self._get_frames(
chunk_path_getter=self._db_data.get_compressed_chunk_path,
reader_class=self._compressed_chunk_reader_class,
out_type=out_type,
)

Loading…
Cancel
Save