From b31afbd82190fefa964f7fbca54dc88b5a1f2504 Mon Sep 17 00:00:00 2001 From: Arash Pakbin Date: Thu, 29 May 2025 21:10:45 +0000 Subject: [PATCH] [ROCm] Exposing Some MIOpen Symbols (#2176) (#154545) This PR exposes some MIOpen symbols, namely: 1. `miopenDataType_t getMiopenDataType(const at::Tensor& tensor)` 2. `miopenHandle_t getMiopenHandle()` 3. `class TensorDescriptor` 4. `class Descriptor` 5. `class FilterDescriptor` 6. `struct ConvolutionDescriptor` 7. `struct DropoutDescriptor` 8. `struct RNNDescriptor` to enable adding extensions that make use of them. Pull Request resolved: https://github.com/pytorch/pytorch/pull/154545 Approved by: https://github.com/jeffdaily, https://github.com/Skylion007 Co-authored-by: Jeff Daily --- aten/src/ATen/miopen/Descriptors.h | 51 +++++++++++++++++++----------- aten/src/ATen/miopen/Handle.h | 8 ++--- aten/src/ATen/miopen/Types.h | 9 +++--- 3 files changed, 42 insertions(+), 26 deletions(-) diff --git a/aten/src/ATen/miopen/Descriptors.h b/aten/src/ATen/miopen/Descriptors.h index 2e67ff49d183..352efc391142 100644 --- a/aten/src/ATen/miopen/Descriptors.h +++ b/aten/src/ATen/miopen/Descriptors.h @@ -5,6 +5,7 @@ #include #include #include +#include namespace at { namespace native { @@ -37,9 +38,9 @@ struct DescriptorDeleter { // initialized the first time you call set() or any other initializing // function. template -class Descriptor -{ -public: +// NOLINTNEXTLINE(bugprone-exception-escape) +class TORCH_CUDA_CPP_API Descriptor { + public: // Use desc() to access the underlying descriptor pointer in // a read-only fashion. Most client code should use this. // If the descriptor was never initialized, this will return @@ -55,7 +56,7 @@ class Descriptor protected: void init() { if (desc_ == nullptr) { - T* raw_desc; + T* raw_desc = nullptr; MIOPEN_CHECK(ctor(&raw_desc)); desc_.reset(raw_desc); } @@ -64,13 +65,12 @@ class Descriptor std::unique_ptr> desc_; }; -class TensorDescriptor - : public Descriptor -{ -public: - TensorDescriptor() {} +class TORCH_CUDA_CPP_API TensorDescriptor : public Descriptor< + miopenTensorDescriptor, + &miopenCreateTensorDescriptor, + &miopenDestroyTensorDescriptor> { + public: + TensorDescriptor() = default; explicit TensorDescriptor(const at::Tensor &t, size_t pad = 0) { set(t, pad); } @@ -88,11 +88,10 @@ class TensorDescriptor std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d); -class FilterDescriptor - : public Descriptor -{ +class TORCH_CUDA_CPP_API FilterDescriptor : public Descriptor< + miopenTensorDescriptor, + &miopenCreateTensorDescriptor, + &miopenDestroyTensorDescriptor> { public: void set(const at::Tensor &t, int64_t pad = 0) { set(t, at::MemoryFormat::Contiguous, pad); @@ -106,7 +105,7 @@ class FilterDescriptor } }; -struct ConvolutionDescriptor +struct TORCH_CUDA_CPP_API ConvolutionDescriptor : public Descriptor @@ -118,8 +117,24 @@ struct ConvolutionDescriptor } }; +// NOLINTNEXTLINE(bugprone-exception-escape) +struct TORCH_CUDA_CPP_API DropoutDescriptor + : public Descriptor< + miopenDropoutDescriptor, + &miopenCreateDropoutDescriptor, + &miopenDestroyDropoutDescriptor> { + void set(miopenHandle_t handle, float dropout, void* states, size_t stateSizeInBytes, + unsigned long long seed, bool use_mask, bool state_evo, miopenRNGType_t rng_mode) { + MIOPEN_CHECK(miopenSetDropoutDescriptor(mut_desc(), handle, dropout, states, stateSizeInBytes, seed, use_mask, state_evo, rng_mode)); + } + + void restore(miopenHandle_t handle, float dropout, void* states, size_t stateSizeInBytes, + unsigned long long seed, bool use_mask, bool state_evo, miopenRNGType_t rng_mode) { + MIOPEN_CHECK(miopenRestoreDropoutDescriptor(mut_desc(), handle, dropout, states, stateSizeInBytes, seed, use_mask, state_evo, rng_mode)); + } +}; -struct RNNDescriptor +struct TORCH_CUDA_CPP_API RNNDescriptor : public Descriptor diff --git a/aten/src/ATen/miopen/Handle.h b/aten/src/ATen/miopen/Handle.h index 9d537d809112..4c80c3aea65b 100644 --- a/aten/src/ATen/miopen/Handle.h +++ b/aten/src/ATen/miopen/Handle.h @@ -1,9 +1,9 @@ #pragma once #include +#include -namespace at { namespace native { +namespace at::native { -miopenHandle_t getMiopenHandle(); - -}} // namespace +TORCH_CUDA_CPP_API miopenHandle_t getMiopenHandle(); +} // namespace at::native diff --git a/aten/src/ATen/miopen/Types.h b/aten/src/ATen/miopen/Types.h index 5a207c83d387..0a8a1a952e2e 100644 --- a/aten/src/ATen/miopen/Types.h +++ b/aten/src/ATen/miopen/Types.h @@ -1,12 +1,13 @@ #pragma once -#include #include +#include +#include -namespace at { namespace native { +namespace at::native { -miopenDataType_t getMiopenDataType(const at::Tensor& tensor); +TORCH_CUDA_CPP_API miopenDataType_t getMiopenDataType(const at::Tensor& tensor); int64_t miopen_version(); -}} // namespace at::miopen +} // namespace at::native