diff --git a/came_pytorch/CAME.py b/came_pytorch/CAME.py index 280edfd..4962285 100644 --- a/came_pytorch/CAME.py +++ b/came_pytorch/CAME.py @@ -60,7 +60,7 @@ def _rms(self, tensor): def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col): r_factor = ( - (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)) + (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True).add_(self.param_groups[0]["eps"][1])) .rsqrt_() .unsqueeze(-1) ) @@ -102,8 +102,8 @@ def step(self, closure=None): grad_shape[:-2] + grad_shape[-1:] ).type_as(grad) - state["exp_avg_res_row"] = torch.zeros(grad_shape[:-1]).type_as(grad) - state["exp_avg_res_col"] = torch.zeros( + state["exp_avg_res_row"] = torch.ones(grad_shape[:-1]).type_as(grad) + state["exp_avg_res_col"] = torch.ones( grad_shape[:-2] + grad_shape[-1:] ).type_as(grad) else: @@ -171,4 +171,4 @@ def step(self, closure=None): update.mul_(group["lr"]) p.data.add_(-update) - return loss \ No newline at end of file + return loss