
Commit f91040c

The gemma.cpp Authors authored and copybara-github committed
Linter cleanup patch.
PiperOrigin-RevId: 621331426
1 parent 7122afe commit f91040c


2 files changed: +154 additions, -128 deletions


BUILD.bazel

Lines changed: 11 additions & 0 deletions
@@ -1,6 +1,7 @@
 # gemma.cpp is a lightweight, standalone C++ inference engine for the Gemma
 # foundation models from Google.
 
+load("//devtools/python/blaze:pytype.bzl", "pytype_strict_library")
 load("@rules_license//rules:license.bzl", "license")
 
 package(
@@ -132,3 +133,13 @@ cc_binary(
         "@hwy//:thread_pool",
     ],
 )
+
+pytype_strict_library(
+    name = "util/convert_weights",
+    srcs = ["util/convert_weights.py"],
+    deps = [
+        "//third_party/py/gemma",
+        "//third_party/py/numpy",
+        "//third_party/py/torch:pytorch",
+    ],
+)

util/convert_weights.py

Lines changed: 143 additions & 128 deletions
@@ -13,28 +13,36 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+"""Convert model weights from Python library formats to the gemma_cpp format."""
 
-from collections import defaultdict
-import torch
+
+import argparse
+import collections
+import os
+
+# Requires torch 2.2 and gemma package from:
+# https://github.com/google/gemma_pytorch
 from gemma import config
 from gemma import model as gemma_model
 import numpy as np
-import argparse
-import os
+import torch
+
+
+def check_file_exists(path):
+  if not os.path.exists(str(path)):
+    raise argparse.ArgumentTypeError(
+        f"The file {path} does not appear to exist."
+    )
+  return path
 
-# Requires torch 2.2 and gemma package from https://github.com/google/gemma_pytorch
 
-def check_file_exists(value):
-    if not os.path.exists(str(value)):
-        raise argparse.ArgumentTypeError("The file %s does not appear to exist." % value)
-    return value
-
+def check_model_types(path):
+  if str(path).lower() not in ["2b", "7b"]:
+    raise argparse.ArgumentTypeError(
+        f"Model type path {path} is not in [2b, 7b]."
+    )
+  return path
 
-def check_model_types(value):
-    if str(value).lower() not in ["2b", "7b"]:
-        raise argparse.ArgumentTypeError("Model type value %s is not in [2b, 7b]." % value)
-    return value
-
 
 parser = argparse.ArgumentParser()
 parser.add_argument(
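Note (illustrative, not part of this commit): helpers like check_file_exists and check_model_types are written to be passed as argparse's type= callback, which receives the raw flag string and rejects it by raising argparse.ArgumentTypeError. A minimal self-contained sketch of that wiring follows; the --weights flag name is a placeholder assumption, since the real parser.add_argument(...) calls fall outside this hunk.

import argparse
import os


def check_file_exists(path):
  # Same idea as the helper added above: accept the value only if it is an
  # existing path, otherwise let argparse print a clean error message.
  if not os.path.exists(str(path)):
    raise argparse.ArgumentTypeError(f"The file {path} does not appear to exist.")
  return path


# Hypothetical flag for illustration only.
parser = argparse.ArgumentParser()
parser.add_argument("--weights", type=check_file_exists, required=True)

if __name__ == "__main__":
  args = parser.parse_args()
  print(f"weights checkpoint: {args.weights}")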
@@ -73,126 +81,133 @@ def check_model_types(value):
 
 
 TRANSFORMATIONS = {
-  "2b":defaultdict(
-    lambda: lambda x: x,
-    {
-      "embedder.weight": lambda x: x,
-      "self_attn.qkv_proj.weight": lambda x: x.reshape((10, 256, 2048)),
-      "self_attn.o_proj.weight": lambda x: x.reshape((2048, 8, 256)).transpose([1,0,2]),
-      "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
-      "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
-      "mlp.down_proj.weight": lambda x: x,
-    }
-  ),
-  "7b":defaultdict(
-    lambda: lambda x: x,
-    {
-      "embedder.weight": lambda x: x,
-      "self_attn.qkv_proj.weight": lambda x: x.reshape((3, 16, 256, 3072)).transpose([1,0,2,3]),
-      "self_attn.o_proj.weight": lambda x: x.reshape((3072, 16, 256)).transpose([1,0,2]),
-      "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
-      "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
-      "mlp.down_proj.weight": lambda x: x,
-    }
-  ),
+    "2b": collections.defaultdict(
+        lambda: lambda x: x,
+        {
+            "embedder.weight": lambda x: x,
+            "self_attn.qkv_proj.weight": lambda x: x.reshape((10, 256, 2048)),
+            "self_attn.o_proj.weight": lambda x: x.reshape(
+                (2048, 8, 256)
+            ).transpose([1, 0, 2]),
+            "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
+            "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
+            "mlp.down_proj.weight": lambda x: x,
+        },
+    ),
+    "7b": collections.defaultdict(
+        lambda: lambda x: x,
+        {
+            "embedder.weight": lambda x: x,
+            "self_attn.qkv_proj.weight": lambda x: x.reshape(
+                (3, 16, 256, 3072)
+            ).transpose([1, 0, 2, 3]),
+            "self_attn.o_proj.weight": lambda x: x.reshape(
+                (3072, 16, 256)
+            ).transpose([1, 0, 2]),
+            "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
+            "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
+            "mlp.down_proj.weight": lambda x: x,
+        },
+    ),
 }
 
 VALIDATIONS = {
-  "2b": {
-    "embedder.weight": lambda x: x.shape == (256000, 2048),
-    "model.norm.weight": lambda x: x.shape == (2048,),
-    "self_attn.qkv_proj.weight": lambda x: x.shape == (10, 256, 2048),
-    "self_attn.o_proj.weight": lambda x: x.shape == (8, 2048, 256),
-    "mlp.gate_proj.weight": lambda x: x.shape == (1, 16384, 2048),
-    "mlp.up_proj.weight": lambda x: x.shape == (1, 16384, 2048),
-    "mlp.down_proj.weight": lambda x: x.shape == (2048, 16384),
-    "input_layernorm.weight": lambda x: x.shape == (2048,),
-    "post_attention_layernorm.weight": lambda x: x.shape == (2048,),
-  },
-  "7b": {
-    "embedder.weight": lambda x: x.shape == (256000, 3072),
-    "model.norm.weight": lambda x: x.shape == (3072,),
-    "self_attn.qkv_proj.weight": lambda x: x.shape == (16, 3, 256, 3072),
-    "self_attn.o_proj.weight": lambda x: x.shape == (16, 3072, 256),
-    "mlp.gate_proj.weight": lambda x: x.shape == (1, 24576, 3072),
-    "mlp.up_proj.weight": lambda x: x.shape == (1, 24576, 3072),
-    "mlp.down_proj.weight": lambda x: x.shape == (3072, 24576),
-    "input_layernorm.weight": lambda x: x.shape == (3072,),
-    "post_attention_layernorm.weight": lambda x: x.shape == (3072,),
-  },
+    "2b": {
+        "embedder.weight": lambda x: x.shape == (256000, 2048),
+        "model.norm.weight": lambda x: x.shape == (2048,),
+        "self_attn.qkv_proj.weight": lambda x: x.shape == (10, 256, 2048),
+        "self_attn.o_proj.weight": lambda x: x.shape == (8, 2048, 256),
+        "mlp.gate_proj.weight": lambda x: x.shape == (1, 16384, 2048),
+        "mlp.up_proj.weight": lambda x: x.shape == (1, 16384, 2048),
+        "mlp.down_proj.weight": lambda x: x.shape == (2048, 16384),
+        "input_layernorm.weight": lambda x: x.shape == (2048,),
+        "post_attention_layernorm.weight": lambda x: x.shape == (2048,),
+    },
+    "7b": {
+        "embedder.weight": lambda x: x.shape == (256000, 3072),
+        "model.norm.weight": lambda x: x.shape == (3072,),
+        "self_attn.qkv_proj.weight": lambda x: x.shape == (16, 3, 256, 3072),
+        "self_attn.o_proj.weight": lambda x: x.shape == (16, 3072, 256),
+        "mlp.gate_proj.weight": lambda x: x.shape == (1, 24576, 3072),
+        "mlp.up_proj.weight": lambda x: x.shape == (1, 24576, 3072),
+        "mlp.down_proj.weight": lambda x: x.shape == (3072, 24576),
+        "input_layernorm.weight": lambda x: x.shape == (3072,),
+        "post_attention_layernorm.weight": lambda x: x.shape == (3072,),
+    },
 }
 
 
-def param_names(num_hidden_layers: int):
-    """Return parameter names in the order they are expected for deserialization."""
-
-    # note *weight_scaler params are ignored in the forward computation unless
-    # quantization is being used.
-    #
-    # since we are working with the full precision weights as input, don't
-    # include these in the parameters being iterated over.
-
-    # fmt: off
-    names = [
-        ("embedder.weight", ) * 2, # embedder_input_embedding
-        ("model.norm.weight", ) * 2 # final_norm_scale
-    ]
-    layer_params = [
-        "self_attn.o_proj.weight", # attn_vec_einsum_w
-        "self_attn.qkv_proj.weight", # qkv_einsum_w
-        "mlp.gate_proj.weight", # gating_einsum_w
-        "mlp.up_proj.weight",
-        "mlp.down_proj.weight", # linear_w
-        "input_layernorm.weight", # pre_attention_norm_scale
-        "post_attention_layernorm.weight", # pre_ffw_norm_scale
-    ]
-    # fmt: on
-    for layer in range(num_hidden_layers):
-        for layer_param in layer_params:
-            names = names + [(f"model.layers.{layer}.{layer_param}", layer_param)]
-    return names
-
-
-def convert_weights():
-    model_type = args.model_type
-    output_file = args.output_file
-
-    model_config = config.get_model_config(model_type)
-    model_config.dtype = "float32"
-    model_config.tokenizer = args.tokenizer
-    device = torch.device("cpu")
-    torch.set_default_dtype(torch.float)
-    model = gemma_model.GemmaForCausalLM(model_config)
-
-    model.load_weights(args.weights)
-    model.to(device).eval()
-
-    model_dict = dict(model.named_parameters())
-    param_order = param_names(model_config.num_hidden_layers)
-
-    all_ok = True
-    print("Checking transformations ...")
+def param_names(num_hidden_layers: int) -> list[str]:
+  """Return parameter names in the order they are expected for deserialization."""
+
+  # note *weight_scaler params are ignored in the forward computation unless
+  # quantization is being used.
+  #
+  # since we are working with the full precision weights as input, don't
+  # include these in the parameters being iterated over.
+
+  names = [
+      ("embedder.weight",) * 2,  # embedder_input_embedding
+      ("model.norm.weight",) * 2,  # final_norm_scale
+  ]
+  layer_params = [
+      "self_attn.o_proj.weight",  # attn_vec_einsum_w
+      "self_attn.qkv_proj.weight",  # qkv_einsum_w
+      "mlp.gate_proj.weight",  # gating_einsum_w
+      "mlp.up_proj.weight",
+      "mlp.down_proj.weight",  # linear_w
+      "input_layernorm.weight",  # pre_attention_norm_scale
+      "post_attention_layernorm.weight",  # pre_ffw_norm_scale
+  ]
+
+  for layer in range(num_hidden_layers):
+    for layer_param in layer_params:
+      names = names + [(f"model.layers.{layer}.{layer_param}", layer_param)]
+  return names
+
+
+def convert_weights() -> None:
+  """Convert model weights from Python library to gemma_cpp format."""
+  model_type = args.model_type
+  output_file = args.output_file
+
+  model_config = config.get_model_config(model_type)
+  model_config.dtype = "float32"
+  model_config.tokenizer = args.tokenizer
+  device = torch.device("cpu")
+  torch.set_default_dtype(torch.float)
+  model = gemma_model.GemmaForCausalLM(model_config)
+
+  model.load_weights(args.weights)
+  model.to(device).eval()
+
+  model_dict = dict(model.named_parameters())
+  param_order = param_names(model_config.num_hidden_layers)
+
+  any_errors = False
+  print("Checking transformations ...")
+  for name, layer_name in param_order:
+    arr = model_dict[name].detach().numpy()
+    arr = TRANSFORMATIONS[model_type][layer_name](arr)
+    check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED"
+
+    if check == "FAILED":
+      any_errors = True
+    print(f" {name : <60}{str(arr.shape) : <20}{check}")
+
+  if any_errors:
+    return None
+
+  print("Writing parameters ...")
+  with open(output_file, "wb") as bin_handle:
     for name, layer_name in param_order:
-        arr = model_dict[name].detach().numpy()
-        arr = TRANSFORMATIONS[model_type][layer_name](arr)
-        check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED"
-
-        if check == "FAILED":
-            all_ok = False
-        print(f" {name : <60}{str(arr.shape) : <20}{check}")
-
-    if all_ok:
-        print("Writing parameters ...")
-        gate = None
-        with open(output_file, "wb") as bin_handle:
-            for name, layer_name in param_order:
-                arr = model_dict[name].detach().numpy()
-                arr = TRANSFORMATIONS[model_type][layer_name](arr)
-                check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED"
-                print(f" {name : <60}{str(arr.shape) : <20}{check}")
-                arr.flatten().astype(np.float32).tofile(bin_handle)
+      arr = model_dict[name].detach().numpy()
+      arr = TRANSFORMATIONS[model_type][layer_name](arr)
+      check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED"
+      print(f" {name : <60}{str(arr.shape) : <20}{check}")
+      arr.flatten().astype(np.float32).tofile(bin_handle)
 
 
 if __name__ == "__main__":
-    convert_weights()
-    print("Done")
+  convert_weights()
+  print("Done")
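Note (illustrative, not part of this commit): a small NumPy sketch of what the 2b "self_attn.qkv_proj.weight" transformation does and of the raw float32 layout that tofile() emits. The 8-query-head / 1-KV-head reading of the leading dimension is an interpretation consistent with the (10, 256, 2048) shape checked in VALIDATIONS; gemma_cpp's own loader remains the authoritative consumer of the output file.

import numpy as np

# Gemma 2B fuses the Q (8 heads), K (1 head) and V (1 head) projections into a
# single (10 * 256, 2048) matrix; the converter reshapes it to (10, 256, 2048)
# so each 256-dim head projection sits on its own slice of the first axis.
fused_qkv = np.random.randn(10 * 256, 2048).astype(np.float32)
per_head = fused_qkv.reshape((10, 256, 2048))
assert per_head[0].shape == (256, 2048)  # one head's projection matrix

# The converter writes every transformed tensor as a flat float32 stream; a
# round trip through a temporary file shows the layout is preserved.
per_head.flatten().astype(np.float32).tofile("/tmp/qkv_demo.bin")
restored = np.fromfile("/tmp/qkv_demo.bin", dtype=np.float32).reshape(10, 256, 2048)
assert np.array_equal(per_head, restored)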
