
Commit 6e3f2f8

Cyrilvallez and vasqu authored
[TP plans] Fix some incorrect TP plans (#42448)
* gemma3
* qwen3 and modulars
* fix tp plans

Co-authored-by: vasqu <[email protected]>
1 parent 83fe012 commit 6e3f2f8
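
For context, base_model_tp_plan maps parameter-name patterns to the partitioning style transformers applies to each weight when a model is loaded with tensor parallelism. Below is a minimal sketch of how such a plan gets exercised, assuming the tp_plan="auto" loading path; the checkpoint id is a placeholder, and the script is meant to be launched with one process per GPU (e.g. `torchrun --nproc-per-node 4 run_tp.py`).

# run_tp.py -- minimal sketch; checkpoint id is a placeholder
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "your-org/your-apertus-checkpoint"  # placeholder, not a real repo

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    tp_plan="auto",  # shard weights according to the config's base_model_tp_plan
)

inputs = tokenizer("Tensor parallelism splits each linear layer across GPUs.", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))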

File tree

5 files changed: +12 −22 lines

src/transformers/models/apertus/configuration_apertus.py

Lines changed: 4 additions & 4 deletions
@@ -101,10 +101,10 @@ class ApertusConfig(PreTrainedConfig):
     keys_to_ignore_at_inference = ["past_key_values"]
     default_theta = 12000000.0
     base_model_tp_plan = {
-        "layers.*.self_attn.q_proj": "colwise_rep",  # we need to replicate here due to the added norm on q and k
-        "layers.*.self_attn.k_proj": "colwise_rep",  # we need to replicate here due to the added norm on q and k
-        "layers.*.self_attn.v_proj": "colwise_rep",  # we need to replicate here due to the added norm on q and k
-        "layers.*.self_attn.o_proj": "rowwise_rep",  # we need to replicate here due to the added norm on q and k
+        "layers.*.self_attn.q_proj": "colwise",
+        "layers.*.self_attn.k_proj": "colwise",
+        "layers.*.self_attn.v_proj": "colwise",
+        "layers.*.self_attn.o_proj": "rowwise",
         "layers.*.mlp.up_proj": "colwise",
         "layers.*.mlp.down_proj": "rowwise",
     }

src/transformers/models/apertus/modular_apertus.py

Lines changed: 4 additions & 4 deletions
@@ -119,10 +119,10 @@ class ApertusConfig(PreTrainedConfig):
     keys_to_ignore_at_inference = ["past_key_values"]
     default_theta = 12000000.0
     base_model_tp_plan = {
-        "layers.*.self_attn.q_proj": "colwise_rep",  # we need to replicate here due to the added norm on q and k
-        "layers.*.self_attn.k_proj": "colwise_rep",  # we need to replicate here due to the added norm on q and k
-        "layers.*.self_attn.v_proj": "colwise_rep",  # we need to replicate here due to the added norm on q and k
-        "layers.*.self_attn.o_proj": "rowwise_rep",  # we need to replicate here due to the added norm on q and k
+        "layers.*.self_attn.q_proj": "colwise",
+        "layers.*.self_attn.k_proj": "colwise",
+        "layers.*.self_attn.v_proj": "colwise",
+        "layers.*.self_attn.o_proj": "rowwise",
         "layers.*.mlp.up_proj": "colwise",
         "layers.*.mlp.down_proj": "rowwise",
     }

src/transformers/models/doge/configuration_doge.py

Lines changed: 0 additions & 5 deletions
@@ -117,11 +117,6 @@ class DogeConfig(PreTrainedConfig):
         "layers.*.self_attn.v_proj": "colwise",
         "layers.*.self_attn.dt_proj": "rowwise",
         "layers.*.self_attn.o_proj": "rowwise",
-        "layers.*.input_layernorm.weight": "sequence_parallel",
-        "layers.*.input_residual": "sequence_parallel",
-        "layers.*.post_attention_layernorm.weight": "sequence_parallel",
-        "layers.*.post_attention_residual": "sequence_parallel",
-        "norm.weight": "sequence_parallel",
         "layers.*.mlp.gate_proj": "colwise",
         "layers.*.mlp.up_proj": "colwise",
         "layers.*.mlp.down_proj": "rowwise",

src/transformers/models/doge/modular_doge.py

Lines changed: 0 additions & 5 deletions
@@ -146,11 +146,6 @@ class DogeConfig(PreTrainedConfig):
         "layers.*.self_attn.v_proj": "colwise",
         "layers.*.self_attn.dt_proj": "rowwise",
         "layers.*.self_attn.o_proj": "rowwise",
-        "layers.*.input_layernorm.weight": "sequence_parallel",
-        "layers.*.input_residual": "sequence_parallel",
-        "layers.*.post_attention_layernorm.weight": "sequence_parallel",
-        "layers.*.post_attention_residual": "sequence_parallel",
-        "norm.weight": "sequence_parallel",
         "layers.*.mlp.gate_proj": "colwise",
         "layers.*.mlp.up_proj": "colwise",
         "layers.*.mlp.down_proj": "rowwise",

src/transformers/models/nanochat/configuration_nanochat.py

Lines changed: 4 additions & 4 deletions
@@ -94,10 +94,10 @@ class NanoChatConfig(PretrainedConfig):
     keys_to_ignore_at_inference = ["past_key_values"]

     base_model_tp_plan = {
-        "layers.*.self_attn.q_proj": "colwise_rep",
-        "layers.*.self_attn.k_proj": "colwise_rep",
-        "layers.*.self_attn.v_proj": "colwise_rep",
-        "layers.*.self_attn.o_proj": "rowwise_rep",
+        "layers.*.self_attn.q_proj": "colwise",
+        "layers.*.self_attn.k_proj": "colwise",
+        "layers.*.self_attn.v_proj": "colwise",
+        "layers.*.self_attn.o_proj": "rowwise",
         "layers.*.mlp.fc1": "colwise",
         "layers.*.mlp.fc2": "rowwise",
     }
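
As an aside, the plan strings correspond closely to PyTorch DTensor parallel styles: "colwise" shards a linear layer over its output features and leaves the activation sharded, "rowwise" shards over input features and all-reduces the result, and the "_rep" variants additionally gather the output back to a replicated tensor. The toy sketch below illustrates the colwise → rowwise pairing with torch.distributed.tensor.parallel directly; it is an illustration of the underlying primitives, not the transformers implementation, and is meant to be run under torchrun (e.g. `torchrun --nproc-per-node 2 tp_styles.py`).

# tp_styles.py -- toy illustration of "colwise"/"rowwise" pairing
import os
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyAttention(nn.Module):
    def __init__(self, dim=64):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.o_proj = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        return self.o_proj(self.q_proj(x))


world_size = int(os.environ["WORLD_SIZE"])
mesh = init_device_mesh("cuda", (world_size,))

block = ToyAttention().cuda()
# Rough equivalent of {"q_proj": "colwise", "o_proj": "rowwise"} in a tp plan:
# q_proj emits a shard of its output features, which rowwise o_proj consumes
# directly and all-reduces back to a full tensor.
parallelize_module(
    block,
    mesh,
    {"q_proj": ColwiseParallel(), "o_proj": RowwiseParallel()},
)

x = torch.randn(1, 8, 64, device="cuda")
print(block(x).shape)  # (1, 8, 64), replicated after the rowwise all-reduce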
