Commit 90bc444

Cleanup readme and remove unneeded flashv3 backward
1 parent 53b56a5 commit 90bc444

2 files changed: 4 additions & 145 deletions

torchtitan/experiments/deterministic_vllm_rl/README.md

Lines changed: 4 additions & 0 deletions
@@ -226,6 +226,10 @@ deterministic_vllm_rl/
 │   └── qwen3/
 │       ├── __init__.py
 │       └── model_vllm_compat.py          # vLLM-compatible Qwen3 model
+├── weights/
+│   ├── __init__.py
+│   ├── converter.py                      # Weight conversion script
+│   └── README.md                         # Weight conversion documentation
 └── tests/
     ├── __init__.py
     ├── test_batch_invariant_backward.py  # Test backward passes
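
Note: the converter.py added above is only named in this tree; its contents are not part of the diff. For orientation, a minimal sketch of what a checkpoint key-remapping step typically looks like follows. Every name below (the key map, the function, the file paths) is hypothetical and is not taken from converter.py.

import torch

# Hypothetical mapping from source checkpoint keys to target model keys.
_KEY_MAP = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    "lm_head.weight": "output.weight",
}

def convert_state_dict(src: dict) -> dict:
    """Rename known keys and pass all other entries through unchanged."""
    return {_KEY_MAP.get(name, name): tensor for name, tensor in src.items()}

if __name__ == "__main__":
    state = torch.load("source_checkpoint.pt", map_location="cpu")   # hypothetical input
    torch.save(convert_state_dict(state), "converted_checkpoint.pt")  # hypothetical output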

torchtitan/experiments/deterministic_vllm_rl/batch_invariant_backward.py

Lines changed: 0 additions & 145 deletions
@@ -42,109 +42,6 @@
 # Custom autograd Functions for vLLM operations
 # ============================================================================

-class FlashAttn3Function(Function):
-    """
-    Autograd function for Flash Attention 3 with proper backward support.
-    """
-
-    @staticmethod
-    def forward(
-        ctx,
-        q, k, v,
-        cu_seqlens_q,
-        cu_seqlens_k,
-        max_seqlen_q,
-        max_seqlen_k,
-        softmax_scale,
-        causal,
-        window_left,
-        window_right,
-        softcap,
-        scheduler_metadata,
-        num_splits,
-    ):
-        """
-        Forward pass using vLLM's FA3 CUDA kernel.
-        """
-        out, softmax_lse, _, _ = torch.ops._vllm_fa3_C.fwd(
-            q, k, v,
-            None, None,  # k_new, v_new
-            None,  # q_v
-            None,  # out
-            cu_seqlens_q,
-            cu_seqlens_k,
-            None,  # cu_seqlens_k_new
-            None, None,  # seqused_q, seqused_k
-            max_seqlen_q, max_seqlen_k,
-            None,  # block_table
-            None,  # kv_batch_idx
-            None,  # leftpad_k
-            None, None, None,  # rotary_cos, rotary_sin, seqlens_rotary
-            None, None, None,  # q_descale, k_descale, v_descale
-            softmax_scale,
-            causal,
-            window_left, window_right,
-            softcap,
-            True,  # rotary_interleaved
-            scheduler_metadata,
-        )
-
-        # Save tensors needed for backward
-        ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k)
-        ctx.softmax_scale = softmax_scale
-        ctx.causal = causal
-        ctx.window_left = window_left
-        ctx.window_right = window_right
-        ctx.softcap = softcap
-        ctx.max_seqlen_q = max_seqlen_q
-        ctx.max_seqlen_k = max_seqlen_k
-        ctx.scheduler_metadata = scheduler_metadata
-
-        return out
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        """
-        Backward pass using vLLM's FA3 CUDA backward kernel.
-        """
-        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
-
-        # Allocate gradient tensors
-        grad_q = torch.empty_like(q)
-        grad_k = torch.empty_like(k)
-        grad_v = torch.empty_like(v)
-
-        # Call FA3 backward kernel
-        torch.ops._vllm_fa3_C.bwd(
-            grad_output,
-            q, k, v,
-            out,
-            softmax_lse,
-            grad_q, grad_k, grad_v,
-            cu_seqlens_q,
-            cu_seqlens_k,
-            None,  # cu_seqlens_k_new
-            None, None,  # seqused_q, seqused_k
-            ctx.max_seqlen_q,
-            ctx.max_seqlen_k,
-            None,  # block_table
-            None,  # kv_batch_idx
-            None,  # leftpad_k
-            None, None, None,  # rotary_cos, rotary_sin, seqlens_rotary
-            None, None, None,  # dq_accum, q_descale, k_descale, v_descale
-            ctx.softmax_scale,
-            ctx.causal,
-            ctx.window_left, ctx.window_right,
-            ctx.softcap,
-            False,  # deterministic
-            True,  # rotary_interleaved
-            ctx.scheduler_metadata,
-        )
-
-        # Return gradients for all forward inputs (None for non-tensor args)
-        return grad_q, grad_k, grad_v, None, None, None, None, None, None, None, None, None, None, None
-
-
 class SiluAndMulFunction(Function):
     """
     Autograd function for vLLM's SiluAndMul activation.
@@ -459,48 +356,6 @@ def patch_batch_invariant_with_gradients():
     _batch_invariant_backward_LIB.impl("aten::matmul_backward", matmul_backward_impl, "CUDA")
     _batch_invariant_backward_LIB.impl("aten::linear_backward", linear_backward_impl, "CUDA")

-    # Monkey-patch vLLM's flash_attn_varlen_func to use our autograd wrapper for FA3
-    import vllm.vllm_flash_attn.flash_attn_interface as fa_interface
-    _original_flash_attn_varlen_func = fa_interface.flash_attn_varlen_func
-
-    def patched_flash_attn_varlen_func(*args, **kwargs):
-        # Only patch FA3 calls
-        fa_version = kwargs.get('fa_version', fa_interface.DEFAULT_FA_VERSION)
-        if fa_version == 3:
-            # Extract the args needed for our autograd function
-            q = args[0]
-            k = args[1]
-            v = args[2]
-            max_seqlen_q = args[3]
-            cu_seqlens_q = args[4]
-            max_seqlen_k = args[5]
-            cu_seqlens_k = args[6] if len(args) > 6 else kwargs.get('cu_seqlens_k')
-            softmax_scale = kwargs.get('softmax_scale')
-            causal = kwargs.get('causal', False)
-            window_size = kwargs.get('window_size', (-1, -1))
-            softcap = kwargs.get('softcap', 0.0)
-            scheduler_metadata = kwargs.get('scheduler_metadata')
-            num_splits = kwargs.get('num_splits', 0)
-
-            if window_size is None:
-                window_size = (-1, -1)
-            window_left, window_right = window_size
-
-            # Use our autograd wrapper
-            return FlashAttn3Function.apply(
-                q, k, v,
-                cu_seqlens_q, cu_seqlens_k,
-                max_seqlen_q, max_seqlen_k,
-                softmax_scale, causal,
-                window_left, window_right,
-                softcap, scheduler_metadata, num_splits
-            )
-        else:
-            # Fall through to original implementation for FA2
-            return _original_flash_attn_varlen_func(*args, **kwargs)
-
-    fa_interface.flash_attn_varlen_func = patched_flash_attn_varlen_func
-
     _batch_invariant_backward_MODE = True
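
Note: the registration calls kept in the context above (_batch_invariant_backward_LIB.impl(...)) follow PyTorch's torch.library override pattern. A minimal sketch of that pattern is shown below for orientation only; the op choice (aten::mm) and the naive implementation are illustrative and are not code from this file.

import torch

# Keep a module-level reference so the registrations are not garbage-collected.
_lib = torch.library.Library("aten", "IMPL")

def mm_batch_invariant(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Naive broadcast-and-reduce matmul, used only so this sketch does not dispatch
    # back into aten::mm; a real override would call a batch-invariant kernel.
    return (a.unsqueeze(-1) * b.unsqueeze(0)).sum(dim=1)

# Route aten::mm to the custom implementation for the CUDA dispatch key only.
_lib.impl("mm", mm_batch_invariant, "CUDA")

After these registrations run, torch.mm on CUDA tensors goes through the override, which is the mechanism the .impl calls above use to swap in batch-invariant backward implementations.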