
Commit 1e921a3

Authored by Codys12 with steinmetzc, codys12, and MekkCyber
Add optional RMSNorm support to BitNet quantization (config + layers) (#38087)
* enable optional RMS in BitLinear
* Fix naming
* Import RMS from Llama using config.*
* make fix-copies
* ran CI loop
* remove default BitNetQuantConfig values
* Fix BitNetQuantConfig to be Optional
* Fix config docstrings to match Optional
* Edit docstrings to match standards

---------

Co-authored-by: steinmetzc <[email protected]>
Co-authored-by: codys12 <[email protected]>
Co-authored-by: Mohamed Mekkouri <[email protected]>
1 parent 57a79f5 commit 1e921a3

3 files changed: +50 additions, -3 deletions

src/transformers/convert_slow_tokenizer.py

Lines changed: 3 additions & 1 deletion
@@ -1584,7 +1584,9 @@ def __init__(
         self.pattern = pattern
         self.add_prefix_space = add_prefix_space
         self.additional_special_tokens = (
-            additional_special_tokens.keys() if type(additional_special_tokens) is dict else additional_special_tokens
+            additional_special_tokens.keys()
+            if isinstance(additional_special_tokens, dict)
+            else additional_special_tokens
         )

     def extract_vocab_merges_from_model(self, tiktoken_url: str):
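Why this change matters: `type(x) is dict` rejects dict subclasses, while `isinstance(x, dict)` accepts them. A minimal standalone illustration (the token values are hypothetical, not part of the tokenizer converter):

    from collections import OrderedDict

    additional_special_tokens = OrderedDict({"<extra_0>": 0, "<extra_1>": 1})

    # Old check: an OrderedDict is not exactly `dict`, so the keys would not be extracted.
    print(type(additional_special_tokens) is dict)       # False
    # New check: any dict subclass passes, so .keys() is used as intended.
    print(isinstance(additional_special_tokens, dict))   # True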

src/transformers/integrations/bitnet.py

Lines changed: 38 additions & 2 deletions
@@ -124,7 +124,16 @@ def unpack_weights(packed: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:


 class BitLinear(nn.Module):
-    def __init__(self, in_features: int, out_features: int, bias: bool, device=None, dtype=None):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool,
+        device=None,
+        dtype=None,
+        use_rms_norm: bool = False,
+        rms_norm_eps: float = 1e-6,
+    ):
         super().__init__()
         self.dtype = dtype
         self.in_features = in_features

@@ -150,6 +159,13 @@ def __init__(self, in_features: int, out_features: int, bias: bool, device=None,
         else:
             self.bias = None

+        # Optional RMSNorm (applied on the activations before quantization).
+        self.rms_norm = None
+        if use_rms_norm:
+            from ..models.llama.modeling_llama import LlamaRMSNorm
+
+            self.rms_norm = LlamaRMSNorm(in_features, eps=rms_norm_eps)
+
     @torch.compile
     def activation_quant(self, input, num_bits=8):
         """

@@ -180,6 +196,10 @@ def post_quant_process(self, input, input_scale, weight_scale):
         return out

     def forward(self, input):
+        # Apply RMSNorm on the input if requested.
+        if self.rms_norm is not None:
+            input = self.rms_norm(input)
+
         w = self.weight
         w_quant = unpack_weights(w, dtype=self.dtype)
         input_quant, input_scale = self.activation_quant(input)

@@ -245,9 +265,17 @@ def __init__(
         device=None,
         dtype=None,
         online_quant: bool = False,
+        use_rms_norm: bool = False,
+        rms_norm_eps: float = 1e-6,
     ):
         super().__init__(in_features, out_features, bias)
         self.online_quant = online_quant
+        # Optional RMSNorm
+        self.rms_norm = None
+        if use_rms_norm:
+            from ..models.llama.modeling_llama import LlamaRMSNorm
+
+            self.rms_norm = LlamaRMSNorm(in_features, eps=rms_norm_eps)
         if not online_quant:
             self.register_buffer(
                 "weight_scale",

@@ -271,6 +299,10 @@ def load_hook(
         return state_dict

     def forward(self, input):
+        # Optional RMSNorm on activations prior to quantization.
+        if self.rms_norm is not None:
+            input = self.rms_norm(input)
+
         if self.online_quant:
             weight = WeightQuant.apply(self.weight)
         else:

@@ -318,6 +350,8 @@ def _replace_with_bitnet_linear(
                     device=module.weight.device,
                     dtype=module.weight.dtype,
                     online_quant=(quantization_config.quantization_mode == "online"),
+                    use_rms_norm=quantization_config.use_rms_norm,
+                    rms_norm_eps=quantization_config.rms_norm_eps,
                 )
                 if quantization_config.quantization_mode == "offline":
                     model._modules[name].requires_grad_(False)

@@ -328,6 +362,8 @@ def _replace_with_bitnet_linear(
                     bias=module.bias is not None,
                     device=module.weight.device,
                     dtype=module.weight.dtype,
+                    use_rms_norm=quantization_config.use_rms_norm,
+                    rms_norm_eps=quantization_config.rms_norm_eps,
                 )
                 model._modules[name].requires_grad_(False)
                 has_been_replaced = True

@@ -363,7 +399,7 @@ def replace_with_bitnet_linear(
         model (`torch.nn.Module`):
             Input model or `torch.nn.Module` as the function is run recursively.
         modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["lm_head"]`):
-            Names of the modules to not convert in `EetqLinear`. In practice we keep the `lm_head` in full precision
+            Names of the modules to not convert in `BitLinear`. In practice we keep the `lm_head` in full precision
             for numerical stability reasons.
         current_key_name (`List[`str`]`, *optional*):
             An array to track the current key of the recursion. This is used to check whether the current key (part of
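For readers who want the shape of the change without reading the hunks line by line: the new flag simply normalizes activations before they are quantized. Below is a self-contained toy sketch of that flow in plain PyTorch, with simplified absmax quantization; `ToyRMSNorm` and `ToyBitLinear` are illustrative names, not the library's actual `LlamaRMSNorm` or `BitLinear`, whose weight packing and scale handling are omitted here.

    import torch
    import torch.nn as nn


    class ToyRMSNorm(nn.Module):
        """Mirrors what LlamaRMSNorm computes: weight * x / sqrt(mean(x**2) + eps)."""

        def __init__(self, hidden_size: int, eps: float = 1e-6):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.eps = eps

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            variance = x.pow(2).mean(-1, keepdim=True)
            return self.weight * x * torch.rsqrt(variance + self.eps)


    class ToyBitLinear(nn.Module):
        """Toy stand-in for BitLinear: optional RMSNorm, then 8-bit absmax activation quantization."""

        def __init__(self, in_features: int, out_features: int,
                     use_rms_norm: bool = False, rms_norm_eps: float = 1e-6):
            super().__init__()
            self.weight = nn.Parameter(torch.empty(out_features, in_features).uniform_(-1, 1))
            # Same pattern as the patch: the norm is only instantiated when requested.
            self.rms_norm = ToyRMSNorm(in_features, eps=rms_norm_eps) if use_rms_norm else None

        def activation_quant(self, x: torch.Tensor, num_bits: int = 8):
            # Per-token absmax quantization to signed integers (simplified).
            qmax = 2 ** (num_bits - 1) - 1
            scale = qmax / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
            return (x * scale).round().clamp(-qmax - 1, qmax), scale

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if self.rms_norm is not None:   # the behaviour this commit adds
                x = self.rms_norm(x)
            x_q, scale = self.activation_quant(x)
            # Dequantize and matmul; the real BitLinear instead unpacks ternary weights
            # and rescales with stored weight/input scales.
            return (x_q / scale) @ self.weight.t()


    layer = ToyBitLinear(16, 32, use_rms_norm=True)
    print(layer(torch.randn(2, 16)).shape)  # torch.Size([2, 32])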

src/transformers/utils/quantization_config.py

Lines changed: 9 additions & 0 deletions
@@ -1791,6 +1791,11 @@ class BitNetQuantConfig(QuantizationConfigMixin):
             In `offline` mode, quantization parameters are pre-calculated *before* inference.
             These parameters are then fixed and loaded into the quantized model. This
             generally results in lower runtime overhead compared to online quantization.
+        use_rms_norm (`bool`, *optional*, defaults to `False`):
+            Whether to apply RMSNorm on the activations before quantization. This matches the original BitNet paper's approach
+            of normalizing activations before quantization/packing.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon value used in the RMSNorm layer for numerical stability.
         kwargs (`Dict[str, Any]`, *optional*):
             Additional keyword arguments that may be used by specific quantization
             backends or future versions.

@@ -1801,6 +1806,8 @@ def __init__(
         modules_to_not_convert: Optional[List] = None,
         linear_class: Optional[str] = "bitlinear",
         quantization_mode: Optional[str] = "offline",
+        use_rms_norm: Optional[bool] = False,
+        rms_norm_eps: Optional[float] = 1e-6,
         **kwargs,
     ):
         if linear_class not in ["bitlinear", "autobitlinear"]:

@@ -1811,6 +1818,8 @@ def __init__(
         self.modules_to_not_convert = modules_to_not_convert
         self.linear_class = linear_class
         self.quantization_mode = quantization_mode
+        self.use_rms_norm = use_rms_norm
+        self.rms_norm_eps = rms_norm_eps
         self.post_init()

     def post_init(self):