
Commit e4eb126

fix the outdated end2end training examples of moe+torchtitan
1 parent 53b5efd commit e4eb126

File tree

1 file changed (+11, -10 lines)


torchao/prototype/moe_training/examples/simple_moe_layer.py

Lines changed: 11 additions & 10 deletions
@@ -10,8 +10,8 @@

 # this example uses torchtitan llama4 MoE, see
 try:
-    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
-    from torchtitan.experiments.llama4.model.moe import MoE
+    from torchtitan.models.moe import MoE, MoEArgs
+    from torchtitan.models.moe.utils import set_token_group_alignment_size_m
 except ImportError as e:
     raise ImportError(
         "torchtitan not installed, see installation instructions at https://github.com/pytorch/torchtitan"
@@ -20,12 +20,10 @@

 # initialize model
 device = torch.device("cuda")
-model_args = TransformerModelArgs(
-    moe_enabled=True,
-    num_experts=8,
-    dim=256,
-)
-model = MoE(model_args).to(torch.bfloat16).to(device)
+model_args = MoEArgs(num_experts=8, top_k=2, use_grouped_mm=True)
+dim = 256
+hidden_dim = dim * 4
+model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).to(device)
 init_std = 0.02
 model.init_weights(init_std, device)
 
@@ -40,14 +38,17 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
     return False
 
 
-# quantize the model
+# quantize the model, by default it is rowwise fp8
 config = MoETrainingConfig()
 quantize_(model, config=config, filter_fn=moe_module_filter_fn)
 
+alignment_size = 32 if config.scaling_type == MoEScalingType.MXFP8 else 16
+set_token_group_alignment_size_m(alignment_size)
+
 # training loop
 optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
 for step in range(10):
-    batch, seq, dim = 8, 2048, 256
+    batch, seq, dim = 8, 2048, dim
     x = torch.randn(
         batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
     )
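Pieced together from the hunks above, the updated example reads roughly as follows. This is a sketch, not the exact file: the torchao import paths (quantize_, MoETrainingConfig, MoEScalingType), the body of moe_module_filter_fn, and the inner training-loop steps are outside this diff and are shown here only as assumptions for context.

import torch
from torch import nn

# torchao imports -- paths assumed, the example's own import block is not in the diff
from torchao.quantization.quant_api import quantize_
from torchao.prototype.moe_training.conversion_utils import (  # assumed location
    MoEScalingType,
    MoETrainingConfig,
)

try:
    # new torchtitan import locations introduced by this commit
    from torchtitan.models.moe import MoE, MoEArgs
    from torchtitan.models.moe.utils import set_token_group_alignment_size_m
except ImportError as e:
    raise ImportError(
        "torchtitan not installed, see installation instructions at https://github.com/pytorch/torchtitan"
    ) from e

# initialize model
device = torch.device("cuda")
model_args = MoEArgs(num_experts=8, top_k=2, use_grouped_mm=True)
dim = 256
hidden_dim = dim * 4
model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).to(device)
init_std = 0.02
model.init_weights(init_std, device)


def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
    # assumed body: convert only the expert weights, skip the router
    if "experts" in cur_fqn:
        return True
    return False


# quantize the model, by default it is rowwise fp8
config = MoETrainingConfig()
quantize_(model, config=config, filter_fn=moe_module_filter_fn)

# mxfp8 scaling needs token groups aligned to multiples of 32; rowwise fp8 uses 16
alignment_size = 32 if config.scaling_type == MoEScalingType.MXFP8 else 16
set_token_group_alignment_size_m(alignment_size)

# training loop
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
for step in range(10):
    batch, seq, dim = 8, 2048, dim
    x = torch.randn(
        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
    )
    # forward/backward below are assumed; the rest of the loop body is outside the hunks
    out = model(x)
    loss = out.sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

The relevant change for users is the import move (torchtitan.experiments.llama4.model.* to torchtitan.models.moe), the new MoE(model_args, dim, hidden_dim) constructor signature, and the set_token_group_alignment_size_m call after quantize_.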
