Refactor train! #1017
Conversation
```julia
    throw(StopException())
end

function step!(loss, ps, minibatch, opt)
```
Can this have a docstring, and be exported?
Yeah, it would be good to converge on its implementation first; once that's settled I'll add the docstrings and a basic test too.
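For reference, a minimal sketch of what the exported, documented function might look like, based on the diff in this PR; the docstring wording and the explicit `nothing` return are illustrative, not settled:

```julia
using Flux
using Flux.Optimise: update!
using Zygote: gradient

"""
    step!(loss, ps, minibatch, opt)

Run a single training step: take the gradient of `loss` on `minibatch`
with respect to the parameters `ps`, then update `ps` in place with the
optimiser `opt`. Returns `nothing`.
"""
function step!(loss, ps, minibatch, opt)
    gs = gradient(ps) do
        loss(minibatch...)
    end
    update!(opt, ps, gs)
    return nothing
end
```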
```julia
gs = gradient(ps) do
    loss(minibatch...)
end
update!(opt, ps, gs)
```
Isn't this return incompatible with the interface proposed in #666 (comment)?
It returns `nothing`, which should be fine for most cases. For that comment specifically, yes, but it is really just a matter of agreeing on the semantics we want from it. Returning the loss here would be fine, but since this performs one training step over one batch, I feel the correct return value is `nothing`, the side effect being a slightly more trained model.
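For callers who do want the loss value, a closure can capture it without changing the return value. A hedged sketch, assuming the `step!` signature from the diff above and using `Zygote.ignore` to keep the bookkeeping out of the gradient (`data`, `loss`, `ps`, and `opt` are presumed defined by the caller):

```julia
using Zygote

losses = Float32[]
for minibatch in data          # each minibatch is a tuple, e.g. (x, y)
    step!(ps, minibatch, opt) do x, y
        l = loss(x, y)
        # Record the value for logging; ignore() keeps it out of the gradient.
        Zygote.ignore() do
            push!(losses, l)
        end
        l
    end
end
```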
Pinging @MikeInnes for his thoughts on the semantics that would make most sense here.
I agree that closing over the batch would make the most sense. I also agree it's best to avoid returning things like losses, since as discussed earlier we can just close over those things, and we will eventually want an out-of-place version of this that returns a model.
This is in response to #666, where we can visualise `train!` as a wrapper around the `step!` function while maintaining the same API; a sketch follows below. To actually accumulate the loss as in #666 (comment), we could of course make the loss function a simple closure here, getting rid of the need to pass the minibatch around.
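Concretely, the wrapper could look something like this (a hedged sketch reusing `train!`'s existing signature; callback handling is simplified):

```julia
# train! keeps its current API but delegates each step to step!.
function train!(loss, ps, data, opt; cb = () -> ())
    for minibatch in data
        step!(loss, ps, minibatch, opt)
        cb()
    end
end
```

If the loss were instead a closure over each batch, the `minibatch` argument could eventually drop out of `step!` entirely.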
With the changes to Zygote, it might be nice to actually have this in our interface.
cc @MikeInnes @oxinabox