diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py
index f14ee7e91f..de55162543 100644
--- a/monai/engines/trainer.py
+++ b/monai/engines/trainer.py
@@ -106,6 +106,9 @@ def __init__(
         amp: bool = False,
         event_names: Optional[List[Union[str, EventEnum]]] = None,
         event_to_attr: Optional[dict] = None,
+        gradient_clip_val: Optional[float] = None,
+        gradient_clip_algo: Optional[str] = None,
+        gradient_clip_norm: Optional[Union[float, int, str]] = None,
     ) -> None:
         super().__init__(
             device=device,
@@ -130,6 +133,10 @@ def __init__(
         self.loss_function = loss_function
         self.inferer = SimpleInferer() if inferer is None else inferer
 
+        self.gradient_clip_val = gradient_clip_val
+        self.gradient_clip_algo = gradient_clip_algo
+        self.gradient_clip_norm = gradient_clip_norm
+
     def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
         """
         Callback function for the Supervised Training processing logic of 1 iteration in Ignite Engine.
@@ -183,6 +190,24 @@ def _compute_pred_loss():
 
         return output
 
+    def _gradient_clipping(self):
+        if self.gradient_clip_algo:
+            if self.amp and self.scaler is not None:
+                self.scaler.unscale_(self.optimizer)
+
+            if self.gradient_clip_algo == "norm":
+                torch.nn.utils.clip_grad_norm_(
+                    parameters=self.network.parameters(),
+                    max_norm=self.gradient_clip_val,
+                    norm_type=self.gradient_clip_norm,
+                )
+            elif self.gradient_clip_algo == "clip":
+                torch.nn.utils.clip_grad_value_(parameters=self.network.parameters(), clip_value=self.gradient_clip_val)
+            else:
+                raise ValueError(
+                    f"gradient_clip_algo can be either 'norm' or 'clip' but received {self.gradient_clip_algo}"
+                )
+
 
 class GanTrainer(Trainer):
     """
@@ -309,9 +334,7 @@ def _iteration(
         g_output = self.g_inferer(g_input, self.g_network)
 
         # Train Discriminator
-        d_total_loss = torch.zeros(
-            1,
-        )
+        d_total_loss = torch.zeros(1)
         for _ in range(self.d_train_steps):
             self.d_optimizer.zero_grad()
             dloss = self.d_loss_function(g_output, d_input)
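For reference, a minimal usage sketch of the new keyword arguments on `SupervisedTrainer`. The tiny `Linear` network and the random dict dataset are placeholders invented for illustration (not part of this patch), and it assumes `_gradient_clipping` is invoked around the optimizer step; only the `gradient_clip_*` arguments come from the diff above.

```python
import torch
from torch.utils.data import DataLoader
from monai.engines import SupervisedTrainer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = torch.nn.Linear(4, 2).to(device)  # placeholder network for illustration

# tiny dict-based dataset so the example runs end to end
data = [{"image": torch.randn(4), "label": torch.tensor(0)} for _ in range(8)]
loader = DataLoader(data, batch_size=2)

trainer = SupervisedTrainer(
    device=device,
    max_epochs=1,
    train_data_loader=loader,
    network=net,
    optimizer=torch.optim.Adam(net.parameters(), lr=1e-3),
    loss_function=torch.nn.CrossEntropyLoss(),
    gradient_clip_val=1.0,      # clipping threshold
    gradient_clip_algo="norm",  # "norm" -> clip_grad_norm_, "clip" -> clip_grad_value_
    gradient_clip_norm=2.0,     # norm type, only used with the "norm" algorithm
)
trainer.run()
```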