Skip to content

Commit 973d510

Browse files
ssnl authored and facebook-github-bot committed
Add device-specific cuFFT plan caches (pytorch#19300)
Summary: Fixes pytorch#19224 Pull Request resolved: pytorch#19300 Differential Revision: D14986967 Pulled By: soumith fbshipit-source-id: 8c31237db50d6924bba1472434c10326610d9255
1 parent b8fb6ea commit 973d510

File tree

11 files changed

+238
-92
lines changed

11 files changed

+238
-92
lines changed

aten/src/ATen/cuda/detail/CUDAHooks.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -196,33 +196,33 @@ double CUDAHooks::batchnormMinEpsilonCuDNN() const {
196196
#endif
197197
}
198198

199-
int64_t CUDAHooks::cuFFTGetPlanCacheMaxSize() const {
199+
int64_t CUDAHooks::cuFFTGetPlanCacheMaxSize(int64_t device_index) const {
200200
#ifndef __HIP_PLATFORM_HCC__
201-
return at::native::detail::cufft_get_plan_cache_max_size_impl();
201+
return at::native::detail::cufft_get_plan_cache_max_size_impl(device_index);
202202
#else
203203
AT_ERROR("cuFFT with HIP is not supported");
204204
#endif
205205
}
206206

207-
void CUDAHooks::cuFFTSetPlanCacheMaxSize(int64_t max_size) const {
207+
void CUDAHooks::cuFFTSetPlanCacheMaxSize(int64_t device_index, int64_t max_size) const {
208208
#ifndef __HIP_PLATFORM_HCC__
209-
at::native::detail::cufft_set_plan_cache_max_size_impl(max_size);
209+
at::native::detail::cufft_set_plan_cache_max_size_impl(device_index, max_size);
210210
#else
211211
AT_ERROR("cuFFT with HIP is not supported");
212212
#endif
213213
}
214214

215-
int64_t CUDAHooks::cuFFTGetPlanCacheSize() const {
215+
int64_t CUDAHooks::cuFFTGetPlanCacheSize(int64_t device_index) const {
216216
#ifndef __HIP_PLATFORM_HCC__
217-
return at::native::detail::cufft_get_plan_cache_size_impl();
217+
return at::native::detail::cufft_get_plan_cache_size_impl(device_index);
218218
#else
219219
AT_ERROR("cuFFT with HIP is not supported");
220220
#endif
221221
}
222222

223-
void CUDAHooks::cuFFTClearPlanCache() const {
223+
void CUDAHooks::cuFFTClearPlanCache(int64_t device_index) const {
224224
#ifndef __HIP_PLATFORM_HCC__
225-
at::native::detail::cufft_clear_plan_cache_impl();
225+
at::native::detail::cufft_clear_plan_cache_impl(device_index);
226226
#else
227227
AT_ERROR("cuFFT with HIP is not supported");
228228
#endif

aten/src/ATen/cuda/detail/CUDAHooks.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ struct CUDAHooks : public at::CUDAHooksInterface {
2424
long versionCuDNN() const override;
2525
std::string showConfig() const override;
2626
double batchnormMinEpsilonCuDNN() const override;
27-
int64_t cuFFTGetPlanCacheMaxSize() const override;
28-
void cuFFTSetPlanCacheMaxSize(int64_t max_size) const override;
29-
int64_t cuFFTGetPlanCacheSize() const override;
30-
void cuFFTClearPlanCache() const override;
27+
int64_t cuFFTGetPlanCacheMaxSize(int64_t device_index) const override;
28+
void cuFFTSetPlanCacheMaxSize(int64_t device_index, int64_t max_size) const override;
29+
int64_t cuFFTGetPlanCacheSize(int64_t device_index) const override;
30+
void cuFFTClearPlanCache(int64_t device_index) const override;
3131
int getNumGPUs() const override;
3232
};
3333

aten/src/ATen/detail/CUDAHooksInterface.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,19 +111,19 @@ struct CAFFE2_API CUDAHooksInterface {
111111
"Cannot query batchnormMinEpsilonCuDNN() without ATen_cuda library. ", CUDA_HELP);
112112
}
113113

114-
virtual int64_t cuFFTGetPlanCacheMaxSize() const {
114+
virtual int64_t cuFFTGetPlanCacheMaxSize(int64_t device_index) const {
115115
AT_ERROR("Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
116116
}
117117

118-
virtual void cuFFTSetPlanCacheMaxSize(int64_t max_size) const {
118+
virtual void cuFFTSetPlanCacheMaxSize(int64_t device_index, int64_t max_size) const {
119119
AT_ERROR("Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
120120
}
121121

122-
virtual int64_t cuFFTGetPlanCacheSize() const {
122+
virtual int64_t cuFFTGetPlanCacheSize(int64_t device_index) const {
123123
AT_ERROR("Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
124124
}
125125

126-
virtual void cuFFTClearPlanCache() const {
126+
virtual void cuFFTClearPlanCache(int64_t device_index) const {
127127
AT_ERROR("Cannot access cuFFT plan cache without ATen_cuda library. ", CUDA_HELP);
128128
}
129129

aten/src/ATen/native/SpectralOps.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -130,20 +130,20 @@ static inline Tensor _fft(const Tensor &self, const int64_t signal_ndim,
130130

131131
// We call the following methods via CUDA hooks because they are really only
132132
// valid when CUDA is available. See native/cuda/CuFFTPlanCache.h for more details.
133-
int64_t _cufft_get_plan_cache_max_size() {
134-
return detail::getCUDAHooks().cuFFTGetPlanCacheMaxSize();
133+
int64_t _cufft_get_plan_cache_max_size(int64_t device_index) {
134+
return detail::getCUDAHooks().cuFFTGetPlanCacheMaxSize(device_index);
135135
}
136136

137-
void _cufft_set_plan_cache_max_size(int64_t max_size) {
138-
detail::getCUDAHooks().cuFFTSetPlanCacheMaxSize(max_size);
137+
void _cufft_set_plan_cache_max_size(int64_t device_index, int64_t max_size) {
138+
detail::getCUDAHooks().cuFFTSetPlanCacheMaxSize(device_index, max_size);
139139
}
140140

141-
int64_t _cufft_get_plan_cache_size() {
142-
return detail::getCUDAHooks().cuFFTGetPlanCacheSize();
141+
int64_t _cufft_get_plan_cache_size(int64_t device_index) {
142+
return detail::getCUDAHooks().cuFFTGetPlanCacheSize(device_index);
143143
}
144144

145-
void _cufft_clear_plan_cache() {
146-
detail::getCUDAHooks().cuFFTClearPlanCache();
145+
void _cufft_clear_plan_cache(int64_t device_index) {
146+
detail::getCUDAHooks().cuFFTClearPlanCache(device_index);
147147
}
148148

149149
Tensor fft(const Tensor& self, const int64_t signal_ndim, const bool normalized) {

aten/src/ATen/native/cuda/CuFFTPlanCache.h

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,18 @@ class CuFFTParamsLRUCache {
373373
_set_max_size(max_size);
374374
}
375375

376+
CuFFTParamsLRUCache(CuFFTParamsLRUCache&& other) noexcept :
377+
_usage_list(std::move(other._usage_list)),
378+
_cache_map(std::move(other._cache_map)),
379+
_max_size(other._max_size) {}
380+
381+
CuFFTParamsLRUCache& operator=(CuFFTParamsLRUCache&& other) noexcept {
382+
_usage_list = std::move(other._usage_list);
383+
_cache_map = std::move(other._cache_map);
384+
_max_size = other._max_size;
385+
return *this;
386+
}
387+
376388
// If key is in this cache, return the cached config. Otherwise, emplace the
377389
// config in this cache using value_args and return it.
378390
// Return const reference because CuFFTConfig shouldn't be tampered with once
@@ -431,6 +443,8 @@ class CuFFTParamsLRUCache {
431443

432444
size_t max_size() const noexcept { return _max_size; }
433445

446+
std::mutex mutex;
447+
434448
private:
435449
// Only sets size and does value check. Does not resize the data structures.
436450
void _set_max_size(int64_t new_size) {
@@ -455,9 +469,9 @@ class CuFFTParamsLRUCache {
455469
// native function counterparts (at native/SpectralOps.cpp), i.e.,
456470
// _cufft_get_plan_cache_max_size, _cufft_set_plan_cache_max_size
457471
// _cufft_get_plan_cache_size, and _cufft_clear_plan_cache.
458-
int64_t cufft_get_plan_cache_max_size_impl();
459-
void cufft_set_plan_cache_max_size_impl(int64_t max_size);
460-
int64_t cufft_get_plan_cache_size_impl();
461-
void cufft_clear_plan_cache_impl();
472+
int64_t cufft_get_plan_cache_max_size_impl(int64_t device_index);
473+
void cufft_set_plan_cache_max_size_impl(int64_t device_index, int64_t max_size);
474+
int64_t cufft_get_plan_cache_size_impl(int64_t device_index);
475+
void cufft_clear_plan_cache_impl(int64_t device_index);
462476

463477
}}} // namespace at::native::detail

aten/src/ATen/native/cuda/SpectralOps.cu

Lines changed: 50 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <ATen/Dispatch.h>
55
#include <ATen/Utils.h>
66
#include <ATen/NativeFunctions.h>
7+
#include <ATen/detail/CUDAHooksInterface.h>
78
#include <ATen/native/SpectralOpsUtils.h>
89
#include <ATen/native/cuda/CuFFTUtils.h>
910
#include <ATen/native/cuda/CuFFTPlanCache.h>
@@ -14,6 +15,7 @@
1415
#include <thrust/unique.h>
1516
#include <cufft.h>
1617
#include <cufftXt.h>
18+
#include <vector>
1719
#include <cmath>
1820

1921
namespace at { namespace native {
@@ -260,29 +262,59 @@ static inline Tensor _run_cufft(
260262
}
261263

262264
// The cuFFT plan cache, defined in CuFFTUtils.h
263-
struct CuFFTParamsLRUCache plan_cache;
264-
std::mutex plan_cache_mutex;
265+
std::vector<optional<CuFFTParamsLRUCache>> plan_caches;
266+
std::mutex plan_caches_mutex;
267+
268+
static inline
269+
CuFFTParamsLRUCache &cufft_get_plan_cache(int64_t device_index) {
270+
std::lock_guard<std::mutex> guard(plan_caches_mutex);
271+
272+
AT_ASSERT(device_index >= 0);
273+
274+
if (device_index >= plan_caches.size()) {
275+
plan_caches.resize(device_index + 1);
276+
}
277+
278+
if (!plan_caches[device_index]) {
279+
plan_caches[device_index].emplace();
280+
}
281+
282+
return *plan_caches[device_index];
283+
}
284+
265285

266286
namespace detail {
267287

268-
int64_t cufft_get_plan_cache_max_size_impl() {
269-
std::lock_guard<std::mutex> guard(plan_cache_mutex);
270-
return plan_cache.max_size();
288+
int64_t cufft_get_plan_cache_max_size_impl(int64_t device_index) {
289+
AT_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
290+
"cufft_get_plan_cache_max_size: expected 0 <= device_index < ",
291+
at::detail::getCUDAHooks().getNumGPUs(), "], but got device_index=",
292+
device_index);
293+
return cufft_get_plan_cache(device_index).max_size();
271294
}
272295

273-
void cufft_set_plan_cache_max_size_impl(int64_t max_size) {
274-
std::lock_guard<std::mutex> guard(plan_cache_mutex);
275-
plan_cache.resize(max_size);
296+
void cufft_set_plan_cache_max_size_impl(int64_t device_index, int64_t max_size) {
297+
AT_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
298+
"cufft_set_plan_cache_max_size: expected 0 <= device_index < ",
299+
at::detail::getCUDAHooks().getNumGPUs(), "], but got device_index=",
300+
device_index);
301+
return cufft_get_plan_cache(device_index).resize(max_size);
276302
}
277303

278-
int64_t cufft_get_plan_cache_size_impl() {
279-
std::lock_guard<std::mutex> guard(plan_cache_mutex);
280-
return plan_cache.size();
304+
int64_t cufft_get_plan_cache_size_impl(int64_t device_index) {
305+
AT_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
306+
"cufft_get_plan_cache_size: expected 0 <= device_index < ",
307+
at::detail::getCUDAHooks().getNumGPUs(), "], but got device_index=",
308+
device_index);
309+
return cufft_get_plan_cache(device_index).size();
281310
}
282311

283-
void cufft_clear_plan_cache_impl() {
284-
std::lock_guard<std::mutex> guard(plan_cache_mutex);
285-
return plan_cache.clear();
312+
void cufft_clear_plan_cache_impl(int64_t device_index) {
313+
AT_CHECK(0 <= device_index && device_index < at::detail::getCUDAHooks().getNumGPUs(),
314+
"cufft_clear_plan_cache: expected 0 <= device_index < ",
315+
at::detail::getCUDAHooks().getNumGPUs(), "], but got device_index=",
316+
device_index);
317+
return cufft_get_plan_cache(device_index).clear();
286318
}
287319

288320
} // namespace at::native::detail
@@ -293,6 +325,9 @@ Tensor _fft_cufft(const Tensor& self, int64_t signal_ndim,
293325
bool complex_input, bool complex_output, bool inverse,
294326
IntArrayRef checked_signal_sizes, bool normalized, bool onesided,
295327
IntArrayRef output_sizes) {
328+
329+
CuFFTParamsLRUCache& plan_cache = cufft_get_plan_cache(self.device().index());
330+
296331
Tensor input = self;
297332
bool input_was_cloned = false;
298333

@@ -334,7 +369,7 @@ Tensor _fft_cufft(const Tensor& self, int64_t signal_ndim,
334369
CuFFTParams params;
335370
setCuFFTParams(&params, input, signal_ndim, complex_input,
336371
complex_output, checked_signal_sizes, onesided);
337-
std::lock_guard<std::mutex> guard(plan_cache_mutex);
372+
std::lock_guard<std::mutex> guard(plan_cache.mutex);
338373
if (plan_cache.max_size() > 0) { // check again after acquiring the lock
339374
const CuFFTConfig &config = plan_cache.try_emplace_value(std::move(params),
340375
input, signal_ndim, complex_input,

aten/src/ATen/native/native_functions.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -917,13 +917,13 @@
917917
CPU: _fft_mkl
918918
CUDA: _fft_cufft
919919

920-
- func: _cufft_get_plan_cache_size() -> int
920+
- func: _cufft_get_plan_cache_size(int device_index) -> int
921921

922-
- func: _cufft_get_plan_cache_max_size() -> int
922+
- func: _cufft_get_plan_cache_max_size(int device_index) -> int
923923

924-
- func: _cufft_set_plan_cache_max_size(int max_size) -> void
924+
- func: _cufft_set_plan_cache_max_size(int device_index, int max_size) -> void
925925

926-
- func: _cufft_clear_plan_cache() -> void
926+
- func: _cufft_clear_plan_cache(int device_index) -> void
927927

928928
- func: index(Tensor self, Tensor?[] indices) -> Tensor
929929
variants: function, method

docs/source/notes/cuda.rst

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,3 +272,31 @@ There are significant caveats to using CUDA models with
272272
:mod:`~torch.multiprocessing`; unless care is taken to meet the data handling
273273
requirements exactly, it is likely that your program will have incorrect or
274274
undefined behavior.
275+
276+
.. _cufft-plan-cache:
277+
278+
cuFFT plan cache
279+
^^^^^^^^^^^^^^^^
280+
281+
For each CUDA device, an LRU cache of cuFFT plans is used to speed up repeatedly
282+
running FFT methods (e.g., :func:`torch.fft`) on CUDA tensors of same geometry
283+
with same configuration. Because some cuFFT plans may allocate GPU memory,
284+
these caches have a maximum capacity.
285+
286+
You may control and query the properties of the cache of current device with
287+
the following APIs:
288+
289+
* ``torch.backends.cuda.cufft_plan_cache.max_size`` gives the capacity of the
290+
cache (default is 4096 on CUDA 10 and newer, and 1023 on older CUDA versions).
291+
Setting this value directly modifies the capacity.
292+
293+
* ``torch.backends.cuda.cufft_plan_cache.size`` gives the number of plans
294+
currently residing in the cache.
295+
296+
* ``torch.backends.cuda.cufft_plan_cache.clear()`` clears the cache.
297+
298+
To control and query plan caches of a non-default device, you can index the
299+
``torch.backends.cuda.cufft_plan_cache`` object with either a :class:`torch.device`
300+
object or a device index, and access one of the above attributes. E.g., to set
301+
the capacity of the cache for device ``1``, one can write
302+
``torch.backends.cuda.cufft_plan_cache[1].max_size = 10``.

test/test_cuda.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2192,11 +2192,15 @@ def test_fft_ifft_rfft_irfft(self):
21922192
_TestTorchMixin._test_fft_ifft_rfft_irfft(self, device=torch.device('cuda'))
21932193

21942194
@contextmanager
2195-
def plan_cache_max_size(n):
2196-
original = torch.backends.cuda.cufft_plan_cache.max_size
2197-
torch.backends.cuda.cufft_plan_cache.max_size = n
2195+
def plan_cache_max_size(n, device=None):
2196+
if device is None:
2197+
plan_cache = torch.backends.cuda.cufft_plan_cache
2198+
else:
2199+
plan_cache = torch.backends.cuda.cufft_plan_cache[device]
2200+
original = plan_cache.max_size
2201+
plan_cache.max_size = n
21982202
yield
2199-
torch.backends.cuda.cufft_plan_cache.max_size = original
2203+
plan_cache.max_size = original
22002204

22012205
with plan_cache_max_size(max(1, torch.backends.cuda.cufft_plan_cache.size - 10)):
22022206
_TestTorchMixin._test_fft_ifft_rfft_irfft(self, device=torch.device('cuda'))
@@ -2216,6 +2220,44 @@ def plan_cache_max_size(n):
22162220
with self.assertRaisesRegex(RuntimeError, r"read-only property"):
22172221
torch.backends.cuda.cufft_plan_cache.size = -1
22182222

2223+
with self.assertRaisesRegex(RuntimeError, r"but got device with index"):
2224+
torch.backends.cuda.cufft_plan_cache[torch.cuda.device_count() + 10]
2225+
2226+
if TEST_MULTIGPU:
2227+
# Test that different GPU has different cache
2228+
x0 = torch.randn(2, 3, 3, device='cuda:0')
2229+
x1 = x0.cuda(1)
2230+
self.assertEqual(x0.rfft(2), x1.rfft(2))
2231+
# If a plan is used across different devices, the following line (or
2232+
# the assert above) would trigger illegal memory access. Other ways
2233+
# to trigger the error include
2234+
# (1) setting CUDA_LAUNCH_BLOCKING=1 (pytorch/pytorch#19224) and
2235+
# (2) printing a device 1 tensor.
2236+
x0.copy_(x1)
2237+
2238+
# Test that un-indexed `torch.backends.cuda.cufft_plan_cache` uses current device
2239+
with plan_cache_max_size(10, device='cuda:0'):
2240+
with plan_cache_max_size(11, device='cuda:1'):
2241+
self.assertEqual(torch.backends.cuda.cufft_plan_cache[0].max_size, 10)
2242+
self.assertEqual(torch.backends.cuda.cufft_plan_cache[1].max_size, 11)
2243+
2244+
self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 10) # default is cuda:0
2245+
with torch.cuda.device(1):
2246+
self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11) # default is cuda:1
2247+
with torch.cuda.device(0):
2248+
self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 10) # default is cuda:0
2249+
2250+
self.assertEqual(torch.backends.cuda.cufft_plan_cache[0].max_size, 10)
2251+
with torch.cuda.device(1):
2252+
with plan_cache_max_size(11): # default is cuda:1
2253+
self.assertEqual(torch.backends.cuda.cufft_plan_cache[0].max_size, 10)
2254+
self.assertEqual(torch.backends.cuda.cufft_plan_cache[1].max_size, 11)
2255+
2256+
self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11) # default is cuda:1
2257+
with torch.cuda.device(0):
2258+
self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 10) # default is cuda:0
2259+
self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11) # default is cuda:1
2260+
22192261
def test_stft(self):
22202262
_TestTorchMixin._test_stft(self, device=torch.device('cuda'))
22212263

0 commit comments

Comments
 (0)