Skip to content

Commit 91c0b23

Browse files
authored
models : add conversion scripts from HuggingFace models to CoreML (#1304)
1 parent 2f668c3 commit 91c0b23

File tree

2 files changed

+134
-6
lines changed

2 files changed

+134
-6
lines changed

models/convert-h5-to-coreml.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import argparse
2+
import importlib.util
3+
4+
spec = importlib.util.spec_from_file_location('whisper_to_coreml', 'models/convert-whisper-to-coreml.py')
5+
whisper_to_coreml = importlib.util.module_from_spec(spec)
6+
spec.loader.exec_module(whisper_to_coreml)
7+
8+
from whisper import load_model
9+
10+
from copy import deepcopy
11+
import torch
12+
from transformers import WhisperForConditionalGeneration
13+
from huggingface_hub import metadata_update
14+
15+
# https://github.com/bayartsogt-ya/whisper-multiple-hf-datasets/blob/main/src/multiple_datasets/hub_default_utils.py
16+
WHISPER_MAPPING = {
17+
"layers": "blocks",
18+
"fc1": "mlp.0",
19+
"fc2": "mlp.2",
20+
"final_layer_norm": "mlp_ln",
21+
"layers": "blocks",
22+
".self_attn.q_proj": ".attn.query",
23+
".self_attn.k_proj": ".attn.key",
24+
".self_attn.v_proj": ".attn.value",
25+
".self_attn_layer_norm": ".attn_ln",
26+
".self_attn.out_proj": ".attn.out",
27+
".encoder_attn.q_proj": ".cross_attn.query",
28+
".encoder_attn.k_proj": ".cross_attn.key",
29+
".encoder_attn.v_proj": ".cross_attn.value",
30+
".encoder_attn_layer_norm": ".cross_attn_ln",
31+
".encoder_attn.out_proj": ".cross_attn.out",
32+
"decoder.layer_norm.": "decoder.ln.",
33+
"encoder.layer_norm.": "encoder.ln_post.",
34+
"embed_tokens": "token_embedding",
35+
"encoder.embed_positions.weight": "encoder.positional_embedding",
36+
"decoder.embed_positions.weight": "decoder.positional_embedding",
37+
"layer_norm": "ln_post",
38+
}
39+
40+
# https://github.com/bayartsogt-ya/whisper-multiple-hf-datasets/blob/main/src/multiple_datasets/hub_default_utils.py
41+
def rename_keys(s_dict):
42+
keys = list(s_dict.keys())
43+
for key in keys:
44+
new_key = key
45+
for k, v in WHISPER_MAPPING.items():
46+
if k in key:
47+
new_key = new_key.replace(k, v)
48+
49+
print(f"{key} -> {new_key}")
50+
51+
s_dict[new_key] = s_dict.pop(key)
52+
return s_dict
53+
54+
# https://github.com/bayartsogt-ya/whisper-multiple-hf-datasets/blob/main/src/multiple_datasets/hub_default_utils.py
55+
def convert_hf_whisper(hf_model_name_or_path: str, whisper_state_path: str):
56+
transformer_model = WhisperForConditionalGeneration.from_pretrained(hf_model_name_or_path)
57+
config = transformer_model.config
58+
59+
# first build dims
60+
dims = {
61+
'n_mels': config.num_mel_bins,
62+
'n_vocab': config.vocab_size,
63+
'n_audio_ctx': config.max_source_positions,
64+
'n_audio_state': config.d_model,
65+
'n_audio_head': config.encoder_attention_heads,
66+
'n_audio_layer': config.encoder_layers,
67+
'n_text_ctx': config.max_target_positions,
68+
'n_text_state': config.d_model,
69+
'n_text_head': config.decoder_attention_heads,
70+
'n_text_layer': config.decoder_layers
71+
}
72+
73+
state_dict = deepcopy(transformer_model.model.state_dict())
74+
state_dict = rename_keys(state_dict)
75+
76+
torch.save({"dims": dims, "model_state_dict": state_dict}, whisper_state_path)
77+
78+
# Ported from models/convert-whisper-to-coreml.py
79+
if __name__ == "__main__":
80+
parser = argparse.ArgumentParser()
81+
parser.add_argument("--model-name", type=str, help="name of model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large, large-v1)", required=True)
82+
parser.add_argument("--model-path", type=str, help="path to the model (e.g. if published on HuggingFace: Oblivion208/whisper-tiny-cantonese)", required=True)
83+
parser.add_argument("--encoder-only", type=bool, help="only convert encoder", default=False)
84+
parser.add_argument("--quantize", type=bool, help="quantize weights to F16", default=False)
85+
parser.add_argument("--optimize-ane", type=bool, help="optimize for ANE execution (currently broken)", default=False)
86+
args = parser.parse_args()
87+
88+
if args.model_name not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large", "large-v1"]:
89+
raise ValueError("Invalid model name")
90+
91+
pt_target_path = f"models/hf-{args.model_name}.pt"
92+
convert_hf_whisper(args.model_path, pt_target_path)
93+
94+
whisper = load_model(pt_target_path).cpu()
95+
hparams = whisper.dims
96+
print(hparams)
97+
98+
if args.optimize_ane:
99+
whisperANE = whisper_to_coreml.WhisperANE(hparams).eval()
100+
whisperANE.load_state_dict(whisper.state_dict())
101+
102+
encoder = whisperANE.encoder
103+
decoder = whisperANE.decoder
104+
else:
105+
encoder = whisper.encoder
106+
decoder = whisper.decoder
107+
108+
# Convert encoder
109+
encoder = whisper_to_coreml.convert_encoder(hparams, encoder, quantize=args.quantize)
110+
encoder.save(f"models/coreml-encoder-{args.model_name}.mlpackage")
111+
112+
if args.encoder_only is False:
113+
# Convert decoder
114+
decoder = whisper_to_coreml.convert_decoder(hparams, decoder, quantize=args.quantize)
115+
decoder.save(f"models/coreml-decoder-{args.model_name}.mlpackage")
116+
117+
print("done converting")

models/generate-coreml-model.sh

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,30 @@
11
#!/bin/bash
22

33
# Usage: ./generate-coreml-model.sh <model-name>
4-
if [ $# -eq 0 ]
5-
then
6-
echo "No model name supplied"
7-
echo "Usage: ./generate-coreml-model.sh <model-name>"
8-
exit 1
4+
if [ $# -eq 0 ]; then
5+
echo "No model name supplied"
6+
echo "Usage for Whisper models: ./generate-coreml-model.sh <model-name>"
7+
echo "Usage for HuggingFace models: ./generate-coreml-model.sh -h5 <model-name> <model-path>"
8+
exit 1
9+
elif [[ "$1" == "-h5" && $# != 3 ]]; then
10+
echo "No model name and model path supplied for a HuggingFace model"
11+
echo "Usage for HuggingFace models: ./generate-coreml-model.sh -h5 <model-name> <model-path>"
12+
exit 1
913
fi
1014

1115
mname="$1"
1216

1317
wd=$(dirname "$0")
1418
cd "$wd/../"
1519

16-
python3 models/convert-whisper-to-coreml.py --model $mname --encoder-only True
20+
if [[ $mname == "-h5" ]]; then
21+
mname="$2"
22+
mpath="$3"
23+
echo $mpath
24+
python3 models/convert-h5-to-coreml.py --model-name $mname --model-path $mpath --encoder-only True
25+
else
26+
python3 models/convert-whisper-to-coreml.py --model $mname --encoder-only True
27+
fi
1728

1829
xcrun coremlc compile models/coreml-encoder-${mname}.mlpackage models/
1930
rm -rf models/ggml-${mname}-encoder.mlmodelc

0 commit comments

Comments
 (0)