Labels: feature, module: optimizer, module: sparse, triaged
Description
Now that sparse tensors are mostly working (#1147), it would be awesome if all the optimizers worked with sparse tensors. This requires some cleverness to do amortized updates to the parameters. For example, for weight decay, you would do something like (pseudocode):
# apply weight decay lazily: only touch rows that appear in the sparse gradient
# dp is the sparse gradient; n[i] is the step at which row i was last updated
n_t += 1
for i in dp.indices():
    p[i] *= (1 - lr * wd) ** (n_t - n[i])
    n[i] = n_t
Note this isn't exactly equivalent (to be exactly equivalent you'd need to apply the weight decay before the forward pass, not after the backward pass), but it's a good approximation. You can do the same thing for momentum; a rough sketch follows.
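As a sketch of what the momentum analogue might look like (assuming plain SGD momentum with no dampening; buf, mu, g, and the exact step bookkeeping around k are illustrative and would need checking, not an existing API):
# lazy momentum catch-up for row i (buf is the momentum buffer, mu the
# momentum coefficient, n/n_t as in the weight-decay sketch above)
k = n_t - n[i] - 1             # steps skipped since row i was last touched
if k > 0 and mu != 1.0:
    # parameter movement the dense update would have accumulated over k steps
    p[i] -= lr * buf[i] * mu * (1 - mu ** k) / (1 - mu)
    buf[i] *= mu ** k          # the buffer itself decays geometrically
# then apply the current sparse gradient value g as usual
buf[i] = mu * buf[i] + g
p[i] -= lr * buf[i]
n[i] = n_t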
I'm guessing the same approach works for Adam/Adamax as well, but I haven't worked through the equations. https://arxiv.org/pdf/1412.6980.pdf
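One obvious per-row variant (just a sketch, not worked out for exactness; exp_avg, exp_avg_sq, beta1, beta2, eps, and step are illustrative names) would be to update the Adam moments only for the rows that appear in the sparse gradient:
# Adam update restricted to rows present in the sparse gradient (illustrative);
# unlike dense Adam, the moments of untouched rows do not decay between steps
for i, g in zip(dp.indices(), dp.values()):
    exp_avg[i] = beta1 * exp_avg[i] + (1 - beta1) * g
    exp_avg_sq[i] = beta2 * exp_avg_sq[i] + (1 - beta2) * g * g
    bias1 = 1 - beta1 ** step
    bias2 = 1 - beta2 ** step
    denom = (exp_avg_sq[i] / bias2) ** 0.5 + eps
    p[i] -= lr * (exp_avg[i] / bias1) / denom
Whether that, or a lazy catch-up like the weight-decay trick above, is the right approximation for Adam's moving averages is exactly the part that still needs working through.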
@ezyang expressed interest in working on this.