44#include < ATen/Dispatch.h>
55#include < ATen/Utils.h>
66#include < ATen/NativeFunctions.h>
7+ #include < ATen/detail/CUDAHooksInterface.h>
78#include < ATen/native/SpectralOpsUtils.h>
89#include < ATen/native/cuda/CuFFTUtils.h>
910#include < ATen/native/cuda/CuFFTPlanCache.h>
1415#include < thrust/unique.h>
1516#include < cufft.h>
1617#include < cufftXt.h>
18+ #include < vector>
1719#include < cmath>
1820
1921namespace at { namespace native {
@@ -260,29 +262,59 @@ static inline Tensor _run_cufft(
260262}
261263
262264// The cuFFT plan cache, defined in CuFFTUtils.h
263- struct CuFFTParamsLRUCache plan_cache;
264- std::mutex plan_cache_mutex;
265+ std::vector<optional<CuFFTParamsLRUCache>> plan_caches;
266+ std::mutex plan_caches_mutex;
267+
268+ static inline
269+ CuFFTParamsLRUCache &cufft_get_plan_cache (int64_t device_index) {
270+ std::lock_guard<std::mutex> guard (plan_caches_mutex);
271+
272+ AT_ASSERT (device_index >= 0 );
273+
274+ if (device_index >= plan_caches.size ()) {
275+ plan_caches.resize (device_index + 1 );
276+ }
277+
278+ if (!plan_caches[device_index]) {
279+ plan_caches[device_index].emplace ();
280+ }
281+
282+ return *plan_caches[device_index];
283+ }
284+
265285
266286namespace detail {
267287
268- int64_t cufft_get_plan_cache_max_size_impl () {
269- std::lock_guard<std::mutex> guard (plan_cache_mutex);
270- return plan_cache.max_size ();
288+ int64_t cufft_get_plan_cache_max_size_impl (int64_t device_index) {
289+ AT_CHECK (0 <= device_index && device_index < at::detail::getCUDAHooks ().getNumGPUs (),
290+ " cufft_get_plan_cache_max_size: expected 0 <= device_index < " ,
291+ at::detail::getCUDAHooks ().getNumGPUs (), " ], but got device_index=" ,
292+ device_index);
293+ return cufft_get_plan_cache (device_index).max_size ();
271294}
272295
273- void cufft_set_plan_cache_max_size_impl (int64_t max_size) {
274- std::lock_guard<std::mutex> guard (plan_cache_mutex);
275- plan_cache.resize (max_size);
296+ void cufft_set_plan_cache_max_size_impl (int64_t device_index, int64_t max_size) {
297+ AT_CHECK (0 <= device_index && device_index < at::detail::getCUDAHooks ().getNumGPUs (),
298+ " cufft_set_plan_cache_max_size: expected 0 <= device_index < " ,
299+ at::detail::getCUDAHooks ().getNumGPUs (), " ], but got device_index=" ,
300+ device_index);
301+ return cufft_get_plan_cache (device_index).resize (max_size);
276302}
277303
278- int64_t cufft_get_plan_cache_size_impl () {
279- std::lock_guard<std::mutex> guard (plan_cache_mutex);
280- return plan_cache.size ();
304+ int64_t cufft_get_plan_cache_size_impl (int64_t device_index) {
305+ AT_CHECK (0 <= device_index && device_index < at::detail::getCUDAHooks ().getNumGPUs (),
306+ " cufft_get_plan_cache_size: expected 0 <= device_index < " ,
307+ at::detail::getCUDAHooks ().getNumGPUs (), " ], but got device_index=" ,
308+ device_index);
309+ return cufft_get_plan_cache (device_index).size ();
281310}
282311
283- void cufft_clear_plan_cache_impl () {
284- std::lock_guard<std::mutex> guard (plan_cache_mutex);
285- return plan_cache.clear ();
312+ void cufft_clear_plan_cache_impl (int64_t device_index) {
313+ AT_CHECK (0 <= device_index && device_index < at::detail::getCUDAHooks ().getNumGPUs (),
314+ " cufft_clear_plan_cache: expected 0 <= device_index < " ,
315+ at::detail::getCUDAHooks ().getNumGPUs (), " ], but got device_index=" ,
316+ device_index);
317+ return cufft_get_plan_cache (device_index).clear ();
286318}
287319
288320} // namespace at::native::detail
@@ -293,6 +325,9 @@ Tensor _fft_cufft(const Tensor& self, int64_t signal_ndim,
293325 bool complex_input, bool complex_output, bool inverse,
294326 IntArrayRef checked_signal_sizes, bool normalized, bool onesided,
295327 IntArrayRef output_sizes) {
328+
329+ CuFFTParamsLRUCache& plan_cache = cufft_get_plan_cache (self.device ().index ());
330+
296331 Tensor input = self;
297332 bool input_was_cloned = false ;
298333
@@ -334,7 +369,7 @@ Tensor _fft_cufft(const Tensor& self, int64_t signal_ndim,
334369 CuFFTParams params;
335370 setCuFFTParams (¶ms, input, signal_ndim, complex_input,
336371 complex_output, checked_signal_sizes, onesided);
337- std::lock_guard<std::mutex> guard (plan_cache_mutex );
372+ std::lock_guard<std::mutex> guard (plan_cache. mutex );
338373 if (plan_cache.max_size () > 0 ) { // check again after acquiring the lock
339374 const CuFFTConfig &config = plan_cache.try_emplace_value (std::move (params),
340375 input, signal_ndim, complex_input,
0 commit comments