@@ -60,7 +60,8 @@ def _init(self):
60
60
self .progressive_steps = self .progressive_configs ["progressive_steps" ]
61
61
self .progressive_type = self .progressive_configs ["progressive_type" ]
62
62
self .use_global = self .progressive_configs ["use_global" ]
63
- self .progressive_logger = False
63
+ self .progressive_logger = True
64
+ self .align_masks_flag = False
64
65
self ._init_for_progressive ()
65
66
66
67
def _init_for_progressive (self ):
@@ -77,6 +78,11 @@ def _init_for_progressive(self):
77
78
self .use_progressive = False
78
79
return
79
80
81
+ if self .pruning_frequency == 1 :
82
+ logger .info ("Current progressive setting will degrade to non-progressive pruning." )
83
+ self .use_progressive = False
84
+ return
85
+
80
86
# step 3: log hyper-parameters. and check validity.
81
87
if self .use_progressive :
82
88
logger .info ("Progressive pruning is enabled!" )
@@ -225,6 +231,10 @@ def on_step_begin(self, local_step):
225
231
226
232
Implement at the start of each step.
227
233
"""
234
+ if self .global_step > self .end_step and self .align_masks_flag is False :
235
+ self .align_masks_after_pruning ()
236
+ self .align_masks_flag = True
237
+
228
238
if self .handled_global_step == self .global_step :
229
239
return
230
240
@@ -270,3 +280,24 @@ def print_progressive_sparsity(self):
270
280
"""Output the progressive sparsity."""
271
281
cur_sp = self .pattern .get_sparsity_ratio_progressive (self .progressive_masks )
272
282
logger .info ("Step: {} -> Current progressive sparsity: {}" .format (self .global_step , cur_sp ))
283
+
284
def obtain_weight_sparsity(self, modules):
    """Return the fraction of weight entries that are exactly zero.

    Args:
        modules: mapping of layer name -> module; each module must expose
            a ``weight`` tensor via ``module.weight.data``.

    Returns:
        float: (number of zero-valued weights) / (total number of weights),
        or 0.0 when *modules* is empty — the original implementation raised
        ZeroDivisionError in that case.
    """
    total_numels = 0
    sparse_numels = 0
    for module in modules.values():
        weight = module.weight.data
        total_numels += weight.numel()
        # Boolean-mask sum counts exact zeros directly; equivalent to the
        # torch.where(w == 0, 1, 0).sum() idiom without the intermediate tensor.
        sparse_numels += int((weight == 0).sum().item())
    if total_numels == 0:
        return 0.0
    return sparse_numels / total_numels
291
+
292
def align_masks_after_pruning(self):
    """Implement at the end of training phase.

    If training ends while a progressive mask is still being applied,
    re-apply the complete block-wise masks (``self.masks``) so the final
    weights carry the target sparsity pattern. No-op when progressive
    pruning is disabled.
    """
    # Original code placed the docstring after this guard, which made it a
    # plain (dead) string statement rather than the function's docstring.
    if not self.use_progressive:
        return
    # step 1: calculate sparsity under the progressive masks
    sparsity1 = self.obtain_weight_sparsity(self.modules)
    # step 2: use block-wise masks to remask weights
    self.mask_weights_general(self.masks)
    # step 3: calculate sparsity under the complete (block-wise) masks
    # (the original comment wrongly said "progressive masks" here)
    sparsity2 = self.obtain_weight_sparsity(self.modules)
    logger.info(f"Replace progressive mask with complete masks: Sparsity Update: {sparsity1} => {sparsity2} ")
0 commit comments