 import logging
 from copy import deepcopy
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

 import torch
 import torch.nn.utils.prune as pytorch_prune
 from torch import nn
+from typing_extensions import TypedDict

+import pytorch_lightning as pl
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.utilities.distributed import rank_zero_debug, rank_zero_only
@@ -47,8 +49,9 @@
 }

 _PARAM_TUPLE = Tuple[nn.Module, str]
-_PARAM_LIST = Union[List[_PARAM_TUPLE], Tuple[_PARAM_TUPLE]]
+_PARAM_LIST = Sequence[_PARAM_TUPLE]
 _MODULE_CONTAINERS = (LightningModule, nn.Sequential, nn.ModuleList, nn.ModuleDict)
+_LayerRef = TypedDict('_LayerRef', {'data': nn.Module, 'names': List[Tuple[int, str]]})


 class ModelPruning(Callback):
@@ -57,7 +60,7 @@ class ModelPruning(Callback):
     def __init__(
         self,
         pruning_fn: Union[Callable, str],
-        parameters_to_prune: Optional[_PARAM_LIST] = None,
+        parameters_to_prune: _PARAM_LIST = (),
         parameter_names: Optional[List[str]] = None,
         use_global_unstructured: bool = True,
         amount: Union[int, float, Callable[[int], Union[int, float]]] = 0.5,
@@ -153,9 +156,9 @@ def __init__(
         self._use_lottery_ticket_hypothesis = use_lottery_ticket_hypothesis
         self._resample_parameters = resample_parameters
         self._parameter_names = parameter_names or self.PARAMETER_NAMES
-        self._global_kwargs = {}
-        self._original_layers = None
-        self._pruning_fn_name = None
+        self._global_kwargs: Dict[str, Any] = {}
+        self._original_layers: Optional[Dict[int, _LayerRef]] = None
+        self._pruning_fn_name: Optional[str] = None

         for name in self._parameter_names:
             if name not in self.PARAMETER_NAMES:
@@ -196,17 +199,18 @@ def __init__(
196199 " HINT: if passing a `BasePruningMethod`, pass the the class, not an instance"
197200 )
198201
199- if use_global_unstructured and pruning_fn .PRUNING_TYPE != "unstructured" :
202+ # need to ignore typing here since pytorch base class does not define the PRUNING_TYPE attribute
203+ if use_global_unstructured and pruning_fn .PRUNING_TYPE != "unstructured" : # type: ignore
200204 raise MisconfigurationException (
201- 'Only the "unstructured" PRUNING_TYPE is supported with `use_global_unstructured=True`.'
205+ 'Only the "unstructured" PRUNING_TYPE is supported with `use_global_unstructured=True`.' # type: ignore
202206 f" Found method { pruning_fn } of type { pruning_fn .PRUNING_TYPE } . "
203207 )
204208
205209 self .pruning_fn = pruning_fn
206210 self ._apply_pruning = apply_pruning
207211 self ._make_pruning_permanent = make_pruning_permanent
208212
209- if not isinstance (amount , (int , float , Callable )):
213+ if not ( isinstance (amount , (int , float )) or callable ( amount )):
210214 raise MisconfigurationException (
211215 "`amount` should be provided and be either an int, a float or Callable function."
212216 )
@@ -218,25 +222,27 @@ def __init__(

         self._verbose = verbose

-    def filter_parameters_to_prune(self, parameters_to_prune: Optional[_PARAM_LIST] = None) -> Optional[_PARAM_LIST]:
+    def filter_parameters_to_prune(self, parameters_to_prune: _PARAM_LIST = ()) -> _PARAM_LIST:
         """
         This function can be overridden to control which module to prune.
         """
         return parameters_to_prune

-    def _create_pruning_fn(self, pruning_fn: str, **kwargs) -> Union[Callable, pytorch_prune.BasePruningMethod]:
+    def _create_pruning_fn(self, pruning_fn: str, **kwargs: Any) -> Union[Callable, pytorch_prune.BasePruningMethod]:
         """
         This function takes `pruning_fn`, a function name.

         IF use_global_unstructured, pruning_fn will be resolved into its associated ``PyTorch BasePruningMethod``
         ELSE, pruning_fn will be resolved into its function counterpart from `torch.nn.utils.prune`.

         """
+        pruning_fn = (
+            _PYTORCH_PRUNING_METHOD[pruning_fn]
+            if self._use_global_unstructured else _PYTORCH_PRUNING_FUNCTIONS[pruning_fn]
+        )
+        assert callable(pruning_fn)
         if self._use_global_unstructured:
-            pruning_fn = _PYTORCH_PRUNING_METHOD[pruning_fn]
             self._global_kwargs = kwargs
-        else:
-            pruning_fn = _PYTORCH_PRUNING_FUNCTIONS[pruning_fn]
         # save the function __name__ now because partial does not include it
         # and there are issues setting the attribute manually in ddp.
         self._pruning_fn_name = pruning_fn.__name__
@@ -245,10 +251,10 @@ def _create_pruning_fn(self, pruning_fn: str, **kwargs) -> Union[Callable, pytor
         return ModelPruning._wrap_pruning_fn(pruning_fn, **kwargs)

     @staticmethod
-    def _wrap_pruning_fn(pruning_fn, **kwargs):
+    def _wrap_pruning_fn(pruning_fn: Callable, **kwargs: Any) -> Callable:
         return partial(pruning_fn, **kwargs)

-    def make_pruning_permanent(self, pl_module: LightningModule):
+    def make_pruning_permanent(self, pl_module: LightningModule) -> None:
         """
         Removes pruning buffers from any pruned modules

@@ -261,14 +267,14 @@ def make_pruning_permanent(self, pl_module: LightningModule):
                     hook.remove(module)
                     del module._forward_pre_hooks[k]

-    def _restore_original_weights(self, module: nn.Module, orig_module: nn.Module, tensor_name: str):
+    def _restore_original_weights(self, module: nn.Module, orig_module: nn.Module, tensor_name: str) -> None:
         trained = getattr(module, tensor_name)
         orig = getattr(orig_module, tensor_name)
         if trained is None or orig is None:
             return
         trained.data = orig.data.to(trained.device)

-    def apply_lottery_ticket_hypothesis(self):
+    def apply_lottery_ticket_hypothesis(self) -> None:
         r"""
         Lottery ticket hypothesis algorithm (see page 2 of the paper):

@@ -282,33 +288,35 @@ def apply_lottery_ticket_hypothesis(self):
         The ``resample_parameters`` argument can be used to reset the parameters with a new :math:`\theta_z \sim \mathcal{D}_\theta`
         """  # noqa: E501

-        def copy_param(new, old, name: str) -> None:
+        def copy_param(new: nn.Module, old: nn.Module, name: str) -> None:
             dst = getattr(new, name)
             src = getattr(old, name)
             if dst is None or src is None or not isinstance(dst, torch.Tensor) or not isinstance(src, torch.Tensor):
                 return
             dst.data = src.data.to(dst.device)

+        assert self._original_layers is not None
         for d in self._original_layers.values():
-            copy, names = d["data"], d["names"]
-            if self._resample_parameters and hasattr(copy, "reset_parameters"):
+            copy = d["data"]
+            names = d["names"]
+            if self._resample_parameters and hasattr(copy, "reset_parameters") and callable(copy.reset_parameters):
                 copy = deepcopy(copy)  # keep the original parameters
                 copy.reset_parameters()
             for i, name in names:
                 new, new_name = self._parameters_to_prune[i]
                 copy_param(new, copy, name)

-    def _apply_local_pruning(self, amount: float):
+    def _apply_local_pruning(self, amount: float) -> None:
         for module, name in self._parameters_to_prune:
             self.pruning_fn(module, name=name, amount=amount)

-    def _resolve_global_kwargs(self, amount: float):
+    def _resolve_global_kwargs(self, amount: float) -> Dict[str, Any]:
         self._global_kwargs["amount"] = amount
         params = set(inspect.signature(self.pruning_fn).parameters)
         params.discard("self")
         return {k: v for k, v in self._global_kwargs.items() if k in params}

-    def _apply_global_pruning(self, amount: float):
+    def _apply_global_pruning(self, amount: float) -> None:
         pytorch_prune.global_unstructured(
             self._parameters_to_prune, pruning_method=self.pruning_fn, **self._resolve_global_kwargs(amount)
         )
@@ -321,7 +329,7 @@ def _get_pruned_stats(module: nn.Module, name: str) -> Tuple[int, int]:
         mask = getattr(module, attr)
         return (mask == 0).sum().item(), mask.numel()

-    def apply_pruning(self, amount: Union[int, float]):
+    def apply_pruning(self, amount: Union[int, float]) -> None:
         """ Applies pruning to ``parameters_to_prune``. """
         if self._verbose:
             prev_stats = [self._get_pruned_stats(m, n) for m, n in self._parameters_to_prune]
@@ -338,7 +346,7 @@ def apply_pruning(self, amount: Union[int, float]):
     @rank_zero_only
     def _log_sparsity_stats(
         self, prev: List[Tuple[int, int]], curr: List[Tuple[int, int]], amount: Union[int, float] = 0
-    ):
+    ) -> None:
         total_params = sum(p.numel() for layer, _ in self._parameters_to_prune for p in layer.parameters())
         prev_total_zeros = sum(zeros for zeros, _ in prev)
         curr_total_zeros = sum(zeros for zeros, _ in curr)
@@ -357,7 +365,7 @@ def _log_sparsity_stats(
357365 f" { curr_mask_zeros } ({ curr_mask_zeros / curr_mask_size :.2%} )"
358366 )
359367
360- def on_before_accelerator_backend_setup (self , trainer , pl_module : LightningModule ):
368+ def on_before_accelerator_backend_setup (self , trainer : 'pl.Trainer' , pl_module : LightningModule ) -> None :
361369 parameters_to_prune = self .sanitize_parameters_to_prune (
362370 pl_module , self ._parameters_to_prune , parameter_names = self ._parameter_names
363371 )
@@ -370,29 +378,34 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModul
             self._original_layers = {}
             for i, (module, name) in enumerate(self._parameters_to_prune):
                 id_ = id(module)
-                self._original_layers.setdefault(id_, {"data": deepcopy(module), "names": []})
+                self._original_layers.setdefault(id_, _LayerRef(data=deepcopy(module), names=[]))
                 self._original_layers[id_]["names"].append((i, name))

-    def on_train_epoch_end(self, trainer, pl_module: LightningModule):
-        current_epoch = trainer.current_epoch
-        prune = self._apply_pruning(current_epoch) if isinstance(self._apply_pruning, Callable) else self._apply_pruning
-        amount = self.amount(current_epoch) if isinstance(self.amount, Callable) else self.amount
+    def on_train_epoch_end(self, trainer: 'pl.Trainer', pl_module: LightningModule) -> None:  # type: ignore
+        current_epoch = pl_module.current_epoch
+        prune = self._apply_pruning(current_epoch) if callable(self._apply_pruning) else self._apply_pruning
+        amount = self.amount(current_epoch) if callable(self.amount) else self.amount
         if not prune or not amount:
             return
         self.apply_pruning(amount)

         if (
             self._use_lottery_ticket_hypothesis(current_epoch)
-            if isinstance(self._use_lottery_ticket_hypothesis, Callable) else self._use_lottery_ticket_hypothesis
+            if callable(self._use_lottery_ticket_hypothesis) else self._use_lottery_ticket_hypothesis
         ):
             self.apply_lottery_ticket_hypothesis()

-    def on_train_end(self, trainer, pl_module: LightningModule):
+    def on_train_end(self, trainer: 'pl.Trainer', pl_module: LightningModule) -> None:
         if self._make_pruning_permanent:
             rank_zero_debug("`ModelPruning.on_train_end`. Pruning is made permanent for this checkpoint.")
             self.make_pruning_permanent(pl_module)

-    def on_save_checkpoint(self, trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]):
+    def on_save_checkpoint(
+        self,
+        trainer: 'pl.Trainer',
+        pl_module: LightningModule,
+        checkpoint: Dict[str, Any],
+    ) -> Dict[str, Any]:
         if self._make_pruning_permanent:
             rank_zero_debug("`ModelPruning.on_save_checkpoint`. Pruning is made permanent for this checkpoint.")
             prev_device = pl_module.device
@@ -402,11 +415,13 @@ def on_save_checkpoint(self, trainer, pl_module: LightningModule, checkpoint: Di
             checkpoint["state_dict"] = copy.state_dict()
         pl_module.to(prev_device)

+        return checkpoint
+
     @staticmethod
     def sanitize_parameters_to_prune(
         pl_module: LightningModule,
-        parameters_to_prune: Optional[_PARAM_LIST] = None,
-        parameter_names: Optional[List[str]] = None,
+        parameters_to_prune: _PARAM_LIST = (),
+        parameter_names: Sequence[str] = (),
     ) -> _PARAM_LIST:
         """
         This function is responsible for sanitizing ``parameters_to_prune`` and ``parameter_names``.
@@ -415,13 +430,13 @@ def sanitize_parameters_to_prune(
         Raises:
             MisconfigurationException:
                 If ``parameters_to_prune`` doesn't exist in the model, or
-                if ``parameters_to_prune`` is neither a list of tuple nor ``None``.
+                if ``parameters_to_prune`` is neither a list nor a tuple.
         """
         parameters = parameter_names or ModelPruning.PARAMETER_NAMES

         current_modules = [m for m in pl_module.modules() if not isinstance(m, _MODULE_CONTAINERS)]

-        if parameters_to_prune is None:
+        if not parameters_to_prune:
             parameters_to_prune = [(m, p) for p in parameters for m in current_modules
                                    if getattr(m, p, None) is not None]
         elif (
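
For context, a minimal usage sketch of the callback as typed in this diff. It is illustrative only: `LitClassifier`, the toy dataset, and the pruning schedule are assumptions rather than part of the change; the constructor arguments (`pruning_fn`, `parameters_to_prune`, `amount`, `use_global_unstructured`, `use_lottery_ticket_hypothesis`) are the ones shown above, with `parameters_to_prune` now defaulting to an empty sequence and `amount` accepted as a number or as a callable of the current epoch.

```python
# Illustrative sketch only: `LitClassifier` is a hypothetical LightningModule,
# and the epoch-based schedule is an example, not taken from this diff.
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelPruning


class LitClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.cross_entropy(self(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


model = LitClassifier()
dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))

pruning = ModelPruning(
    pruning_fn="l1_unstructured",
    # a Sequence of (module, parameter_name) tuples; the new default `()` means
    # every supported parameter of every module gets pruned
    parameters_to_prune=[(model.layer, "weight")],
    # `amount` may be an int, a float, or a callable of the current epoch;
    # returning 0 skips pruning for that epoch (see `on_train_epoch_end` above)
    amount=lambda epoch: 0.1 if epoch >= 2 else 0.0,
    use_global_unstructured=True,
    use_lottery_ticket_hypothesis=True,
)

trainer = pl.Trainer(max_epochs=5, callbacks=[pruning])
trainer.fit(model, DataLoader(dataset, batch_size=16))
```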
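The `filter_parameters_to_prune` hook in the diff is documented as overridable to control which modules get pruned. A possible override, consistent with the new `Sequence`-based `_PARAM_LIST` typing, might look like the hypothetical subclass below.

```python
# Hypothetical subclass, not part of the diff: restricts pruning to nn.Linear modules.
from typing import Sequence, Tuple

from torch import nn

from pytorch_lightning.callbacks import ModelPruning


class LinearOnlyPruning(ModelPruning):
    def filter_parameters_to_prune(
        self, parameters_to_prune: Sequence[Tuple[nn.Module, str]] = ()
    ) -> Sequence[Tuple[nn.Module, str]]:
        # keep only the (module, name) pairs whose module is an nn.Linear
        return [(module, name) for module, name in parameters_to_prune if isinstance(module, nn.Linear)]
```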