Skip to content

Commit fcdc29a

Browse files
Align sparsity with block-wise masks in progressive pruning. (#1250)
Signed-off-by: YIYANGCAI <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent ec06411 commit fcdc29a

File tree

1 file changed

+32
-1
lines changed

1 file changed

+32
-1
lines changed

neural_compressor/compression/pruner/pruners/progressive.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ def _init(self):
6060
self.progressive_steps = self.progressive_configs["progressive_steps"]
6161
self.progressive_type = self.progressive_configs["progressive_type"]
6262
self.use_global = self.progressive_configs["use_global"]
63-
self.progressive_logger = False
63+
self.progressive_logger = True
64+
self.align_masks_flag = False
6465
self._init_for_progressive()
6566

6667
def _init_for_progressive(self):
@@ -77,6 +78,11 @@ def _init_for_progressive(self):
7778
self.use_progressive = False
7879
return
7980

81+
if self.pruning_frequency == 1:
82+
logger.info("Current progressive setting will degrading to non-progressive pruning.")
83+
self.use_progressive = False
84+
return
85+
8086
# step 3: log hyper-parameters. and check validity.
8187
if self.use_progressive:
8288
logger.info("Progressive pruning is enabled!")
@@ -225,6 +231,10 @@ def on_step_begin(self, local_step):
225231
226232
Implement at the start of each step.
227233
"""
234+
if self.global_step > self.end_step and self.align_masks_flag is False:
235+
self.align_masks_after_pruning()
236+
self.align_masks_flag = True
237+
228238
if self.handled_global_step == self.global_step:
229239
return
230240

def print_progressive_sparsity(self):
    """Output the progressive sparsity."""
    # Sparsity currently realized by the progressive (fine-grained) masks.
    ratio = self.pattern.get_sparsity_ratio_progressive(self.progressive_masks)
    message = "Step: {} -> Current progressive sparsity: {}".format(self.global_step, ratio)
    logger.info(message)
283+
284+
def obtain_weight_sparsity(self, modules):
285+
total_numels = 0
286+
sparse_numels = 0
287+
for key in modules.keys():
288+
total_numels += modules[key].weight.data.numel()
289+
sparse_numels += torch.sum(torch.where(modules[key].weight.data == 0, 1, 0)).item()
290+
return sparse_numels / total_numels
291+
292+
def align_masks_after_pruning(self):
293+
if not self.use_progressive:
294+
return
295+
"""Implement at the end of training phase."""
296+
# If training ends while a progressive masks is applying, we have to use self.masks to align
297+
# step 1 calculate sparsity under progressive masks
298+
sparsity1 = self.obtain_weight_sparsity(self.modules)
299+
# step 2 use block-wise masks to remask weights
300+
self.mask_weights_general(self.masks)
301+
# step 3 calculate sparsity under progressive masks
302+
sparsity2 = self.obtain_weight_sparsity(self.modules)
303+
logger.info(f"Replace progressive mask with complete masks: Sparsity Update: {sparsity1} => {sparsity2}")

0 commit comments

Comments
 (0)