
Commit 801399b

Address review comments
- Add BSD-style license headers to all new files: * batch_invariant_backward.py * simple_rl.py * tests/test_batch_invariant_backward.py * tests/test_exact_determinism.py * weights_vllm_compat.py * weights/converter.py * weights/__init__.py - Add note about single-device limitation in README.md Currently supports single-device training only; future work will extend to distributed training with parallelism - Remove unused imports in simple_rl.py: * Remove 'import torchtitan.experiments.compat' (unused) * Remove duplicate imports of torchtitan_to_vllm_compat - Fix all imports to use absolute paths for python -m compatibility: * Update model_vllm_compat.py to import from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward * Update simple_rl.py to import from torchtitan.experiments.deterministic_vllm_rl modules * Removes sys.path manipulation - now works cleanly with python -m - Remove duplicate RMSNormFunction from model_vllm_compat.py: * Import rms_norm_with_gradients from batch_invariant_backward.py * Remove duplicate RMSNormFunction class and function definition * Keeps gradient-enabled operations centralized in utilities module
1 parent 2823b41 commit 801399b
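In practice, the import fix described above amounts to replacing the script-style imports (which needed sys.path manipulation when the files were run directly) with package-qualified absolute imports. A minimal sketch drawn from the diff below:

```python
# Before: only resolvable when run from inside the experiment directory.
# sys.path.insert(0, str(Path(__file__).parent))
# from batch_invariant_backward import rms_norm_with_gradients

# After: absolute import, works from anywhere, including `python -m` invocations.
from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward import (
    rms_norm_with_gradients,
)
```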

File tree

9 files changed: +69 -114 lines changed

torchtitan/experiments/deterministic_vllm_rl/README.md

Lines changed: 3 additions & 2 deletions
@@ -21,6 +21,8 @@ This experiment solves both problems by:
 - **Gradient Support**: Full backward pass support for training
 - **Model Compatibility**: Drop-in replacement for standard Qwen3 models in TorchTitan
 
+**Note**: This experiment currently supports single-device training only. We plan to extend support for distributed training with tensor parallelism and pipeline parallelism in the future.
+
 ## Architecture
 
 ### Components
@@ -110,8 +112,7 @@
 Run the complete RL training loop:
 
 ```bash
-cd torchtitan/experiments/deterministic_vllm_rl
-python simple_rl.py
+VLLM_BATCH_INVARIANT=1 VLLM_FLASH_ATTN_VERSION=3 with-proxy python -m torchtitan.experiments.deterministic_vllm_rl.simple_rl
 ```
 
 This will:

torchtitan/experiments/deterministic_vllm_rl/batch_invariant_backward.py

Lines changed: 6 additions & 0 deletions
@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
 """
 Batch-invariant operations with backward pass support.

torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py

Lines changed: 17 additions & 94 deletions
@@ -8,28 +8,23 @@
 # Uses merged gate_up projections and vLLM Flash Attention
 
 import torch
-import torch.nn.functional as F
 from torch import nn
-from torch.nn.attention.flex_attention import and_masks, BlockMask
 
 from torchtitan.components.tokenizer import BaseTokenizer
-from torchtitan.models.attention import (
-    create_attention_mask,
-    get_causal_mask_mod,
-    get_document_mask_mod,
-)
-from torchtitan.protocols.model import AttentionMasksType
-from torchtitan.protocols.train_spec import ModelProtocol
-
-# Import from local experiment's models
-from ..attention import VLLMCompatibleFlashAttention
 
 # Import from main torchtitan
 from torchtitan.models.qwen3.model.args import Qwen3ModelArgs
+from torchtitan.protocols.model import AttentionMasksType
+from torchtitan.protocols.train_spec import ModelProtocol
 
 # Import vLLM's exact operations for bitwise determinism
 from vllm.model_executor.layers.activation import SiluAndMul as VLLMSiluAndMul
-from vllm.model_executor.layers.batch_invariant import rms_norm as vllm_rms_norm
+
+# Import gradient-enabled operations from experiment utilities
+from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward import rms_norm_with_gradients
+
+# Import from local experiment's models
+from ..attention import VLLMCompatibleFlashAttention
 
 
 # RoPE functions (same as original)
@@ -90,84 +85,6 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
 SiluAndMul = VLLMSiluAndMul
 
 
-class RMSNormFunction(torch.autograd.Function):
-    """
-    Autograd function for RMS normalization using vLLM's Triton kernel in forward
-    and batch-invariant operations in backward.
-    """
-
-    @staticmethod
-    def forward(ctx, input, weight, eps):
-        """
-        Forward pass using vLLM's rms_norm Triton kernel.
-
-        Args:
-            input: Input tensor [*, hidden_size]
-            weight: Weight tensor [hidden_size]
-            eps: Epsilon for numerical stability
-
-        Returns:
-            output: Normalized and scaled tensor [*, hidden_size]
-        """
-        # Use vLLM's Triton kernel for forward (deterministic)
-        output = vllm_rms_norm(input, weight, eps)
-
-        # Save for backward
-        ctx.save_for_backward(input, weight)
-        ctx.eps = eps
-
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        """
-        Backward pass using batch-invariant PyTorch operations.
-
-        Returns:
-            (grad_input, grad_weight, None)
-        """
-        input, weight = ctx.saved_tensors
-        eps = ctx.eps
-
-        # Compute forward pass values needed for backward
-        # variance = mean(x^2) along last dim
-        variance = (input * input).mean(dim=-1, keepdim=True)
-        rms = torch.sqrt(variance + eps)
-        x_norm = input / rms
-
-        # Gradient w.r.t. weight
-        # grad_weight = sum(grad_output * x_norm) over all dims except last
-        grad_weight = (grad_output * x_norm).sum(dim=tuple(range(grad_output.ndim - 1)))
-
-        # Gradient w.r.t. input
-        # grad_x_norm = grad_output * weight
-        grad_x_norm = grad_output * weight
-
-        # grad_x = (grad_x_norm - mean(grad_x_norm * x_norm) * x_norm) / rms
-        mean_term = (grad_x_norm * x_norm).mean(dim=-1, keepdim=True)
-        grad_input = (grad_x_norm - mean_term * x_norm) / rms
-
-        return grad_input, grad_weight, None
-
-
-def rms_norm_with_gradients(input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
-    """
-    RMS normalization with gradient support.
-
-    Uses vLLM's Triton kernel for forward pass (deterministic) and
-    batch-invariant PyTorch operations for backward pass.
-
-    Args:
-        input: Input tensor [*, hidden_size]
-        weight: Weight tensor [hidden_size]
-        eps: Epsilon for numerical stability
-
-    Returns:
-        output: Normalized and scaled tensor [*, hidden_size]
-    """
-    return RMSNormFunction.apply(input, weight, eps)
-
-
 class VLLMRMSNorm(nn.Module):
     """
     RMSNorm using vLLM's exact Triton kernel for bitwise determinism.
@@ -253,10 +170,14 @@ def __init__(self, model_args: Qwen3ModelArgs):
         self.k_norm = None
 
         # QKV projections
-        self.wq = nn.Linear(model_args.dim, model_args.n_heads * self.head_dim, bias=False)
+        self.wq = nn.Linear(
+            model_args.dim, model_args.n_heads * self.head_dim, bias=False
+        )
         self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
         self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wo = nn.Linear(model_args.n_heads * self.head_dim, model_args.dim, bias=False)
+        self.wo = nn.Linear(
+            model_args.n_heads * self.head_dim, model_args.dim, bias=False
+        )
 
         # Always use vLLM compatible flash attention
         self.inner_attention = VLLMCompatibleFlashAttention()
@@ -303,7 +224,9 @@ def forward(
         xv = values.transpose(1, 2)
 
         # Apply flash attention (vLLM compatible, no flex attention)
-        assert attention_masks is None, "vLLM compat mode doesn't use flex attention masks"
+        assert (
+            attention_masks is None
+        ), "vLLM compat mode doesn't use flex attention masks"
         output = self.inner_attention(xq, xk, xv, scale=self.scaling)
 
         # Transpose back
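The backward() removed above hand-derives the RMSNorm gradients that are now centralized in batch_invariant_backward.py. A minimal, self-contained sketch to sanity-check that algebra against autograd, using a plain PyTorch reference (rms_norm_ref is a hypothetical stand-in for vLLM's Triton kernel, not part of the experiment):

```python
import torch

def rms_norm_ref(x, w, eps=1e-6):
    # Reference RMSNorm: y = (x / rms(x)) * w, rms = sqrt(mean(x^2) + eps)
    rms = torch.sqrt((x * x).mean(dim=-1, keepdim=True) + eps)
    return (x / rms) * w

torch.manual_seed(0)
eps = 1e-6
x = torch.randn(4, 16, dtype=torch.float64, requires_grad=True)
w = torch.randn(16, dtype=torch.float64, requires_grad=True)

# Gradients from autograd on the reference implementation
out = rms_norm_ref(x, w, eps)
grad_out = torch.randn_like(out)
gx_auto, gw_auto = torch.autograd.grad(out, (x, w), grad_out)

# Same algebra as the removed backward():
# grad_x = (grad_out*w - mean(grad_out*w*x_norm) * x_norm) / rms
with torch.no_grad():
    rms = torch.sqrt((x * x).mean(dim=-1, keepdim=True) + eps)
    x_norm = x / rms
    gw_manual = (grad_out * x_norm).sum(dim=tuple(range(grad_out.ndim - 1)))
    g_xnorm = grad_out * w
    mean_term = (g_xnorm * x_norm).mean(dim=-1, keepdim=True)
    gx_manual = (g_xnorm - mean_term * x_norm) / rms

print(torch.allclose(gx_auto, gx_manual), torch.allclose(gw_auto, gw_manual))
```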

torchtitan/experiments/deterministic_vllm_rl/simple_rl.py

Lines changed: 10 additions & 10 deletions
@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
 """
 Simple RL training loop with GRPO-style advantage estimation.
 
@@ -11,19 +17,16 @@
 """
 
 import os
-import tempfile
 import torch
 import torch.nn.functional as F
 from transformers import AutoTokenizer, AutoConfig
 from safetensors.torch import load_file, save_file
 from huggingface_hub import snapshot_download
-import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 
-import torchtitan.experiments.compat
 from torchtitan.models.qwen3.model.args import Qwen3ModelArgs
-from weights_vllm_compat import torchtitan_to_vllm_compat, vllm_compat_to_torchtitan
-from weights.converter import torchtitan_to_vllm, vllm_to_torchtitan
+from torchtitan.experiments.deterministic_vllm_rl.weights_vllm_compat import torchtitan_to_vllm_compat, vllm_compat_to_torchtitan
+from torchtitan.experiments.deterministic_vllm_rl.weights.converter import torchtitan_to_vllm, vllm_to_torchtitan
 
 from vllm import LLM, SamplingParams
 from vllm.model_executor.layers.batch_invariant import init_batch_invariance
@@ -618,16 +621,13 @@ def rl_update_step(
         metrics: Dict of training metrics
     """
     # Update vLLM weights from current policy
-    from weights_vllm_compat import torchtitan_to_vllm_compat
     titan_state = model.state_dict()
     vllm_compat_state = torchtitan_to_vllm_compat(titan_state)
     vllm_engine.update_weights(vllm_compat_state)
 
     # Round-trip: load weights back from disk to maintain consistency with vLLM
     import glob
     from safetensors.torch import load_file as sf_load
-    from weights.converter import vllm_to_torchtitan
-    from weights_vllm_compat import torchtitan_to_vllm_compat as titan_to_vllm_compat
 
     shard_files = sorted(glob.glob(os.path.join(vllm_engine.temp_model_dir, "model-*.safetensors")))
     if shard_files:
@@ -642,7 +642,7 @@
 
         if use_vllm_compat:
             # Convert to vLLM-compat format for vLLM-compatible model
-            weights_for_model = titan_to_vllm_compat(titan_from_disk)
+            weights_for_model = torchtitan_to_vllm_compat(titan_from_disk)
         else:
             # Use standard TorchTitan format for standard model
             weights_for_model = titan_from_disk
@@ -776,7 +776,7 @@ def main():
         print("Batch invariance detected - using vLLM-compatible model")
         # Add backward pass support to vLLM's batch_invariant mode
         print("Adding gradient support to vLLM's batch_invariant mode...")
-        from batch_invariant_backward import patch_batch_invariant_with_gradients
+        from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward import patch_batch_invariant_with_gradients
         patch_batch_invariant_with_gradients()
     else:
         print("Batch invariance NOT detected - using standard model")

torchtitan/experiments/deterministic_vllm_rl/tests/test_batch_invariant_backward.py

Lines changed: 7 additions & 6 deletions
@@ -1,15 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
 """
 Test batch_invariant_backward module to ensure it works correctly.
 """
 
 import torch
-import sys
-from pathlib import Path
-
-# Add current directory to path
-sys.path.insert(0, str(Path(__file__).parent))
 
-from batch_invariant_backward import (
+from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward import (
     enable_batch_invariant_backward_mode,
     disable_batch_invariant_backward_mode,
     mm_batch_invariant_backward,

torchtitan/experiments/deterministic_vllm_rl/tests/test_exact_determinism.py

Lines changed: 8 additions & 2 deletions
@@ -1,12 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
 """
 Test if batch_invariant operations are EXACTLY deterministic.
 
 This runs the same operation multiple times and checks if results are bit-for-bit identical.
 """
 
 import torch
-from batch_invariant_backward import enable_batch_invariant_backward_mode
-from vllm.model_executor.layers.batch_invariant import disable_batch_invariant_mode, matmul_persistent
+from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward import enable_batch_invariant_backward_mode
+from vllm.model_executor.layers.batch_invariant import disable_batch_invariant_mode
 
 print("Enabling batch_invariant_backward mode...")
 disable_batch_invariant_mode()
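For context, a bit-for-bit determinism check of the kind this test performs can be illustrated with a minimal, self-contained sketch in plain PyTorch (it does not use the experiment's or vLLM's API):

```python
import torch

# Run the same matmul twice on identical inputs and require exact equality.
# torch.equal compares element-wise values exactly, so a kernel whose
# reduction order varies between runs would make this check fail.
torch.manual_seed(0)
a = torch.randn(128, 256)
b = torch.randn(256, 64)

out1 = a @ b
out2 = a @ b
print("bitwise identical:", torch.equal(out1, out2))
```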

torchtitan/experiments/deterministic_vllm_rl/weights/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
 """Weight conversion utilities for vLLM and TorchTitan."""
 
 from .converter import vllm_to_torchtitan, torchtitan_to_vllm

torchtitan/experiments/deterministic_vllm_rl/weights/converter.py

Lines changed: 6 additions & 0 deletions
@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
 """
 Minimal weight converter between vLLM and TorchTitan formats for Qwen3-1.7B.

torchtitan/experiments/deterministic_vllm_rl/weights_vllm_compat.py

Lines changed: 6 additions & 0 deletions
@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
 """
 Weight conversion utilities for Qwen3VLLMCompatModel.

0 commit comments