51 changes: 39 additions & 12 deletions torchao/prototype/moe_training/examples/simple_moe_layer.py
@@ -1,31 +1,51 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn
from torch.nn import functional as F

# this feature requires CUDA and SM89+
assert torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.prototype.moe_training.conversion_utils import (
    MoEScalingType,
    MoETrainingConfig,
)
from torchao.quantization.quant_api import quantize_

# this example uses torchtitan llama4 MoE, see https://github.com/pytorch/torchtitan
try:
    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
    from torchtitan.experiments.llama4.model.moe import MoE
    from torchtitan.models.moe import MoE, MoEArgs
    from torchtitan.models.moe.utils import set_token_group_alignment_size_m
except ImportError as e:
    raise ImportError(
        "torchtitan not installed, see installation instructions at https://github.com/pytorch/torchtitan"
    ) from e


from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument(
"--scaling_type",
type=str,
default="fp8_rowwise",
choices=["fp8_rowwise", "mxfp8"],
)
args = parser.parse_args()


# initialize model
device = torch.device("cuda")
model_args = TransformerModelArgs(
    moe_enabled=True,
    num_experts=8,
    dim=256,
)
model = MoE(model_args).to(torch.bfloat16).to(device)
torch.manual_seed(42)
model_args = MoEArgs(num_experts=8, top_k=2, use_grouped_mm=True)
dim = 1024
hidden_dim = dim * 4
model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).to(device)
init_std = 0.02
model.init_weights(init_std, device)
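
Before wiring in quantization, the new MoEArgs-based construction can be smoke-tested on its own. A minimal sketch, assuming torchtitan is installed and that the MoE layer maps a (batch, seq, dim) bfloat16 activation back to the same shape (the forward call itself is not visible in this diff):

import torch
from torchtitan.models.moe import MoE, MoEArgs

device = torch.device("cuda")
torch.manual_seed(42)
dim, hidden_dim = 1024, 4 * 1024
model_args = MoEArgs(num_experts=8, top_k=2, use_grouped_mm=True)
model = MoE(model_args, dim, hidden_dim).to(torch.bfloat16).to(device)
model.init_weights(0.02, device)

# one dummy forward pass; the layer is assumed to preserve the input shape
x = torch.randn(8, 2048, dim, dtype=torch.bfloat16, device=device)
out = model(x)
print(out.shape)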

@@ -40,14 +60,21 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
    return False


# quantize the model
config = MoETrainingConfig()
if args.scaling_type == "fp8_rowwise":
    config = MoETrainingConfig()
    alignment_size = 16

elif args.scaling_type == "mxfp8":
    config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8)
    alignment_size = 32

quantize_(model, config=config, filter_fn=moe_module_filter_fn)
set_token_group_alignment_size_m(alignment_size)

# training loop
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
for step in range(10):
    batch, seq, dim = 8, 2048, 256
    batch, seq = 8, 2048
    x = torch.randn(
        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
    )
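
The branch above pairs each scaling recipe with a token-group alignment size (the 32 for mxfp8 matches the 32-element MX scaling block). A minimal sketch that folds the selection into one helper, where select_scaling is only an illustrative name and the config/alignment pairs are taken verbatim from this diff:

from torchao.prototype.moe_training.conversion_utils import (
    MoEScalingType,
    MoETrainingConfig,
)


def select_scaling(scaling_type: str) -> tuple[MoETrainingConfig, int]:
    # Map the --scaling_type flag to a quantize_ config and the token-group
    # alignment passed to set_token_group_alignment_size_m.
    if scaling_type == "fp8_rowwise":
        return MoETrainingConfig(), 16
    if scaling_type == "mxfp8":
        return MoETrainingConfig(scaling_type=MoEScalingType.MXFP8), 32
    raise ValueError(f"unsupported scaling_type: {scaling_type}")


config, alignment_size = select_scaling(args.scaling_type)
quantize_(model, config=config, filter_fn=moe_module_filter_fn)
set_token_group_alignment_size_m(alignment_size)

Either recipe can then be exercised end to end with, e.g., python simple_moe_layer.py --scaling_type mxfp8.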