
Commit e196b50

Refactor attention v2 (#10707)
Pull Request resolved: #10623

Pull attention creation out of Transformer/TransformerBlock and instead pass the layers into Transformer. The motivation is to let us customize the linear layers in attention for LoRA (e.g. make wq a LoraLinear instead of a regular linear). In the next diff (D73517350), we pull wq, wk, wv, wo out of the attention and pass those in as well. This lets us customize attention parameters without passing in ModelArgs and doing the customization deep inside attention.py.

This modularizes our attention/transformer components, though it also means users have to do some more work to construct the attention layers and pass them to the Transformer. It follows the torchtune structure more closely, e.g. https://github.com/pytorch/torchtune/blob/main/torchtune/models/llama3_2/_component_builders.py#L221

Previously here: D73474110

ghstack-source-id: 282118266
@exported-using-ghexport

Differential Revision: [D73538697](https://our.internmc.facebook.com/intern/diff/D73538697/)
1 parent 6c57bc8 commit e196b50
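As a rough illustration of the new construction flow, here is a minimal sketch (not code from this PR): it assumes the constructors shown in the llama_transformer.py diff below and that the default ModelArgs attention_type is registered in ATTENTION_REGISTRY; the LoRA swap is purely hypothetical and only assumes the attention module exposes a wq linear, as the MHA implementation does.

import torch
from executorch.examples.models.llama.attention import ATTENTION_REGISTRY
from executorch.examples.models.llama.llama_transformer import Transformer, TransformerBlock
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope

model_args = ModelArgs(n_layers=4, vocab_size=128)
rope = Rope(model_args)
attn_cls = ATTENTION_REGISTRY[model_args.attention_type]

layers = torch.nn.ModuleList()
for layer_id in range(model_args.n_layers):
    attention = attn_cls(model_args, layer_id, rope)
    # Customization hook enabled by this refactor, e.g.:
    #   attention.wq = LoraLinear(...)  # hypothetical LoRA wrapper, not part of this PR
    layers.append(TransformerBlock(model_args, attention))

model = Transformer(model_args, layers, rope)

For the default case this is exactly what the new construct_transformer helper (added below) does; callers that do not need per-layer customization can keep using it.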

5 files changed: +73, -22 lines
examples/models/llama/llama_transformer.py

Lines changed: 60 additions & 13 deletions
@@ -13,6 +13,7 @@
 import torch.nn.functional as F
 
 from executorch.examples.models.llama.attention import (
+    Attention,
     ATTENTION_REGISTRY,
     ForwardOptions,
 )
@@ -83,26 +84,46 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
+    def __init__(self, args: ModelArgs, attention: Attention):
+        """
+        Transformer block with support for pre-norm and post-norm.
+        Args:
+            args (ModelArgs): model configuration parameters.
+            attention (Attention): attention object to use in the transformer
+                block. See `attention.py` for types of attention. Make sure
+                the attention type is registered in the ATTENTION_REGISTRY.
+        """
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.head_dim = args.head_dim
-        if args.attention_type not in ATTENTION_REGISTRY:
-            raise ValueError(
-                f"Unknown attention type: {args.attention_type}. "
-                f"Available: {list(ATTENTION_REGISTRY.keys())}"
-            )
-        cls = ATTENTION_REGISTRY[args.attention_type]
-        self.attention = cls(args, layer_id, rope)
+        self.attention = attention
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         else:
             self.feed_forward = FeedForward(args)
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
 
+    @classmethod
+    def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
+        """
+        Create a TransformerBlock with the legacy constructor.
+        Args:
+            layer_id (int): the index of the layer.
+            args (ModelArgs): model configuration parameters.
+            rope (Rope): the rope object to use for rotary embeddings.
+        """
+        if args.attention_type not in ATTENTION_REGISTRY:
+            raise ValueError(
+                f"Unknown attention type: {args.attention_type}. "
+                f"Available: {list(ATTENTION_REGISTRY.keys())}"
+            )
+        cls = ATTENTION_REGISTRY[args.attention_type]
+        attention = cls(args, layer_id, rope)
+        return TransformerBlock(args, attention)
+
     def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x: 1xN
         h, attn_options_update = self.attention.forward(
             self.attention_norm(x), freqs_cos, freqs_sin, **attn_options
@@ -117,7 +138,15 @@ def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x:
 
 
 class Transformer(nn.Module):
-    def __init__(self, params: ModelArgs):
+    def __init__(self, params: ModelArgs, layers: nn.ModuleList, rope: Rope):
+        """
+        Transformer model.
+        Args:
+            params (ModelArgs): model configuration parameters.
+            layers (nn.ModuleList): list of transformer blocks - see the
+                `TransformerBlock` type above.
+            rope (Rope): the rope object to use for rotary embeddings.
+        """
         super().__init__()
         self.params = params
         self.vocab_size = params.vocab_size
@@ -130,10 +159,8 @@ def __init__(self, params: ModelArgs):
             if self.apply_embedding
             else None
         )
-        self.rope = Rope(params)
-        self.layers = torch.nn.ModuleList()
-        for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params, self.rope))
+        self.layers = layers
+        self.rope = rope
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = (
             nn.Linear(params.dim, params.vocab_size, bias=False)
@@ -212,3 +239,23 @@ def forward(
             return logits, attn_options_update
 
         return logits
+
+
+def construct_transformer(model_args: ModelArgs) -> Transformer:
+    """
+    Construct a Transformer model from the given model arguments.
+    """
+    rope = Rope(model_args)
+    if model_args.attention_type not in ATTENTION_REGISTRY:
+        raise ValueError(
+            f"Unknown attention type: {model_args.attention_type}. "
+            f"Available: {list(ATTENTION_REGISTRY.keys())}"
+        )
+    layers = torch.nn.ModuleList()
+    cls = ATTENTION_REGISTRY[model_args.attention_type]
+    for layer_id in range(model_args.n_layers):
+        attention = cls(model_args, layer_id, rope)
+        transformer_block = TransformerBlock(model_args, attention)
+        layers.append(transformer_block)
+
+    return Transformer(model_args, layers, rope)
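
For existing call sites, construct_transformer is the drop-in replacement for the old Transformer(model_args) constructor, and TransformerBlock.from_type keeps the legacy per-layer signature. A minimal usage sketch, assuming a ModelArgs whose attention_type is registered in ATTENTION_REGISTRY:

from executorch.examples.models.llama.llama_transformer import (
    construct_transformer,
    TransformerBlock,
)
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope

model_args = ModelArgs(n_layers=2, vocab_size=512)

# Old: model = Transformer(model_args); the new helper builds rope + layers internally.
model = construct_transformer(model_args)

# Legacy-style single-block construction, for callers that still want the old signature.
block = TransformerBlock.from_type(layer_id=0, args=model_args, rope=Rope(model_args))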

examples/models/llama/model.py

Lines changed: 3 additions & 2 deletions
@@ -15,9 +15,10 @@
     get_checkpoint_dtype,
     get_default_model_resource_dir,
 )
-from executorch.examples.models.llama.llama_transformer import Transformer
 
+from executorch.examples.models.llama.llama_transformer import construct_transformer
 from executorch.examples.models.llama.model_args import ModelArgs
+from executorch.examples.models.llama.rope import Rope
 from torchao.utils import TorchAOBaseTensor
 
 try:
@@ -174,7 +175,7 @@ def __init__(self, **kwargs):
         # They possess all other metadata a tensor carries such as size, stride, requires_grad.
         with torch.device("meta"):
             # Model itself is loaded in default dtype, fp32.
-            self.model_ = Transformer(model_args)
+            self.model_ = construct_transformer(model_args)
         # Get checkpoint dtype.
         if checkpoint:
             self.model_.checkpoint_dtype = get_checkpoint_dtype(checkpoint)

examples/models/llama/tests/test_pre_quantization_transforms.py

Lines changed: 5 additions & 2 deletions
@@ -7,7 +7,10 @@
 import unittest
 
 import torch
-from executorch.examples.models.llama.llama_transformer import Transformer
+from executorch.examples.models.llama.llama_transformer import (
+    construct_transformer,
+    Transformer,
+)
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.source_transformation.pre_quantization import (
     sanitize_checkpoint_from_pre_quantization,
@@ -39,7 +42,7 @@ def _prepare_dummy_model(self) -> Transformer:
             vocab_size=32000,
         )
 
-        model = Transformer(model_args)
+        model = construct_transformer(model_args)
 
         return model
 

examples/models/llama/tests/test_static_attention.py

Lines changed: 3 additions & 3 deletions
@@ -2,7 +2,7 @@
 
 import torch
 from executorch.examples.models.llama.attention import AttentionMHA, ForwardOptions
-from executorch.examples.models.llama.llama_transformer import Transformer
+from executorch.examples.models.llama.llama_transformer import construct_transformer
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.rope import Rope
 from executorch.examples.models.llama.static_attention import (
@@ -160,10 +160,10 @@ def test_within_transformer(self):
             n_layers=4,
             vocab_size=128,
         )
-        mha_transformer = Transformer(config).eval()
+        mha_transformer = construct_transformer(config).eval()
 
         config.attention_type = "static"
-        static_transformer = Transformer(config).eval()
+        static_transformer = construct_transformer(config).eval()
         static_transformer.load_state_dict(mha_transformer.state_dict(), strict=False)
         for mha_layer, static_layer in zip(
             mha_transformer.layers, static_transformer.layers

examples/models/llava/model.py

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@
 
 import requests
 import torch
-from executorch.examples.models.llama.llama_transformer import Transformer
+from executorch.examples.models.llama.llama_transformer import construct_transformer
 from executorch.examples.models.llama.model_args import ModelArgs
 
 from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
@@ -66,7 +66,7 @@ def __init__(
             use_hf_rope=True,
             max_seq_len=max_seq_len,
         )
-        self.text_model = Transformer(self.text_model_args)
+        self.text_model = construct_transformer(self.text_model_args)
         # use custom op for SDPA.
         if use_sdpa_with_kv_cache_op:
             self.text_model = replace_kv_cache_with_custom_kv_cache(self.text_model)
