
Conversation

@DhairyaLGandhi
Member

This is in response to #666, where we can view train! as a wrapper around the step! function while maintaining the same API.

To actually accumulate the loss as in #666 (comment), we could of course go with the loss function being a simple closure here, getting rid of the need to send around the minibatch.

With the changes to Zygote, it might be nice to actually have this in our interface.

cc @MikeInnes @oxinabox

throw(StopException())
end

function step!(loss, ps, minibatch, opt)
Member

Can this have a docstring, and be exported?

Member Author

Yeah, it would be good to converge on the implementation first; once we do, I'll add the docstrings and a basic test too.

gs = gradient(ps) do
loss(minibatch...)
end
update!(opt, ps, gs)
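Assembled from the fragments above, the proposed split might look like the sketch below. Zygote's gradient and Flux's update! are replaced by toy stand-ins (toy_gradient, toy_update!) so the example runs in base Julia; those names and the loop details are purely illustrative, not the PR's actual implementation.

```julia
# Toy sketch of the proposed step!/train! interface.
# `toy_gradient` and `toy_update!` are made-up stand-ins for Zygote's
# `gradient` and Flux's `update!`, used only to keep the sketch runnable.
struct StopException <: Exception end

# Stand-in for Zygote's gradient: central finite differences over a
# parameter vector.
function toy_gradient(f, ps; h = 1e-6)
    map(eachindex(ps)) do i
        up = copy(ps); up[i] += h
        dn = copy(ps); dn[i] -= h
        (f(up) - f(dn)) / 2h
    end
end

# Stand-in for `update!(opt, ps, gs)`: plain gradient descent,
# treating `opt` as a learning rate.
toy_update!(opt, ps, gs) = (ps .-= opt .* gs)

function step!(loss, ps, minibatch, opt)
    gs = toy_gradient(p -> loss(p, minibatch...), ps)
    toy_update!(opt, ps, gs)
    return nothing          # train for effect; no loss is returned
end

# train! is then just a loop over step!, with StopException as the
# escape hatch for callbacks.
function train!(loss, ps, data, opt)
    for minibatch in data
        try
            step!(loss, ps, minibatch, opt)
        catch e
            e isa StopException || rethrow(e)
            break
        end
    end
end

# Fit w ≈ 2 on y = 2x with squared error.
ps = [0.0]
sqloss(p, x, y) = (p[1] * x - y)^2
data = [(x, 2.0x) for x in 1.0:0.5:3.0]
for _ in 1:200
    train!(sqloss, ps, data, 0.01)
end
```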
Member

Isn't this return incompatible with the interface proposed in #666 (comment)?

Member Author

It returns nothing, which should be fine for most cases. For that comment specifically, yes, but it's really just a matter of agreeing on the desired semantics. Returning the loss here would be fine, but since it steps through a batch and performs a training step, I feel the correct return value is nothing, with a slightly more trained model as it were.

@DhairyaLGandhi
Member Author

Pinging @MikeInnes for his thoughts on the semantics that would make most sense here.

@MikeInnes
Member

I agree that closing over the batch would make the most sense. I also agree it's best to avoid returning things like losses, since as discussed earlier we can just close over those things, and we will eventually want an out-of-place version of this that returns a model.
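To make the closure idea concrete: the caller can capture both the minibatch and any loss accumulator inside the loss function itself, so the training step has nothing to return. A minimal pure-Julia illustration; all names here (losses, w, minibatch) are invented for the example:

```julia
losses = Float64[]
w = [1.0]                                       # toy "parameters"
minibatch = ([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])  # (xs, ys) with y = 2x

# The loss closes over the minibatch and the `losses` accumulator,
# so nothing needs to be threaded through or returned by the step.
loss = function ()
    xs, ys = minibatch
    l = sum(abs2, w[1] .* xs .- ys)
    push!(losses, l)
    return l
end

loss()   # with w = 1, residuals are (-1, -2, -3), so l = 1 + 4 + 9 = 14
```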

train! will likely need a fair bit of refactoring once we figure out new optimisers + accelerator support, but we can figure that out later.
