InternLM xcomposer2 support (at least on par with GPT4V and CogVLM) - help needed #5232

Closed · wants to merge 3 commits
145 changes: 145 additions & 0 deletions convert-hf-to-gguf.py
@@ -203,6 +203,8 @@ def from_model_architecture(model_architecture):
return CodeShellModel
if model_architecture == "OrionForCausalLM":
return OrionModel
if model_architecture == "InternLM2ForCausalLM":
return InternLM2Model
return Model

def _is_model_safetensors(self) -> bool:
@@ -254,6 +256,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
return gguf.MODEL_ARCH.CODESHELL
if arch == "OrionForCausalLM":
return gguf.MODEL_ARCH.ORION
if arch == "InternLM2ForCausalLM":
return gguf.MODEL_ARCH.INTERNLM2

raise NotImplementedError(f'Architecture "{arch}" not supported!')

@@ -1344,6 +1348,147 @@ def write_tensors(self):
self.gguf_writer.add_tensor("output.weight", data)
print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")

class InternLM2Model(Model):
def set_vocab(self):
# (TODO): Is there a better way?
# Copied from _set_vocab_sentencepiece; the only difference is that we treat the character
# \x00 specially and convert it into an emoji character to prevent it from being mistakenly
# recognized as an empty string in C++.
from sentencepiece import SentencePieceProcessor

tokenizer_path = self.dir_model / 'tokenizer.model'

tokens: list[bytes] = []
scores: list[float] = []
toktypes: list[int] = []

if not tokenizer_path.is_file():
print(f'Error: Missing {tokenizer_path}', file=sys.stderr)
sys.exit(1)

tokenizer = SentencePieceProcessor(str(tokenizer_path))
vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())

for token_id in range(vocab_size):
piece = tokenizer.id_to_piece(token_id)
text = piece.encode("utf-8")
score = tokenizer.get_score(token_id)
if text == b"\x00":
# TODO: fix this properly
# Hack: replace the \x00 token so it is not treated as an empty string in C++.
print(f"InternLM2 converting token '{text}' to '🐉'!")
text = "🐉"

toktype = SentencePieceTokenTypes.NORMAL
if tokenizer.is_unknown(token_id):
toktype = SentencePieceTokenTypes.UNKNOWN
elif tokenizer.is_control(token_id):
toktype = SentencePieceTokenTypes.CONTROL
elif tokenizer.is_unused(token_id):
toktype = SentencePieceTokenTypes.UNUSED
elif tokenizer.is_byte(token_id):
toktype = SentencePieceTokenTypes.BYTE

tokens.append(text)
scores.append(score)
toktypes.append(toktype)

added_tokens_file = self.dir_model / 'added_tokens.json'
if added_tokens_file.is_file():
with open(added_tokens_file, "r", encoding="utf-8") as f:
added_tokens_json = json.load(f)

for key in added_tokens_json:
tokens.append(key.encode("utf-8"))
scores.append(-1000.0)
toktypes.append(SentencePieceTokenTypes.USER_DEFINED)

self.gguf_writer.add_tokenizer_model("llama")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes)

special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer)

def set_gguf_parameters(self):
self.gguf_writer.add_name("InternLM2")
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])

def post_write_tensors(self, tensor_map, name, data_torch):
old_dtype = data_torch.dtype

# convert any unsupported data types to float32
if data_torch.dtype not in (torch.float16, torch.float32):
data_torch = data_torch.to(torch.float32)

data = data_torch.squeeze().numpy()

# map tensor names
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
if new_name is None:
print(f"Can not map tensor {name!r}")
sys.exit()

n_dims = len(data.shape)
data_dtype = data.dtype

# if f32 desired, convert any float16 to float32
if self.ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32)

# TODO: Why can't we use these float16 values as-is? There should be no reason to store float16 as float32
if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
data = data.astype(np.float32)

# if f16 desired, convert any float32 2-dim weight tensors to float16
if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
data = data.astype(np.float16)

print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
self.gguf_writer.add_tensor(new_name, data)

def write_tensors(self):
from einops import rearrange

num_heads = self.hparams.get("num_attention_heads")
num_kv_heads = self.hparams.get("num_key_value_heads")
hidden_size = self.hparams.get("hidden_size")
q_per_kv = num_heads // num_kv_heads
head_dim = hidden_size // num_heads
num_groups = num_heads // q_per_kv

block_count = self.hparams["num_hidden_layers"]
model_kv = dict(self.get_tensors())
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
qkv_pattern = r"model\.layers\.(\d+)\.attention\.wqkv"
for name, data_torch in model_kv.items():
# we don't need these
if name.endswith(".rotary_emb.inv_freq"):
continue
# Plora (partial LoRA) tensors are not split here; they fall through to the normal mapping path below
plora_tensor = "Plora" in name
if re.match(qkv_pattern, name) and not plora_tensor:
bid = re.findall(qkv_pattern, name)[0]
qkv = data_torch
qkv = rearrange(qkv.T, "o (g n i) -> o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
q, k, v = qkv[..., :q_per_kv, :], qkv[..., q_per_kv:q_per_kv + 1, :], qkv[..., q_per_kv + 1:q_per_kv + 2, :]
q = rearrange(q, "o g n i -> o (g n i)").T
k = rearrange(k, "o g n i -> o (g n i)").T
v = rearrange(v, "o g n i -> o (g n i)").T
self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wq.weight", q)
self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wk.weight", k)
self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wv.weight", v)
else:
self.post_write_tensors(tensor_map, name, data_torch)


###### CONVERSION LOGIC ######


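For context, the InternLM2 checkpoints pack q, k and v for each kv group into a single `wqkv` weight, which is what the einops rearrange in `InternLM2Model.write_tensors` above unpacks. Below is a minimal, self-contained sketch of that split; the head counts and hidden size are assumed (roughly 7B-scale) purely to check the resulting shapes, and it is not part of the PR.

```python
# Standalone sketch of the wqkv split done in InternLM2Model.write_tensors.
# The hparams below are assumptions for illustration, not read from any config.
import torch
from einops import rearrange

num_heads, num_kv_heads, hidden_size = 32, 8, 4096
q_per_kv = num_heads // num_kv_heads      # query heads sharing one kv head -> 4
head_dim = hidden_size // num_heads       # 128
num_groups = num_heads // q_per_kv        # kv groups -> 8

# packed weight: each group holds q_per_kv query heads plus one k and one v head
wqkv = torch.randn(num_groups * (q_per_kv + 2) * head_dim, hidden_size)

qkv = rearrange(wqkv.T, "o (g n i) -> o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
q = rearrange(qkv[..., :q_per_kv, :], "o g n i -> o (g n i)").T
k = rearrange(qkv[..., q_per_kv:q_per_kv + 1, :], "o g n i -> o (g n i)").T
v = rearrange(qkv[..., q_per_kv + 1:, :], "o g n i -> o (g n i)").T

print(q.shape, k.shape, v.shape)
# torch.Size([4096, 4096]) torch.Size([1024, 4096]) torch.Size([1024, 4096])
```

With these assumed sizes, the resulting q/k/v shapes should line up with the separate attn_q/attn_k/attn_v tensors that the existing llama-style tensor mapping expects for grouped-query attention.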
12 changes: 8 additions & 4 deletions examples/llava/convert-image-encoder-to-gguf.py
@@ -41,6 +41,9 @@ def get_tensor_name(name: str) -> str:

if "mm_projector" in name:
return name.replace("model.mm_projector", "mm")

if "vision_proj" in name:
return name.replace("vision_proj", "mm")

return name.replace("text_model", "t").replace("vision_model", "v").replace("encoder.layers", "blk").replace("embeddings.", "").replace("_proj", "").replace("self_attn.", "attn_").replace("layer_norm", "ln").replace("layernorm", "ln").replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("embedding", "embd").replace("final", "post").replace("layrnorm", "ln")

@@ -80,12 +83,13 @@ def bytes_to_unicode():
help="Save a vision-only model. It can't be used to encode texts")
ap.add_argument("--clip_model_is_vision", action="store_true", required=False,
help="The clip model is a pure vision model (ShareGPT4V vision extract for example)")
ap.add_argument("--clip_model_is_openclip", action="store_true", required=False,
help="The clip model is from openclip (for ViT-SO400M type))")
ap.add_argument("--llava-projector", help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.")
ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp", choices=["mlp", "ldp"], default="mlp")
ap.add_argument("--image-mean", nargs=3, type=float, required=False, help="Override image mean values")
ap.add_argument("--image-std", nargs=3, type=float, required=False, help="Override image std values")
ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None)
# Example --image_mean 0.48145466 0.4578275 0.40821073 --image_std 0.26862954 0.26130258 0.27577711
# Example --image_mean 0.5 0.5 0.5 --image_std 0.5 0.5 0.5
default_image_mean = [0.48145466, 0.4578275, 0.40821073]
default_image_std = [0.26862954, 0.26130258, 0.27577711]
ap.add_argument('--image_mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None)
@@ -105,7 +109,7 @@ def bytes_to_unicode():
# output in the same directory as the model if output_dir is None
dir_model = args.model_dir

if args.clip_model_is_vision:
if args.clip_model_is_vision or not os.path.exists(dir_model + "/vocab.json") or args.clip_model_is_openclip:
vocab = None
tokens = None
else:
@@ -133,7 +137,7 @@ def bytes_to_unicode():
if args.use_f32:
ftype = 0

if args.clip_model_is_vision:
if args.clip_model_is_vision or args.clip_model_is_openclip:
model = CLIPVisionModel.from_pretrained(dir_model)
processor = None
else:
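As a small illustration of the new `vision_proj` branch added to `get_tensor_name` above: xcomposer2-style projector tensors are simply renamed to the `mm` prefix already used for the LLaVA projector. The tensor name below is a hypothetical example, not taken from the diff.

```python
# Hypothetical xcomposer2 projector tensor name; the new branch in
# get_tensor_name rewrites the "vision_proj" prefix to "mm".
name = "vision_proj.0.weight"
if "vision_proj" in name:
    name = name.replace("vision_proj", "mm")
print(name)  # mm.0.weight
```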
55 changes: 55 additions & 0 deletions gguf-py/gguf/constants.py
@@ -102,6 +102,7 @@ class MODEL_ARCH(IntEnum):
PLAMO = auto()
CODESHELL = auto()
ORION = auto()
INTERNLM2 = auto()


class MODEL_TENSOR(IntEnum):
@@ -132,6 +133,21 @@ class MODEL_TENSOR(IntEnum):
ATTN_Q_NORM = auto()
ATTN_K_NORM = auto()

ATTN_QKV_LORA_A = auto()
ATTN_QKV_LORA_B = auto()
ATTN_OUT_LORA_A = auto()
ATTN_OUT_LORA_B = auto()
FFN_UP_LORA_A = auto()
FFN_UP_LORA_B = auto()
FFN_GATE_LORA_A = auto()
FFN_GATE_LORA_B = auto()
FFN_DOWN_LORA_A = auto()
FFN_DOWN_LORA_B = auto()






MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.LLAMA: "llama",
@@ -153,6 +169,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.PLAMO: "plamo",
MODEL_ARCH.CODESHELL: "codeshell",
MODEL_ARCH.ORION: "orion",
MODEL_ARCH.INTERNLM2: "internlm2",
}

TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -182,6 +199,18 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate.{xid}",
MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down.{xid}",
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up.{xid}",

MODEL_TENSOR.ATTN_QKV_LORA_A : "blk.{bid}.attn_qkv_lora_a",
MODEL_TENSOR.ATTN_QKV_LORA_B : "blk.{bid}.attn_qkv_lora_b",
MODEL_TENSOR.ATTN_OUT_LORA_A : "blk.{bid}.attn_out_lora_a",
MODEL_TENSOR.ATTN_OUT_LORA_B : "blk.{bid}.attn_out_lora_b",
MODEL_TENSOR.FFN_UP_LORA_A : "blk.{bid}.ffn_up_lora_a",
MODEL_TENSOR.FFN_UP_LORA_B : "blk.{bid}.ffn_up_lora_b",
MODEL_TENSOR.FFN_GATE_LORA_A : "blk.{bid}.ffn_gate_lora_a",
MODEL_TENSOR.FFN_GATE_LORA_B : "blk.{bid}.ffn_gate_lora_b",
MODEL_TENSOR.FFN_DOWN_LORA_A : "blk.{bid}.ffn_down_lora_a",
MODEL_TENSOR.FFN_DOWN_LORA_B : "blk.{bid}.ffn_down_lora_b",

}

MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -446,6 +475,32 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.INTERNLM2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,

MODEL_TENSOR.ATTN_QKV_LORA_A,
MODEL_TENSOR.ATTN_QKV_LORA_B,
MODEL_TENSOR.ATTN_OUT_LORA_A,
MODEL_TENSOR.ATTN_OUT_LORA_B,
MODEL_TENSOR.FFN_UP_LORA_A,
MODEL_TENSOR.FFN_UP_LORA_B,
MODEL_TENSOR.FFN_GATE_LORA_A,
MODEL_TENSOR.FFN_GATE_LORA_B,
MODEL_TENSOR.FFN_DOWN_LORA_A,
MODEL_TENSOR.FFN_DOWN_LORA_B,
],
# TODO
}

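To make the new LoRA-related constants concrete, here is a small sketch (not code from this diff) showing the per-block GGUF tensor names that the added `TENSOR_NAMES` templates expand to:

```python
# Expand the new partial-LoRA name templates (copied from TENSOR_NAMES above)
# for one block id, to show the names the GGUF file would contain.
lora_templates = [
    "blk.{bid}.attn_qkv_lora_a", "blk.{bid}.attn_qkv_lora_b",
    "blk.{bid}.attn_out_lora_a", "blk.{bid}.attn_out_lora_b",
    "blk.{bid}.ffn_up_lora_a",   "blk.{bid}.ffn_up_lora_b",
    "blk.{bid}.ffn_gate_lora_a", "blk.{bid}.ffn_gate_lora_b",
    "blk.{bid}.ffn_down_lora_a", "blk.{bid}.ffn_down_lora_b",
]
for template in lora_templates:
    print(template.format(bid=0))  # e.g. blk.0.attn_qkv_lora_a
```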