@@ -95,6 +95,7 @@ def register_patched_measure_modules(model, mod_list, observer_class, d_shapes=N
95
95
)
96
96
patched_types .add (type (mod ))
97
97
98
+ set_hqt_config (mod , top_level_config )
98
99
mod_extra_config = init_measure_object (
99
100
mod ,
100
101
name ,
@@ -104,7 +105,6 @@ def register_patched_measure_modules(model, mod_list, observer_class, d_shapes=N
104
105
(d_shapes [name ] if ((d_shapes is not None ) and (name in d_shapes )) else None ),
105
106
params ,
106
107
)
107
- set_hqt_config (mod , top_level_config ) # set config in the module, as it consumed by the patched module
108
108
pmod = patch_module_measure (mod , mod_extra_config , mod_default_dict )
109
109
for param_name in pmod ._mod_extra_config .params :
110
110
param = getattr (pmod , param_name )
@@ -247,27 +247,13 @@ def __init__(self, name, mod, d_shape=None, params=None):
247
247
self .mod = mod
248
248
self .first = True
249
249
self .used = False
250
- self .state = self .init_state_from_shape (d_shape )
251
-
252
- def init_state (self , x ):
253
- device = x .device
254
- state = torch .zeros ((1 , 1 ), device = device , dtype = torch .float32 )
255
- self .shape = list (x .shape )
256
- return state
257
-
258
- def init_state_from_shape (self , x_shape , device = "hpu" ):
259
- state = torch .zeros ((1 , 1 ), device = device , dtype = torch .float32 )
260
- self .first = False
261
- return state
250
+ config = get_hqt_config (mod ).cfg
251
+ self .state = torch .zeros ((1 , 1 ), device = "hpu" , dtype = config ["hp_dtype" ])
262
252
263
253
def update_state (self , x ):
264
- # TODO: [SW-189690] Find better way to update self.state in MaxAbsObserver class in HQT
265
- self .state = torch .maximum (torch .max (torch .abs (x )), self .state )
254
+ self .state .copy_ (torch .maximum (torch .max (torch .abs (x )), self .state ))
266
255
267
256
def measure (self , x ):
268
- if self .first :
269
- self .state = self .init_state (x )
270
- self .first = False
271
257
self .update_state (x )
272
258
self .used = True
273
259
0 commit comments