
Commit a7e623b

fix the outdated end-to-end training example for MoE + torchtitan
1 parent 53b5efd commit a7e623b

File tree

1 file changed (+21, -11)


torchao/prototype/moe_training/examples/simple_moe_layer.py

Lines changed: 21 additions & 11 deletions
@@ -1,17 +1,26 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
 import torch
 from torch import nn
 from torch.nn import functional as F

 # this feature requires CUDA and SM89+
 assert torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

-from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
+from torchao.prototype.moe_training.conversion_utils import (
+    MoEScalingType,
+    MoETrainingConfig,
+)
 from torchao.quantization.quant_api import quantize_

 # this example uses torchtitan llama4 MoE, see
 try:
-    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
-    from torchtitan.experiments.llama4.model.moe import MoE
+    from torchtitan.models.moe import MoE, MoEArgs
+    from torchtitan.models.moe.utils import set_token_group_alignment_size_m
 except ImportError as e:
     raise ImportError(
         "torchtitan not installed, see installation instructions at https://github.com/pytorch/torchtitan"
@@ -20,12 +29,10 @@

 # initialize model
 device = torch.device("cuda")
-model_args = TransformerModelArgs(
-    moe_enabled=True,
-    num_experts=8,
-    dim=256,
-)
-model = MoE(model_args).to(torch.bfloat16).to(device)
+model_args = MoEArgs(num_experts=8, top_k=2, use_grouped_mm=True)
+dim = 256
+hidden_dim = dim * 4
+model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).to(device)
 init_std = 0.02
 model.init_weights(init_std, device)

@@ -40,14 +47,17 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
     return False


-# quantize the model
+# quantize the model, by default it is rowwise fp8
 config = MoETrainingConfig()
 quantize_(model, config=config, filter_fn=moe_module_filter_fn)

+alignment_size = 32 if config.scaling_type == MoEScalingType.MXFP8 else 16
+set_token_group_alignment_size_m(alignment_size)
+
 # training loop
 optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
 for step in range(10):
-    batch, seq, dim = 8, 2048, 256
+    batch, seq, dim = 8, 2048, dim
     x = torch.randn(
         batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
     )
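
Note: the hunks above show only the signature and final return of moe_module_filter_fn; its body is untouched by this commit. A minimal sketch of a filter with that signature, assuming the intent is to convert just the torchtitan MoE module and leave the rest of the model in bf16 (the isinstance check is an assumption, not taken from the diff):

def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
    # Illustrative only: convert the MoE module itself and skip everything else
    # (attention, norms, embeddings). The real example file defines its own body.
    if isinstance(mod, MoE):
        return True
    return False

The new alignment logic picks 32 for MXFP8 and 16 for the default rowwise fp8 recipe. To exercise the MXFP8 path you would presumably construct the config with a non-default scaling type, e.g. MoETrainingConfig(scaling_type=MoEScalingType.MXFP8); the constructor argument is an assumption here, since the diff only reads config.scaling_type.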
