Description
There are a few cases where I find myself wondering whether we should make it more explicit how the train loop design can be extended to be friendlier to callbacks, so they don't have to cheat to get things like the loss. Further, projects like FluxTraining.jl show that we lack a set of ready-made callbacks that downstream packages wouldn't have to rewrite.
With this in mind, I think using `pullback` instead of `gradient` would be a step in that direction, as would not optimising until after a pre-hook has had a chance to check for callback conditions etc. This should also fit in nicely with how we want to set up schedulers. I would also like to figure out where distributed and multi-GPU training fall in this, so we know how to proceed.
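A minimal sketch of what such a step could look like, assuming Zygote's `pullback` with the current implicit-`Params` style; the `train_step!` name and the `prehook` keyword are just placeholders for illustration, not a proposed API:

```julia
using Flux, Zygote

# Sketch only: `train_step!` and `prehook` are hypothetical names.
function train_step!(loss, ps::Flux.Params, batch, opt; prehook = (l, gs) -> true)
    # pullback returns the loss value *and* the pullback closure,
    # so callbacks can see the loss without recomputing it
    l, back = Zygote.pullback(() -> loss(batch...), ps)
    gs = back(one(l))
    # let callbacks inspect (or veto) the step before the optimiser runs
    prehook(l, gs) || return l
    Flux.Optimise.update!(opt, ps, gs)
    return l
end
```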
We don't necessarily want to return the losses etc., but perhaps a slightly more trained model? That would also fall in line with how Optimisers.jl is shaping up.
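For the return-the-model idea, here is a rough sketch of an explicit-state loop in the style Optimisers.jl is heading towards (`setup`/`update` are Optimisers.jl's API; the `train` signature itself is only illustrative):

```julia
using Optimisers, Zygote

# Sketch only: the signature is hypothetical, but setup/update follow Optimisers.jl.
function train(loss, model, data, rule)
    state = Optimisers.setup(rule, model)
    for batch in data
        # explicit gradients w.r.t. the model itself, rather than implicit Params
        l, back = Zygote.pullback(m -> loss(m, batch...), model)
        grad = first(back(one(l)))
        state, model = Optimisers.update(state, model, grad)
    end
    return model  # hand back the updated model rather than a log of losses
end
```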
cc @lorenzoh