@@ -40,7 +40,7 @@ def quantize_model_(
     if qlinear_config == "8w":
         assert (
             qembedding_group_size == 0
-        ), "8-bit embedding quantization only supports per-channel at the moment, please use qembedding_group_size = 0."
+        ), "8-bit embedding quantization only supports per-token at the moment, please use qembedding_group_size = 0."
     if qembedding_group_size == 0:
         embedding_weight_granularity = PerAxis(0)
     else:
@@ -67,6 +67,7 @@ def quantize_model_(
     )

     if qlinear_config:
+
         def build_linear_config(config_key: str, granularity):
             if config_key == "8da4w":
                 return Int8DynamicActivationIntxWeightConfig(
@@ -94,9 +95,7 @@ def build_linear_config(config_key: str, granularity):
         if any(cfg == "" for cfg in qlinear_configs):
             raise ValueError("Linear quantization config entries must be non-empty.")
         if len(qlinear_configs) > 2:
-            raise ValueError(
-                "Expected at most one fallback linear quantization config, got more than one comma."
-            )
+            raise ValueError("Expected at most one fallback linear quantization config, got more than one comma.")

         primary_linear_config_key = qlinear_configs[0]
         fallback_linear_config_key = qlinear_configs[1] if len(qlinear_configs) == 2 else None
@@ -109,16 +108,16 @@ def build_linear_config(config_key: str, granularity):
109108 )
110109 fallback_linear_config_key = None
111110 else :
112- assert qlinear_group_size % 2 == 0 , f"Linear quantization group size must be a multiple of 2, got { qlinear_group_size } ."
111+ assert (
112+ qlinear_group_size % 2 == 0
113+ ), f"Linear quantization group size must be a multiple of 2, got { qlinear_group_size } ."
113114 linear_weight_granularity = PerGroup (qlinear_group_size )
114115
115116 logging .info ("Quantizing linear layers." )
116- primary_linear_config = build_linear_config (
117- primary_linear_config_key , linear_weight_granularity
118- )
117+ primary_linear_config = build_linear_config (primary_linear_config_key , linear_weight_granularity )
119118
120119 # First, quantize layers that are compatible with group quantization
121- def quant_filter (module , fqn ):
120+ def per_group_filter (module , fqn ):
122121 if isinstance (module , torch .nn .Linear ):
123122 # Check if hidden dimension is divisible by group size
124123 # For Linear layers, weight shape is [out_features, in_features]
@@ -129,20 +128,16 @@ def quant_filter(module, fqn):
         quantize_(
             eager_model,
             primary_linear_config,
-            filter_fn=quant_filter,
+            filter_fn=per_group_filter,
         )

         # Then, quantize incompatible layers using the fallback per-axis config
         if fallback_linear_config_key is not None:
-            fallback_linear_config = build_linear_config(
-                fallback_linear_config_key, PerAxis(0)
-            )
-
-            def per_channel_filter(module, fqn):
+            fallback_linear_config = build_linear_config(fallback_linear_config_key, PerAxis(0))
+
+            def per_token_filter(module, fqn):
                 if isinstance(module, torch.nn.Linear):
-                    # Only quantize layers that are NOT compatible with group quantization
-                    # and haven't been quantized yet
-                    return not quant_filter(module, fqn)
+                    return module.weight.shape[1] % qlinear_group_size != 0
                 return False

             logging.info(
@@ -152,7 +147,7 @@ def per_channel_filter(module, fqn):
             quantize_(
                 eager_model,
                 fallback_linear_config,
-                filter_fn=per_channel_filter,
+                filter_fn=per_token_filter,
             )

     unwrap_tensor_subclass(eager_model)
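
For readers skimming the hunks above: the renamed filters split torch.nn.Linear modules into two disjoint sets, those whose in_features is divisible by the group size (quantized with the primary per-group config) and the rest (quantized with the fallback PerAxis(0) config). Below is a minimal, self-contained sketch of that partition on a toy model; the QLINEAR_GROUP_SIZE constant, the toy Sequential, and the print loop are illustrative stand-ins rather than part of this commit, and the real code passes these filters as filter_fn to torchao's quantize_.

import torch

QLINEAR_GROUP_SIZE = 32  # illustrative stand-in for qlinear_group_size


def per_group_filter(module, fqn):
    # Linear weights have shape [out_features, in_features]; group
    # quantization needs in_features divisible by the group size.
    if isinstance(module, torch.nn.Linear):
        return module.weight.shape[1] % QLINEAR_GROUP_SIZE == 0
    return False


def per_token_filter(module, fqn):
    # Complement of per_group_filter over Linear modules: these layers
    # fall back to per-axis (axis 0) weight quantization.
    if isinstance(module, torch.nn.Linear):
        return module.weight.shape[1] % QLINEAR_GROUP_SIZE != 0
    return False


# Hypothetical toy model: one group-friendly layer, one that is not.
toy = torch.nn.Sequential(
    torch.nn.Linear(64, 16),  # in_features 64 % 32 == 0 -> per-group
    torch.nn.Linear(10, 16),  # in_features 10 % 32 != 0 -> fallback
)

for fqn, module in toy.named_modules():
    if per_group_filter(module, fqn):
        print(f"{fqn}: per-group linear config")
    elif per_token_filter(module, fqn):
        print(f"{fqn}: fallback PerAxis(0) linear config")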