Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,13 +181,17 @@ def backward(ctx, grad_output):
if ctx.in_feat_requires_grad and ctx.proj_weight_requires_grad:
grad_in_feat, grad_proj_weight = ctx.saved_tensors
elif not ctx.in_feat_requires_grad and ctx.proj_weight_requires_grad:
grad_in_feat = None
grad_proj_weight, = ctx.saved_tensors
elif ctx.in_feat_requires_grad and not ctx.proj_weight_requires_grad:
grad_in_feat, = ctx.saved_tensors
grad_proj_weight = None

assert grad_output.shape == tuple(), grad_output.shape
grad_in_feat *= grad_output
grad_proj_weight *= grad_output
if grad_in_feat is not None:
grad_in_feat *= grad_output
if grad_proj_weight is not None:
grad_proj_weight *= grad_output

return grad_in_feat, grad_proj_weight, None, None, None, None

Expand Down