Skip to content

Commit 6ed6d54

Browse files
liudangyi authored and copybara-github committed
Allow gradients on QArray
This patch introduces a new approach to associating gradients with a QArray. As a result, VJP rules for `quantize` and `dot_general` can be defined separately while still ensuring a correct backward pass.

PiperOrigin-RevId: 823246141
1 parent 212b8a8 commit 6ed6d54

File tree

6 files changed

+206
-93
lines changed

6 files changed

+206
-93
lines changed

qwix/_src/core/conv_general_qt.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ class ConvGeneralQtConfig:
4747

4848
# Misc.
4949
disable_channelwise_axes: bool = False
50-
bwd_use_original_residuals: bool = False
5150

5251

5352
# Swaps the first two dimension indices of a specification.
@@ -187,11 +186,8 @@ def _quantize_operand(
187186
operand, qtype, scale, zero_point
188187
)
189188

190-
residuals = (lhs, rhs)
191189
lhs = _quantize_operand(lhs, for_lhs=True)
192190
rhs = _quantize_operand(rhs, for_lhs=False)
193-
if not config.bwd_use_original_residuals:
194-
residuals = (lhs, rhs)
195191

196192
primal_out = conv_general.conv_general_dilated(
197193
lhs,
@@ -204,6 +200,7 @@ def _quantize_operand(
204200
feature_group_count,
205201
batch_group_count,
206202
)
203+
residuals = (lhs, rhs)
207204

208205
return primal_out, residuals
209206

qwix/_src/core/dot_general_qt.py

Lines changed: 72 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from qwix._src.core import dot_general
2525
from qwix._src.core import numerics
2626
from qwix._src.core import qarray
27+
from qwix._src.core import qarray_qt
2728

2829

2930
@dataclasses.dataclass(slots=True, frozen=True, kw_only=True)
@@ -43,19 +44,16 @@ class DotGeneralQtConfig:
4344
dlhs_grad_qtype: jax.typing.DTypeLike | None = None # incoming gradient
4445
dlhs_grad_calibration_method: str = 'absmax'
4546
dlhs_tile_size: int | float | None = None
47+
dlhs_stochastic_rounding_noise_fn: numerics.NoiseFn | None = None
4648

4749
# Backward pass (drhs).
4850
drhs_grad_qtype: jax.typing.DTypeLike | None = None # incoming gradient
4951
drhs_grad_calibration_method: str = 'absmax'
5052
drhs_tile_size: int | float | None = None
53+
drhs_stochastic_rounding_noise_fn: numerics.NoiseFn | None = None
5154

5255
# Misc.
5356
disable_channelwise_axes: bool = False
54-
bwd_use_original_residuals: bool = False # what to use as residuals
55-
56-
# Configs for stochastic rounding.
57-
dlhs_stochastic_rounding_noise_fn: numerics.NoiseFn | None = None
58-
drhs_stochastic_rounding_noise_fn: numerics.NoiseFn | None = None
5957

6058

6159
def _ranges_like(*xs):
@@ -124,84 +122,47 @@ def _apply_rhs_scale_to_lhs(lhs, rhs_scale, dnums):
124122
# disable interceptions for dot_general_qt_fwd.
125123
@interception.disable_interceptions
126124
def dot_general_qt_fwd(
127-
lhs: jax.Array,
128-
rhs: jax.Array,
125+
lhs: jax.Array | qarray_qt.QArrayWithGradient,
126+
rhs: jax.Array | qarray_qt.QArrayWithGradient,
129127
dimension_numbers: jax.lax.DotDimensionNumbers,
130128
config: DotGeneralQtConfig,
131129
):
132130
"""Forward pass for dot_general_qt custom VJP."""
133-
ndims = (lhs.ndim, rhs.ndim)
134-
135-
def _quantize_operand(operand: jax.Array, is_lhs: bool) -> qarray.MaybeQArray:
136-
"""Quantizes a single operand for the forward pass if configured to do so."""
137-
if is_lhs:
138-
qtype = config.lhs_qtype
139-
calibration_method = config.lhs_calibration_method
140-
collect_quant_stat = config.lhs_collect_quant_stat
141-
else:
142-
qtype = config.rhs_qtype
143-
calibration_method = config.rhs_calibration_method
144-
collect_quant_stat = config.rhs_collect_quant_stat
145-
146-
if not (qtype and numerics.should_quantize(operand.dtype)):
147-
return operand
148-
149-
how = dot_general.get_how_to_quantize(
150-
dimension_numbers=dimension_numbers,
151-
ndims=ndims,
152-
for_lhs=is_lhs,
153-
qtype=qtype,
154-
tile_size=config.tile_size,
155-
calibration_method=calibration_method,
156-
)
157-
if config.disable_channelwise_axes:
158-
how = dataclasses.replace(how, channelwise_axes=[])
159-
160-
calibration = qarray.calibrate(operand, how)
161-
if collect_quant_stat:
162-
calibration = collect_quant_stat(calibration)
163-
scale, zero_point = qarray.compute_scale_zero_point(calibration, qtype)
164-
return qarray.quantize_with_scale_zero_point(
165-
operand, how.qtype, scale, zero_point
166-
)
167-
168-
qlhs = _quantize_operand(lhs, is_lhs=True)
169-
qrhs = _quantize_operand(rhs, is_lhs=False)
170-
171-
primal_out = dot_general.dot_general(qlhs, qrhs, dimension_numbers)
172-
173-
if config.bwd_use_original_residuals:
174-
residuals = (lhs, rhs)
175-
else:
176-
residuals = (qlhs, qrhs)
177-
178-
return primal_out, residuals
131+
del config
132+
return dot_general.dot_general(lhs, rhs, dimension_numbers), (lhs, rhs)
179133

180134

181135
def dot_general_qt_bwd(
182136
fwd_dimension_numbers: jax.lax.DotDimensionNumbers,
183137
config: DotGeneralQtConfig,
184-
residuals: tuple[qarray.MaybeQArray, qarray.MaybeQArray],
138+
residuals: tuple[
139+
jax.Array | qarray_qt.QArrayWithGradient,
140+
jax.Array | qarray_qt.QArrayWithGradient,
141+
],
185142
g: jax.Array,
186143
):
187144
"""Backward pass for dot_general_qt custom VJP."""
188145
lhs, rhs = residuals
189146

190147
def _compute_gradient_for_operand(
191148
g: jax.Array, y: qarray.MaybeQArray, *, for_dlhs: bool
192-
):
193-
"""Compute dot_general for gradient and other_fwd_operand."""
149+
) -> jax.Array | qarray_qt.QArrayWithGradient:
150+
"""Compute dx from g and y."""
194151
bwd_dnums, transpose_axes = _update_dimension_numbers_for_backward(
195152
fwd_dimension_numbers, (lhs.ndim, rhs.ndim), for_dlhs=for_dlhs
196153
)
197154
if for_dlhs:
198155
g_qtype = config.dlhs_grad_qtype
199156
g_tile_size = config.dlhs_tile_size
200157
g_calibration_method = config.dlhs_grad_calibration_method
158+
g_noise_fn = config.dlhs_stochastic_rounding_noise_fn
159+
result_type = lhs # the result gradient must match this type.
201160
else:
202161
g_qtype = config.drhs_grad_qtype
203162
g_tile_size = config.drhs_tile_size
204163
g_calibration_method = config.drhs_grad_calibration_method
164+
g_noise_fn = config.drhs_stochastic_rounding_noise_fn
165+
result_type = rhs # the result gradient must match this type.
205166

