diff --git a/src/ATen/native/transformers/Attention.cpp b/src/ATen/native/transformers/Attention.cpp
index df0a2c9bc0..97c6468147 100644
--- a/src/ATen/native/transformers/Attention.cpp
+++ b/src/ATen/native/transformers/Attention.cpp
@@ -3,6 +3,8 @@
 #include
 #include
 #include
+#include
+#include
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include
@@ -294,5 +296,19 @@ std::tuple native_multi_head_attention_xpu(
   return std::make_tuple(std::move(proj), std::move(qkt));
 }
 
+/**
+ * get the mask for dropout. only used for testing, not much
+ * attention is paid to performance
+ */
+at::Tensor& _fill_mem_eff_dropout_mask_(
+    Tensor& self,
+    double dropout_p,
+    const int64_t seed,
+    const int64_t offset) {
+  auto mask = std::get<1>(xpu::dropout_kernel(self, dropout_p, true));
+  self.copy_(mask);
+  return self;
+}
+
 } // namespace native
 } // namespace at
diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml
index a3281791de..7bf88e5eb8 100644
--- a/yaml/native/native_functions.yaml
+++ b/yaml/native/native_functions.yaml
@@ -7922,6 +7922,12 @@
     SparseCsrXPU: angle_sparse_csr_out
   tags: pointwise
 
+- func: _fill_mem_eff_dropout_mask_(Tensor(a!) self, float dropout_p, int seed, int offset) -> Tensor(a!)
+  variants: function
+  dispatch:
+    XPU: _fill_mem_eff_dropout_mask_
+  tags: nondeterministic_seeded
+
 - func: special_airy_ai(Tensor x) -> Tensor
   python_module: special
   structured_delegate: special_airy_ai.out
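For context, below is a minimal sketch of how the new op might be exercised once the YAML entry is code-generated, assuming an XPU-enabled build that exposes the function as at::_fill_mem_eff_dropout_mask_; the tensor shape, dtype, and seed/offset values are illustrative only and not part of this patch.

// Minimal usage sketch (assumption), not part of this patch.
#include <ATen/ATen.h>

int main() {
  // Any floating-point tensor on the XPU device can receive the mask.
  at::Tensor mask = at::empty({4, 4},
      at::TensorOptions().dtype(at::kFloat).device(at::kXPU));
  // Fills `mask` in place with a 0/1 dropout keep/drop pattern for p = 0.5.
  // Note: this XPU path derives the mask from the regular dropout kernel,
  // so seed/offset are accepted mainly for API parity with other backends.
  at::_fill_mem_eff_dropout_mask_(mask, /*dropout_p=*/0.5, /*seed=*/42, /*offset=*/0);
  return 0;
}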