@@ -40,7 +40,7 @@ def quantize_model_(
         if qlinear_config == "8w":
             assert (
                 qembedding_group_size == 0
-            ), "8-bit embedding quantization only supports per-channel at the moment, please use qembedding_group_size = 0."
+            ), "8-bit embedding quantization only supports per-token at the moment, please use qembedding_group_size = 0."
         if qembedding_group_size == 0:
             embedding_weight_granularity = PerAxis(0)
         else:
@@ -67,6 +67,7 @@ def quantize_model_(
         )
 
     if qlinear_config:
+
         def build_linear_config(config_key: str, granularity):
             if config_key == "8da4w":
                 return Int8DynamicActivationIntxWeightConfig(
@@ -94,9 +95,7 @@ def build_linear_config(config_key: str, granularity):
         if any(cfg == "" for cfg in qlinear_configs):
             raise ValueError("Linear quantization config entries must be non-empty.")
         if len(qlinear_configs) > 2:
-            raise ValueError(
-                "Expected at most one fallback linear quantization config, got more than one comma."
-            )
+            raise ValueError("Expected at most one fallback linear quantization config, got more than one comma.")
 
         primary_linear_config_key = qlinear_configs[0]
         fallback_linear_config_key = qlinear_configs[1] if len(qlinear_configs) == 2 else None
@@ -109,16 +108,16 @@ def build_linear_config(config_key: str, granularity):
                 )
                 fallback_linear_config_key = None
         else:
-            assert qlinear_group_size % 2 == 0, f"Linear quantization group size must be a multiple of 2, got {qlinear_group_size}"
+            assert (
+                qlinear_group_size % 2 == 0
+            ), f"Linear quantization group size must be a multiple of 2, got {qlinear_group_size}"
             linear_weight_granularity = PerGroup(qlinear_group_size)
 
         logging.info("Quantizing linear layers.")
-        primary_linear_config = build_linear_config(
-            primary_linear_config_key, linear_weight_granularity
-        )
+        primary_linear_config = build_linear_config(primary_linear_config_key, linear_weight_granularity)
 
         # First, quantize layers that are compatible with group quantization
-        def quant_filter(module, fqn):
+        def per_group_filter(module, fqn):
             if isinstance(module, torch.nn.Linear):
                 # Check if hidden dimension is divisible by group size
                 # For Linear layers, weight shape is [out_features, in_features]
@@ -129,20 +128,16 @@ def quant_filter(module, fqn):
         quantize_(
             eager_model,
             primary_linear_config,
-            filter_fn=quant_filter,
+            filter_fn=per_group_filter,
         )
 
         # Then, quantize incompatible layers using the fallback per-axis config
         if fallback_linear_config_key is not None:
-            fallback_linear_config = build_linear_config(
-                fallback_linear_config_key, PerAxis(0)
-            )
-
-            def per_channel_filter(module, fqn):
+            fallback_linear_config = build_linear_config(fallback_linear_config_key, PerAxis(0))
+
+            def per_token_filter(module, fqn):
                 if isinstance(module, torch.nn.Linear):
-                    # Only quantize layers that are NOT compatible with group quantization
-                    # and haven't been quantized yet
-                    return not quant_filter(module, fqn)
+                    return module.weight.shape[1] % qlinear_group_size != 0
                 return False
 
             logging.info(
@@ -152,7 +147,7 @@ def per_channel_filter(module, fqn):
             quantize_(
                 eager_model,
                 fallback_linear_config,
-                filter_fn=per_channel_filter,
+                filter_fn=per_token_filter,
             )
 
     unwrap_tensor_subclass(eager_model)
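
A minimal invocation sketch of the comma-separated fallback behavior introduced in this diff. It assumes quantize_model_ exposes the keyword arguments referenced inside the hunks (its full signature sits above the first hunk and is not shown here) and that "8da4w" is an accepted config key, as used in build_linear_config; names not visible in the diff are hypothetical.

    # Hypothetical call; argument names are taken from the variables used in the hunks above.
    quantize_model_(
        eager_model=model,              # "model" is an assumed eager nn.Module
        qlinear_config="8da4w,8da4w",   # per-group primary entry + optional per-axis fallback entry
        qlinear_group_size=32,          # must be a multiple of 2; Linear layers whose in_features is
                                        # not divisible by 32 are handled by the per-axis fallback
    )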