
# this example uses torchtitan llama4 MoE, see
try:
-    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
-    from torchtitan.experiments.llama4.model.moe import MoE
+    from torchtitan.models.moe import MoE, MoEArgs
+    from torchtitan.models.moe.utils import set_token_group_alignment_size_m
except ImportError as e:
    raise ImportError(
        "torchtitan not installed, see installation instructions at https://github.com/pytorch/torchtitan"
@@ -20,12 +20,10 @@

# initialize model
device = torch.device("cuda")
-model_args = TransformerModelArgs(
-    moe_enabled=True,
-    num_experts=8,
-    dim=256,
-)
-model = MoE(model_args).to(torch.bfloat16).to(device)
+model_args = MoEArgs(num_experts=8, top_k=2, use_grouped_mm=True)
+dim = 256
+hidden_dim = dim * 4
+model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).to(device)
init_std = 0.02
model.init_weights(init_std, device)

@@ -40,14 +38,17 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
    return False


-# quantize the model
+# quantize the model, by default it is rowwise fp8
config = MoETrainingConfig()
quantize_(model, config=config, filter_fn=moe_module_filter_fn)

+alignment_size = 32 if config.scaling_type == MoEScalingType.MXFP8 else 16
+set_token_group_alignment_size_m(alignment_size)
+
# training loop
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
for step in range(10):
-    batch, seq, dim = 8, 2048, 256
+    batch, seq, dim = 8, 2048, dim
    x = torch.randn(
        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
    )
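
The hunk ends right after the input tensor is built. For orientation only, a minimal sketch of how the rest of the loop body typically continues is shown below (written at loop-body indentation); the sum() placeholder loss and the loop steps are illustrative assumptions, not lines from this diff:

    out = model(x)            # forward pass through the quantized MoE module
    loss = out.float().sum()  # placeholder objective, for illustration only
    loss.backward()           # backward runs through the quantized expert weights
    optimizer.step()
    optimizer.zero_grad()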