Skip to content

Commit 51b10ec

Browse files
NicolasHug authored and facebook-github-bot committed
[fbsync] Support for decoding jpegs on GPU with nvjpeg (#3792)
Summary: Co-authored-by: James Thewlis <[email protected]> Reviewed By: datumbox Differential Revision: D28473331 fbshipit-source-id: d82d415e81876b660e599997860c737848d9afc0
1 parent 36347bb commit 51b10ec

File tree

9 files changed

+300
-12
lines changed

9 files changed

+300
-12
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ include(CMakePackageConfigHelpers)
6161

6262
set(TVCPP torchvision/csrc)
6363
list(APPEND ALLOW_LISTED ${TVCPP} ${TVCPP}/io/image ${TVCPP}/io/image/cpu ${TVCPP}/models ${TVCPP}/ops
64-
${TVCPP}/ops/autograd ${TVCPP}/ops/cpu)
64+
${TVCPP}/ops/autograd ${TVCPP}/ops/cpu ${TVCPP}/io/image/cuda)
6565
if(WITH_CUDA)
6666
list(APPEND ALLOW_LISTED ${TVCPP}/ops/cuda ${TVCPP}/ops/autocast)
6767
endif()

setup.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,8 +315,23 @@ def get_extensions():
315315
image_library += [jpeg_lib]
316316
image_include += [jpeg_include]
317317

318+
# Locating nvjpeg
319+
# Should be included in CUDA_HOME for CUDA >= 10.1, which is the minimum version we have in the CI
320+
nvjpeg_found = (
321+
extension is CUDAExtension and
322+
CUDA_HOME is not None and
323+
os.path.exists(os.path.join(CUDA_HOME, 'include', 'nvjpeg.h'))
324+
)
325+
326+
print('NVJPEG found: {0}'.format(nvjpeg_found))
327+
image_macros += [('NVJPEG_FOUND', str(int(nvjpeg_found)))]
328+
if nvjpeg_found:
329+
print('Building torchvision with NVJPEG image support')
330+
image_link_flags.append('nvjpeg')
331+
318332
image_path = os.path.join(extensions_dir, 'io', 'image')
319-
image_src = glob.glob(os.path.join(image_path, '*.cpp')) + glob.glob(os.path.join(image_path, 'cpu', '*.cpp'))
333+
image_src = (glob.glob(os.path.join(image_path, '*.cpp')) + glob.glob(os.path.join(image_path, 'cpu', '*.cpp'))
334+
+ glob.glob(os.path.join(image_path, 'cuda', '*.cpp')))
320335

