from qwix._src.core import dot_general
from qwix._src.core import numerics
from qwix._src.core import qarray
from qwix._src.core import qarray_qt
2728
2829
2930@dataclasses .dataclass (slots = True , frozen = True , kw_only = True )
@@ -43,19 +44,16 @@ class DotGeneralQtConfig:
4344 dlhs_grad_qtype : jax .typing .DTypeLike | None = None # incoming gradient
4445 dlhs_grad_calibration_method : str = 'absmax'
4546 dlhs_tile_size : int | float | None = None
47+ dlhs_stochastic_rounding_noise_fn : numerics .NoiseFn | None = None
4648
4749 # Backward pass (drhs).
4850 drhs_grad_qtype : jax .typing .DTypeLike | None = None # incoming gradient
4951 drhs_grad_calibration_method : str = 'absmax'
5052 drhs_tile_size : int | float | None = None
53+ drhs_stochastic_rounding_noise_fn : numerics .NoiseFn | None = None
5154
5255 # Misc.
5356 disable_channelwise_axes : bool = False
54- bwd_use_original_residuals : bool = False # what to use as residuals
55-
56- # Configs for stochastic rounding.
57- dlhs_stochastic_rounding_noise_fn : numerics .NoiseFn | None = None
58- drhs_stochastic_rounding_noise_fn : numerics .NoiseFn | None = None
5957
6058
6159def _ranges_like (* xs ):
@@ -124,84 +122,47 @@ def _apply_rhs_scale_to_lhs(lhs, rhs_scale, dnums):
124122# disable interceptions for dot_general_qt_fwd.
@interception.disable_interceptions
def dot_general_qt_fwd(
    lhs: jax.Array | qarray_qt.QArrayWithGradient,
    rhs: jax.Array | qarray_qt.QArrayWithGradient,
    dimension_numbers: jax.lax.DotDimensionNumbers,
    config: DotGeneralQtConfig,
):
  """Forward pass for dot_general_qt custom VJP.

  The operands are used exactly as received: no quantization happens here.

  Args:
    lhs: The left-hand operand, a plain array or a QArrayWithGradient.
    rhs: The right-hand operand, a plain array or a QArrayWithGradient.
    dimension_numbers: The dimension numbers passed through to dot_general.
    config: Unused in the forward pass; it is part of the signature because it
      is one of the primal function's arguments.

  Returns:
    A `(primal_out, residuals)` pair where `primal_out` is the dot_general
    result and `residuals` is the unmodified `(lhs, rhs)` tuple kept for the
    backward pass.
  """
  del config  # Only the backward pass reads the config.
  primal_out = dot_general.dot_general(lhs, rhs, dimension_numbers)
  residuals = (lhs, rhs)
  return primal_out, residuals
