
Commit 9386701

eqy authored and pytorchmergebot committed
[cuDNN][SDPA] cuDNN SDPA refactor/cleanup, nested tensor backward, test priority bump for sm90, sm100 (pytorch#149282)
Cleans up tuple/tensor boilerplate in cuDNN SDPA in preparation for the nested/ragged tensor backward.

Pull Request resolved: pytorch#149282
Approved by: https://github.com/drisspg
1 parent 8521a69 commit 9386701

File tree

11 files changed (+996, -422 lines)


aten/src/ATen/native/cudnn/MHA.cpp

Lines changed: 706 additions & 341 deletions
Large diffs are not rendered by default.
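
Note (not part of the diff): since this large refactor is not rendered, a minimal usage sketch of the path it touches may help. The sketch below assumes a libtorch build with CUDA; whether the cuDNN SDPA backend is actually selected at runtime depends on the backend-selection logic and the hardware (the PR title also mentions a test priority bump for sm90/sm100).

    // Illustrative only; not code from this commit.
    #include <torch/torch.h>

    int main() {
      if (!torch::cuda::is_available()) return 0;  // nothing to demo without a GPU
      const int64_t b = 2, h = 8, s = 128, d = 64;
      auto opts = torch::dtype(torch::kHalf).device(torch::kCUDA);
      auto q = torch::randn({b, h, s, d}, opts).requires_grad_(true);
      auto k = torch::randn({b, h, s, d}, opts).requires_grad_(true);
      auto v = torch::randn({b, h, s, d}, opts).requires_grad_(true);
      // Forward may route to the fused cuDNN kernel (_cudnn_attention_forward).
      auto out = at::scaled_dot_product_attention(q, k, v, /*attn_mask=*/{},
                                                  /*dropout_p=*/0.0, /*is_causal=*/true);
      // Backward then reaches _scaled_dot_product_cudnn_attention_backward_cuda,
      // which (after this commit) forwards to the new _cudnn_attention_backward op.
      out.sum().backward();
      return 0;
    }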

aten/src/ATen/native/cudnn/MHA.h

Lines changed: 27 additions & 0 deletions
@@ -70,4 +70,31 @@ void run_cudnn_SDP_bprop(
     const Tensor& dropoutseed,
     const Tensor& dropoutoffset);
 
+void run_cudnn_SDP_bprop_nestedtensor(
+    int64_t b,
+    int64_t h_q,
+    int64_t h_k,
+    int64_t h_v,
+    int64_t s_q,
+    int64_t s_kv,
+    int64_t d_qk,
+    int64_t d_v,
+    float scaling_factor,
+    bool is_causal,
+    float dropout_probability,
+    const Tensor& cum_seqlen_q,
+    const Tensor& cum_seqlen_kv,
+    const Tensor& q,
+    const Tensor& k,
+    const Tensor& v,
+    const std::optional<Tensor>& attn_bias,
+    const Tensor& o,
+    const Tensor& dO,
+    const Tensor& softmaxstats,
+    Tensor& dQ,
+    Tensor& dK,
+    Tensor& dV,
+    const Tensor& dropoutseed,
+    const Tensor& dropoutoffset);
+
 } // namespace at::native
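
Note (not part of the diff): the cum_seqlen_q / cum_seqlen_kv arguments of the new entry point describe a ragged batch as cumulative offsets with batch_size + 1 entries. A minimal sketch of that layout, with a hypothetical helper name and the common int32 convention for offset tensors:

    // Illustrative only; build_cum_seqlens is not part of this commit.
    #include <torch/torch.h>
    #include <vector>

    at::Tensor build_cum_seqlens(const std::vector<int64_t>& seqlens) {
      auto lens = torch::tensor(seqlens);                       // int64 CPU tensor
      auto cum = torch::cat({torch::zeros({1}, lens.options()),
                             torch::cumsum(lens, /*dim=*/0)});  // {0, s0, s0+s1, ...}
      return cum.to(torch::kInt);  // offsets are commonly passed as int32
    }

    // e.g. lengths {3, 5, 2} -> offsets {0, 3, 8, 10}; the batch size is then
    // offsets.size(0) - 1, matching cum_seq_q.size(0) - 1 used later in this diff.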

aten/src/ATen/native/native_functions.yaml

Lines changed: 6 additions & 0 deletions
@@ -14958,6 +14958,7 @@
 - func: _scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor)
   dispatch:
     CUDA: _scaled_dot_product_cudnn_attention_backward_cuda
+    NestedTensorCUDA: _scaled_dot_product_cudnn_attention_nestedtensor_backward_cuda
   tags: nondeterministic_seeded
 
 - func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor rng_state, Tensor unused, Tensor debug_attn_mask)
@@ -14990,6 +14991,11 @@
     CUDA: _cudnn_attention_forward
   tags: nondeterministic_seeded
 
+- func: _cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor)
+  dispatch:
+    CUDA: _cudnn_attention_backward
+  tags: nondeterministic_seeded
+
 - func: _triton_scaled_dot_attention(Tensor q, Tensor k, Tensor v, float dropout_p=0.0) -> Tensor
   variants: function
   dispatch:
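
Note (not part of the diff): the new schema line generates a C++ entry point, at::_cudnn_attention_backward, whose CUDA kernel is the renamed function in attention_backward.cu below. A hedged sketch of a direct call, with placeholder tensors and the dense (non-varlen) convention of passing undefined cum_seq tensors:

    // Illustrative only; not code from this commit.
    #include <ATen/ATen.h>
    #include <optional>

    void call_cudnn_sdpa_backward(const at::Tensor& grad_out, const at::Tensor& q,
                                  const at::Tensor& k, const at::Tensor& v,
                                  const at::Tensor& out, const at::Tensor& lse,
                                  const at::Tensor& seed, const at::Tensor& offset) {
      // Dense CUDA tensors hit the CUDA: _cudnn_attention_backward kernel;
      // undefined cum_seq_q / cum_seq_k selects the non-nested branch.
      auto [dq, dk, dv] = at::_cudnn_attention_backward(
          grad_out, q, k, v, out, lse, seed, offset,
          /*attn_bias=*/at::Tensor(), /*cum_seq_q=*/at::Tensor(),
          /*cum_seq_k=*/at::Tensor(), /*max_q=*/q.size(2), /*max_k=*/k.size(2),
          /*dropout_p=*/0.0, /*is_causal=*/false, /*scale=*/std::nullopt);
      (void)dq; (void)dk; (void)dv;
    }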

aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp

Lines changed: 57 additions & 0 deletions
@@ -349,6 +349,63 @@ _scaled_dot_product_cudnn_attention_nestedtensor_cuda(
   return std::make_tuple(std::move(attention), std::move(log_sumexp), cumulative_sequence_length_q, cumulative_sequence_length_kv, max_seqlen_batch_q, max_seqlen_batch_kv, std::move(cudnn_seed), std::move(cudnn_offset), Tensor());
 }
 
+std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_nestedtensor_backward_cuda(
+    const Tensor& grad_out,
+    const Tensor& query,
+    const Tensor& key,
+    const Tensor& value,
+    const Tensor& out,
+    const Tensor& logsumexp,
+    const Tensor& philox_seed,
+    const Tensor& philox_offset,
+    const Tensor& attn_bias,
+    const Tensor& cum_seq_q,
+    const Tensor& cum_seq_k,
+    const int64_t max_q,
+    const int64_t max_k,
+    double dropout_p,
+    bool is_causal,
+    std::optional<double> scale) {
+  if (!grad_out.defined()) {
+    return std::make_tuple(Tensor{}, Tensor{}, Tensor{});
+  }
+  auto [
+      grad_out_buffer_reshaped,
+      query_buffer_reshaped,
+      key_buffer_reshaped,
+      value_buffer_reshaped,
+      output_buffer_reshaped] =
+      preprocessing::sdpa_nested_preprocessing_backward(
+          grad_out,
+          query,
+          key,
+          value,
+          out,
+          cum_seq_q,
+          cum_seq_k,
+          max_q,
+          max_k);
+
+  auto [dq, dk, dv] = at::_cudnn_attention_backward(grad_out_buffer_reshaped,
+                                                    query_buffer_reshaped,
+                                                    key_buffer_reshaped,
+                                                    value_buffer_reshaped,
+                                                    output_buffer_reshaped,
+                                                    logsumexp,
+                                                    philox_seed,
+                                                    philox_offset,
+                                                    attn_bias,
+                                                    cum_seq_q,
+                                                    cum_seq_k,
+                                                    max_q,
+                                                    max_k,
+                                                    dropout_p,
+                                                    is_causal,
+                                                    scale);
+  return std::make_tuple(dq, dk, dv);
+}
+
+
 std::tuple<at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_flash_attention_backward_nested(
     const at::Tensor& grad_out_,
     const at::Tensor& query,
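
Note (not part of the diff): sdpa_nested_preprocessing_backward hands the dense op packed "buffer_reshaped" views of the ragged inputs. The packing idea, illustrated with a hypothetical helper (each sequence of shape (s_i, h, d) is concatenated into one (sum s_i, h, d) buffer addressed via the cumulative offsets):

    // Illustrative only; not code from this commit.
    #include <torch/torch.h>
    #include <vector>

    at::Tensor pack_ragged(const std::vector<at::Tensor>& seqs) {
      // Variable-length sequences are stored back to back instead of padded
      // to (B, S_max, h, d); cum_seq_q / cum_seq_k record where each one starts.
      return torch::cat(seqs, /*dim=*/0);
    }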

aten/src/ATen/native/transformers/cuda/attention.cu

Lines changed: 0 additions & 10 deletions
@@ -848,16 +848,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt, Tensor, Ten
     // TODO(eqy): support debug_attn_mask
     return std::make_tuple(std::move(attention), std::move(log_sumexp), Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_kv, std::move(cudnn_seed), std::move(cudnn_offset), Tensor());
   } else {
-    //auto [
-    //    query_buffer_reshaped,
-    //    key_buffer_reshaped,
-    //    value_buffer_reshaped,
-    //    cumulative_sequence_length_q,
-    //    cumulative_sequence_length_kv,
-    //    max_seqlen_batch_q,
-    //    max_seqlen_batch_kv,
-    //    output_shape] = preprocessing::sdpa_nested_preprocessing(query, key, value);
-    // C10_LOG_API_USAGE_ONCE("torch.sdpa.flash_attention_cudnn");
     // TODO(eqy): debug mask support
     // BHSD ...
     const int64_t batch_size = cumulative_sequence_length_q.value().size(0) - 1;

aten/src/ATen/native/transformers/cuda/attention_backward.cu

Lines changed: 145 additions & 47 deletions
@@ -24,6 +24,8 @@
 #include <ATen/Functions.h>
 #include <ATen/NativeFunctions.h>
 #else
+#include <ATen/ops/_cudnn_attention_backward.h>
+#include <ATen/ops/_cudnn_attention_backward_native.h>
 #include <ATen/ops/_flash_attention_backward.h>
 #include <ATen/ops/_flash_attention_backward_native.h>
 #include <ATen/ops/_efficient_attention_backward.h>
@@ -170,7 +172,7 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
     return std::make_tuple(Tensor(), Tensor(), Tensor());
 }
 
-std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_backward_cuda(
+std::tuple<Tensor, Tensor, Tensor> _cudnn_attention_backward(
     const Tensor& grad_out,
     const Tensor& query,
     const Tensor& key,
@@ -197,57 +199,117 @@ std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_backward_
     }
   }
 
-  const int64_t batch_size = query.size(0);
-  const int64_t num_heads = query.size(1);
-  const int64_t head_dim_qk = query.size(3);
-  const int64_t head_dim_v = value.size(3);
+  const bool is_nested = cum_seq_q.defined();
   const int64_t max_seqlen_batch_q = query.size(2);
   const int64_t max_seqlen_batch_k = key.size(2);
 
-  // This is needed because SaveVariable automatically converts
-  // std::optional to undefined tensor
-  std::optional<Tensor> attn_bias_;
-  if (attn_bias.defined()) {
-    attn_bias_ = attn_bias;
-  }
-  if (attn_bias_.has_value()) {
-    const auto bias_dim = attn_bias_.value().dim();
-    if (bias_dim == 2) {
-      attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
-    } else if (bias_dim == 3) {
-      attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
-    } else {
-      TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D");
-      attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k});
+  if (!is_nested) {
+    const int64_t batch_size = query.size(0);
+    const int64_t num_heads = query.size(1);
+    const int64_t head_dim_qk = query.size(3);
+    const int64_t head_dim_v = value.size(3);
+
+    // This is needed because SaveVariable automatically converts
+    // std::optional to undefined tensor
+    std::optional<Tensor> attn_bias_;
+    if (attn_bias.defined()) {
+      attn_bias_ = attn_bias;
+    }
+    if (attn_bias_.has_value()) {
+      const auto bias_dim = attn_bias_.value().dim();
+      if (bias_dim == 2) {
+        attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
+      } else if (bias_dim == 3) {
+        attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
+      } else {
+        TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D");
+        attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k});
+      }
     }
-  }
 
-  const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float();
-  auto dq = at::empty_like(query);
-  auto dk = at::empty_like(key);
-  auto dv = at::empty_like(value);
-  run_cudnn_SDP_bprop(batch_size /*int64_t b*/,
-                      num_heads /*int64_t h*/,
-                      max_q/*int64_t s_q*/,
-                      max_k/*int64_t s_kv*/,
-                      head_dim_qk /*int64_t d_qk*/,
-                      head_dim_v /*int64_t d_v*/,
-                      softmax_scale /*float scaling_factor*/,
-                      is_causal /*bool is_causal*/,
-                      dropout_p /*float dropout_probability*/,
-                      query /*const Tensor& q*/,
-                      key /*const Tensor& k*/,
-                      value /*const Tensor& v*/,
-                      attn_bias_ /*const std::optional<Tensor>& attn_bias*/,
-                      out /*const Tensor& o*/,
-                      grad_out/*const Tensor& dO*/,
-                      logsumexp.unsqueeze(-1)/*const Tensor& softmaxstats*/,
-                      dq/*Tensor& dQ*/,
-                      dk/*Tensor& dK*/,
-                      dv/*Tensor& dV*/,
-                      philox_seed/*Tensor& dropoutseed*/,
-                      philox_offset/*Tensor& dropoutoffset*/);
-  return std::make_tuple(std::move(dq), std::move(dk), std::move(dv));
+    const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float();
+    auto dq = at::empty_like(query);
+    auto dk = at::empty_like(key);
+    auto dv = at::empty_like(value);
+    run_cudnn_SDP_bprop(batch_size /*int64_t b*/,
+                        num_heads /*int64_t h*/,
+                        max_q/*int64_t s_q*/,
+                        max_k/*int64_t s_kv*/,
+                        head_dim_qk /*int64_t d_qk*/,
+                        head_dim_v /*int64_t d_v*/,
+                        softmax_scale /*float scaling_factor*/,
+                        is_causal /*bool is_causal*/,
+                        dropout_p /*float dropout_probability*/,
+                        query /*const Tensor& q*/,
+                        key /*const Tensor& k*/,
+                        value /*const Tensor& v*/,
+                        attn_bias_ /*const std::optional<Tensor>& attn_bias*/,
+                        out /*const Tensor& o*/,
+                        grad_out/*const Tensor& dO*/,
+                        logsumexp.unsqueeze(-1)/*const Tensor& softmaxstats*/,
+                        dq/*Tensor& dQ*/,
+                        dk/*Tensor& dK*/,
+                        dv/*Tensor& dV*/,
+                        philox_seed/*Tensor& dropoutseed*/,
+                        philox_offset/*Tensor& dropoutoffset*/);
+    return std::make_tuple(std::move(dq), std::move(dk), std::move(dv));
+  } else {
+    // BHSD ...
+    const int64_t batch_size = cum_seq_q.size(0) - 1;
+    const int64_t num_heads_q = query.size(-2);
+    const int64_t num_heads_k = key.size(-2);
+    const int64_t num_heads_v = value.size(-2);
+    const int64_t head_dim_qk = query.size(-1);
+    const int64_t head_dim_v = value.size(-1);
+    std::optional<Tensor> attn_bias_;
+    if (attn_bias.defined()) {
+      attn_bias_ = attn_bias;
+    }
+    if (attn_bias_.has_value()) {
+      const auto bias_dim = attn_bias_.value().dim();
+      if (bias_dim == 2) {
+        attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
+      } else if (bias_dim == 3) {
+        attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
+      } else {
+        attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k});
+        TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D");
+      }
+    }
+
+    auto dq = at::empty_like(query);
+    auto dk = at::empty_like(key);
+    auto dv = at::empty_like(value);
+
+    const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
+    run_cudnn_SDP_bprop_nestedtensor(
+        batch_size,
+        num_heads_q,
+        num_heads_k,
+        num_heads_v,
+        max_seqlen_batch_q,
+        max_seqlen_batch_k,
+        head_dim_qk,
+        head_dim_v,
+        softmax_scale,
+        is_causal,
+        dropout_p,
+        cum_seq_q,
+        cum_seq_k,
+        query,
+        key,
+        value,
+        attn_bias_,
+        out,
+        grad_out,
+        logsumexp,
+        dq,
+        dk,
+        dv,
+        philox_seed,
+        philox_offset);
+    return std::make_tuple(std::move(dq), std::move(dk), std::move(dv));
+  }
 }
 
 std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
@@ -950,4 +1012,40 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_e
       grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2), grad_bias);
 }
 
+std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_backward_cuda(
+    const Tensor& grad_out,
+    const Tensor& query,
+    const Tensor& key,
+    const Tensor& value,
+    const Tensor& out,
+    const Tensor& logsumexp,
+    const Tensor& philox_seed,
+    const Tensor& philox_offset,
+    const Tensor& attn_bias,
+    const Tensor& cum_seq_q,
+    const Tensor& cum_seq_k,
+    const int64_t max_q,
+    const int64_t max_k,
+    double dropout_p,
+    bool is_causal,
+    std::optional<double> scale) {
+  return at::_cudnn_attention_backward(
+      grad_out,
+      query,
+      key,
+      value,
+      out,
+      logsumexp,
+      philox_seed,
+      philox_offset,
+      attn_bias,
+      cum_seq_q,
+      cum_seq_k,
+      max_q,
+      max_k,
+      dropout_p,
+      is_causal,
+      scale);
+}
+
 } // namespace at::native
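
Note (not part of the diff): in both branches above, sdp::calculate_scale falls back to the conventional 1/sqrt(head_dim) softmax scale when the optional scale argument is empty. A simplified sketch of that default, for reference:

    // Illustrative only; not code from this commit.
    #include <cmath>
    #include <cstdint>
    #include <optional>

    inline double softmax_scale_or_default(std::optional<double> scale, int64_t head_dim_qk) {
      // Matches the usual scaled-dot-product-attention convention of 1/sqrt(d_qk).
      return scale.has_value() ? *scale : 1.0 / std::sqrt(static_cast<double>(head_dim_qk));
    }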