321336
if png_found or jpeg_found:
322337
ext_modules.append(extension(

test/common_utils.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
PY39_SEGFAULT_SKIP_MSG = "Segmentation fault with Python 3.9, see https://github.com/pytorch/vision/issues/3367"
2525
PY39_SKIP = unittest.skipIf(IS_PY39, PY39_SEGFAULT_SKIP_MSG)
2626
IN_CIRCLE_CI = os.getenv("CIRCLECI", False) == 'true'
27+
IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
28+
IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
29+
CUDA_NOT_AVAILABLE_MSG = 'CUDA device not available'
2730

2831

2932
@contextlib.contextmanager
@@ -407,11 +410,8 @@ def call_args_to_kwargs_only(call_args, *callable_or_arg_names):
407410

408411
def cpu_and_gpu():
409412
import pytest # noqa
410-
# ignore CPU tests in RE as they're already covered by another contbuild
411-
IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
412-
IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
413-
CUDA_NOT_AVAILABLE_MSG = 'CUDA device not available'
414413

414+
# ignore CPU tests in RE as they're already covered by another contbuild
415415
devices = [] if IN_RE_WORKER else ['cpu']
416416

417417
if torch.cuda.is_available():
@@ -427,3 +427,17 @@ def cpu_and_gpu():
427427
devices.append(pytest.param('cuda', marks=cuda_marks))
428428

429429
return devices
430+
431+
432+
def needs_cuda(test_func):
    """Decorator for tests that can only run when a CUDA device is present.

    Depending on the environment, the decorated test is either:
    - marked with ``pytest.mark.dont_collect`` (in fbcode, outside an RE worker),
    - returned untouched (CUDA is available),
    - marked as skipped with ``CUDA_NOT_AVAILABLE_MSG`` (everything else).
    """
    import pytest  # noqa

    fbcode_outside_re_worker = IN_FBCODE and not IN_RE_WORKER
    if fbcode_outside_re_worker:
        # We don't want to skip in fbcode, so we just don't collect
        # TODO: slightly more robust way would be to detect if we're in a sandcastle instance
        # so that the test will still be collected (and skipped) in the devvms.
        dont_collect = pytest.mark.dont_collect
        return dont_collect(test_func)

    if torch.cuda.is_available():
        return test_func

    skip_without_cuda = pytest.mark.skip(reason=CUDA_NOT_AVAILABLE_MSG)
    return skip_without_cuda(test_func)

test/test_image.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
import os
44
import unittest
55

6+
import pytest
67
import numpy as np
78
import torch
89
from PIL import Image
9-
from common_utils import get_tmp_dir
10+
from common_utils import get_tmp_dir, needs_cuda
1011

1112
from torchvision.io.image import (
1213
decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
@@ -278,5 +279,51 @@ def test_write_file_non_ascii(self):
278279
os.unlink(fpath)
279280

280281

282+
@needs_cuda
@pytest.mark.parametrize('img_path', [
    # We need to change the "id" for that parameter.
    # If we don't, the test id (i.e. its name) will contain the whole path to the image which is machine-specific,
    # and this creates issues when the test is running in a different machine than where it was collected
    # (typically, in fb internal infra).
    # os.path.basename is used instead of splitting on '/' so the id is the bare
    # file name regardless of the platform's path separator.
    pytest.param(jpeg_path, id=os.path.basename(jpeg_path))
    for jpeg_path in get_images(IMAGE_ROOT, ".jpg")
])
@pytest.mark.parametrize('mode', [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB])
@pytest.mark.parametrize('scripted', (False, True))
def test_decode_jpeg_cuda(mode, img_path, scripted):
    """Check that decoding a jpeg with device='cuda' closely matches decode_image.

    Runs for every read mode, with and without torch.jit.script applied to
    decode_jpeg. Images whose path contains 'cmyk' are expected to fail.
    """
    if 'cmyk' in img_path:
        pytest.xfail("Decoding a CMYK jpeg isn't supported")
    data = read_file(img_path)
    img = decode_image(data, mode=mode)
    f = torch.jit.script(decode_jpeg) if scripted else decode_jpeg
    img_nvjpeg = f(data, mode=mode, device='cuda')

    # Some difference expected between jpeg implementations, so only the mean
    # absolute difference is bounded rather than requiring exact equality.
    mean_abs_diff = (img.float() - img_nvjpeg.cpu().float()).abs().mean()
    assert mean_abs_diff < 2
304+
305+
306+
@needs_cuda
307+
@pytest.mark.parametrize('cuda_device', ('cuda', 'cuda:0', torch.device('cuda')))
308+
def test_decode_jpeg_cuda_device_param(cuda_device):
309+
"""Make sure we can pass a string or a torch.device as device param"""
310+
torch.randint(0, 10, (10,), device=cuda_device)
311+
data = read_file(next(get_images(IMAGE_ROOT, ".jpg")))
312+
decode_jpeg(data, device=cuda_device)
313+
314+
315+
@needs_cuda
316+
def test_decode_jpeg_cuda_errors():
317+
data = read_file(next(get_images(IMAGE_ROOT, ".jpg")))
318+
with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
319+
decode_jpeg(data.reshape(-1, 1), device='cuda')
320+
with pytest.raises(RuntimeError, match="input tensor must be on CPU"):
321+
decode_jpeg(data.to('cuda'), device='cuda')
322+
with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
323+
decode_jpeg(data.to(torch.float), device='cuda')
324+
with pytest.raises(RuntimeError, match="Expected a cuda device"):
325+
torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, 'cpu')
326+
327+
281328
if __name__ == '__main__':
282329
unittest.main()
torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
#include "decode_jpeg_cuda.h"
2+
3+
#include <ATen/ATen.h>
4+
5+
#if NVJPEG_FOUND
6+
#include <ATen/cuda/CUDAContext.h>
7+
#include <c10/cuda/CUDAGuard.h>
8+
#include <nvjpeg.h>
9+
#endif
10+
11+
#include <string>
12+
13+
namespace vision {
14+
namespace image {
15+
16+
#if !NVJPEG_FOUND
17+
18+
torch::Tensor decode_jpeg_cuda(
19+
const torch::Tensor& data,
20+
ImageReadMode mode,
21+
torch::Device device) {
22+
TORCH_CHECK(
23+
false, "decode_jpeg_cuda: torchvision not compiled with nvJPEG support");
24+
}
25+
26+
#else
27+
28+
namespace {
29+
static nvjpegHandle_t nvjpeg_handle = nullptr;
30+
}
31+
32+
torch::Tensor decode_jpeg_cuda(
33+
const torch::Tensor& data,
34+
ImageReadMode mode,
35+
torch::Device device) {
36+
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
37+
38+
TORCH_CHECK(
39+
!data.is_cuda(),
40+
"The input tensor must be on CPU when decoding with nvjpeg")
41+
42+
TORCH_CHECK(
43+
data.dim() == 1 && data.numel() > 0,
44+
"Expected a non empty 1-dimensional tensor");
45+
46+
TORCH_CHECK(device.is_cuda(), "Expected a cuda device")
47+
48+
at::cuda::CUDAGuard device_guard(device);
49+
50+
// Create global nvJPEG handle
51+
std::once_flag nvjpeg_handle_creation_flag;
52+
std::call_once(nvjpeg_handle_creation_flag, []() {
53+
if (nvjpeg_handle == nullptr) {
54+
nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle);
55+
56+
if (create_status != NVJPEG_STATUS_SUCCESS) {
57+
// Reset handle so that one can still call the function again in the
58+
// same process if there was a failure
59+
free(nvjpeg_handle);
60+
nvjpeg_handle = nullptr;
61+
}
62+
TORCH_CHECK(
63+
create_status == NVJPEG_STATUS_SUCCESS,
64+
"nvjpegCreateSimple failed: ",
65+
create_status);
66+
}
67+
});
68+
69+
// Create the jpeg state
70+
nvjpegJpegState_t jpeg_state;
71+
nvjpegStatus_t state_status =
72+
nvjpegJpegStateCreate(nvjpeg_handle, &jpeg_state);
73+
74+
TORCH_CHECK(
75+
state_status == NVJPEG_STATUS_SUCCESS,
76+
"nvjpegJpegStateCreate failed: ",
77+
state_status);
78+
79+
auto datap = data.data_ptr<uint8_t>();
80+
81+
// Get the image information
82+
int num_channels;
83+
nvjpegChromaSubsampling_t subsampling;
84+
int widths[NVJPEG_MAX_COMPONENT];
85+
int heights[NVJPEG_MAX_COMPONENT];
86+
nvjpegStatus_t info_status = nvjpegGetImageInfo(
87+
nvjpeg_handle,
88+
datap,
89+
data.numel(),
90+
&num_channels,
91+
&subsampling,
92+
widths,
93+
heights);
94+
95+
if (info_status != NVJPEG_STATUS_SUCCESS) {
96+
nvjpegJpegStateDestroy(jpeg_state);
97+
TORCH_CHECK(false, "nvjpegGetImageInfo failed: ", info_status);
98+
}
99+
100+
if (subsampling == NVJPEG_CSS_UNKNOWN) {
101+
nvjpegJpegStateDestroy(jpeg_state);
102+
TORCH_CHECK(false, "Unknown NVJPEG chroma subsampling");
103+
}
104+
105+
int width = widths[0];
106+
int height = heights[0];
107+
108+
nvjpegOutputFormat_t ouput_format;
109+
int num_channels_output;
110+
111+
switch (mode) {
112+
case IMAGE_READ_MODE_UNCHANGED:
113+
num_channels_output = num_channels;
114+
// For some reason, setting output_format to NVJPEG_OUTPUT_UNCHANGED will
115+
// not properly decode RGB images (it's fine for grayscale), so we set
116+
// output_format manually here
117+
if (num_channels == 1) {
118+
ouput_format = NVJPEG_OUTPUT_Y;
119+
} else if (num_channels == 3) {
120+
ouput_format = NVJPEG_OUTPUT_RGB;
121+
} else {
122+
nvjpegJpegStateDestroy(jpeg_state);
123+
TORCH_CHECK(
124+
false,
125+
"When mode is UNCHANGED, only 1 or 3 input channels are allowed.");
126+
}
127+
break;
128+
case IMAGE_READ_MODE_GRAY:
129+
ouput_format = NVJPEG_OUTPUT_Y;
130+
num_channels_output = 1;
131+
break;
132+
case IMAGE_READ_MODE_RGB:
133+
ouput_format = NVJPEG_OUTPUT_RGB;
134+
num_channels_output = 3;
135+
break;
136+
default:
137+
nvjpegJpegStateDestroy(jpeg_state);
138+
TORCH_CHECK(
139+
false, "The provided mode is not supported for JPEG decoding on GPU");
140+
}
141+
142+
auto out_tensor = torch::empty(
143+
{int64_t(num_channels_output), int64_t(height), int64_t(width)},
144+
torch::dtype(torch::kU8).device(device));
145+
146+
// nvjpegImage_t is a struct with
147+
// - an array of pointers to each channel
148+
// - the pitch for each channel
149+
// which must be filled in manually
150+
nvjpegImage_t out_image;
151+
152+
for (int c = 0; c < num_channels_output; c++) {
153+
out_image.channel[c] = out_tensor[c].data_ptr<uint8_t>();
154+
out_image.pitch[c] = width;
155+
}
156+
for (int c = num_channels_output; c < NVJPEG_MAX_COMPONENT; c++) {
157+
out_image.channel[c] = nullptr;
158+
out_image.pitch[c] = 0;
159+
}
160+
161+
cudaStream_t stream = at::cuda::getCurrentCUDAStream(device.index());
162+
163+
nvjpegStatus_t decode_status = nvjpegDecode(
164+
nvjpeg_handle,
165+
jpeg_state,
166+
datap,
167+
data.numel(),
168+
ouput_format,
169+
&out_image,
170+
stream);
171+
172+
nvjpegJpegStateDestroy(jpeg_state);
173+
174+
TORCH_CHECK(
175+
decode_status == NVJPEG_STATUS_SUCCESS,
176+
"nvjpegDecode failed: ",
177+
decode_status);
178+
179+
return out_tensor;
180+
}
181+
182+
#endif // NVJPEG_FOUND
183+
184+
} // namespace image
185+
} // namespace vision
torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#pragma once
2+
3+
#include <torch/types.h>
4+
#include "../image_read_mode.h"
5+
6+
namespace vision {
7+
namespace image {
8+
9+
C10_EXPORT torch::Tensor decode_jpeg_cuda(
10+
const torch::Tensor& data,
11+
ImageReadMode mode,
12+
torch::Device device);
13+
14+
} // namespace image
15+
} // namespace vision

torchvision/csrc/io/image/image.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ static auto registry = torch::RegisterOperators()
2121
.op("image::encode_jpeg", &encode_jpeg)
2222
.op("image::read_file", &read_file)
2323
.op("image::write_file", &write_file)
24-
.op("image::decode_image", &decode_image);
24+
.op("image::decode_image", &decode_image)
25+
.op("image::decode_jpeg_cuda", &decode_jpeg_cuda);
2526

2627
} // namespace image
2728
} // namespace vision

torchvision/csrc/io/image/image.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@
66
#include "cpu/encode_jpeg.h"
77
#include "cpu/encode_png.h"
88
#include "cpu/read_write_file.h"
9+
#include "cuda/decode_jpeg_cuda.h"

0 commit comments

Comments
 (0)