Skip to content

Commit 3e97d89

Browse files
author
Daniel Flores
committed
to_filelike, update test
1 parent e4af1df commit 3e97d89

File tree

4 files changed

+87
-3
lines changed

4 files changed

+87
-3
lines changed

src/torchcodec/_core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
encode_audio_to_file_like,
2727
encode_audio_to_tensor,
2828
encode_video_to_file,
29+
encode_video_to_file_like,
2930
encode_video_to_tensor,
3031
get_ffmpeg_library_versions,
3132
get_frame_at_index,

src/torchcodec/_core/custom_ops.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ TORCH_LIBRARY(torchcodec_ns, m) {
4040
"encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()");
4141
m.def(
4242
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, int? crf=None) -> Tensor");
43+
m.def(
44+
"_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, int? crf=None) -> ()");
4345
m.def(
4446
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
4547
m.def(
@@ -606,6 +608,30 @@ at::Tensor encode_video_to_tensor(
606608
.encodeToTensor();
607609
}
608610

611+
void _encode_video_to_file_like(
612+
const at::Tensor& frames,
613+
int64_t frame_rate,
614+
std::string_view format,
615+
int64_t file_like_context,
616+
std::optional<int64_t> crf = std::nullopt) {
617+
auto fileLikeContext =
618+
reinterpret_cast<AVIOFileLikeContext*>(file_like_context);
619+
TORCH_CHECK(
620+
fileLikeContext != nullptr, "file_like_context must be a valid pointer");
621+
std::unique_ptr<AVIOFileLikeContext> avioContextHolder(fileLikeContext);
622+
623+
VideoStreamOptions videoStreamOptions;
624+
videoStreamOptions.crf = crf;
625+
626+
VideoEncoder encoder(
627+
frames,
628+
validateInt64ToInt(frame_rate, "frame_rate"),
629+
format,
630+
std::move(avioContextHolder),
631+
videoStreamOptions);
632+
encoder.encode();
633+
}
634+
609635
// For testing only. We need to implement this operation as a core library
610636
// function because what we're testing is round-tripping pts values as
611637
// double-precision floating point numbers from C++ to Python and back to C++.
@@ -870,6 +896,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
870896
m.impl("_encode_audio_to_file_like", &_encode_audio_to_file_like);
871897
m.impl("encode_video_to_file", &encode_video_to_file);
872898
m.impl("encode_video_to_tensor", &encode_video_to_tensor);
899+
m.impl("_encode_video_to_file_like", &_encode_video_to_file_like);
873900
m.impl("seek_to_pts", &seek_to_pts);
874901
m.impl("add_video_stream", &add_video_stream);
875902
m.impl("_add_video_stream", &_add_video_stream);

src/torchcodec/_core/ops.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,9 @@ def load_torchcodec_shared_libraries():
104104
encode_video_to_tensor = torch._dynamo.disallow_in_graph(
105105
torch.ops.torchcodec_ns.encode_video_to_tensor.default
106106
)
107+
_encode_video_to_file_like = torch._dynamo.disallow_in_graph(
108+
torch.ops.torchcodec_ns._encode_video_to_file_like.default
109+
)
107110
create_from_tensor = torch._dynamo.disallow_in_graph(
108111
torch.ops.torchcodec_ns.create_from_tensor.default
109112
)
@@ -203,6 +206,33 @@ def encode_audio_to_file_like(
203206
)
204207

205208

209+
def encode_video_to_file_like(
210+
frames: torch.Tensor,
211+
frame_rate: int,
212+
format: str,
213+
file_like: Union[io.RawIOBase, io.BufferedIOBase],
214+
crf: Optional[int] = None,
215+
) -> None:
216+
"""Encode video frames to a file-like object.
217+
218+
Args:
219+
frames: Video frames tensor
220+
frame_rate: Frame rate in frames per second
221+
format: Video format (e.g., "mp4", "mov", "mkv")
222+
file_like: File-like object that supports write() and seek() methods
223+
crf: Optional constant rate factor for encoding quality
224+
"""
225+
assert _pybind_ops is not None
226+
227+
_encode_video_to_file_like(
228+
frames,
229+
frame_rate,
230+
format,
231+
_pybind_ops.create_file_like_context(file_like, True), # True means for writing
232+
crf,
233+
)
234+
235+
206236
def get_frames_at_indices(
207237
decoder: torch.Tensor, *, frame_indices: Union[torch.Tensor, list[int]]
208238
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -302,6 +332,17 @@ def encode_video_to_tensor_abstract(
302332
return torch.empty([], dtype=torch.long)
303333

304334

335+
@register_fake("torchcodec_ns::_encode_video_to_file_like")
336+
def _encode_video_to_file_like_abstract(
337+
frames: torch.Tensor,
338+
frame_rate: int,
339+
format: str,
340+
file_like_context: int,
341+
crf: Optional[int] = None,
342+
) -> None:
343+
return
344+
345+
305346
@register_fake("torchcodec_ns::create_from_tensor")
306347
def create_from_tensor_abstract(
307348
video_tensor: torch.Tensor, seek_mode: Optional[str]

test/test_ops.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
create_from_tensor,
3030
encode_audio_to_file,
3131
encode_video_to_file,
32+
encode_video_to_file_like,
3233
encode_video_to_tensor,
3334
get_ffmpeg_library_versions,
3435
get_frame_at_index,
@@ -1379,17 +1380,19 @@ def test_bad_input(self, tmp_path):
13791380
filename="./bad/path.mp3",
13801381
)
13811382

1382-
def decode(self, file_path=None, tensor=None) -> torch.Tensor:
1383+
def decode(self, file_path=None, tensor=None, file_like=None) -> torch.Tensor:
13831384
if file_path is not None:
13841385
decoder = create_from_file(str(file_path), seek_mode="approximate")
13851386
elif tensor is not None:
13861387
decoder = create_from_tensor(tensor, seek_mode="approximate")
1388+
elif file_like is not None:
1389+
decoder = create_from_file_like(file_like, seek_mode="approximate")
13871390
add_video_stream(decoder)
13881391
frames, *_ = get_frames_in_range(decoder, start=0, stop=60)
13891392
return frames
13901393

13911394
@pytest.mark.parametrize("format", ("mov", "mp4", "mkv", "webm"))
1392-
@pytest.mark.parametrize("output_method", ("to_file", "to_tensor"))
1395+
@pytest.mark.parametrize("output_method", ("to_file", "to_tensor", "to_file_like"))
13931396
def test_video_encoder_round_trip(self, tmp_path, format, output_method):
13941397
# Test that decode(encode(decode(asset))) == decode(asset)
13951398
ffmpeg_version = get_ffmpeg_major_version()
@@ -1416,12 +1419,24 @@ def test_video_encoder_round_trip(self, tmp_path, format, output_method):
14161419
crf=0,
14171420
)
14181421
round_trip_frames = self.decode(file_path=encoded_path).data
1419-
else: # to_tensor
1422+
elif output_method == "to_tensor":
14201423
format = "matroska" if format == "mkv" else format
14211424
encoded_tensor = encode_video_to_tensor(
14221425
source_frames, frame_rate, format, crf=0
14231426
)
14241427
round_trip_frames = self.decode(tensor=encoded_tensor).data
1428+
else: # to_file_like
1429+
format = "matroska" if format == "mkv" else format
1430+
file_like = io.BytesIO()
1431+
encode_video_to_file_like(
1432+
frames=source_frames,
1433+
frame_rate=frame_rate,
1434+
format=format,
1435+
file_like=file_like,
1436+
crf=0,
1437+
)
1438+
file_like.seek(0)
1439+
round_trip_frames = self.decode(file_like=file_like).data
14251440

14261441
assert source_frames.shape == round_trip_frames.shape
14271442
assert source_frames.dtype == round_trip_frames.dtype

0 commit comments

Comments
 (0)