Skip to content

Commit b084768

Browse files
authored
Add set_cuda_backend Context Manager to publicly expose the BETA CUDA Interface (#959)
1 parent 61202b9 commit b084768

File tree

4 files changed

+122
-13
lines changed

4 files changed

+122
-13
lines changed

src/torchcodec/decoders/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from .._core import AudioStreamMetadata, VideoStreamMetadata
88
from ._audio_decoder import AudioDecoder # noqa
9+
from ._decoder_utils import set_cuda_backend # noqa
910
from ._video_decoder import VideoDecoder # noqa
1011

1112
SimpleVideoDecoder = VideoDecoder

src/torchcodec/decoders/_decoder_utils.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import contextvars
78
import io
9+
from contextlib import contextmanager
810
from pathlib import Path
911

10-
from typing import Union
12+
from typing import Generator, Union
1113

1214
from torch import Tensor
1315
from torchcodec import _core as core
@@ -50,3 +52,52 @@ def create_decoder(
5052
"read(self, size: int) -> bytes and "
5153
"seek(self, offset: int, whence: int) -> int methods."
5254
)
55+
56+
57+
# Thread-local and async-safe storage for the current CUDA backend
58+
_CUDA_BACKEND: contextvars.ContextVar[str] = contextvars.ContextVar(
59+
"_CUDA_BACKEND", default="ffmpeg"
60+
)
61+
62+
63+
@contextmanager
64+
def set_cuda_backend(backend: str) -> Generator[None, None, None]:
65+
"""Context Manager to set the CUDA backend for :class:`~torchcodec.decoders.VideoDecoder`.
66+
67+
This context manager allows you to specify which CUDA backend implementation
68+
to use when creating :class:`~torchcodec.decoders.VideoDecoder` instances
69+
with CUDA devices. This is thread-safe and async-safe.
70+
71+
Note that you still need to pass `device="cuda"` when creating the
72+
:class:`~torchcodec.decoders.VideoDecoder` instance. If a CUDA device isn't
73+
specified, this context manager will have no effect.
74+
75+
Only the creation of the decoder needs to be inside the context manager, the
76+
decoding methods can be called outside of it.
77+
78+
Args:
79+
backend (str): The CUDA backend to use. Can be "ffmpeg" or "beta". Default is "ffmpeg".
80+
81+
Example:
82+
>>> with torchcodec.set_cuda_backend("beta"):
83+
... decoder = VideoDecoder("video.mp4", device="cuda")
84+
...
85+
... # Only the decoder creation needs to be part of the context manager.
86+
... # Decoder will now the beta CUDA implementation:
87+
... decoder.get_frame_at(0)
88+
"""
89+
backend = backend.lower()
90+
if backend not in ("ffmpeg", "beta"):
91+
raise ValueError(
92+
f"Invalid CUDA backend ({backend}). Supported values are 'ffmpeg' and 'beta'."
93+
)
94+
95+
previous_state = _CUDA_BACKEND.set(backend)
96+
try:
97+
yield
98+
finally:
99+
_CUDA_BACKEND.reset(previous_state)
100+
101+
102+
def _get_cuda_backend() -> str:
103+
return _CUDA_BACKEND.get()

src/torchcodec/decoders/_video_decoder.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from torchcodec import _core as core, Frame, FrameBatch
1717
from torchcodec.decoders._decoder_utils import (
18+
_get_cuda_backend,
1819
create_decoder,
1920
ERROR_REPORTING_INSTRUCTIONS,
2021
)
@@ -143,17 +144,17 @@ def __init__(
143144
if isinstance(device, torch_device):
144145
device = str(device)
145146

146-
# If device looks like "cuda:0:beta", make it "cuda:0" and set
147-
# device_variant to "beta"
148-
# TODONVDEC P2 Consider alternative ways of exposing custom device
149-
# variants, and if we want this new decoder backend to be a "device
150-
# variant" at all.
151-
device_variant = "default"
152-
if device is not None:
153-
device_split = device.split(":")
154-
if len(device_split) == 3:
155-
device_variant = device_split[2]
156-
device = ":".join(device_split[0:2])
147+
device_variant = _get_cuda_backend()
148+
if device_variant == "ffmpeg":
149+
# TODONVDEC P2 rename 'default' into 'ffmpeg' everywhere.
150+
device_variant = "default"
151+
152+
# Legacy support for device="cuda:0:beta" syntax
153+
# TODONVDEC P2: remove support for this everywhere. This will require
154+
# updating our tests.
155+
if device == "cuda:0:beta":
156+
device = "cuda:0"
157+
device_variant = "beta"
157158

158159
core.add_video_stream(
159160
self._decoder,

test/test_decoders.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@
1818
from torchcodec.decoders import (
1919
AudioDecoder,
2020
AudioStreamMetadata,
21+
set_cuda_backend,
2122
VideoDecoder,
2223
VideoStreamMetadata,
2324
)
25+
from torchcodec.decoders._decoder_utils import _get_cuda_backend
2426

2527
from .utils import (
2628
all_supported_devices,
@@ -1702,9 +1704,63 @@ def test_beta_cuda_interface_small_h265(self):
17021704

17031705
@needs_cuda
17041706
def test_beta_cuda_interface_error(self):
1705-
with pytest.raises(RuntimeError, match="Unsupported device"):
1707+
with pytest.raises(RuntimeError, match="Invalid device string"):
17061708
VideoDecoder(NASA_VIDEO.path, device="cuda:0:bad_variant")
17071709

1710+
@needs_cuda
1711+
def test_set_cuda_backend(self):
1712+
# Tests for the set_cuda_backend() context manager.
1713+
1714+
with pytest.raises(ValueError, match="Invalid CUDA backend"):
1715+
with set_cuda_backend("bad_backend"):
1716+
pass
1717+
1718+
# set_cuda_backend() is meant to be used as a context manager. Using it
1719+
# as a global call does nothing because the "context" is exited right
1720+
# away. This is a good thing, we prefer users to use it as a CM only.
1721+
set_cuda_backend("beta")
1722+
assert _get_cuda_backend() == "ffmpeg" # Not changed to "beta".
1723+
1724+
# Case insensitive
1725+
with set_cuda_backend("BETA"):
1726+
assert _get_cuda_backend() == "beta"
1727+
1728+
def assert_decoder_uses(decoder, *, expected_backend):
1729+
# Assert that a decoder instance is using a given backend.
1730+
#
1731+
# We know H265_VIDEO fails on the BETA backend while it works on the
1732+
# ffmpeg one.
1733+
if expected_backend == "ffmpeg":
1734+
decoder.get_frame_at(0) # this would fail if this was BETA
1735+
else:
1736+
with pytest.raises(RuntimeError, match="Video is too small"):
1737+
decoder.get_frame_at(0)
1738+
1739+
# Check that the default is the ffmpeg backend
1740+
assert _get_cuda_backend() == "ffmpeg"
1741+
dec = VideoDecoder(H265_VIDEO.path, device="cuda")
1742+
assert_decoder_uses(dec, expected_backend="ffmpeg")
1743+
1744+
# Check the setting "beta" effectively uses the BETA backend.
1745+
# We also show that the affects decoder creation only. When the decoder
1746+
# is created with a given backend, it stays in this backend for the rest
1747+
# of its life. This is normal and intended.
1748+
with set_cuda_backend("beta"):
1749+
dec = VideoDecoder(H265_VIDEO.path, device="cuda")
1750+
assert _get_cuda_backend() == "ffmpeg"
1751+
assert_decoder_uses(dec, expected_backend="beta")
1752+
with set_cuda_backend("ffmpeg"):
1753+
assert_decoder_uses(dec, expected_backend="beta")
1754+
1755+
# Hacky way to ensure passing "cuda:1" is supported by both backends. We
1756+
# just check that there's an error when passing cuda:N where N is too
1757+
# high.
1758+
bad_device_number = torch.cuda.device_count() + 1
1759+
for backend in ("ffmpeg", "beta"):
1760+
with pytest.raises(RuntimeError, match="invalid device ordinal"):
1761+
with set_cuda_backend(backend):
1762+
VideoDecoder(H265_VIDEO.path, device=f"cuda:{bad_device_number}")
1763+
17081764

17091765
class TestAudioDecoder:
17101766
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3, SINE_MONO_S32))

0 commit comments

Comments
 (0)