179133
180134
181135def dot_general_qt_bwd (
182136 fwd_dimension_numbers : jax .lax .DotDimensionNumbers ,
183137 config : DotGeneralQtConfig ,
184- residuals : tuple [qarray .MaybeQArray , qarray .MaybeQArray ],
138+ residuals : tuple [
139+ jax .Array | qarray_qt .QArrayWithGradient ,
140+ jax .Array | qarray_qt .QArrayWithGradient ,
141+ ],
185142 g : jax .Array ,
186143):
187144 """Backward pass for dot_general_qt custom VJP."""
188145 lhs , rhs = residuals
189146
190147 def _compute_gradient_for_operand (
191148 g : jax .Array , y : qarray .MaybeQArray , * , for_dlhs : bool
192- ):
193- """Compute dot_general for gradient and other_fwd_operand ."""
149+ ) -> jax . Array | qarray_qt . QArrayWithGradient :
150+ """Compute dx from g and y ."""
194151 bwd_dnums , transpose_axes = _update_dimension_numbers_for_backward (
195152 fwd_dimension_numbers , (lhs .ndim , rhs .ndim ), for_dlhs = for_dlhs
196153 )
197154 if for_dlhs :
198155 g_qtype = config .dlhs_grad_qtype
199156 g_tile_size = config .dlhs_tile_size
200157 g_calibration_method = config .dlhs_grad_calibration_method
158+ g_noise_fn = config .dlhs_stochastic_rounding_noise_fn
159+ result_type = lhs # the result gradient must match this type.
201160 else :
202161 g_qtype = config .drhs_grad_qtype
203162 g_tile_size = config .drhs_tile_size
204163 g_calibration_method = config .drhs_grad_calibration_method
164+ g_noise_fn = config .drhs_stochastic_rounding_noise_fn
165+ result_type = rhs # the result gradient must match this type.
205166
206167 if g_qtype and numerics .should_quantize (g .dtype ):
207168 if isinstance (y , qarray .QArray ) and not qarray .get_tiled_axes (y ):
@@ -219,23 +180,20 @@ def _compute_gradient_for_operand(
219180 tile_size = g_tile_size ,
220181 calibration_method = g_calibration_method ,
221182 )
183+ g_how = dataclasses .replace (g_how , noise_fn = g_noise_fn )
222184 if config .disable_channelwise_axes :
223185 g_how = dataclasses .replace (g_how , channelwise_axes = [])
224186
225- if for_dlhs and config .dlhs_stochastic_rounding_noise_fn :
226- g_how = dataclasses .replace (
227- g_how ,
228- noise_fn = config .dlhs_stochastic_rounding_noise_fn ,
229- )
230- if not for_dlhs and config .drhs_stochastic_rounding_noise_fn :
231- g_how = dataclasses .replace (
232- g_how ,
233- noise_fn = config .drhs_stochastic_rounding_noise_fn ,
234- )
235187 g = qarray .quantize (g , g_how )
236188
237189 grad_res = dot_general .dot_general (g , y , bwd_dnums )
238- return jax .lax .transpose (grad_res , transpose_axes )
190+ grad_res = jax .lax .transpose (grad_res , transpose_axes )
191+ if isinstance (result_type , qarray_qt .QArrayWithGradient ):
192+ return dataclasses .replace (
193+ result_type , qvalue = None , scale = None , zero_point = None , _grad = grad_res
194+ )
195+ else :
196+ return grad_res
239197
240198 dlhs = _compute_gradient_for_operand (g , rhs , for_dlhs = True )
241199 drhs = _compute_gradient_for_operand (g , lhs , for_dlhs = False )
@@ -244,15 +202,59 @@ def _compute_gradient_for_operand(
244202
245203
@functools.partial(jax.custom_vjp, nondiff_argnums=(2, 3))
def dot_general_fwd_bwd(
    lhs: jax.Array | qarray_qt.QArrayWithGradient,
    rhs: jax.Array | qarray_qt.QArrayWithGradient,
    dimension_numbers: jax.lax.DotDimensionNumbers,
    config: DotGeneralQtConfig,
) -> jax.Array:
  """Quantized dot_general with backpropagation support.

  This is the primal function of a custom VJP. `dimension_numbers` and
  `config` (positions 2 and 3) are declared non-differentiable, so only `lhs`
  and `rhs` receive gradients.

  Args:
    lhs: The left-hand operand, a plain array or a QArrayWithGradient.
    rhs: The right-hand operand, a plain array or a QArrayWithGradient.
    dimension_numbers: The dimension numbers passed through to dot_general.
    config: Unused by the primal computation itself.

  Returns:
    The result of dot_general on the given operands.
  """
  del config  # Not needed to compute the primal output.
  return dot_general.dot_general(lhs, rhs, dimension_numbers)


# Register the forward and backward rules used when differentiating.
dot_general_fwd_bwd.defvjp(dot_general_qt_fwd, dot_general_qt_bwd)
217+
218+
def _quantize_fwd_operand(
    operand: jax.Array,
    *,
    for_lhs: bool,
    ndims: tuple[int, int],
    dimension_numbers: jax.lax.DotDimensionNumbers,
    config: DotGeneralQtConfig,
) -> jax.Array | qarray_qt.QArrayWithGradient:
  """Quantizes one dot_general operand when the config requests it.

  Args:
    operand: The lhs or rhs array of the dot_general.
    for_lhs: Whether `operand` is the lhs (selects the `lhs_*` config fields)
      or the rhs (selects the `rhs_*` config fields).
    ndims: The `(lhs.ndim, rhs.ndim)` pair of the original operands.
    dimension_numbers: The dimension numbers of the dot_general.
    config: The quantization config.

  Returns:
    `operand` unchanged when no qtype is configured for this side or when
    `numerics.should_quantize` rejects its dtype; otherwise the result of
    `qarray_qt.quantize_with_calibration`.
  """
  if for_lhs:
    qtype = config.lhs_qtype
    calibration_method = config.lhs_calibration_method
    collect_quant_stat = config.lhs_collect_quant_stat
  else:
    qtype = config.rhs_qtype
    calibration_method = config.rhs_calibration_method
    collect_quant_stat = config.rhs_collect_quant_stat

  if not (qtype and numerics.should_quantize(operand.dtype)):
    return operand

  how = dot_general.get_how_to_quantize(
      dimension_numbers=dimension_numbers,
      ndims=ndims,
      for_lhs=for_lhs,
      qtype=qtype,
      tile_size=config.tile_size,
      calibration_method=calibration_method,
  )
  if config.disable_channelwise_axes:
    how = dataclasses.replace(how, channelwise_axes=[])

  calibration = qarray.calibrate(operand, how)
  if collect_quant_stat:
    # The callback's return value replaces the calibration.
    calibration = collect_quant_stat(calibration)
  return qarray_qt.quantize_with_calibration(operand, how.qtype, calibration)


def dot_general_qt(
    lhs: jax.Array,
    rhs: jax.Array,
    dimension_numbers: jax.lax.DotDimensionNumbers,
    config: DotGeneralQtConfig,
) -> jax.Array:
  """Quantized dot_general with backpropagation support.

  Each operand is quantized according to `config` (lhs first, then rhs) and
  the results are passed to `dot_general_fwd_bwd`.

  Args:
    lhs: The left-hand operand.
    rhs: The right-hand operand.
    dimension_numbers: The dimension numbers of the dot_general.
    config: The quantization config for the forward and backward passes.

  Returns:
    The output of `dot_general_fwd_bwd` on the (possibly quantized) operands.
  """
  # Capture the ranks from the raw arrays before either operand is replaced
  # by its quantized form.
  ndims = (lhs.ndim, rhs.ndim)
  qlhs = _quantize_fwd_operand(
      lhs,
      for_lhs=True,
      ndims=ndims,
      dimension_numbers=dimension_numbers,
      config=config,
  )
  qrhs = _quantize_fwd_operand(
      rhs,
      for_lhs=False,
      ndims=ndims,
      dimension_numbers=dimension_numbers,
      config=config,
  )
  return dot_general_fwd_bwd(qlhs, qrhs, dimension_numbers, config)
0 commit comments