10 changes: 5 additions & 5 deletions torchtitan/experiments/deterministic_vllm_rl/simple_rl.py
@@ -25,10 +25,6 @@
 from huggingface_hub import snapshot_download
 from safetensors.torch import load_file, save_file
 from torch.utils.tensorboard import SummaryWriter
-from transformers import AutoConfig, AutoTokenizer
-
-from vllm import LLM, SamplingParams
-from vllm.model_executor.layers.batch_invariant import init_batch_invariance
 
 from torchtitan.experiments.deterministic_vllm_rl.weights.converter import (
 torchtitan_to_vllm,
@@ -39,6 +35,10 @@
 )
 
 from torchtitan.models.qwen3.model.args import Qwen3ModelArgs
+from transformers import AutoConfig, AutoTokenizer
+
+from vllm import LLM, SamplingParams
+from vllm.model_executor.layers.batch_invariant import init_batch_invariance
 
 init_batch_invariance()
 
@@ -169,7 +169,6 @@ def update_weights(self, vllm_compat_state: dict) -> None:
 dtype="bfloat16",
 gpu_memory_utilization=0.3,  # Reduced from 0.5
 seed=42,  # Fixed seed for determinism
-enforce_eager=True,
 )
 print("✓ Created new vLLM engine")
 else:
@@ -1086,6 +1085,7 @@ def main():
 )
 model = model.to(device)
 model.train()
+model = torch.compile(model)
 
 # Save initial weights for delta computation (on CPU to save GPU memory)
 print("Saving initial weights for tracking...")
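For context on the torch.compile addition above and the converter changes below, a minimal illustrative sketch (not taken from this PR, stock PyTorch behavior assumed): torch.compile wraps the model in an OptimizedModule, so the keys reported by state_dict() gain an "_orig_mod." prefix, which is why the vLLM/TorchTitan name mapping has to change.

# Illustrative sketch, not from this PR: torch.compile returns an OptimizedModule
# whose state_dict keys carry the "_orig_mod." prefix of the wrapped module.
import torch
import torch.nn as nn

plain = nn.Linear(4, 4)
print(list(plain.state_dict()))     # ['weight', 'bias']

compiled = torch.compile(plain)
print(list(compiled.state_dict()))  # ['_orig_mod.weight', '_orig_mod.bias']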
35 changes: 17 additions & 18 deletions torchtitan/experiments/deterministic_vllm_rl/weights/converter.py
@@ -20,24 +20,24 @@
 
 # Weight name mapping from HuggingFace/vLLM to TorchTitan
 VLLM_TO_TITAN_MAP = {
-"model.embed_tokens.weight": "tok_embeddings.weight",
+"model.embed_tokens.weight": "_orig_mod.tok_embeddings.weight",
 # Attention weights
-"model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
-"model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
-"model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
-"model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
-"model.layers.{}.self_attn.q_norm.weight": "layers.{}.attention.q_norm.weight",
-"model.layers.{}.self_attn.k_norm.weight": "layers.{}.attention.k_norm.weight",
+"model.layers.{}.self_attn.q_proj.weight": "_orig_mod.layers.{}.attention.wq.weight",
+"model.layers.{}.self_attn.k_proj.weight": "_orig_mod.layers.{}.attention.wk.weight",
+"model.layers.{}.self_attn.v_proj.weight": "_orig_mod.layers.{}.attention.wv.weight",
+"model.layers.{}.self_attn.o_proj.weight": "_orig_mod.layers.{}.attention.wo.weight",
+"model.layers.{}.self_attn.q_norm.weight": "_orig_mod.layers.{}.attention.q_norm.weight",
+"model.layers.{}.self_attn.k_norm.weight": "_orig_mod.layers.{}.attention.k_norm.weight",
 # MLP weights
-"model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
-"model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
-"model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
+"model.layers.{}.mlp.gate_proj.weight": "_orig_mod.layers.{}.feed_forward.w1.weight",
+"model.layers.{}.mlp.up_proj.weight": "_orig_mod.layers.{}.feed_forward.w3.weight",
+"model.layers.{}.mlp.down_proj.weight": "_orig_mod.layers.{}.feed_forward.w2.weight",
 # Layer norms
-"model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
-"model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
+"model.layers.{}.input_layernorm.weight": "_orig_mod.layers.{}.attention_norm.weight",
+"model.layers.{}.post_attention_layernorm.weight": "_orig_mod.layers.{}.ffn_norm.weight",
 # Final norm and output
-"model.norm.weight": "norm.weight",
-"lm_head.weight": "output.weight",
+"model.norm.weight": "_orig_mod.norm.weight",
+"lm_head.weight": "_orig_mod.output.weight",
 }
 
 
@@ -142,7 +142,7 @@ def torchtitan_to_vllm(titan_state: dict[str, torch.Tensor]) -> dict[str, torch.
 
 # Extract layer number
 parts = titan_key.split(".")
-layer_idx = parts[1]
+layer_idx = parts[2]
 
 # Create vLLM keys
 gate_key = f"model.layers.{layer_idx}.mlp.gate_proj.weight"
@@ -155,7 +155,7 @@
 # Handle down_proj (vLLM-compat format)
 if ".feed_forward.down_proj.weight" in titan_key:
 parts = titan_key.split(".")
-layer_idx = parts[1]
+layer_idx = parts[2]
 vllm_key = f"model.layers.{layer_idx}.mlp.down_proj.weight"
 # CLONE to avoid aliasing
 vllm_state[vllm_key] = tensor.clone()
@@ -165,7 +165,7 @@
 if "layers." in titan_key:
 # Extract layer number
 parts = titan_key.split(".")
-layer_idx = parts[1]
+layer_idx = parts[2]
 
 # Create abstract key with placeholder
 abstract_titan_key = titan_key.replace(f".{layer_idx}.", ".{}.")
@@ -210,7 +210,6 @@ def torchtitan_to_vllm(titan_state: dict[str, torch.Tensor]) -> dict[str, torch.
 mode = sys.argv[1]
 input_path = sys.argv[2]
 output_path = sys.argv[3]
-
 if mode == "vllm_to_titan":
 # Convert vLLM to TorchTitan
 titan_state = vllm_to_torchtitan(input_path)
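A small sketch of the indexing change in torchtitan_to_vllm, using a hypothetical key for illustration: once TorchTitan keys start with the "_orig_mod." prefix, splitting on "." puts the layer index at position 2 rather than 1, hence parts[2].

# Sketch with a hypothetical key; the "_orig_mod." prefix shifts the layer index right by one.
titan_key = "_orig_mod.layers.5.feed_forward.w1.weight"
parts = titan_key.split(".")
# parts == ['_orig_mod', 'layers', '5', 'feed_forward', 'w1', 'weight']
layer_idx = parts[2]  # "5"; it was parts[1] when keys began with "layers."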