 #include <ATen/Functions.h>
 #include <ATen/NativeFunctions.h>
 #else
+#include <ATen/ops/_cudnn_attention_backward.h>
+#include <ATen/ops/_cudnn_attention_backward_native.h>
 #include <ATen/ops/_flash_attention_backward.h>
 #include <ATen/ops/_flash_attention_backward_native.h>
 #include <ATen/ops/_efficient_attention_backward.h>
@@ -170,7 +172,7 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
   return std::make_tuple(Tensor(), Tensor(), Tensor());
 }
 
-std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_backward_cuda(
+std::tuple<Tensor, Tensor, Tensor> _cudnn_attention_backward(
     const Tensor& grad_out,
     const Tensor& query,
     const Tensor& key,
@@ -197,57 +199,117 @@ std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_backward_
     }
   }
 
-  const int64_t batch_size = query.size(0);
-  const int64_t num_heads = query.size(1);
-  const int64_t head_dim_qk = query.size(3);
-  const int64_t head_dim_v = value.size(3);
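+  // A defined cum_seq_q distinguishes the varlen (nested tensor) call from
+  // the dense fixed-length call; the two code paths diverge below.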
+  const bool is_nested = cum_seq_q.defined();
   const int64_t max_seqlen_batch_q = query.size(2);
   const int64_t max_seqlen_batch_k = key.size(2);
 
-  // This is needed because SaveVariable automatically converts
-  // std::optional to undefined tensor
-  std::optional<Tensor> attn_bias_;
-  if (attn_bias.defined()) {
-    attn_bias_ = attn_bias;
-  }
-  if (attn_bias_.has_value()) {
-    const auto bias_dim = attn_bias_.value().dim();
-    if (bias_dim == 2) {
-      attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
-    } else if (bias_dim == 3) {
-      attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
-    } else {
-      TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D");
-      attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k});
+  if (!is_nested) {
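+    // Dense path: query/key/value are fixed-length [B, H, S, D] tensors.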
+    const int64_t batch_size = query.size(0);
+    const int64_t num_heads = query.size(1);
+    const int64_t head_dim_qk = query.size(3);
+    const int64_t head_dim_v = value.size(3);
+
+    // This is needed because SaveVariable automatically converts
+    // std::optional to undefined tensor
+    std::optional<Tensor> attn_bias_;
+    if (attn_bias.defined()) {
+      attn_bias_ = attn_bias;
+    }
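+    // Broadcast 2D/3D biases up to the 4D [B, 1, s_q, s_kv] layout cuDNN expects.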
+    if (attn_bias_.has_value()) {
+      const auto bias_dim = attn_bias_.value().dim();
+      if (bias_dim == 2) {
+        attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
+      } else if (bias_dim == 3) {
+        attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
+      } else {
+        TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D");
+        attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k});
+      }
     }
-  }
 
-  const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float();
-  auto dq = at::empty_like(query);
-  auto dk = at::empty_like(key);
-  auto dv = at::empty_like(value);
-  run_cudnn_SDP_bprop(batch_size /*int64_t b*/,
-                      num_heads /*int64_t h*/,
-                      max_q /*int64_t s_q*/,
-                      max_k /*int64_t s_kv*/,
-                      head_dim_qk /*int64_t d_qk*/,
-                      head_dim_v /*int64_t d_v*/,
-                      softmax_scale /*float scaling_factor*/,
-                      is_causal /*bool is_causal*/,
-                      dropout_p /*float dropout_probability*/,
-                      query /*const Tensor& q*/,
-                      key /*const Tensor& k*/,
-                      value /*const Tensor& v*/,
-                      attn_bias_ /*const std::optional<Tensor>& attn_bias*/,
-                      out /*const Tensor& o*/,
-                      grad_out /*const Tensor& dO*/,
-                      logsumexp.unsqueeze(-1) /*const Tensor& softmaxstats*/,
-                      dq /*Tensor& dQ*/,
-                      dk /*Tensor& dK*/,
-                      dv /*Tensor& dV*/,
-                      philox_seed /*Tensor& dropoutseed*/,
-                      philox_offset /*Tensor& dropoutoffset*/);
-  return std::make_tuple(std::move(dq), std::move(dk), std::move(dv));
+    const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float();
+    auto dq = at::empty_like(query);
+    auto dk = at::empty_like(key);
+    auto dv = at::empty_like(value);
+    run_cudnn_SDP_bprop(batch_size /*int64_t b*/,
+                        num_heads /*int64_t h*/,
+                        max_q /*int64_t s_q*/,
+                        max_k /*int64_t s_kv*/,
+                        head_dim_qk /*int64_t d_qk*/,
+                        head_dim_v /*int64_t d_v*/,
+                        softmax_scale /*float scaling_factor*/,
+                        is_causal /*bool is_causal*/,
+                        dropout_p /*float dropout_probability*/,
+                        query /*const Tensor& q*/,
+                        key /*const Tensor& k*/,
+                        value /*const Tensor& v*/,
+                        attn_bias_ /*const std::optional<Tensor>& attn_bias*/,
+                        out /*const Tensor& o*/,
+                        grad_out /*const Tensor& dO*/,
+                        logsumexp.unsqueeze(-1) /*const Tensor& softmaxstats*/,
+                        dq /*Tensor& dQ*/,
+                        dk /*Tensor& dK*/,
+                        dv /*Tensor& dV*/,
+                        philox_seed /*Tensor& dropoutseed*/,
+                        philox_offset /*Tensor& dropoutoffset*/);
+    return std::make_tuple(std::move(dq), std::move(dk), std::move(dv));
+  } else {
+    // BHSD ...
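+    // cum_seq_q/cum_seq_k hold cumulative sequence-length offsets, one entry
+    // per batch element plus a leading zero, hence the size(0) - 1 below.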
+    const int64_t batch_size = cum_seq_q.size(0) - 1;
+    const int64_t num_heads_q = query.size(-2);
+    const int64_t num_heads_k = key.size(-2);
+    const int64_t num_heads_v = value.size(-2);
+    const int64_t head_dim_qk = query.size(-1);
+    const int64_t head_dim_v = value.size(-1);
+    std::optional<Tensor> attn_bias_;
+    if (attn_bias.defined()) {
+      attn_bias_ = attn_bias;
+    }
+    if (attn_bias_.has_value()) {
+      const auto bias_dim = attn_bias_.value().dim();
+      if (bias_dim == 2) {
+        attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
+      } else if (bias_dim == 3) {
+        attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k});
+      } else {
+        TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D");
+        attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k});
+      }
+    }
+
+    auto dq = at::empty_like(query);
+    auto dk = at::empty_like(key);
+    auto dv = at::empty_like(value);
+
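+    // sdp::calculate_scale falls back to 1/sqrt(head_dim_qk) when no explicit
+    // scale is provided.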
+    const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
+    run_cudnn_SDP_bprop_nestedtensor(
+        batch_size,
+        num_heads_q,
+        num_heads_k,
+        num_heads_v,
+        max_seqlen_batch_q,
+        max_seqlen_batch_k,
+        head_dim_qk,
+        head_dim_v,
+        softmax_scale,
+        is_causal,
+        dropout_p,
+        cum_seq_q,
+        cum_seq_k,
+        query,
+        key,
+        value,
+        attn_bias_,
+        out,
+        grad_out,
+        logsumexp,
+        dq,
+        dk,
+        dv,
+        philox_seed,
+        philox_offset);
+    return std::make_tuple(std::move(dq), std::move(dk), std::move(dv));
+  }
 }
 
 std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
@@ -950,4 +1012,40 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_e
       grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2), grad_bias);
 }
 
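+// Thin compatibility wrapper: the public op keeps its signature and simply
+// forwards to the new unified _cudnn_attention_backward entry point above.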
+std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_backward_cuda(
+    const Tensor& grad_out,
+    const Tensor& query,
+    const Tensor& key,
+    const Tensor& value,
+    const Tensor& out,
+    const Tensor& logsumexp,
+    const Tensor& philox_seed,
+    const Tensor& philox_offset,
+    const Tensor& attn_bias,
+    const Tensor& cum_seq_q,
+    const Tensor& cum_seq_k,
+    const int64_t max_q,
+    const int64_t max_k,
+    double dropout_p,
+    bool is_causal,
+    std::optional<double> scale) {
+  return at::_cudnn_attention_backward(
+      grad_out,
+      query,
+      key,
+      value,
+      out,
+      logsumexp,
+      philox_seed,
+      philox_offset,
+      attn_bias,
+      cum_seq_q,
+      cum_seq_k,
+      max_q,
+      max_k,
+      dropout_p,
+      is_causal,
+      scale);
+}
+
 } // namespace at::native