@@ -40,7 +40,7 @@ def quantize_model_(
4040        if  qlinear_config  ==  "8w" :
4141            assert  (
4242                qembedding_group_size  ==  0 
43-             ), "8-bit embedding quantization only supports per-channel  at the moment, please use qembedding_group_size = 0." 
43+             ), "8-bit embedding quantization only supports per-token  at the moment, please use qembedding_group_size = 0." 
4444        if  qembedding_group_size  ==  0 :
4545            embedding_weight_granularity  =  PerAxis (0 )
4646        else :
@@ -94,9 +94,7 @@ def build_linear_config(config_key: str, granularity):
9494        if  any (cfg  ==  ""  for  cfg  in  qlinear_configs ):
9595            raise  ValueError ("Linear quantization config entries must be non-empty." )
9696        if  len (qlinear_configs ) >  2 :
97-             raise  ValueError (
98-                 "Expected at most one fallback linear quantization config, got more than one comma." 
99-             )
97+             raise  ValueError ("Expected at most one fallback linear quantization config, got more than one comma." )
10098
10199        primary_linear_config_key  =  qlinear_configs [0 ]
102100        fallback_linear_config_key  =  qlinear_configs [1 ] if  len (qlinear_configs ) ==  2  else  None 
@@ -109,16 +107,16 @@ def build_linear_config(config_key: str, granularity):
109107                )
110108                fallback_linear_config_key  =  None 
111109        else :
112-             assert  qlinear_group_size  %  2  ==  0 , f"Linear quantization group size must be a multiple of 2, got { qlinear_group_size } "
110+             assert  (
111+                 qlinear_group_size  %  2  ==  0 
112+             ), f"Linear quantization group size must be a multiple of 2, got { qlinear_group_size } "
113113            linear_weight_granularity  =  PerGroup (qlinear_group_size )
114114
115115        logging .info ("Quantizing linear layers." )
116-         primary_linear_config  =  build_linear_config (
117-             primary_linear_config_key , linear_weight_granularity 
118-         )
116+         primary_linear_config  =  build_linear_config (primary_linear_config_key , linear_weight_granularity )
119117
120118        # First, quantize layers that are compatible with group quantization 
121-         def  quant_filter (module , fqn ):
119+         def  per_group_filter (module , fqn ):
122120            if  isinstance (module , torch .nn .Linear ):
123121                # Check if hidden dimension is divisible by group size 
124122                # For Linear layers, weight shape is [out_features, in_features] 
@@ -129,20 +127,16 @@ def quant_filter(module, fqn):
129127        quantize_ (
130128            eager_model ,
131129            primary_linear_config ,
132-             filter_fn = quant_filter ,
130+             filter_fn = per_group_filter ,
133131        )
134132
135133        # Then, quantize incompatible layers using the fallback per-axis config 
136134        if  fallback_linear_config_key  is  not None :
137-             fallback_linear_config  =  build_linear_config (
138-                 fallback_linear_config_key , PerAxis (0 )
139-             )
140-             
141-             def  per_channel_filter (module , fqn ):
135+             fallback_linear_config  =  build_linear_config (fallback_linear_config_key , PerAxis (0 ))
136+ 
137+             def  per_token_filter (module , fqn ):
142138                if  isinstance (module , torch .nn .Linear ):
143-                     # Only quantize layers that are NOT compatible with group quantization 
144-                     # and haven't been quantized yet 
145-                     return  not  quant_filter (module , fqn )
139+                     return  module .weight .shape [1 ] %  qlinear_group_size  !=  0 
146140                return  False 
147141
148142            logging .info (
@@ -152,7 +146,7 @@ def per_channel_filter(module, fqn):
152146            quantize_ (
153147                eager_model ,
154148                fallback_linear_config ,
155-                 filter_fn = per_channel_filter ,
149+                 filter_fn = per_token_filter ,
156150            )
157151
158152    unwrap_tensor_subclass (eager_model )
0 commit comments