Skip to content

Commit 5e3a679

Browse files
committed
[SW-191415] update fp8 maxAbs observer using torch.copy_
Change-Id: I3923c832f9a8a2b14e392f3f4719d233a457702f
1 parent 7f62871 commit 5e3a679

File tree

1 file changed

+4
-18
lines changed
  • neural_compressor/torch/algorithms/fp8_quant/_core

1 file changed

+4
-18
lines changed

neural_compressor/torch/algorithms/fp8_quant/_core/measure.py

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def register_patched_measure_modules(model, mod_list, observer_class, d_shapes=N
9595
)
9696
patched_types.add(type(mod))
9797

98+
set_hqt_config(mod, top_level_config)
9899
mod_extra_config = init_measure_object(
99100
mod,
100101
name,
@@ -104,7 +105,6 @@ def register_patched_measure_modules(model, mod_list, observer_class, d_shapes=N
104105
(d_shapes[name] if ((d_shapes is not None) and (name in d_shapes)) else None),
105106
params,
106107
)
107-
set_hqt_config(mod, top_level_config) # set config in the module, as it consumed by the patched module
108108
pmod = patch_module_measure(mod, mod_extra_config, mod_default_dict)
109109
for param_name in pmod._mod_extra_config.params:
110110
param = getattr(pmod, param_name)
@@ -247,27 +247,13 @@ def __init__(self, name, mod, d_shape=None, params=None):
247247
self.mod = mod
248248
self.first = True
249249
self.used = False
250-
self.state = self.init_state_from_shape(d_shape)
251-
252-
def init_state(self, x):
253-
device = x.device
254-
state = torch.zeros((1, 1), device=device, dtype=torch.float32)
255-
self.shape = list(x.shape)
256-
return state
257-
258-
def init_state_from_shape(self, x_shape, device="hpu"):
259-
state = torch.zeros((1, 1), device=device, dtype=torch.float32)
260-
self.first = False
261-
return state
250+
config = get_hqt_config(mod).cfg
251+
self.state = torch.zeros((1, 1), device="hpu", dtype=config["hp_dtype"])
262252

263253
def update_state(self, x):
264-
# TODO: [SW-189690] Find better way to update self.state in MaxAbsObserver class in HQT
265-
self.state = torch.maximum(torch.max(torch.abs(x)), self.state)
254+
self.state.copy_(torch.maximum(torch.max(torch.abs(x)), self.state))
266255

267256
def measure(self, x):
268-
if self.first:
269-
self.state = self.init_state(x)
270-
self.first = False
271257
self.update_state(x)
272258
self.used = True
273259

0 commit comments

Comments (0)