Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions src/optimise/train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ function stop()
throw(StopException())
end

function step!(loss, ps, minibatch, opt)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this have a docstring, and be exported?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it would be good to converge on its implementation, with that I'll add the doctrings and a basic test too.

gs = gradient(ps) do
loss(minibatch...)
end
update!(opt, ps, gs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this return incompatible with the interface proposed in #666 (comment)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It returns nothing, which should be fine for most cases. For that comment specifically, yes, but it really is just a matter of agreeing on the semantics desired out of it. returning the loss here should be fine, but seeing as it is stepping through a batch, and doing a training step, I feel the correct return should be a nothing, with a slightly more trained model as it were

end

"""
train!(loss, params, data, opt; cb)

Expand All @@ -65,10 +72,11 @@ function train!(loss, ps, data, opt; cb = () -> ())
cb = runall(cb)
@progress for d in data
try
gs = gradient(ps) do
loss(d...)
end
update!(opt, ps, gs)
# gs = gradient(ps) do
# loss(d...)
# end
# update!(opt, ps, gs)
step!(loss, ps, d, opt)
cb()
catch ex
if ex isa StopException
Expand Down