Skip to content

Commit 471aab6

Browse files
foldl and Judd authored
convert : add support of baichuan-7b (#2055)
Co-authored-by: Judd <[email protected]>
1 parent 463f2f4 commit 471aab6

File tree

2 files changed

+37
-5
lines changed

2 files changed

+37
-5
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ as the main playground for developing new features for the [ggml](https://github
8585
- [X] [OpenBuddy 🐶 (Multilingual)](https://github.com/OpenBuddy/OpenBuddy)
8686
- [X] [Pygmalion 7B / Metharme 7B](#using-pygmalion-7b--metharme-7b)
8787
- [X] [WizardLM](https://github.com/nlpxucan/WizardLM)
88+
- [X] [Baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B)
8889

8990
**Bindings:**
9091

convert.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def find_n_mult(n_ff: int, n_embd: int) -> int:
136136
calc_ff = (((8*n_embd) // 3 + n_mult - 1) // n_mult)*n_mult
137137
if calc_ff == n_ff:
138138
return n_mult
139-
return 1
139+
raise Exception(f"failed to find n_mult for (n_ff={n_ff}, n_embd={n_embd}).")
140140

141141
@dataclass
142142
class Params:
@@ -321,6 +321,10 @@ def astype(self, data_type: DataType) -> 'Tensor': ...
321321
@abstractmethod
322322
def permute(self, n_head: int) -> 'Tensor': ...
323323
@abstractmethod
324+
def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor': ...
325+
@abstractmethod
326+
def part(self, n_part: int) -> 'UnquantizedTensor': ...
327+
@abstractmethod
324328
def to_ggml(self) -> 'GGMLCompatibleTensor': ...
325329

326330

@@ -345,6 +349,14 @@ def astype(self, data_type: DataType) -> Tensor:
345349
def to_ggml(self) -> 'UnquantizedTensor':
346350
return self
347351

352+
def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor':
    # The packed weight stacks three equal slices along dim 0; callers pass
    # n_part = 0/1/2 to pick one slice, which is then head-permuted.
    part_rows = self.ndarray.shape[0] // 3
    start = part_rows * n_part
    sliced = self.ndarray[start : start + part_rows, ...]
    return UnquantizedTensor(permute(sliced, n_head))
355+
356+
def part(self, n_part: int) -> 'UnquantizedTensor':
    # Return the n_part-th of three equal slices along the first axis,
    # without any permutation.
    part_rows = self.ndarray.shape[0] // 3
    start = part_rows * n_part
    return UnquantizedTensor(self.ndarray[start : start + part_rows, ...])
359+
348360
def permute(self, n_head: int) -> 'UnquantizedTensor':
349361
return UnquantizedTensor(permute(self.ndarray, n_head))
350362

@@ -642,6 +654,19 @@ def load() -> Tensor:
642654
return lazy_tensor.load().permute(n_head)
643655
return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}) ' + lazy_tensor.description)
644656

657+
def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int) -> LazyTensor:
    """Lazily take one third of a packed tensor along dim 0 and permute it.

    Used for packed QKV weights (W_pack): n_part selects which third
    (0 = Q, 1 = K per the caller), n_head drives the head permutation.
    The actual slice/permute is deferred until load() is called.
    """
    def load() -> Tensor:
        return lazy_tensor.load().permute_part(n_part, n_head)
    s = lazy_tensor.shape.copy()
    # The packed tensor stacks three equal parts along the first dimension.
    s[0] = s[0] // 3
    # Fix: include n_part in the description so the different slices taken
    # from the same packed tensor are distinguishable (the original tagged
    # them all identically as f'permute({n_head}) ').
    return LazyTensor(load, s, lazy_tensor.data_type,
                      f'permute({n_head}) part({n_part}) ' + lazy_tensor.description)
663+
664+
def part_lazy(lazy_tensor: LazyTensor, n_part: int) -> LazyTensor:
    """Lazy counterpart of Tensor.part(): defer slicing one third of the
    packed tensor along dim 0 until load() is called."""
    def load() -> Tensor:
        return lazy_tensor.load().part(n_part)
    new_shape = lazy_tensor.shape.copy()
    new_shape[0] = new_shape[0] // 3  # one third of the packed first dimension
    return LazyTensor(load, new_shape, lazy_tensor.data_type,
                      'part ' + lazy_tensor.description)
645670

646671
def convert_transformers_to_orig(model: LazyModel, params: Params) -> LazyModel:
647672
out: LazyModel = {}
@@ -650,11 +675,17 @@ def convert_transformers_to_orig(model: LazyModel, params: Params) -> LazyModel:
650675
out["output.weight"] = model["lm_head.weight"]
651676

652677
for i in itertools.count():
653-
if f"model.layers.{i}.self_attn.q_proj.weight" not in model:
678+
if f"model.layers.{i}.self_attn.q_proj.weight" in model:
679+
out[f"layers.{i}.attention.wq.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head)
680+
out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head)
681+
out[f"layers.{i}.attention.wv.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
682+
elif f"model.layers.{i}.self_attn.W_pack.weight" in model:
683+
out[f"layers.{i}.attention.wq.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head)
684+
out[f"layers.{i}.attention.wk.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 1, params.n_head)
685+
out[f"layers.{i}.attention.wv.weight"] = part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 2)
686+
else:
654687
break
655-
out[f"layers.{i}.attention.wq.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head)
656-
out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head)
657-
out[f"layers.{i}.attention.wv.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
688+
658689
out[f"layers.{i}.attention.wo.weight"] = model[f"model.layers.{i}.self_attn.o_proj.weight"]
659690

660691
out[f"layers.{i}.feed_forward.w1.weight"] = model[f"model.layers.{i}.mlp.gate_proj.weight"]

0 commit comments

Comments
 (0)