# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn
from torch.nn import functional as F

# this feature requires CUDA and SM89+ (compute capability >= 8.9)
assert torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

from torchao.prototype.moe_training.conversion_utils import (
    MoEScalingType,
    MoETrainingConfig,
)
from torchao.quantization.quant_api import quantize_

# this example uses the torchtitan MoE layer; see
# https://github.com/pytorch/torchtitan for installation instructions
try:
    from torchtitan.models.moe import MoE, MoEArgs
    from torchtitan.models.moe.utils import set_token_group_alignment_size_m
except ImportError as e:
    raise ImportError(
        "torchtitan not installed, see installation instructions at https://github.com/pytorch/torchtitan"
    ) from e

# initialize model
device = torch.device("cuda")
model_args = MoEArgs(num_experts=8, top_k=2, use_grouped_mm=True)
dim = 256
hidden_dim = dim * 4
model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).to(device)
init_std = 0.02
model.init_weights(init_std, device)


def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
    # NOTE: the body of this filter was elided in the original diff; this is a
    # minimal reconstruction, assuming the intent is to select only the MoE
    # expert weights (modules whose FQN contains "experts") for quantization.
    if "experts" in cur_fqn:
        return True
    return False


# quantize the model; the default MoETrainingConfig uses rowwise fp8 scaling
config = MoETrainingConfig()
quantize_(model, config=config, filter_fn=moe_module_filter_fn)

# token groups fed to the grouped GEMMs must be padded to a multiple of
# 32 for mxfp8 (its block size), otherwise 16 for rowwise fp8
alignment_size = 32 if config.scaling_type == MoEScalingType.MXFP8 else 16
set_token_group_alignment_size_m(alignment_size)

# training loop
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
for step in range(10):
    batch, seq = 8, 2048
    x = torch.randn(
        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
    )
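    # The rest of the training step is not shown in the source diff; below is a
    # minimal sketch (an assumption, not the original code): forward pass, a
    # dummy MSE loss against a ones target, backward pass, and optimizer update.
    out = model(x)
    loss = F.mse_loss(out, torch.ones_like(out))
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()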