
Commit 5c076fb

ArthurZucker and zRzRzRzRzRzRzR authored and committed

Add glm4 (#37388)

* add changed
* Revert "add changed" (this reverts commit 0a0166a)
* update with NEW MODEL class called GLM4
* update
* Update glm4.md
* Name
* style
* fix copies
* fixup test

Co-authored-by: Yuxuan Zhang <[email protected]>

1 parent 28c9541 · commit 5c076fb

File tree: 15 files changed, +1911 -0 lines changed


docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions

@@ -461,6 +461,8 @@
       title: Gemma2
     - local: model_doc/glm
       title: GLM
+    - local: model_doc/glm4
+      title: glm4
     - local: model_doc/openai-gpt
       title: GPT
     - local: model_doc/gpt_neo

docs/source/en/model_doc/glm4.md

Lines changed: 45 additions & 0 deletions (new file)

<!--Copyright 2025 The GLM & ZhipuAI team and The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Glm4

## Overview

To be released with the official model launch.

## Glm4Config

[[autodoc]] Glm4Config

## Glm4Model

[[autodoc]] Glm4Model
    - forward

## Glm4ForCausalLM

[[autodoc]] Glm4ForCausalLM
    - forward

## Glm4ForSequenceClassification

[[autodoc]] Glm4ForSequenceClassification
    - forward

## Glm4ForTokenClassification

[[autodoc]] Glm4ForTokenClassification
    - forward
src/transformers/__init__.py

Lines changed: 18 additions & 0 deletions

@@ -482,6 +482,7 @@
         "GitVisionConfig",
     ],
     "models.glm": ["GlmConfig"],
+    "models.glm4": ["Glm4Config"],
     "models.glpn": ["GLPNConfig"],
     "models.got_ocr2": [
         "GotOcr2Config",
@@ -2526,6 +2527,15 @@
             "Llama4PreTrainedModel",
         ]
     )
+    _import_structure["models.glm4"].extend(
+        [
+            "Glm4ForCausalLM",
+            "Glm4ForSequenceClassification",
+            "Glm4ForTokenClassification",
+            "Glm4Model",
+            "Glm4PreTrainedModel",
+        ]
+    )
     _import_structure["models.glpn"].extend(
         [
             "GLPNForDepthEstimation",
@@ -5742,6 +5752,7 @@
         GitVisionConfig,
     )
     from .models.glm import GlmConfig
+    from .models.glm4 import Glm4Config
     from .models.glpn import GLPNConfig
     from .models.got_ocr2 import GotOcr2Config, GotOcr2Processor, GotOcr2VisionConfig
     from .models.gpt2 import (
@@ -7624,6 +7635,13 @@
         GlmModel,
         GlmPreTrainedModel,
     )
+    from .models.glm4 import (
+        Glm4ForCausalLM,
+        Glm4ForSequenceClassification,
+        Glm4ForTokenClassification,
+        Glm4Model,
+        Glm4PreTrainedModel,
+    )
     from .models.glpn import (
         GLPNForDepthEstimation,
         GLPNModel,
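These additions register the GLM4 symbols in `_import_structure` and mirror them under `TYPE_CHECKING`, which is how the top-level `transformers` namespace exposes them while keeping runtime imports lazy. A small sketch of what that enables, assuming a transformers build that includes this commit:

```python
# The top-level package is a lazy module: the attribute accesses below are what
# actually trigger importing transformers.models.glm4.
import transformers

print(hasattr(transformers, "Glm4Config"))       # True
print(hasattr(transformers, "Glm4ForCausalLM"))  # True

from transformers import Glm4Config
print(Glm4Config().model_type)  # "glm4"
```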

src/transformers/models/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -110,6 +110,7 @@
     gemma3,
     git,
     glm,
+    glm4,
     glpn,
     got_ocr2,
     gpt2,

src/transformers/models/auto/configuration_auto.py

Lines changed: 2 additions & 0 deletions

@@ -129,6 +129,7 @@
         ("gemma3_text", "Gemma3TextConfig"),
         ("git", "GitConfig"),
         ("glm", "GlmConfig"),
+        ("glm4", "Glm4Config"),
         ("glpn", "GLPNConfig"),
         ("got_ocr2", "GotOcr2Config"),
         ("gpt-sw3", "GPT2Config"),
@@ -476,6 +477,7 @@
         ("gemma3_text", "Gemma3ForCausalLM"),
         ("git", "GIT"),
         ("glm", "GLM"),
+        ("glm4", "glm4"),
         ("glpn", "GLPN"),
         ("got_ocr2", "GOT-OCR2"),
         ("gpt-sw3", "GPT-Sw3"),

src/transformers/models/auto/modeling_auto.py

Lines changed: 4 additions & 0 deletions

@@ -122,6 +122,7 @@
         ("gemma3_text", "Gemma3TextModel"),
         ("git", "GitModel"),
         ("glm", "GlmModel"),
+        ("glm4", "Glm4Model"),
         ("glpn", "GLPNModel"),
         ("got_ocr2", "GotOcr2ForConditionalGeneration"),
         ("gpt-sw3", "GPT2Model"),
@@ -532,6 +533,7 @@
         ("gemma3_text", "Gemma3ForCausalLM"),
         ("git", "GitForCausalLM"),
         ("glm", "GlmForCausalLM"),
+        ("glm4", "Glm4ForCausalLM"),
         ("got_ocr2", "GotOcr2ForConditionalGeneration"),
         ("gpt-sw3", "GPT2LMHeadModel"),
         ("gpt2", "GPT2LMHeadModel"),
@@ -1035,6 +1037,7 @@
         ("gemma", "GemmaForSequenceClassification"),
         ("gemma2", "Gemma2ForSequenceClassification"),
         ("glm", "GlmForSequenceClassification"),
+        ("glm4", "Glm4ForSequenceClassification"),
         ("gpt-sw3", "GPT2ForSequenceClassification"),
         ("gpt2", "GPT2ForSequenceClassification"),
         ("gpt_bigcode", "GPTBigCodeForSequenceClassification"),
@@ -1236,6 +1239,7 @@
         ("gemma", "GemmaForTokenClassification"),
         ("gemma2", "Gemma2ForTokenClassification"),
         ("glm", "GlmForTokenClassification"),
+        ("glm4", "Glm4ForTokenClassification"),
         ("gpt-sw3", "GPT2ForTokenClassification"),
         ("gpt2", "GPT2ForTokenClassification"),
         ("gpt_bigcode", "GPTBigCodeForTokenClassification"),

src/transformers/models/auto/tokenization_auto.py

Lines changed: 1 addition & 0 deletions

@@ -238,6 +238,7 @@
         ),
         ("git", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
         ("glm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
+        ("glm4", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
         ("gpt-sw3", ("GPTSw3Tokenizer" if is_sentencepiece_available() else None, None)),
         ("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
         ("gpt_bigcode", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
src/transformers/models/glm4/__init__.py

Lines changed: 27 additions & 0 deletions (new file)

# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_glm4 import *
    from .modeling_glm4 import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
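This `__init__.py` follows the library's `_LazyModule` pattern: type checkers see the star imports, while at runtime the module object is swapped for a lazy proxy that imports `configuration_glm4` and `modeling_glm4` on first attribute access. A small sketch of that behavior, assuming a build containing this commit:

```python
import importlib

# Importing the package is cheap; the heavy submodules are not loaded yet.
glm4 = importlib.import_module("transformers.models.glm4")
print(type(glm4).__name__)     # "_LazyModule"

# Attribute access triggers the real import of configuration_glm4.
config_cls = glm4.Glm4Config
print(config_cls.model_type)   # "glm4"
```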
src/transformers/models/glm4/configuration_glm4.py

Lines changed: 152 additions & 0 deletions (new file)

# coding=utf-8
# Copyright 2025 The GLM4 & ZhipuAI team and HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...configuration_utils import PretrainedConfig


class Glm4Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Glm4Model`]. It is used to instantiate a Glm4
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a configuration similar to that of Glm4-4-9b-chat,
    e.g. [THUDM/glm-4-0414-9b-chat-chat](https://huggingface.co/THUDM/glm-4-0414-9b-chat-chat).
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 151552):
            Vocabulary size of the Glm4 model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`Glm4Model`].
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 13696):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 40):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 2):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details check out [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        partial_rotary_factor (`float`, *optional*, defaults to 0.5):
            The factor of the partial rotary position.
        head_dim (`int`, *optional*, defaults to 128):
            The attention head dimension.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The legacy activation function. It is overwritten by the `hidden_activation`.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (`int`, *optional*, defaults to 131072):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1.5625e-07):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        pad_token_id (`int`, *optional*, defaults to 151329):
            Padding token id.
        eos_token_id (`int` | `list`, *optional*, defaults to `[151329, 151336, 151338]`):
            End of stream token id.
        bos_token_id (`int`, *optional*):
            Beginning of stream token id.
        attention_bias (`bool`, *optional*, defaults to `True`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.

    ```python
    >>> from transformers import Glm4Model, Glm4Config
    >>> # Initializing a Glm4 glm4-4-9b-chat style configuration
    >>> configuration = Glm4Config()
    >>> # Initializing a model from the glm4-4-9b-chat style configuration
    >>> model = Glm4Model(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "glm4"
    keys_to_ignore_at_inference = ["past_key_values"]
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_up_proj": "colwise_rep",  # we need to replicate here due to the `chunk` operation
        "layers.*.mlp.down_proj": "rowwise_rep",  # we need to replicate here due to the `chunk` operation
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=151552,
        hidden_size=4096,
        intermediate_size=13696,
        num_hidden_layers=40,
        num_attention_heads=32,
        num_key_value_heads=2,
        partial_rotary_factor=0.5,
        head_dim=128,
        hidden_act="silu",
        attention_dropout=0.0,
        max_position_embeddings=131072,
        initializer_range=0.02,
        rms_norm_eps=0.00000015625,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        pad_token_id=151329,
        eos_token_id=[151329, 151336, 151338],
        bos_token_id=None,
        attention_bias=True,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.partial_rotary_factor = partial_rotary_factor
        self.head_dim = head_dim
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


__all__ = ["Glm4Config"]

0 commit comments
