@@ -13,28 +13,36 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+"""Convert model weights from Python library formats to the gemma_cpp format."""
 
-from collections import defaultdict
-import torch
+
+import argparse
+import collections
+import os
+
+# Requires torch 2.2 and gemma package from:
+# https://github.com/google/gemma_pytorch
 from gemma import config
 from gemma import model as gemma_model
 import numpy as np
-import argparse
-import os
+import torch
+
+
+def check_file_exists(path):
+  if not os.path.exists(str(path)):
+    raise argparse.ArgumentTypeError(
+        f"The file {path} does not appear to exist."
+    )
+  return path
 
-# Requires torch 2.2 and gemma package from https://github.com/google/gemma_pytorch
 
-def check_file_exists(value):
-    if not os.path.exists(str(value)):
-        raise argparse.ArgumentTypeError("The file %s does not appear to exist." % value)
-    return value
-
+def check_model_types(path):
+  if str(path).lower() not in ["2b", "7b"]:
+    raise argparse.ArgumentTypeError(
+        f"Model type path {path} is not in [2b, 7b]."
+    )
+  return path
 
-def check_model_types(value):
-    if str(value).lower() not in ["2b", "7b"]:
-        raise argparse.ArgumentTypeError("Model type value %s is not in [2b, 7b]." % value)
-    return value
-
 
 parser = argparse.ArgumentParser()
 parser.add_argument(
@@ -73,126 +81,133 @@ def check_model_types(value):
 
 
 TRANSFORMATIONS = {
-    "2b":defaultdict(
-        lambda: lambda x: x,
-        {
-            "embedder.weight": lambda x: x,
-            "self_attn.qkv_proj.weight": lambda x: x.reshape((10, 256, 2048)),
-            "self_attn.o_proj.weight": lambda x: x.reshape((2048, 8, 256)).transpose([1,0,2]),
-            "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
-            "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
-            "mlp.down_proj.weight": lambda x: x,
-        }
-    ),
-    "7b":defaultdict(
-        lambda: lambda x: x,
-        {
-            "embedder.weight": lambda x: x,
-            "self_attn.qkv_proj.weight": lambda x: x.reshape((3, 16, 256, 3072)).transpose([1,0,2,3]),
-            "self_attn.o_proj.weight": lambda x: x.reshape((3072, 16, 256)).transpose([1,0,2]),
-            "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
-            "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
-            "mlp.down_proj.weight": lambda x: x,
-        }
-    ),
+    "2b": collections.defaultdict(
+        lambda: lambda x: x,
+        {
+            "embedder.weight": lambda x: x,
+            "self_attn.qkv_proj.weight": lambda x: x.reshape((10, 256, 2048)),
+            "self_attn.o_proj.weight": lambda x: x.reshape(
+                (2048, 8, 256)
+            ).transpose([1, 0, 2]),
+            "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
+            "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
+            "mlp.down_proj.weight": lambda x: x,
+        },
+    ),
+    "7b": collections.defaultdict(
+        lambda: lambda x: x,
+        {
+            "embedder.weight": lambda x: x,
+            "self_attn.qkv_proj.weight": lambda x: x.reshape(
+                (3, 16, 256, 3072)
+            ).transpose([1, 0, 2, 3]),
+            "self_attn.o_proj.weight": lambda x: x.reshape(
+                (3072, 16, 256)
+            ).transpose([1, 0, 2]),
+            "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :],
+            "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :],
+            "mlp.down_proj.weight": lambda x: x,
+        },
+    ),
 }
 
 VALIDATIONS = {
-    "2b": {
-        "embedder.weight": lambda x: x.shape == (256000, 2048),
-        "model.norm.weight": lambda x: x.shape == (2048,),
-        "self_attn.qkv_proj.weight": lambda x: x.shape == (10, 256, 2048),
-        "self_attn.o_proj.weight": lambda x: x.shape == (8, 2048, 256),
-        "mlp.gate_proj.weight": lambda x: x.shape == (1, 16384, 2048),
-        "mlp.up_proj.weight": lambda x: x.shape == (1, 16384, 2048),
-        "mlp.down_proj.weight": lambda x: x.shape == (2048, 16384),
-        "input_layernorm.weight": lambda x: x.shape == (2048,),
-        "post_attention_layernorm.weight": lambda x: x.shape == (2048,),
-    },
-    "7b": {
-        "embedder.weight": lambda x: x.shape == (256000, 3072),
-        "model.norm.weight": lambda x: x.shape == (3072,),
-        "self_attn.qkv_proj.weight": lambda x: x.shape == (16, 3, 256, 3072),
-        "self_attn.o_proj.weight": lambda x: x.shape == (16, 3072, 256),
-        "mlp.gate_proj.weight": lambda x: x.shape == (1, 24576, 3072),
-        "mlp.up_proj.weight": lambda x: x.shape == (1, 24576, 3072),
-        "mlp.down_proj.weight": lambda x: x.shape == (3072, 24576),
-        "input_layernorm.weight": lambda x: x.shape == (3072,),
-        "post_attention_layernorm.weight": lambda x: x.shape == (3072,),
-    },
+    "2b": {
+        "embedder.weight": lambda x: x.shape == (256000, 2048),
+        "model.norm.weight": lambda x: x.shape == (2048,),
+        "self_attn.qkv_proj.weight": lambda x: x.shape == (10, 256, 2048),
+        "self_attn.o_proj.weight": lambda x: x.shape == (8, 2048, 256),
+        "mlp.gate_proj.weight": lambda x: x.shape == (1, 16384, 2048),
+        "mlp.up_proj.weight": lambda x: x.shape == (1, 16384, 2048),
+        "mlp.down_proj.weight": lambda x: x.shape == (2048, 16384),
+        "input_layernorm.weight": lambda x: x.shape == (2048,),
+        "post_attention_layernorm.weight": lambda x: x.shape == (2048,),
+    },
+    "7b": {
+        "embedder.weight": lambda x: x.shape == (256000, 3072),
+        "model.norm.weight": lambda x: x.shape == (3072,),
+        "self_attn.qkv_proj.weight": lambda x: x.shape == (16, 3, 256, 3072),
+        "self_attn.o_proj.weight": lambda x: x.shape == (16, 3072, 256),
+        "mlp.gate_proj.weight": lambda x: x.shape == (1, 24576, 3072),
+        "mlp.up_proj.weight": lambda x: x.shape == (1, 24576, 3072),
+        "mlp.down_proj.weight": lambda x: x.shape == (3072, 24576),
+        "input_layernorm.weight": lambda x: x.shape == (3072,),
+        "post_attention_layernorm.weight": lambda x: x.shape == (3072,),
+    },
 }
 
 
-def param_names(num_hidden_layers: int):
-    """Return parameter names in the order they are expected for deserialization."""
-
-    # note *weight_scaler params are ignored in the forward computation unless
-    # quantization is being used.
-    #
-    # since we are working with the full precision weights as input, don't
-    # include these in the parameters being iterated over.
-
-    # fmt: off
-    names = [
-        ("embedder.weight", ) * 2, # embedder_input_embedding
-        ("model.norm.weight", ) * 2 # final_norm_scale
-    ]
-    layer_params = [
-        "self_attn.o_proj.weight", # attn_vec_einsum_w
-        "self_attn.qkv_proj.weight", # qkv_einsum_w
-        "mlp.gate_proj.weight", # gating_einsum_w
-        "mlp.up_proj.weight",
-        "mlp.down_proj.weight", # linear_w
-        "input_layernorm.weight", # pre_attention_norm_scale
-        "post_attention_layernorm.weight", # pre_ffw_norm_scale
-    ]
-    # fmt: on
-    for layer in range(num_hidden_layers):
-        for layer_param in layer_params:
-            names = names + [(f"model.layers.{layer}.{layer_param}", layer_param)]
-    return names
-
-
-def convert_weights():
-    model_type = args.model_type
-    output_file = args.output_file
-
-    model_config = config.get_model_config(model_type)
-    model_config.dtype = "float32"
-    model_config.tokenizer = args.tokenizer
-    device = torch.device("cpu")
-    torch.set_default_dtype(torch.float)
-    model = gemma_model.GemmaForCausalLM(model_config)
-
-    model.load_weights(args.weights)
-    model.to(device).eval()
-
-    model_dict = dict(model.named_parameters())
-    param_order = param_names(model_config.num_hidden_layers)
-
-    all_ok = True
-    print("Checking transformations ...")
+def param_names(num_hidden_layers: int) -> list[tuple[str, str]]:
+  """Return parameter names in the order they are expected for deserialization."""
+
+  # note *weight_scaler params are ignored in the forward computation unless
+  # quantization is being used.
+  #
+  # since we are working with the full precision weights as input, don't
+  # include these in the parameters being iterated over.
+
+  names = [
+      ("embedder.weight",) * 2,  # embedder_input_embedding
+      ("model.norm.weight",) * 2,  # final_norm_scale
+  ]
+  layer_params = [
+      "self_attn.o_proj.weight",  # attn_vec_einsum_w
+      "self_attn.qkv_proj.weight",  # qkv_einsum_w
+      "mlp.gate_proj.weight",  # gating_einsum_w
+      "mlp.up_proj.weight",
+      "mlp.down_proj.weight",  # linear_w
+      "input_layernorm.weight",  # pre_attention_norm_scale
+      "post_attention_layernorm.weight",  # pre_ffw_norm_scale
+  ]
+
+  for layer in range(num_hidden_layers):
+    for layer_param in layer_params:
+      names = names + [(f"model.layers.{layer}.{layer_param}", layer_param)]
+  return names
+
+
+def convert_weights() -> None:
+  """Convert model weights from Python library to gemma_cpp format."""
+  model_type = args.model_type
+  output_file = args.output_file
+
+  model_config = config.get_model_config(model_type)
+  model_config.dtype = "float32"
+  model_config.tokenizer = args.tokenizer
+  device = torch.device("cpu")
+  torch.set_default_dtype(torch.float)
+  model = gemma_model.GemmaForCausalLM(model_config)
+
+  model.load_weights(args.weights)
+  model.to(device).eval()
+
+  model_dict = dict(model.named_parameters())
+  param_order = param_names(model_config.num_hidden_layers)
+
+  any_errors = False
+  print("Checking transformations ...")
+  for name, layer_name in param_order:
+    arr = model_dict[name].detach().numpy()
+    arr = TRANSFORMATIONS[model_type][layer_name](arr)
+    check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED"
+
+    if check == "FAILED":
+      any_errors = True
+    print(f"  {name : <60}{str(arr.shape) : <20}{check}")
+
+  if any_errors:
+    return None
+
+  print("Writing parameters ...")
+  with open(output_file, "wb") as bin_handle:
     for name, layer_name in param_order:
-        arr = model_dict[name].detach().numpy()
-        arr = TRANSFORMATIONS[model_type][layer_name](arr)
-        check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED"
-
-        if check == "FAILED":
-            all_ok = False
-        print(f"  {name : <60}{str(arr.shape) : <20}{check}")
-
-    if all_ok:
-        print("Writing parameters ...")
-        gate = None
-        with open(output_file, "wb") as bin_handle:
-            for name, layer_name in param_order:
-                arr = model_dict[name].detach().numpy()
-                arr = TRANSFORMATIONS[model_type][layer_name](arr)
-                check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED"
-                print(f"  {name : <60}{str(arr.shape) : <20}{check}")
-                arr.flatten().astype(np.float32).tofile(bin_handle)
+      arr = model_dict[name].detach().numpy()
+      arr = TRANSFORMATIONS[model_type][layer_name](arr)
+      check = "OK" if VALIDATIONS[model_type][layer_name](arr) else "FAILED"
+      print(f"  {name : <60}{str(arr.shape) : <20}{check}")
+      arr.flatten().astype(np.float32).tofile(bin_handle)
 
 
 if __name__ == "__main__":
-    convert_weights()
-    print("Done")
+  convert_weights()
+  print("Done")
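
For context on the on-disk format this change writes: `convert_weights` emits one flat stream of raw float32 values, tensor after tensor, in the order given by `param_names`. Below is a minimal read-back sketch for sanity-checking a converted file. It is illustrative only and not part of this commit: the filename `2b.bin` is hypothetical, and the expected shapes are copied from `VALIDATIONS["2b"]` above.

```python
import numpy as np

# First two tensors in the stream, per param_names(): the embedding table,
# then the final norm scale. Shapes copied from VALIDATIONS["2b"] above.
EXPECTED_2B_HEAD = [
    ("embedder.weight", (256000, 2048)),
    ("model.norm.weight", (2048,)),
]


def read_head(path):
  """Read the leading tensors of a converted 2b file and check their sizes."""
  with open(path, "rb") as f:
    for name, shape in EXPECTED_2B_HEAD:
      count = int(np.prod(shape))
      flat = np.fromfile(f, dtype=np.float32, count=count)
      if flat.size != count:
        raise ValueError(f"{name}: expected {count} floats, got {flat.size}")
      print(f"  {name : <60}{str(flat.reshape(shape).shape) : <20}OK")


read_head("2b.bin")  # hypothetical path previously passed as the converter's output file
```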