diff --git a/torchtitan/experiments/deterministic_vllm_rl/simple_rl.py b/torchtitan/experiments/deterministic_vllm_rl/simple_rl.py
index ffc7d52eb0..318ec6351d 100644
--- a/torchtitan/experiments/deterministic_vllm_rl/simple_rl.py
+++ b/torchtitan/experiments/deterministic_vllm_rl/simple_rl.py
@@ -25,10 +25,6 @@
 from huggingface_hub import snapshot_download
 from safetensors.torch import load_file, save_file
 from torch.utils.tensorboard import SummaryWriter
-from transformers import AutoConfig, AutoTokenizer
-
-from vllm import LLM, SamplingParams
-from vllm.model_executor.layers.batch_invariant import init_batch_invariance
 
 from torchtitan.experiments.deterministic_vllm_rl.weights.converter import (
     torchtitan_to_vllm,
@@ -39,6 +35,10 @@
 )
 
 from torchtitan.models.qwen3.model.args import Qwen3ModelArgs
 
+from transformers import AutoConfig, AutoTokenizer
+
+from vllm import LLM, SamplingParams
+from vllm.model_executor.layers.batch_invariant import init_batch_invariance
 
 init_batch_invariance()
@@ -169,7 +169,6 @@ def update_weights(self, vllm_compat_state: dict) -> None:
                 dtype="bfloat16",
                 gpu_memory_utilization=0.3,  # Reduced from 0.5
                 seed=42,  # Fixed seed for determinism
-                enforce_eager=True,
             )
             print("✓ Created new vLLM engine")
         else:
@@ -1086,6 +1085,7 @@ def main():
     )
     model = model.to(device)
     model.train()
+    model = torch.compile(model)
 
     # Save initial weights for delta computation (on CPU to save GPU memory)
     print("Saving initial weights for tracking...")
diff --git a/torchtitan/experiments/deterministic_vllm_rl/weights/converter.py b/torchtitan/experiments/deterministic_vllm_rl/weights/converter.py
index 092af9c37d..d351236473 100644
--- a/torchtitan/experiments/deterministic_vllm_rl/weights/converter.py
+++ b/torchtitan/experiments/deterministic_vllm_rl/weights/converter.py
@@ -20,24 +20,24 @@
 # Weight name mapping from HuggingFace/vLLM to TorchTitan
 VLLM_TO_TITAN_MAP = {
-    "model.embed_tokens.weight": "tok_embeddings.weight",
+    "model.embed_tokens.weight": "_orig_mod.tok_embeddings.weight",
     # Attention weights
-    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
-    "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
-    "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
-    "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
-    "model.layers.{}.self_attn.q_norm.weight": "layers.{}.attention.q_norm.weight",
-    "model.layers.{}.self_attn.k_norm.weight": "layers.{}.attention.k_norm.weight",
+    "model.layers.{}.self_attn.q_proj.weight": "_orig_mod.layers.{}.attention.wq.weight",
+    "model.layers.{}.self_attn.k_proj.weight": "_orig_mod.layers.{}.attention.wk.weight",
+    "model.layers.{}.self_attn.v_proj.weight": "_orig_mod.layers.{}.attention.wv.weight",
+    "model.layers.{}.self_attn.o_proj.weight": "_orig_mod.layers.{}.attention.wo.weight",
+    "model.layers.{}.self_attn.q_norm.weight": "_orig_mod.layers.{}.attention.q_norm.weight",
+    "model.layers.{}.self_attn.k_norm.weight": "_orig_mod.layers.{}.attention.k_norm.weight",
     # MLP weights
-    "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
-    "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
-    "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
+    "model.layers.{}.mlp.gate_proj.weight": "_orig_mod.layers.{}.feed_forward.w1.weight",
+    "model.layers.{}.mlp.up_proj.weight": "_orig_mod.layers.{}.feed_forward.w3.weight",
+    "model.layers.{}.mlp.down_proj.weight": "_orig_mod.layers.{}.feed_forward.w2.weight",
     # Layer norms
-    "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
-    "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
+    "model.layers.{}.input_layernorm.weight": "_orig_mod.layers.{}.attention_norm.weight",
+    "model.layers.{}.post_attention_layernorm.weight": "_orig_mod.layers.{}.ffn_norm.weight",
     # Final norm and output
-    "model.norm.weight": "norm.weight",
-    "lm_head.weight": "output.weight",
+    "model.norm.weight": "_orig_mod.norm.weight",
+    "lm_head.weight": "_orig_mod.output.weight",
 }
@@ -142,7 +142,7 @@ def torchtitan_to_vllm(titan_state: dict[str, torch.Tensor]) -> dict[str, torch.
 
             # Extract layer number
             parts = titan_key.split(".")
-            layer_idx = parts[1]
+            layer_idx = parts[2]
 
             # Create vLLM keys
             gate_key = f"model.layers.{layer_idx}.mlp.gate_proj.weight"
@@ -155,7 +155,7 @@ def torchtitan_to_vllm(titan_state: dict[str, torch.Tensor]) -> dict[str, torch.
         # Handle down_proj (vLLM-compat format)
         if ".feed_forward.down_proj.weight" in titan_key:
             parts = titan_key.split(".")
-            layer_idx = parts[1]
+            layer_idx = parts[2]
             vllm_key = f"model.layers.{layer_idx}.mlp.down_proj.weight"
             # CLONE to avoid aliasing
             vllm_state[vllm_key] = tensor.clone()
@@ -165,7 +165,7 @@ def torchtitan_to_vllm(titan_state: dict[str, torch.Tensor]) -> dict[str, torch.
         if "layers." in titan_key:
             # Extract layer number
             parts = titan_key.split(".")
-            layer_idx = parts[1]
+            layer_idx = parts[2]
 
             # Create abstract key with placeholder
             abstract_titan_key = titan_key.replace(f".{layer_idx}.", ".{}.")
@@ -210,7 +210,6 @@ def torchtitan_to_vllm(titan_state: dict[str, torch.Tensor]) -> dict[str, torch.
     mode = sys.argv[1]
     input_path = sys.argv[2]
    output_path = sys.argv[3]
-
     if mode == "vllm_to_titan":
         # Convert vLLM to TorchTitan
         titan_state = vllm_to_torchtitan(input_path)
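The key renames in converter.py track a known PyTorch behavior: wrapping a module with torch.compile returns an OptimizedModule whose state_dict() keys carry an "_orig_mod." prefix, so the TorchTitan-side names in VLLM_TO_TITAN_MAP gain that prefix and the layer index in a split key shifts from parts[1] to parts[2]. A minimal sketch of that behavior (the Tiny module below is a hypothetical stand-in, not part of this patch):

    # Illustrative only: shows why "_orig_mod." and parts[2] appear
    # once the training model is wrapped with torch.compile.
    import torch
    import torch.nn as nn

    class Tiny(nn.Module):  # hypothetical stand-in for the Qwen3 model
        def __init__(self):
            super().__init__()
            self.layers = nn.ModuleList([nn.Linear(4, 4)])

        def forward(self, x):
            return self.layers[0](x)

    compiled = torch.compile(Tiny())
    key = next(iter(compiled.state_dict().keys()))
    print(key)                # "_orig_mod.layers.0.weight"
    print(key.split(".")[2])  # "0" -- the layer index now sits at parts[2]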