Skip to content

Commit 3db4013

Browse files
author
Andrew David Werner Rosemberg
committed
add tolorance kwarg
1 parent 2dc709c commit 3db4013

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

src/FullyConnected.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ end
6868
# Define a container to hold any optimiser specific parameters (if any):
6969
struct ConvexRule <: Flux.Optimise.AbstractOptimiser
7070
rule::Flux.Optimise.AbstractOptimiser
71+
tol::Real
72+
end
73+
function ConvexRule(rule::Flux.Optimise.AbstractOptimiser; tol=1e-6)
74+
return ConvexRule(rule, tol)
7175
end
7276

7377
"""
@@ -113,7 +117,7 @@ function MLJFlux.train!(
113117
return batch_loss
114118
end
115119
Flux.update!(optimiser.rule, parameters, gs)
116-
make_convex!(chain)
120+
make_convex!(chain; tol=optimiser.tol)
117121
end
118122
return training_loss / n_batches
119123
end

0 commit comments

Comments
 (0)