@@ -124,7 +124,16 @@ def unpack_weights(packed: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
 
 
 class BitLinear(nn.Module):
-    def __init__(self, in_features: int, out_features: int, bias: bool, device=None, dtype=None):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool,
+        device=None,
+        dtype=None,
+        use_rms_norm: bool = False,
+        rms_norm_eps: float = 1e-6,
+    ):
         super().__init__()
         self.dtype = dtype
         self.in_features = in_features
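With this patch applied, the layer can be constructed directly with the two new arguments. A minimal sketch (the `transformers.integrations.bitnet` import path and the concrete sizes are assumptions for illustration, not part of this diff):

```python
import torch

# Assumed import path for the module this diff modifies.
from transformers.integrations.bitnet import BitLinear

# Packed-weight BitLinear whose activations get RMS-normalized (eps=1e-6)
# before the 8-bit activation quantization performed in forward().
layer = BitLinear(
    in_features=1024,
    out_features=1024,
    bias=False,
    device="cpu",
    dtype=torch.bfloat16,
    use_rms_norm=True,
    rms_norm_eps=1e-6,
)
# layer.rms_norm is a LlamaRMSNorm module; it stays None when use_rms_norm=False.
```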
@@ -150,6 +159,13 @@ def __init__(self, in_features: int, out_features: int, bias: bool, device=None,
         else:
             self.bias = None
 
+        # Optional RMSNorm (applied on the activations before quantization).
+        self.rms_norm = None
+        if use_rms_norm:
+            from ..models.llama.modeling_llama import LlamaRMSNorm
+
+            self.rms_norm = LlamaRMSNorm(in_features, eps=rms_norm_eps)
+
     @torch.compile
     def activation_quant(self, input, num_bits=8):
         """
@@ -180,6 +196,10 @@ def post_quant_process(self, input, input_scale, weight_scale):
         return out
 
     def forward(self, input):
+        # Apply RMSNorm on the input if requested.
+        if self.rms_norm is not None:
+            input = self.rms_norm(input)
+
         w = self.weight
         w_quant = unpack_weights(w, dtype=self.dtype)
         input_quant, input_scale = self.activation_quant(input)
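Taken together, the hunks above make `BitLinear.forward` normalize the activations before they are quantized. The standalone sketch below shows that ordering in plain PyTorch; `SimpleRMSNorm` and `activation_quant_int8` are illustrative stand-ins (the real code uses `LlamaRMSNorm` and `BitLinear.activation_quant`, whose body is not shown in this diff):

```python
import torch
import torch.nn as nn


class SimpleRMSNorm(nn.Module):
    """Minimal RMSNorm (learned scale, no bias), in the spirit of LlamaRMSNorm."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.float().pow(2).mean(-1, keepdim=True)
        return (self.weight * x.float() * torch.rsqrt(variance + self.eps)).to(x.dtype)


def activation_quant_int8(x: torch.Tensor):
    """Per-token absmax quantization to the int8 range, as commonly used for BitNet activations."""
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    x_q = (x * scale).round().clamp(-128, 127)
    return x_q, scale


x = torch.randn(2, 16)
norm = SimpleRMSNorm(16)

# The ordering introduced by this diff: RMSNorm first, then quantize.
x_q, scale = activation_quant_int8(norm(x))
x_approx = x_q / scale  # dequantized view of the *normalized* activations
```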
@@ -245,9 +265,17 @@ def __init__(
         device=None,
         dtype=None,
         online_quant: bool = False,
+        use_rms_norm: bool = False,
+        rms_norm_eps: float = 1e-6,
     ):
         super().__init__(in_features, out_features, bias)
         self.online_quant = online_quant
+        # Optional RMSNorm
+        self.rms_norm = None
+        if use_rms_norm:
+            from ..models.llama.modeling_llama import LlamaRMSNorm
+
+            self.rms_norm = LlamaRMSNorm(in_features, eps=rms_norm_eps)
         if not online_quant:
             self.register_buffer(
                 "weight_scale",
@@ -271,6 +299,10 @@ def load_hook(
         return state_dict
 
     def forward(self, input):
+        # Optional RMSNorm on activations prior to quantization.
+        if self.rms_norm is not None:
+            input = self.rms_norm(input)
+
         if self.online_quant:
             weight = WeightQuant.apply(self.weight)
         else:
@@ -318,6 +350,8 @@ def _replace_with_bitnet_linear(
                             device=module.weight.device,
                             dtype=module.weight.dtype,
                             online_quant=(quantization_config.quantization_mode == "online"),
+                            use_rms_norm=quantization_config.use_rms_norm,
+                            rms_norm_eps=quantization_config.rms_norm_eps,
                         )
                         if quantization_config.quantization_mode == "offline":
                             model._modules[name].requires_grad_(False)
@@ -328,6 +362,8 @@ def _replace_with_bitnet_linear(
                             bias=module.bias is not None,
                             device=module.weight.device,
                             dtype=module.weight.dtype,
+                            use_rms_norm=quantization_config.use_rms_norm,
+                            rms_norm_eps=quantization_config.rms_norm_eps,
                         )
                         model._modules[name].requires_grad_(False)
                     has_been_replaced = True
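Both replacement branches above read `use_rms_norm` and `rms_norm_eps` off the quantization config, so whichever config object is passed in must expose those attributes. The dataclass below is a hypothetical stand-in that only sketches this contract; it is not the real `transformers` config class:

```python
from dataclasses import dataclass


@dataclass
class BitNetQuantConfigSketch:
    """Hypothetical stand-in listing only the attributes _replace_with_bitnet_linear reads here."""

    quantization_mode: str = "offline"  # "online" or "offline", as checked in the hunks above
    use_rms_norm: bool = False          # new flag forwarded to BitLinear / AutoBitLinear
    rms_norm_eps: float = 1e-6          # new epsilon forwarded alongside it


config = BitNetQuantConfigSketch(use_rms_norm=True)
# _replace_with_bitnet_linear(model, ..., quantization_config=config) would then build every
# replacement layer with use_rms_norm=True and rms_norm_eps=1e-6, as shown above.
```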
@@ -363,7 +399,7 @@ def replace_with_bitnet_linear(
         model (`torch.nn.Module`):
             Input model or `torch.nn.Module` as the function is run recursively.
         modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["lm_head"]`):
-            Names of the modules to not convert in `EetqLinear`. In practice we keep the `lm_head` in full precision
+            Names of the modules to not convert in `BitLinear`. In practice we keep the `lm_head` in full precision
             for numerical stability reasons.
         current_key_name (`List[`str`]`, *optional*):
             An array to track the current key of the recursion. This is used to check whether the current key (part of