206167
if g_qtype and numerics.should_quantize(g.dtype):
207168
if isinstance(y, qarray.QArray) and not qarray.get_tiled_axes(y):
@@ -219,23 +180,20 @@ def _compute_gradient_for_operand(
219180
tile_size=g_tile_size,
220181
calibration_method=g_calibration_method,
221182
)
183+
g_how = dataclasses.replace(g_how, noise_fn=g_noise_fn)
222184
if config.disable_channelwise_axes:
223185
g_how = dataclasses.replace(g_how, channelwise_axes=[])
224186

225-
if for_dlhs and config.dlhs_stochastic_rounding_noise_fn:
226-
g_how = dataclasses.replace(
227-
g_how,
228-
noise_fn=config.dlhs_stochastic_rounding_noise_fn,
229-
)
230-
if not for_dlhs and config.drhs_stochastic_rounding_noise_fn:
231-
g_how = dataclasses.replace(
232-
g_how,
233-
noise_fn=config.drhs_stochastic_rounding_noise_fn,
234-
)
235187
g = qarray.quantize(g, g_how)
236188

237189
grad_res = dot_general.dot_general(g, y, bwd_dnums)
238-
return jax.lax.transpose(grad_res, transpose_axes)
190+
grad_res = jax.lax.transpose(grad_res, transpose_axes)
191+
if isinstance(result_type, qarray_qt.QArrayWithGradient):
192+
return dataclasses.replace(
193+
result_type, qvalue=None, scale=None, zero_point=None, _grad=grad_res
194+
)
195+
else:
196+
return grad_res
239197

