Skip to content

[rocm6.4_internal_testing] Exposing Some MIOpen Symbols (#2176) (#154545) #2218

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 33 additions & 18 deletions aten/src/ATen/miopen/Descriptors.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <ATen/miopen/miopen-wrapper.h>
#include <ATen/core/Tensor.h>
#include <ATen/TensorUtils.h>
#include <c10/macros/Export.h>

namespace at { namespace native {

Expand Down Expand Up @@ -37,9 +38,9 @@ struct DescriptorDeleter {
// initialized the first time you call set() or any other initializing
// function.
template <typename T, miopenStatus_t (*ctor)(T**), miopenStatus_t (*dtor)(T*)>
class Descriptor
{
public:
// NOLINTNEXTLINE(bugprone-exception-escape)
class TORCH_CUDA_CPP_API Descriptor {
public:
// Use desc() to access the underlying descriptor pointer in
// a read-only fashion. Most client code should use this.
// If the descriptor was never initialized, this will return
Expand All @@ -55,7 +56,7 @@ class Descriptor
protected:
void init() {
if (desc_ == nullptr) {
T* raw_desc;
T* raw_desc = nullptr;
MIOPEN_CHECK(ctor(&raw_desc));
desc_.reset(raw_desc);
}
Expand All @@ -64,13 +65,12 @@ class Descriptor
std::unique_ptr<T, DescriptorDeleter<T, dtor>> desc_;
};

class TensorDescriptor
: public Descriptor<miopenTensorDescriptor,
&miopenCreateTensorDescriptor,
&miopenDestroyTensorDescriptor>
{
public:
TensorDescriptor() {}
class TORCH_CUDA_CPP_API TensorDescriptor : public Descriptor<
miopenTensorDescriptor,
&miopenCreateTensorDescriptor,
&miopenDestroyTensorDescriptor> {
public:
TensorDescriptor() = default;
explicit TensorDescriptor(const at::Tensor &t, size_t pad = 0) {
set(t, pad);
}
Expand All @@ -88,11 +88,10 @@ class TensorDescriptor

std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d);

class FilterDescriptor
: public Descriptor<miopenTensorDescriptor,
&miopenCreateTensorDescriptor,
&miopenDestroyTensorDescriptor>
{
class TORCH_CUDA_CPP_API FilterDescriptor : public Descriptor<
miopenTensorDescriptor,
&miopenCreateTensorDescriptor,
&miopenDestroyTensorDescriptor> {
public:
void set(const at::Tensor &t, int64_t pad = 0) {
set(t, at::MemoryFormat::Contiguous, pad);
Expand All @@ -106,7 +105,7 @@ class FilterDescriptor
}
};

struct ConvolutionDescriptor
struct TORCH_CUDA_CPP_API ConvolutionDescriptor
: public Descriptor<miopenConvolutionDescriptor,
&miopenCreateConvolutionDescriptor,
&miopenDestroyConvolutionDescriptor>
Expand All @@ -118,8 +117,24 @@ struct ConvolutionDescriptor
}
};

// NOLINTNEXTLINE(bugprone-exception-escape)
struct TORCH_CUDA_CPP_API DropoutDescriptor
: public Descriptor<
miopenDropoutDescriptor,
&miopenCreateDropoutDescriptor,
&miopenDestroyDropoutDescriptor> {
void set(miopenHandle_t handle, float dropout, void* states, size_t stateSizeInBytes,
unsigned long long seed, bool use_mask, bool state_evo, miopenRNGType_t rng_mode) {
MIOPEN_CHECK(miopenSetDropoutDescriptor(mut_desc(), handle, dropout, states, stateSizeInBytes, seed, use_mask, state_evo, rng_mode));
}

void restore(miopenHandle_t handle, float dropout, void* states, size_t stateSizeInBytes,
unsigned long long seed, bool use_mask, bool state_evo, miopenRNGType_t rng_mode) {
MIOPEN_CHECK(miopenRestoreDropoutDescriptor(mut_desc(), handle, dropout, states, stateSizeInBytes, seed, use_mask, state_evo, rng_mode));
}
};

struct RNNDescriptor
struct TORCH_CUDA_CPP_API RNNDescriptor
: public Descriptor<miopenRNNDescriptor,
&miopenCreateRNNDescriptor,
&miopenDestroyRNNDescriptor>
Expand Down
8 changes: 4 additions & 4 deletions aten/src/ATen/miopen/Handle.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#pragma once

#include <ATen/miopen/miopen-wrapper.h>
#include <c10/macros/Export.h>

namespace at { namespace native {
namespace at::native {

miopenHandle_t getMiopenHandle();

}} // namespace
TORCH_CUDA_CPP_API miopenHandle_t getMiopenHandle();
} // namespace at::native
9 changes: 5 additions & 4 deletions aten/src/ATen/miopen/Types.h
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
#pragma once

#include <ATen/miopen/miopen-wrapper.h>
#include <ATen/Tensor.h>
#include <ATen/miopen/miopen-wrapper.h>
#include <c10/macros/Export.h>

namespace at { namespace native {
namespace at::native {

miopenDataType_t getMiopenDataType(const at::Tensor& tensor);
TORCH_CUDA_CPP_API miopenDataType_t getMiopenDataType(const at::Tensor& tensor);

int64_t miopen_version();

}} // namespace at::miopen
} // namespace at::native