Commit ddc2c20

Merge pull request #994 from FluxML/ox/doccustomtraining: Add custom training loops to docs
2 parents 620cffc + 7797e31

File tree: 1 file changed (+27, −0 lines)


docs/src/training/training.md

Lines changed: 27 additions & 0 deletions
@@ -110,3 +110,30 @@ cb = function ()
  accuracy() > 0.9 && Flux.stop()
end
```

## Custom Training loops

The `Flux.train!` function can be very convenient, especially for simple problems, and the use of callbacks makes it quite flexible. For some problems, however, it is much cleaner to write your own custom training loop. The example below works like the default `Flux.train!`, but without callbacks: rather than registering callbacks, you place the corresponding calls directly in the loop, at the places marked with comments.

```
function my_custom_train!(loss, ps, data, opt)
  ps = Params(ps)
  for d in data
    gs = gradient(ps) do
      training_loss = loss(d...)
      # Insert whatever code you want here that needs the training loss, e.g. logging.
      return training_loss
    end
    # Insert whatever code you want here that needs the gradient,
    # e.g. logging it with TensorBoardLogger.jl as a histogram to check whether it is exploding.
    update!(opt, ps, gs)
    # Here you might like to check the validation set accuracy, and break out to do early stopping.
  end
end
```
You could simplify this further, for example by hard-coding the loss function.
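For instance, a version with the loss hard-coded into the loop might look like the following sketch. This is only an illustration, not part of the original change: the `Dense` model, the `Flux.mse` loss, and the toy dataset are assumptions chosen for the example.

```
using Flux
using Flux: Params, gradient, update!

# Hypothetical model for illustration; any differentiable Flux model would do.
model = Dense(2, 1)

function my_simple_train!(ps, data, opt)
  ps = Params(ps)
  for d in data
    gs = gradient(ps) do
      # The loss is now hard-coded rather than passed in as an argument.
      Flux.mse(model(d[1]), d[2])
    end
    update!(opt, ps, gs)
  end
end

# Example call with a toy dataset of (input, target) batches.
data = [(rand(Float32, 2, 8), rand(Float32, 1, 8)) for _ in 1:10]
my_simple_train!(params(model), data, Descent(0.1))
```

The trade-off is the usual one: hard-coding the loss makes the loop shorter, but the function is no longer reusable across models or objectives.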