240198
dlhs = _compute_gradient_for_operand(g, rhs, for_dlhs=True)
241199
drhs = _compute_gradient_for_operand(g, lhs, for_dlhs=False)
@@ -244,15 +202,59 @@ def _compute_gradient_for_operand(
244202

245203

246204
@functools.partial(jax.custom_vjp, nondiff_argnums=(2, 3))
205+
def dot_general_fwd_bwd(
206+
lhs: jax.Array | qarray_qt.QArrayWithGradient,
207+
rhs: jax.Array | qarray_qt.QArrayWithGradient,
208+
dimension_numbers: jax.lax.DotDimensionNumbers,
209+
config: DotGeneralQtConfig,
210+
) -> jax.Array:
211+
"""Quantized dot_general with backpropagation support."""
212+
del config
213+
return dot_general.dot_general(lhs, rhs, dimension_numbers)
214+
215+
216+
dot_general_fwd_bwd.defvjp(dot_general_qt_fwd, dot_general_qt_bwd)
217+
218+
247219
def dot_general_qt(
248220
lhs: jax.Array,
249221
rhs: jax.Array,
250222
dimension_numbers: jax.lax.DotDimensionNumbers,
251223
config: DotGeneralQtConfig,
252224
) -> jax.Array:
253225
"""Quantized dot_general with backpropagation support."""
254-
result, _ = dot_general_qt_fwd(lhs, rhs, dimension_numbers, config)
255-
return result
226+
if config.lhs_qtype and numerics.should_quantize(lhs.dtype):
227+
how = dot_general.get_how_to_quantize(
228+
dimension_numbers=dimension_numbers,
229+
ndims=(lhs.ndim, rhs.ndim),
230+
for_lhs=True,
231+
qtype=config.lhs_qtype,
232+
tile_size=config.tile_size,
233+
calibration_method=config.lhs_calibration_method,
234+
)
235+
if config.disable_channelwise_axes:
236+
how = dataclasses.replace(how, channelwise_axes=[])
237+
238+
calibration = qarray.calibrate(lhs, how)
239+
if config.lhs_collect_quant_stat:
240+
calibration = config.lhs_collect_quant_stat(calibration)
241+
lhs = qarray_qt.quantize_with_calibration(lhs, how.qtype, calibration)
242+
243+
if config.rhs_qtype and numerics.should_quantize(rhs.dtype):
244+
how = dot_general.get_how_to_quantize(
245+
dimension_numbers=dimension_numbers,
246+
ndims=(lhs.ndim, rhs.ndim),
247+
for_lhs=False,
248+
qtype=config.rhs_qtype,
249+
tile_size=config.tile_size,
250+
calibration_method=config.rhs_calibration_method,
251+
)
252+
if config.disable_channelwise_axes:
253+
how = dataclasses.replace(how, channelwise_axes=[])
256254

255+
calibration = qarray.calibrate(rhs, how)
256+
if config.rhs_collect_quant_stat:
257+
calibration = config.rhs_collect_quant_stat(calibration)
258+
rhs = qarray_qt.quantize_with_calibration(rhs, how.qtype, calibration)
257259

258-
dot_general_qt.defvjp(dot_general_qt_fwd, dot_general_qt_bwd)
260+
return dot_general_fwd_bwd(lhs, rhs, dimension_numbers, config)

qwix/_src/core/qarray_qt.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""QArray with gradient for custom VJP."""
15+
16+
import dataclasses
17+
from typing import Mapping
18+
import flax.struct
19+
import jax
20+
from qwix._src.core import qarray
21+
22+
23+
@flax.struct.dataclass
24+
class QArrayWithGradient(qarray.QArray):
25+
"""QArray with gradient.
26+
27+
This dataclass allows us to associate a gradient with the QArray. It's
28+
achieved by defining an extra attribute `_grad` on the QArray, which has the
29+
same dtype and the same shape as the unquantized array. In forward pass, the
30+
`_grad` does nothing and should never be consumed. In backward pass, the
31+
`_grad` carries the gradient of the whole QArray.
32+
33+
This approach overcomes the Jax limitation on the gradients, i.e., the
34+
gradient of a qvalue of int8[128,128] has to be float0[128,128], while the
35+
gradient of a scale of float32[1,1] has to be float32[1,1]. An alternative
36+
is to define the QArray as a new Hijax type, which is more complex.
37+
"""
38+
39+
_grad: jax.Array = flax.struct.field(kw_only=True)
40+
41+
42+
def quantize_with_calibration(
43+
array: jax.Array,
44+
qtype: jax.typing.DTypeLike,
45+
calibration: Mapping[str, jax.Array],
46+
clip_gradient: bool = False,
47+
) -> QArrayWithGradient:
48+
"""Quantizes an array with calibration with backpropagation support.
49+
50+
Args:
51+
array: The array to quantize.
52+
qtype: The quantized type.
53+
calibration: The calibration of the array.
54+
clip_gradient: Whether to clip the straight-through estimator to the
55+
calibration range, i.e., the gradient outside the calibration range is 0.
56+
57+
Returns:
58+
The quantized array with backpropagation support.
59+
"""
60+
scale, zero_point = qarray.compute_scale_zero_point(calibration, qtype)
61+
res = qarray.quantize_with_scale_zero_point(array, qtype, scale, zero_point)
62+
if clip_gradient:
63+
array = qarray.clip_to_calibration(
64+
array, calibration, qarray.get_tiled_axes(res)
65+
)
66+
# Do not allow gradients on the quantized array to flow back to the input.
67+
res = jax.lax.stop_gradient(res)
68+
return QArrayWithGradient(**dataclasses.asdict(res), _grad=array)
69+
70+
71+
@jax.custom_jvp
72+
def dequantize(array: QArrayWithGradient) -> jax.Array:
73+
"""Dequantizes an array."""
74+
return qarray.dequantize(array)
75+
76+
77+
@dequantize.defjvp
78+
def _dequantize_jvp(primals, tangents):
79+
return dequantize(*primals), tangents[0]._grad # pylint: disable=protected-access

qwix/_src/providers/qt.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class QtRule(qconfig.QuantizationRule):
3434

3535
# In backward pass, quantize the gradients to the given type. This doesn't
3636
# affect the residuals as the residuals will reuse the quantization in the
37-
# forward pass, unless bwd_use_original_residuals is set.
37+
# forward pass.
3838
bwd_qtype: jax.typing.DTypeLike | None = None
3939

4040
# In backward pass, calibrate the gradients using the given method.
@@ -48,11 +48,6 @@ class QtRule(qconfig.QuantizationRule):
4848
# If True, disable channelwise axes for both forward and backward passes.
4949
disable_channelwise_axes: bool = False
5050

51-
# If True, use the original values instead of the quantized values as the
52-
# residuals for backward pass. Enabling this prevents using low-precision
53-
# matmuls during bwd pass and has a negative impact on performance.
54-
bwd_use_original_residuals: bool = False
55-
5651
# Use stochastic rounding for the gradients. (Only 'uniform' is supported.)
5752
bwd_stochastic_rounding: str | None = None
5853

@@ -293,7 +288,6 @@ def _create_conv_general_qt_config(
293288
drhs_grad_calibration_method=rule.bwd_calibration_method,
294289
# misc.
295290
disable_channelwise_axes=rule.disable_channelwise_axes,
296-
bwd_use_original_residuals=rule.bwd_use_original_residuals,
297291
)
298292

299293
def _create_dot_general_qt_config(
@@ -392,7 +386,6 @@ def _create_dot_general_qt_config(
392386
drhs_grad_calibration_method=rule.bwd_calibration_method,
393387
# misc.
394388
disable_channelwise_axes=rule.disable_channelwise_axes,
395-
bwd_use_original_residuals=rule.bwd_use_original_residuals,
396389
dlhs_stochastic_rounding_noise_fn=dlhs_stochastic_rounding_noise_fn,
397390
drhs_stochastic_rounding_noise_fn=drhs_stochastic_rounding_noise_fn,
398391
)

0 commit comments

Comments
 (0)