diff --git a/modules.py b/modules.py index 36d2116..1f36ecb 100644 --- a/modules.py +++ b/modules.py @@ -181,13 +181,17 @@ def backward(ctx, grad_output): if ctx.in_feat_requires_grad and ctx.proj_weight_requires_grad: grad_in_feat, grad_proj_weight = ctx.saved_tensors elif not ctx.in_feat_requires_grad and ctx.proj_weight_requires_grad: + grad_in_feat = None grad_proj_weight, = ctx.saved_tensors elif ctx.in_feat_requires_grad and not ctx.proj_weight_requires_grad: grad_in_feat, = ctx.saved_tensors + grad_proj_weight = None assert grad_output.shape == tuple(), grad_output.shape - grad_in_feat *= grad_output - grad_proj_weight *= grad_output + if grad_in_feat is not None: + grad_in_feat *= grad_output + if grad_proj_weight is not None: + grad_proj_weight *= grad_output return grad_in_feat, grad_proj_weight, None, None, None, None