#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
+
+ from collections import defaultdict
from collections.abc import Callable
from functools import partial
from typing import Any, Optional

import torch
from torch import Tensor
from torch.optim import Optimizer

- from ..quant import LSBQuantizer, Quantizer
+ from ..quant import Quantizer
+ from ..utils import HAS_DTENSOR, is_dtensor
from .proxmap import ProxMap

- try:
-     from torch.distributed.tensor import DTensor
-
-     HAS_DTENSOR = True
- except ImportError:
-     HAS_DTENSOR = False
+ if HAS_DTENSOR:
+     from torch.distributed.tensor import distribute_tensor
+     from torch.distributed.tensor.experimental import local_map
+     from torch.distributed.tensor.placement_types import Shard


class QuantOptimizer(Optimizer):
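
The conditional import block above brings in the DTensor helpers used later in step(): distribute_tensor() places a tensor on a device mesh with the given placements, and local_map() wraps a plain-tensor function so that it runs on each local shard of a DTensor. A rough, self-contained sketch of that pattern follows; the single-rank gloo setup and the rowwise_scale function are purely illustrative and not part of this file.

    import os
    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import distribute_tensor
    from torch.distributed.tensor.experimental import local_map
    from torch.distributed.tensor.placement_types import Shard

    # single-process setup purely for demonstration
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)
    mesh = init_device_mesh("cpu", (1,))

    # shard a weight along dim 0, like the row-sharded params handled in step()
    w = distribute_tensor(torch.randn(8, 4), device_mesh=mesh, placements=[Shard(0)])

    def rowwise_scale(local_w):
        # receives the local shard as a plain torch.Tensor
        return local_w.abs().amax(dim=-1, keepdim=True)

    # local_map re-wraps the plain-tensor function so it accepts and returns DTensors
    scale = local_map(
        rowwise_scale,
        out_placements=[Shard(0)],
        in_placements=([Shard(0)],),
    )(w)
    print(type(scale), scale.shape)  # a DTensor with global shape (8, 1)

    dist.destroy_process_group()
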
@@ -31,7 +32,7 @@ class QuantOptimizer(Optimizer):
        a proximal mapping (e.g., HardQuant/STE, PARQ, BinaryRelax)
        - update model parameters based on the above two updates
    Other parameters:
-        - warmup_steps: int > 0
+        - warmup_steps: int >= 0
        - quant_period: int > 0
        - quant_per_channel: True or False
        - quant_shrink: True or False
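
For orientation, here is a rough sketch of how the pieces named in this docstring fit together. The import paths assume this file lives under torchao.prototype.parq, and UnifQuantizer/ProxHardQuant are just one possible Quantizer/ProxMap pair; treat the exact names and keywords as illustrative rather than definitive.

    import torch
    from torchao.prototype.parq.optim import ProxHardQuant, QuantOptimizer
    from torchao.prototype.parq.quant import UnifQuantizer

    model = torch.nn.Linear(16, 8)

    # one group to be quantized (carries "quant_bits"), one left in full precision
    param_groups = [
        {"params": [model.weight], "quant_bits": 2},
        {"params": [model.bias]},
    ]
    base_optimizer = torch.optim.SGD(param_groups, lr=0.1)

    optimizer = QuantOptimizer(
        base_optimizer,
        UnifQuantizer(),       # produces the quantization targets Q
        ProxHardQuant(),       # proximal map: hard quantization / STE
        warmup_steps=0,        # full-precision steps before QAT begins
        quant_period=1,        # refresh Q every step
        quant_per_channel=True,
    )
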
@@ -86,23 +87,23 @@ def __repr__(self) -> str:
        extra_repr = "\n ".join(("(", base_optimizer, f"{quantizer=}", f"{prox_map=}"))
        return f"{self.__class__.__name__} {extra_repr}\n)"

+    @property
+    def state(self) -> defaultdict[Tensor, Any]:  # pyre-ignore[3]
+        return self._state if hasattr(self, "_state") else self.base_optimizer.state
+
    @staticmethod
    def quantize_(
        p: Tensor,
        quants: Tensor,
        quantizer: Quantizer,
        b: int,
-        quant_update: bool,
        dim: Optional[int] = None,
    ) -> Optional[Tensor]:
        """Optionally update the quantization targets `quants` in place.
        Return the quantized `p` as a by-product if `quant_update=True`.
        """
-        if quant_update:  # update Q for each channel
-            q, Q = quantizer.quantize(p, b, dim=dim)  # pyre-ignore[28]
-            quants.copy_(Q)
-        else:
-            q = None
+        q, Q = quantizer.quantize(p, b, dim=dim)  # pyre-ignore[28]
+        quants.copy_(Q)
        return q

    def regularized_param_groups(self):  # pyre-ignore[3]
@@ -122,12 +123,13 @@ def state_dict(self) -> dict[str, Any]:
    def load_state_dict(
        self, state_dict: dict[str, Any], start_step: Optional[int] = None
    ) -> None:
-        qat_state = state_dict.pop("qat_state")
+        qat_state = state_dict.get("qat_state")
        # resuming from a checkpoint usually does not correspond to the saved num_steps,
        # so allow an explicit start_step computed from epochs * steps_per_epoch
        if start_step is not None:
            self.num_steps = start_step
-        else:  # hope discrepancy in num_steps does not cause major problem!
+        elif qat_state is not None:
+            # hope the discrepancy in num_steps does not cause a major problem!
            self.num_steps = qat_state["num_steps"]
        self.base_optimizer.load_state_dict(state_dict)
@@ -144,16 +146,23 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
            self.num_steps += 1
            return loss

-        # call base optimizer step() method to update latent parameters
-        loss = self.base_optimizer.step(closure=closure)  # pyre-ignore[6]
-
        if self.num_steps == self.warmup_steps:
            # first step of QAT: save latent params instead of restoring them
            self.save_latent_params()
        else:
            # QAT: restore latent params for update by the base optimizer
            self.restore_latent_params()

+        # call base optimizer step() method to update latent parameters
+        loss = self.base_optimizer.step(closure=closure)  # pyre-ignore[6]
+
+        if hasattr(self, "_state"):
+            assert self.warmup_steps == 0
+            # move the temporary state into the base optimizer's state
+            for p in self._state.keys():
+                self.base_optimizer.state[p]["latent"] = self._state[p]["latent"]
+            del self._state
+
        # check if it is time to update the set of quantization values Q
        if (self.num_steps - self.warmup_steps) % self.quant_period == 0:
            quant_update = True
@@ -165,6 +174,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
            group["cumu_lr"] += group["lr"]
            gamma = max(1.0, group["cumu_lr"])
            b = group["quant_bits"]
+            block_size = group.get("quant_block_size")
            inv_slope = 0.0
            for p in group["params"]:
                if not p.requires_grad:
@@ -177,44 +187,66 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
                if self.quant_shrink:
                    p.div_(gamma)

+                # reshape p according to the block size, if specified
+                if block_size is not None:
+                    assert (
+                        p.size(-1) % block_size == 0
+                    ), f"{p.size(-1)=} is not divisible by {block_size=}"
+                    assert p.dim() <= 2, f"Invalid {p.dim()=} for {block_size=}"
+                    if p.dim() == 1:
+                        p = p.unsqueeze(0)
+
+                    # row-major ordering ensures this is correct
+                    p = p.view(-1, block_size)
+
                # quantization by channel or by layer
                # update quantization targets periodically
                per_channel = self.quant_per_channel and p.dim() > 1
                if quant_update:
-                    quants_size = 3 if b == 0 else 2**b
-                    if per_channel:
-                        quants_size = (p.size(0), quants_size)
-                    state["quants"] = torch.empty(
-                        quants_size, device=p.device
-                    )  # pyre-ignore[6]
+                    quant_size = self.quantizer.get_quant_size(b)

-                # avoid type mismatch between sharded and full tensors
-                if HAS_DTENSOR and isinstance(p, DTensor):
-                    p = p.full_tensor()
+                    if per_channel:
+                        quant_size = (p.size(0), quant_size)
+                    state["quants"] = torch.empty(quant_size, device=p.device)
+                    if is_dtensor(p):
+                        state["quants"] = distribute_tensor(
+                            state["quants"],
+                            device_mesh=p.device_mesh,
+                            placements=p.placements,
+                        )

                dim = -1 if per_channel else None
                if per_channel and p.dim() > 2:
                    p = p.flatten(start_dim=1)

-                # NOTE: for LSBQ and optimal=False, use faster per-channel
-                # implementation instead of vmap
-                if isinstance(self.quantizer, LSBQuantizer) and self.quantizer.optimal:
+                q = None
+                if quant_update:
                    qfunc = partial(
-                        self.quantize_,
-                        quantizer=self.quantizer,
-                        b=b,
-                        quant_update=quant_update,
-                    )
-                    q = torch.vmap(qfunc, in_dims=0, out_dims=0)(p, state["quants"])
-                else:
-                    q = self.quantize_(
-                        p, state["quants"], self.quantizer, b, quant_update, dim=dim
+                        self.quantize_, quantizer=self.quantizer, b=b, dim=dim
                    )
+                    if is_dtensor(p):
+                        qfunc = local_map(
+                            qfunc,
+                            out_placements=[*p.placements],
+                            in_placements=([Shard(0)], [Shard(0)]),
+                        )
+                    q = qfunc(p, state["quants"])

                # apply (step-dependent) proximal mapping in place
-                inv_slope = self.prox_map.apply_(  # pyre-ignore[28]
-                    p, q, state["quants"], self.num_steps, dim=dim
+                pfunc = partial(
+                    self.prox_map.apply_, step_count=self.num_steps, dim=dim
                )
+                if is_dtensor(p):
+                    pfunc = local_map(
+                        pfunc,
+                        out_placements=None,
+                        in_placements=(
+                            [Shard(0)],
+                            None if q is None else [Shard(0)],
+                            [Shard(0)],
+                        ),
+                    )
+                inv_slope = pfunc(p, q, state["quants"])

                # quantized parameters share the same PARQ inverse slope
                if inv_slope:
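
A note on the new block-size path above: when a group sets "quant_block_size", each row of the (at most 2-D) parameter is split into consecutive blocks and the tensor is viewed as (-1, block_size), so the existing per-channel path (dim=-1) quantizes one block at a time. A minimal pure-tensor sketch of that reshape, with an arbitrary weight shape chosen for illustration:

    import torch

    block_size = 4
    w = torch.randn(3, 8)              # e.g. a Linear weight, row-major layout

    assert w.size(-1) % block_size == 0
    blocks = w.view(-1, block_size)    # (3 * 8 // 4, 4) == (6, 4): two blocks per row

    # per-"channel" statistics along dim=-1 are now per-block statistics,
    # e.g. one scale per block for a per-channel quantizer
    scales = blocks.abs().amax(dim=-1, keepdim=True)
    print(blocks.shape, scales.shape)  # torch.Size([6, 4]) torch.Size([6, 1])
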
@@ -239,6 +271,12 @@ def restore_latent_params(self) -> None:
    @torch._disable_dynamo
    def save_latent_params(self) -> None:
        """Save updated latent parameters before applying prox-map"""
+        if self.warmup_steps == 0:
+            assert len(self.state) == 0, "Expected empty state at first step()"
+            # Maintain the invariant that `len(self.state) == 0` before the first
+            # self.base_optimizer.step() call by using a temporary state buffer
+            self._state = defaultdict(dict)
+
        for group in self.regularized_param_groups():
            for p in group["params"]:
                if p.requires_grad:
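
The warmup_steps == 0 branch above leans on the fact that PyTorch optimizers allocate per-parameter state lazily: self.state (which aliases base_optimizer.state via the new property) must still be empty when the base optimizer takes its first step, so the latent copies go into the temporary _state buffer and are merged back inside step(). A small standalone illustration of the lazy allocation, using plain torch.optim and nothing specific to this file:

    import torch

    p = torch.nn.Parameter(torch.randn(4, 4))
    opt = torch.optim.AdamW([p], lr=1e-3)
    print(len(opt.state))  # 0: exp_avg / exp_avg_sq are not allocated yet

    p.grad = torch.zeros_like(p)
    opt.step()
    print(len(opt.state))  # 1: state is created during the first step()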