diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index d3ffb51d496596..5060bf7f6aabdf 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -425,12 +425,37 @@ class TORCH_API Context { c10::utils::check_env("TORCH_LINALG_PREFER_CUSOLVER") == true ? at::LinalgBackend::Cusolver : at::LinalgBackend::Default; - at::BlasBackend blas_preferred_backend = #ifdef USE_ROCM - (c10::utils::check_env("TORCH_BLAS_PREFER_HIPBLASLT") != false) + // AMD Instinct targets prefer hipblaslt + const bool _hipblaslt_preferred_default = []() { + const std::vector archs = { + "gfx90a", "gfx942" + }; + for (auto index: c10::irange(detail::getCUDAHooks().deviceCount())) { + if (!detail::getCUDAHooks().isGPUArch(index, archs)) { + return false; + } + } + return true; + }(); #else - (c10::utils::check_env("TORCH_BLAS_PREFER_CUBLASLT") == true) + const bool _hipblaslt_preferred_default = false; +#endif + const bool _blaslt_preferred = [&]() { + auto env = c10::utils::check_env("TORCH_BLAS_PREFER_CUBLASLT"); + if (env.has_value()) { + return env.value(); + } + env = c10::utils::check_env("TORCH_BLAS_PREFER_HIPBLASLT"); + if (env.has_value()) { + return env.value(); + } +#ifdef USE_ROCM + return _hipblaslt_preferred_default; #endif + return false; + }(); + at::BlasBackend blas_preferred_backend = _blaslt_preferred ? at::BlasBackend::Cublaslt : at::BlasBackend::Cublas; at::ROCmFABackend rocm_fa_preferred_backend =