cb = function ()
  accuracy() > 0.9 && Flux.stop()
end
```

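For context, `Flux.train!` accepts such a callback through its `cb` keyword. A minimal sketch, assuming `loss`, `model`, `data`, and `opt` are defined as in the surrounding examples:

```julia
# Sketch: pass the early-stopping callback above via train!'s `cb` keyword.
# `loss`, `model`, `data`, `opt`, and `cb` are assumed from the surrounding examples.
Flux.train!(loss, Flux.params(model), data, opt, cb = cb)
```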
## Custom Training loops

The `Flux.train!` function can be very convenient, especially for simple problems.
It's also very flexible with the use of callbacks.
But for some problems it's much cleaner to write your own custom training loop.
The example below works like the default `Flux.train!`, but with no callbacks.
You don't need callbacks if you just code the calls to your functions directly into the loop,
e.g. in the places marked with comments.

+ ```
124
+ function my_custom_train!(loss, ps, data, opt)
125
+ ps = Params(ps)
126
+ for d in data
127
+ gs = gradient(ps) do
128
+ training_loss = loss(d...)
129
+ # Insert what ever code you want here that needs Training loss, e.g. logging
130
+ return training_loss
131
+ end
132
+ # insert what ever code you want here that needs gradient
133
+ # E.g. logging with TensorBoardLogger.jl as histogram so you can see if it is becoming huge
134
+ update!(opt, ps, gs)
135
+ # Here you might like to check validation set accuracy, and break out to do early stopping
136
+ end
137
+ end
138
+ ```
139
You could simplify this further, for example by hard-coding the loss function.
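For instance, here is a hedged sketch of that simplification with a mean-squared-error loss hard-coded into the loop. The names `m` (the model) and the `(x, y)` pairs in `data` are assumptions for illustration, not from the original:

```julia
using Flux
using Flux: update!

# A sketch of the simplified loop with the loss hard-coded.
# Assumed names: `m` is the model, `data` yields `(x, y)` pairs.
function my_simple_train!(m, data, opt)
  ps = Flux.params(m)
  for (x, y) in data
    gs = gradient(ps) do
      Flux.mse(m(x), y)  # loss hard-coded into the loop
    end
    update!(opt, ps, gs)
  end
end
```

Since the loss closes over `m` directly, there is no separate `loss` argument to thread through, at the cost of the loop being tied to one loss function.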