Read video from memory newapi (#6771)
* add tensor as optional param

* add init from memory

* fix bug

* fix bug

* first working version

* apply formatting and add tests

* simplify tests

* fix tests

* fix wrong variable name

* add path as optional parameter

* add src as optional

* address pr comments

* Fix warning messages

* address pr comments

* make tests stricter

* Revert "make tests stricter"

This reverts commit 6c92e94.
jdsgomes authored Oct 21, 2022
1 parent 246de07 commit 06ad05f
Showing 5 changed files with 148 additions and 19 deletions.
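For orientation, here is a minimal usage sketch of the API this commit introduces; the file name video.mp4 is illustrative, and frames come back as dicts with "data" and "pts" keys, as in the tests below:

import torch
from torchvision.io import VideoReader

# from a file path (existing behaviour, now via the src parameter)
reader = VideoReader("video.mp4", "video:0")  # explicit stream id

# from raw bytes held in memory
with open("video.mp4", "rb") as f:
    data = f.read()
reader = VideoReader(data, "video")

# from a one-dimensional uint8 tensor holding the encoded file
buf = torch.frombuffer(data, dtype=torch.uint8)
reader = VideoReader(buf, "video")

for frame in reader:
    print(frame["pts"], frame["data"].shape)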
41 changes: 41 additions & 0 deletions test/test_videoapi.py
@@ -77,6 +77,7 @@ def test_frame_reading(self, test_video):
        # compare the frames and pts values
        for i in range(len(vr_frames)):
            assert float(av_pts[i]) == approx(vr_pts[i], abs=0.1)

            mean_delta = torch.mean(torch.abs(av_frames[i].float() - vr_frames[i].float()))
            # on average the difference is very small and caused
            # by decoding (around 1%)
@@ -114,6 +115,46 @@ def test_frame_reading(self, test_video):
        # we ensure that there is never more than 1% difference in signal
        assert max_delta.item() < 0.001

@pytest.mark.parametrize("stream", ["video", "audio"])
@pytest.mark.parametrize("test_video", test_videos.keys())
def test_frame_reading_mem_vs_file(self, test_video, stream):
full_path = os.path.join(VIDEO_DIR, test_video)

# Test video reading from file vs from memory
vr_frames, vr_frames_mem = [], []
vr_pts, vr_pts_mem = [], []
# get vr frames
video_reader = VideoReader(full_path, stream)
for vr_frame in video_reader:
vr_frames.append(vr_frame["data"])
vr_pts.append(vr_frame["pts"])

# get vr frames = read from memory
f = open(full_path, "rb")
fbytes = f.read()
f.close()
video_reader_from_mem = VideoReader(fbytes, stream)

for vr_frame_from_mem in video_reader_from_mem:
vr_frames_mem.append(vr_frame_from_mem["data"])
vr_pts_mem.append(vr_frame_from_mem["pts"])

# same number of frames
assert len(vr_frames) == len(vr_frames_mem)
assert len(vr_pts) == len(vr_pts_mem)

# compare the frames and ptss
for i in range(len(vr_frames)):
assert vr_pts[i] == vr_pts_mem[i]
mean_delta = torch.mean(torch.abs(vr_frames[i].float() - vr_frames_mem[i].float()))
# on average the difference is very small and caused
# by decoding (around 1%)
# TODO: asses empirically how to set this? atm it's 1%
# averaged over all frames
assert mean_delta.item() < 2.55

del vr_frames, vr_pts, vr_frames_mem, vr_pts_mem
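A side note on the 2.55 threshold above: frames are uint8 in [0, 255], so a mean absolute difference of 2.55 corresponds to 1% of full scale. A tiny standalone check of that arithmetic:

import torch

full_scale = 255.0
tolerance = 0.01  # 1% of full scale
assert abs(full_scale * tolerance - 2.55) < 1e-12

# identical frames trivially satisfy the bound
a = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8)
mean_delta = torch.mean(torch.abs(a.float() - a.float()))
assert mean_delta.item() < 2.55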

@pytest.mark.parametrize("test_video,config", test_videos.items())
def test_metadata(self, test_video, config):
"""
2 changes: 1 addition & 1 deletion torchvision/csrc/io/decoder/defs.h
@@ -165,7 +165,7 @@ struct MediaFormat {
struct DecoderParameters {
  // local file, remote file, http url, rtmp stream uri, etc. anything that
  // ffmpeg can recognize
  std::string uri;
  std::string uri{std::string()};
  // timeout on getting bytes for decoding
  size_t timeoutMs{1000};
  // logging level, default AV_LOG_PANIC
57 changes: 45 additions & 12 deletions torchvision/csrc/io/video/video.cpp
@@ -156,14 +156,34 @@ void Video::_getDecoderParams(

} // _get decoder params

Video::Video(std::string videoPath, std::string stream, int64_t numThreads) {
  C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.video.video.Video");
void Video::initFromFile(
    std::string videoPath,
    std::string stream,
    int64_t numThreads) {
  TORCH_CHECK(!initialized, "Video object can only be initialized once");
  initialized = true;
  params.uri = videoPath;
  _init(stream, numThreads);
}

void Video::initFromMemory(
    torch::Tensor videoTensor,
    std::string stream,
    int64_t numThreads) {
  TORCH_CHECK(!initialized, "Video object can only be initialized once");
  initialized = true;
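  // build a decoder read callback over the tensor's underlying byte buffer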
  callback = MemoryBuffer::getCallback(
      videoTensor.data_ptr<uint8_t>(), videoTensor.size(0));
  _init(stream, numThreads);
}

void Video::_init(std::string stream, int64_t numThreads) {
  // set the global number of decoder threads
  numThreads_ = numThreads;
  // parse stream information
  current_stream = _parseStream(stream);
  // note that in the initial call we want to get all streams
  Video::_getDecoderParams(
  _getDecoderParams(
      0, // video start
      0, // headerOnly
      std::get<0>(current_stream), // stream info - remove that
@@ -175,11 +195,6 @@ Video::Video(std::string videoPath, std::string stream, int64_t numThreads) {

  std::string logMessage, logType;

  // TODO: add read from memory option
  params.uri = videoPath;
  logType = "file";
  logMessage = videoPath;

  // locals
  std::vector<double> audioFPS, videoFPS;
  std::vector<double> audioDuration, videoDuration, ccDuration, subsDuration;
@@ -190,7 +205,8 @@ Video::Video(std::string videoPath, std::string stream, int64_t numThreads) {
  c10::Dict<std::string, std::vector<double>> subsMetadata;

  // callback and metadata defined in struct
  succeeded = decoder.init(params, std::move(callback), &metadata);
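  // copy the stored callback first: decoder.init() consumes its argument via
  // std::move, so the member stays valid for later re-inits (setCurrentStream, Seek)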
  DecoderInCallback tmp_callback = callback;
  succeeded = decoder.init(params, std::move(tmp_callback), &metadata);
  if (succeeded) {
    for (const auto& header : metadata) {
      double fps = double(header.fps);
@@ -225,16 +241,24 @@ Video::Video(std::string videoPath, std::string stream, int64_t numThreads) {
streamsMetadata.insert("subtitles", subsMetadata);
streamsMetadata.insert("cc", ccMetadata);

succeeded = Video::setCurrentStream(stream);
succeeded = setCurrentStream(stream);
LOG(INFO) << "\nDecoder inited with: " << succeeded << "\n";
if (std::get<1>(current_stream) != -1) {
LOG(INFO)
<< "Stream index set to " << std::get<1>(current_stream)
<< ". If you encounter trouble, consider switching it to automatic stream discovery. \n";
}
}

Video::Video(std::string videoPath, std::string stream, int64_t numThreads) {
C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.video.video.Video");
if (!videoPath.empty()) {
initFromFile(videoPath, stream, numThreads);
}
} // video

bool Video::setCurrentStream(std::string stream = "video") {
  TORCH_CHECK(initialized, "Video object has to be initialized first");
  if ((!stream.empty()) && (_parseStream(stream) != current_stream)) {
    current_stream = _parseStream(stream);
  }
@@ -256,19 +280,23 @@ bool Video::setCurrentStream(std::string stream = "video") {
  );

  // callback and metadata defined in Video.h
  return (decoder.init(params, std::move(callback), &metadata));
  DecoderInCallback tmp_callback = callback;
  return (decoder.init(params, std::move(tmp_callback), &metadata));
}

std::tuple<std::string, int64_t> Video::getCurrentStream() const {
  TORCH_CHECK(initialized, "Video object has to be initialized first");
  return current_stream;
}

c10::Dict<std::string, c10::Dict<std::string, std::vector<double>>> Video::
    getStreamMetadata() const {
  TORCH_CHECK(initialized, "Video object has to be initialized first");
  return streamsMetadata;
}

void Video::Seek(double ts, bool fastSeek = false) {
  TORCH_CHECK(initialized, "Video object has to be initialized first");
  // initialize the class variables used for seeking and return
  _getDecoderParams(
      ts, // video start
@@ -282,11 +310,14 @@ void Video::Seek(double ts, bool fastSeek = false) {
  );

  // callback and metadata defined in Video.h
  succeeded = decoder.init(params, std::move(callback), &metadata);
  DecoderInCallback tmp_callback = callback;
  succeeded = decoder.init(params, std::move(tmp_callback), &metadata);

  LOG(INFO) << "Decoder init at seek " << succeeded << "\n";
}

std::tuple<torch::Tensor, double> Video::Next() {
  TORCH_CHECK(initialized, "Video object has to be initialized first");
  // if decoding fails simply return a null tensor (note, should we
  // raise an exception?)
  double frame_pts_s;
@@ -345,6 +376,8 @@ std::tuple<torch::Tensor, double> Video::Next() {
static auto registerVideo =
    torch::class_<Video>("torchvision", "Video")
        .def(torch::init<std::string, std::string, int64_t>())
        .def("init_from_file", &Video::initFromFile)
        .def("init_from_memory", &Video::initFromMemory)
        .def("get_current_stream", &Video::getCurrentStream)
        .def("set_current_stream", &Video::setCurrentStream)
        .def("get_metadata", &Video::getStreamMetadata)
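The registered class is reachable from Python as torch.classes.torchvision.Video. A rough sketch of the two-step memory initialization that VideoReader performs internally (file name illustrative; importing torchvision loads the registered ops):

import torch
import torchvision  # noqa: F401  (loads the torchvision.Video binding)

# construct an uninitialized shell, then feed it the encoded bytes
v = torch.classes.torchvision.Video("", "", 0)
with open("video.mp4", "rb") as f:
    buf = torch.frombuffer(f.read(), dtype=torch.uint8)
v.init_from_memory(buf, "video", 0)
metadata = v.get_metadata()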
20 changes: 19 additions & 1 deletion torchvision/csrc/io/video/video.h
@@ -19,7 +19,19 @@ struct Video : torch::CustomClassHolder {
  int64_t numThreads_{0};

 public:
  Video(std::string videoPath, std::string stream, int64_t numThreads);
  Video(
      std::string videoPath = std::string(),
      std::string stream = std::string("video"),
      int64_t numThreads = 0);
  void initFromFile(
      std::string videoPath,
      std::string stream,
      int64_t numThreads);
  void initFromMemory(
      torch::Tensor videoTensor,
      std::string stream,
      int64_t numThreads);

  std::tuple<std::string, int64_t> getCurrentStream() const;
  c10::Dict<std::string, c10::Dict<std::string, std::vector<double>>>
      getStreamMetadata() const;
@@ -34,6 +46,12 @@
  // time in combination with any_frame settings
  double seekTS = -1;

  bool initialized = false;

  void _init(
      std::string stream,
      int64_t numThreads); // expects params.uri OR callback to be set

  void _getDecoderParams(
      double videoStartS,
      int64_t getPtsOnly,
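Because of the initialized guard declared above, each Video object accepts exactly one initialization. A sketch of what a second attempt does, assuming the binding shown earlier (TORCH_CHECK failures surface as RuntimeError in Python):

import torch
import torchvision  # noqa: F401  (loads the torchvision.Video binding)

v = torch.classes.torchvision.Video("video.mp4", "video", 0)
try:
    v.init_from_file("video.mp4", "video", 0)
except RuntimeError as e:
    print(e)  # "Video object can only be initialized once"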
47 changes: 42 additions & 5 deletions torchvision/io/video_reader.py
@@ -1,4 +1,5 @@
from typing import Any, Dict, Iterator
import warnings
from typing import Any, Dict, Iterator, Optional

import torch

@@ -71,8 +72,13 @@ class VideoReader:
    If only stream type is passed, the decoder auto-detects the first stream of that type.

    Args:
        src (string, bytes object, or tensor): The media source.
            If string-type, it must be a file path supported by FFMPEG.
            If bytes, it should be an in-memory representation of a file supported by FFMPEG.
            If Tensor, it is interpreted internally as a byte buffer.
            It must be one-dimensional, of type ``torch.uint8``.
        path (string): Path to the video file in supported format
        stream (string, optional): descriptor of the required stream, followed by the stream id,
            in the format ``{stream_type}:{stream_id}``. Defaults to ``"video:0"``.
@@ -85,17 +91,31 @@ class VideoReader:
        device (str, optional): Device to be used for decoding. Defaults to ``"cpu"``.
            To use GPU decoding, pass ``device="cuda"``.
        path (str, optional):
            .. warning::
                This parameter was deprecated in ``0.15`` and will be removed in ``0.17``.
                Please use ``src`` instead.
    """

    def __init__(self, path: str, stream: str = "video", num_threads: int = 0, device: str = "cpu") -> None:
    def __init__(
        self,
        src: str = "",
        stream: str = "video",
        num_threads: int = 0,
        device: str = "cpu",
        path: Optional[str] = None,
    ) -> None:
        _log_api_usage_once(self)
        self.is_cuda = False
        device = torch.device(device)
        if device.type == "cuda":
            if not _HAS_GPU_VIDEO_DECODER:
                raise RuntimeError("Not compiled with GPU decoder support.")
            self.is_cuda = True
            self._c = torch.classes.torchvision.GPUDecoder(path, device)
            self._c = torch.classes.torchvision.GPUDecoder(src, device)
            return
        if not _has_video_opt():
            raise RuntimeError(
raise RuntimeError(
@@ -105,7 +125,24 @@ def __init__(self, path: str, stream: str = "video", num_threads: int = 0, device: str = "cpu") -> None:
+ "build torchvision from source."
)

        self._c = torch.classes.torchvision.Video(path, stream, num_threads)
        if src == "":
            if path is None:
                raise TypeError("src cannot be empty")
            src = path
            warnings.warn("path is deprecated and will be removed in 0.17. Please use src instead")

        elif isinstance(src, bytes):
            src = torch.frombuffer(src, dtype=torch.uint8)

        if isinstance(src, str):
            self._c = torch.classes.torchvision.Video(src, stream, num_threads)
        elif isinstance(src, torch.Tensor):
            if self.is_cuda:
                raise RuntimeError("GPU VideoReader cannot be initialized from Tensor or bytes object.")
            self._c = torch.classes.torchvision.Video("", "", 0)
            self._c.init_from_memory(src, stream, num_threads)
        else:
            raise TypeError("`src` must be either string, Tensor or bytes object.")

    def __next__(self) -> Dict[str, Any]:
        """Decodes and returns the next frame of the current stream.
