@@ -54,9 +54,16 @@ def split_param_groups(model) -> tuple[list, list, list]:
     params_quant, params_embed, params_no_quant = [], [], []
 
     def get_param_groups(model):
+        seen_data_ptrs = set()  # avoid duplicates in case of tied weights
         for module in model.children():
             is_linear = _is_linear(module)
             for n, p in module.named_parameters():
+                if n == "weight":
+                    data_ptr = p.data_ptr()
+                    if data_ptr in seen_data_ptrs:
+                        continue
+                    seen_data_ptrs.add(data_ptr)
+
                 if is_linear and n == "weight":
                     params_quant.append(p)
                 elif isinstance(module, nn.Embedding) and n == "weight":
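
For orientation, a minimal sketch (not part of the diff) of the situation the new `seen_data_ptrs` check guards against: once two modules share a weight `Parameter`, iterating each child's `named_parameters()` yields the same underlying storage twice, and `data_ptr()` is a cheap way to detect that.

import torch.nn as nn

# Tie an embedding and a linear weight the way M(tied_weights=True) does below;
# both weights have shape (16, 256), so the assignment is legal.
emb = nn.Embedding(16, 256)
lin = nn.Linear(256, 16, bias=False)
lin.weight = emb.weight

# Both modules now report the same storage, so without the data_ptr() dedup the
# shared tensor would be appended to both params_quant and params_embed.
assert lin.weight.data_ptr() == emb.weight.data_ptr()
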
@@ -152,7 +159,12 @@ def compare_parq_convert(
 def check_torchao_tensor_subclass(
     test_case: common_utils.TestCase, model: nn.Module, weight_only: bool = False
 ):
-    for module in model.modules():
+    for name, module in model.named_modules():
+        if not hasattr(module, "weight") or f"{name}.weight" in getattr(
+            model, "_tied_weights_keys", []
+        ):
+            continue
+
         if not weight_only and _is_linear(module):
             test_case.assertTrue(isinstance(module.weight, IntxUnpackedToInt8Tensor))
             test_case.assertTrue(
@@ -163,34 +175,58 @@ def check_torchao_tensor_subclass(
             test_case.assertTrue(module.weight.activation_quantization is None)
 
 
+def apply_activation_quantization(
+    model: nn.Module, optimizer: torch.optim.Optimizer, model_dtype: torch.dtype
+):
+    # apply torchao quantized activations on top
+    activation_config = IntxFakeQuantizeConfig(
+        torch.int8, "per_token", is_symmetric=False, scale_precision=model_dtype
+    )
+    qat_config = QATConfig(activation_config=activation_config, step="prepare")
+    for filter_fn in optimizer.get_filter_fns(model):
+        try:
+            quantize_(model, qat_config, filter_fn=filter_fn)
+        except ValueError as e:
+            if str(e) == "Activation fake quantization is not supported for embedding":
+                pass
+
+
 class M(nn.Module):
-    def __init__(self, m=256, n=128, k=16, bias=False, embedding=True):
+    _tied_weights_keys: list[str] = []
+
+    def __init__(
+        self, m=256, n=128, k=16, bias=False, embedding=True, tied_weights=False
+    ):
         super().__init__()
-        self.embedding = nn.Embedding(10, m) if embedding else nn.Identity()
+        self.embedding = nn.Embedding(k, m) if embedding else nn.Identity()
         self.linear1 = nn.Linear(m, n, bias=bias)
         self.linear2 = nn.Linear(n, k, bias=bias)
         self.relu = nn.ReLU()
         self.sigmoid = nn.Sigmoid()
 
+        if embedding and tied_weights:
+            assert self.embedding.weight.shape == self.linear2.weight.shape
+            self.linear2.weight = self.embedding.weight
+            self._tied_weights_keys.append("linear2.weight")
+
     def reset_parameters(self):
         for module in (self.linear1, self.linear2):
             nn.init.xavier_uniform_(module.weight)
             if module.bias is not None:
                 nn.init.zeros_(module.bias)
 
     def example_inputs(self, device=None):
-        return (
-            torch.randint(1, 10, (1, self.linear1.in_features), device=device)
-            if isinstance(self.embedding, nn.Embedding)
-            else torch.randn(1, self.linear1.in_features, device=device)
-        )
+        if isinstance(self.embedding, nn.Identity):
+            inputs = torch.randn(1, self.linear1.in_features, device=device)
+        else:
+            k = self.embedding.num_embeddings
+            inputs = torch.randint(1, k, (1, self.linear1.in_features), device=device)
+        return inputs
 
     def forward(self, x):
         x = self.embedding(x)
-        x = self.linear1(x)
-        x = self.relu(x)
-        x = self.linear2(x)
-        x = self.sigmoid(x)
+        x = self.relu(self.linear1(x))
+        x = self.sigmoid(self.linear2(x))
         return x
 
 
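
A brief usage sketch of the updated `M` (assuming the class exactly as defined above, not part of the diff): tying only works when `m == n`, because the `assert` in `__init__` requires `nn.Embedding(k, m).weight` and `nn.Linear(n, k).weight` to have the same shape, i.e. (k, m) == (k, n).

# Hedged sketch: construct M with tied weights and check that the tie is both
# real (same Parameter object) and recorded in _tied_weights_keys.
model = M(m=256, n=256, tied_weights=True)
assert model.linear2.weight is model.embedding.weight
assert "linear2.weight" in model._tied_weights_keys

# An integer batch flows embedding -> linear1 -> linear2, ending at width k=16.
x = model.example_inputs()
y = model(x)
assert y.shape == (1, model.linear1.in_features, model.linear2.out_features)
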
@@ -297,7 +333,7 @@ def test_int4_weight_only_e2e(self, group_size: int = 32):
             ProxHardQuant(),
             quant_per_channel=True,
         )
-        compare_parq_convert(model, m_ref, optimizer)
+        compare_parq_convert(model, m_ref, optimizer, weight_only=True)
 
     @unittest.skipIf(_DEVICE == "cpu", "Need GPU available")
     @common_utils.parametrize("b", [2, 3, 4, 8])
@@ -399,6 +435,30 @@ def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
         compare_parq_convert(model, m_ref, optimizer, weight_only=True)
         check_torchao_tensor_subclass(self, model, weight_only=True)
 
+    @common_utils.parametrize("b", [2, 3])
+    @common_utils.parametrize(
+        "model_dtype", [torch.float16, torch.float32, torch.bfloat16]
+    )
+    def test_intx_weight_only_tied_embed_linear(
+        self, b: int = 2, model_dtype: torch.dtype = torch.float32
+    ):
+        model = M(m=256, n=256, tied_weights=True).to(_DEVICE)
+
+        quantizer = StretchedUnifTorchaoQuantizer(b)
+        base_optimizer = torch.optim.SGD(build_param_groups(model, b))
+        optimizer = QuantOptimizer(
+            base_optimizer, quantizer, ProxHardQuant(), quant_per_channel=True
+        )
+        optimizer.zero_grad()
+        optimizer.step()
+
+        apply_activation_quantization(model, optimizer, model_dtype)
+        optimizer.torchao_convert(model)
+        check_torchao_tensor_subclass(self, model)
+        self.assertTrue(
+            torch.equal(model.embedding.weight.qdata, model.linear2.weight.qdata)
+        )
+
 
 class TestInt8DynamicActivationTorchaoQuantizer(common_utils.TestCase):
     def setUp(self):
@@ -435,16 +495,12 @@ def test_int8_dynamic_activation_intx_e2e(
         optimizer = QuantOptimizer(
             base_optimizer, quantizer, ProxHardQuant(), quant_per_channel=True
         )
+
         optimizer.zero_grad()
         optimizer.step()
 
-        # apply torchao quantized activations on top
-        activation_config = IntxFakeQuantizeConfig(
-            torch.int8, "per_token", is_symmetric=False, scale_precision=model_dtype
-        )
-        qat_config = QATConfig(activation_config=activation_config, step="prepare")
-        for filter_fn in optimizer.get_filter_fns(model):
-            quantize_(model, qat_config, filter_fn=filter_fn)
+        apply_activation_quantization(model, optimizer, model_dtype)
+
         out = model(x)
         torch.testing.assert_close(out, ref_out, atol=0, rtol=0)
 
@@ -462,7 +518,10 @@ def test_int8_dynamic_activation_intx_e2e(
         check_torchao_tensor_subclass(self, model)
 
         if attach_hf_config:
-            reg_param_names = {n for n, m in model.named_modules() if _is_linear(m)}
+            reg_param_names = {
+                n for n, m in model.named_modules() if isinstance(m, nn.Embedding)
+            }
+            reg_param_names.add("_default")
             module_fqn_to_config = (
                 model.config.quantization_config.quant_type.module_fqn_to_config
             )
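
As a concrete illustration of the updated expectation (a hedged sketch against a bare `M` instance; if the test model nests `M` inside a wrapper with an HF-style config, the FQNs would carry a prefix): the set now tracks embedding module names plus the "_default" key instead of the linear layers.

# Hedged sketch, not part of the diff.
model = M(m=256, n=256, tied_weights=True)
reg_param_names = {n for n, m in model.named_modules() if isinstance(m, nn.Embedding)}
reg_param_names.add("_default")
assert reg_param_names == {"embedding", "_default"}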