 # Custom autograd Functions for vLLM operations
 # ============================================================================

-class FlashAttn3Function(Function):
-    """
-    Autograd function for Flash Attention 3 with proper backward support.
-    """
-
-    @staticmethod
-    def forward(
-        ctx,
-        q, k, v,
-        cu_seqlens_q,
-        cu_seqlens_k,
-        max_seqlen_q,
-        max_seqlen_k,
-        softmax_scale,
-        causal,
-        window_left,
-        window_right,
-        softcap,
-        scheduler_metadata,
-        num_splits,
-    ):
-        """
-        Forward pass using vLLM's FA3 CUDA kernel.
-        """
-        out, softmax_lse, _, _ = torch.ops._vllm_fa3_C.fwd(
-            q, k, v,
-            None, None, # k_new, v_new
-            None, # q_v
-            None, # out
-            cu_seqlens_q,
-            cu_seqlens_k,
-            None, # cu_seqlens_k_new
-            None, None, # seqused_q, seqused_k
-            max_seqlen_q, max_seqlen_k,
-            None, # block_table
-            None, # kv_batch_idx
-            None, # leftpad_k
-            None, None, None, # rotary_cos, rotary_sin, seqlens_rotary
-            None, None, None, # q_descale, k_descale, v_descale
-            softmax_scale,
-            causal,
-            window_left, window_right,
-            softcap,
-            True, # rotary_interleaved
-            scheduler_metadata,
-        )
-
-        # Save tensors needed for backward
-        ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k)
-        ctx.softmax_scale = softmax_scale
-        ctx.causal = causal
-        ctx.window_left = window_left
-        ctx.window_right = window_right
-        ctx.softcap = softcap
-        ctx.max_seqlen_q = max_seqlen_q
-        ctx.max_seqlen_k = max_seqlen_k
-        ctx.scheduler_metadata = scheduler_metadata
-
-        return out
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        """
-        Backward pass using vLLM's FA3 CUDA backward kernel.
-        """
-        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
-
-        # Allocate gradient tensors
-        grad_q = torch.empty_like(q)
-        grad_k = torch.empty_like(k)
-        grad_v = torch.empty_like(v)
-
-        # Call FA3 backward kernel
-        torch.ops._vllm_fa3_C.bwd(
-            grad_output,
-            q, k, v,
-            out,
-            softmax_lse,
-            grad_q, grad_k, grad_v,
-            cu_seqlens_q,
-            cu_seqlens_k,
-            None, # cu_seqlens_k_new
-            None, None, # seqused_q, seqused_k
-            ctx.max_seqlen_q,
-            ctx.max_seqlen_k,
-            None, # block_table
-            None, # kv_batch_idx
-            None, # leftpad_k
-            None, None, None, # rotary_cos, rotary_sin, seqlens_rotary
-            None, None, None, # dq_accum, q_descale, k_descale, v_descale
-            ctx.softmax_scale,
-            ctx.causal,
-            ctx.window_left, ctx.window_right,
-            ctx.softcap,
-            False, # deterministic
-            True, # rotary_interleaved
-            ctx.scheduler_metadata,
-        )
-
-        # Return gradients for all forward inputs (None for non-tensor args)
-        return grad_q, grad_k, grad_v, None, None, None, None, None, None, None, None, None, None, None
-
-
 class SiluAndMulFunction(Function):
     """
     Autograd function for vLLM's SiluAndMul activation.
@@ -459,48 +356,6 @@ def patch_batch_invariant_with_gradients(): |
     _batch_invariant_backward_LIB.impl("aten::matmul_backward", matmul_backward_impl, "CUDA")
     _batch_invariant_backward_LIB.impl("aten::linear_backward", linear_backward_impl, "CUDA")

-    # Monkey-patch vLLM's flash_attn_varlen_func to use our autograd wrapper for FA3
-    import vllm.vllm_flash_attn.flash_attn_interface as fa_interface
-    _original_flash_attn_varlen_func = fa_interface.flash_attn_varlen_func
-
-    def patched_flash_attn_varlen_func(*args, **kwargs):
-        # Only patch FA3 calls
-        fa_version = kwargs.get('fa_version', fa_interface.DEFAULT_FA_VERSION)
-        if fa_version == 3:
-            # Extract the args needed for our autograd function
-            q = args[0]
-            k = args[1]
-            v = args[2]
-            max_seqlen_q = args[3]
-            cu_seqlens_q = args[4]
-            max_seqlen_k = args[5]
-            cu_seqlens_k = args[6] if len(args) > 6 else kwargs.get('cu_seqlens_k')
-            softmax_scale = kwargs.get('softmax_scale')
-            causal = kwargs.get('causal', False)
-            window_size = kwargs.get('window_size', (-1, -1))
-            softcap = kwargs.get('softcap', 0.0)
-            scheduler_metadata = kwargs.get('scheduler_metadata')
-            num_splits = kwargs.get('num_splits', 0)
-
-            if window_size is None:
-                window_size = (-1, -1)
-            window_left, window_right = window_size
-
-            # Use our autograd wrapper
-            return FlashAttn3Function.apply(
-                q, k, v,
-                cu_seqlens_q, cu_seqlens_k,
-                max_seqlen_q, max_seqlen_k,
-                softmax_scale, causal,
-                window_left, window_right,
-                softcap, scheduler_metadata, num_splits
-            )
-        else:
-            # Fall through to original implementation for FA2
-            return _original_flash_attn_varlen_func(*args, **kwargs)
-
-    fa_interface.flash_attn_varlen_func = patched_flash_attn_varlen_func
-
     _batch_invariant_backward_MODE = True


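For reference, the removed code combined two standard PyTorch techniques: a torch.autograd.Function that pairs a forward-only kernel with an explicit backward, and a monkey-patch that swaps the wrapper in at the call site. Below is a minimal, self-contained sketch of that pattern using plain PyTorch softmax attention rather than vLLM's FA3 bindings; the names (WrappedAttnFunction, patch_attention, module.attn_func) are illustrative assumptions, not vLLM's API.

    # Sketch only: plain-PyTorch stand-in for the kernel-backed autograd wrapper.
    import torch
    from torch.autograd import Function

    class WrappedAttnFunction(Function):
        @staticmethod
        def forward(ctx, q, k, v, softmax_scale):
            # The real wrapper would call the forward-only CUDA kernel here.
            scores = (q @ k.transpose(-2, -1)) * softmax_scale
            probs = scores.softmax(dim=-1)
            out = probs @ v
            ctx.save_for_backward(q, k, v, probs)
            ctx.softmax_scale = softmax_scale
            return out

        @staticmethod
        def backward(ctx, grad_out):
            q, k, v, probs = ctx.saved_tensors
            grad_v = probs.transpose(-2, -1) @ grad_out
            grad_probs = grad_out @ v.transpose(-2, -1)
            # Softmax backward: dS = P * (dP - rowsum(dP * P))
            grad_scores = probs * (grad_probs - (grad_probs * probs).sum(dim=-1, keepdim=True))
            grad_q = (grad_scores @ k) * ctx.softmax_scale
            grad_k = (grad_scores.transpose(-2, -1) @ q) * ctx.softmax_scale
            # One gradient per forward input; None for the non-tensor scale.
            return grad_q, grad_k, grad_v, None

    def patch_attention(module):
        """Monkey-patch a hypothetical module.attn_func with the autograd-aware wrapper."""
        original = module.attn_func

        def patched(q, k, v, softmax_scale=1.0, **kwargs):
            if q.requires_grad or k.requires_grad or v.requires_grad:
                return WrappedAttnFunction.apply(q, k, v, softmax_scale)
            return original(q, k, v, softmax_scale=softmax_scale, **kwargs)

        module.attn_func = patched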