Skip to content

Commit 514799c

Browse files
committed
use less isnewton
1 parent f6b6c00 commit 514799c

File tree

9 files changed

+50
-49
lines changed

9 files changed

+50
-49
lines changed

lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ using ArrayInterface: ismutable
3636
import OrdinaryDiffEqCore
3737
using OrdinaryDiffEqDifferentiation: UJacobianWrapper
3838
using OrdinaryDiffEqNonlinearSolve: NLNewton, du_alias_or_new, build_nlsolver,
39-
nlsolve!, nlsolvefail, isnewton, markfirststage!,
39+
nlsolve!, nlsolvefail, markfirststage!,
4040
set_new_W!, DIRK, compute_step!, COEFFICIENT_MULTISTEP,
4141
NonlinearSolveAlg
4242
import ADTypes: AutoForwardDiff, AutoFiniteDiff, AbstractADType

lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -985,15 +985,15 @@ end
985985
### STEP 2
986986
nlsolver.tmp = z₁
987987
nlsolver.c = 2
988-
isnewton(nlsolver) && set_new_W!(nlsolver, false)
988+
set_new_W!(nlsolver, false)
989989
z = nlsolve!(nlsolver, integrator, cache, repeat_step)
990990
nlsolvefail(nlsolver) && return
991991
z₂ = z₁ + z
992992
### STEP 3
993993
tmp2 = 0.5uprev + z₁ - 0.5z₂
994994
nlsolver.tmp = tmp2
995995
nlsolver.c = 1
996-
isnewton(nlsolver) && set_new_W!(nlsolver, false)
996+
set_new_W!(nlsolver, false)
997997
z = nlsolve!(nlsolver, integrator, cache, repeat_step)
998998
nlsolvefail(nlsolver) && return
999999
u = tmp2 + z
@@ -1039,7 +1039,7 @@ end
10391039
### STEP 2
10401040
nlsolver.tmp = z₁
10411041
nlsolver.c = 2
1042-
isnewton(nlsolver) && set_new_W!(nlsolver, false)
1042+
set_new_W!(nlsolver, false)
10431043
z = nlsolve!(nlsolver, integrator, cache, repeat_step)
10441044
nlsolvefail(nlsolver) && return
10451045
@.. broadcast=false z₂=z₁+z
@@ -1048,7 +1048,7 @@ end
10481048
@.. broadcast=false tmp2=0.5uprev+z₁-0.5z₂
10491049
nlsolver.tmp = tmp2
10501050
nlsolver.c = 1
1051-
isnewton(nlsolver) && set_new_W!(nlsolver, false)
1051+
set_new_W!(nlsolver, false)
10521052
z = nlsolve!(nlsolver, integrator, cache, repeat_step)
10531053
nlsolvefail(nlsolver) && return
10541054
@.. broadcast=false u=tmp2+z

lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,7 @@ function calc_W!(W, integrator, nlsolver::Union{Nothing, AbstractNLSolver}, cach
606606
new_jac, new_W = newJW
607607
end
608608

609-
if new_jac && isnewton(lcache)
609+
if new_jac && (isnewton(lcache))
610610
lcache.J_t = t
611611
if isdae
612612
lcache.uf.α = nlsolver.α

lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,17 @@ end
126126
@unpack tstep, invγdt, atmp, ustep = cache
127127

128128
new_jac, new_W = do_newJW(integrator, integrator.alg, nlsolver, false)
129+
cache.new_W = new_W
130+
@show new_jac, new_W
129131
if is_always_new(nlsolver) || new_jac || new_W
132+
cache.W_γdt = γ*dt
133+
cache.J_t = t
130134
recompute_jacobian = true
131135
else
132136
recompute_jacobian = false
133137
end
134138

135-
nlcache = nlsolver.cache.cache
139+
nlcache = cache.cache
136140
nlstep_data = integrator.f.nlstep_data
137141
step!(nlcache; recompute_jacobian)
138142

lib/OrdinaryDiffEqNonlinearSolve/src/nlsolve.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ function postamble!(nlsolver::NLSolver, integrator::SciMLBase.DEIntegrator)
154154
end
155155
integrator.force_stepfail = nlsolvefail(nlsolver)
156156
setfirststage!(nlsolver, false)
157-
isnewton(nlsolver) && (nlsolver.cache.firstcall = false)
157+
(isnewton(nlsolver) || isnonlinearsolve(nlsolver)) && (nlsolver.cache.firstcall = false)
158158

159159
nlsolver.z
160160
end

lib/OrdinaryDiffEqNonlinearSolve/src/type.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,13 @@ mutable struct NonlinearSolveCache{uType, tType, rateType, tType2, P, C} <:
215215
tstep::tType
216216
k::rateType
217217
atmp::uType
218-
invγdt::tType2
219218
prob::P
220219
cache::C
221220
new_W::Bool
221+
firststage::Bool
222+
firstcall::Bool
223+
W_γdt::tType
224+
invγdt::tType2
225+
new_W_γdt_cutoff::tType
226+
J_t::tType
222227
end

lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ isJcurrent(nlsolver::AbstractNLSolver, integrator) = integrator.t == nlsolver.ca
2626
isfirstcall(nlsolver::AbstractNLSolver) = nlsolver.cache.firstcall
2727
isfirststage(nlsolver::AbstractNLSolver) = nlsolver.cache.firststage
2828
setfirststage!(nlsolver::AbstractNLSolver, val::Bool) = setfirststage!(nlsolver.cache, val)
29-
function setfirststage!(nlcache::Union{NLNewtonCache, NLNewtonConstantCache}, val::Bool)
29+
function setfirststage!(nlcache::Union{NLNewtonCache, NLNewtonConstantCache, NonlinearSolveCache}, val::Bool)
3030
(nlcache.firststage = val)
3131
end
3232
setfirststage!(::Any, val::Bool) = nothing
@@ -37,9 +37,9 @@ getnfails(nlsolver::AbstractNLSolver) = nlsolver.nfails
3737

3838
set_new_W!(nlsolver::AbstractNLSolver, val::Bool)::Bool = set_new_W!(nlsolver.cache, val)
3939
set_new_W!(nlcache::Union{NLNewtonCache, NLNewtonConstantCache, NonlinearSolveCache}, val::Bool)::Bool = nlcache.new_W = val
40+
set_new_W!(nlcache::AbstractNLSolverCache, val::Bool)::Bool = nothing
4041
get_new_W!(nlsolver::AbstractNLSolver)::Bool = get_new_W!(nlsolver.cache)
41-
get_new_W!(nlcache::Union{NLNewtonCache, NLNewtonConstantCache, NonlinearSolveCache})::Bool = nlcache.new_W
42-
get_new_W!(::AbstractNLSolverCache)::Bool = true
42+
get_new_W!(::AbstractNLSolverCache)::Bool = nlcache.new_W
4343

4444
get_W(nlsolver::AbstractNLSolver) = get_W(nlsolver.cache)
4545
get_W(nlcache::Union{NLNewtonCache, NLNewtonConstantCache}) = nlcache.W
@@ -243,7 +243,8 @@ function build_nlsolver(
243243
NonlinearProblem(NonlinearFunction{true}(nlf), ztmp, nlp_params)
244244
end
245245
cache = init(prob, nlalg.alg)
246-
nlcache = NonlinearSolveCache(ustep, tstep, k, atmp, invγdt, prob, cache, true)
246+
nlcache = NonlinearSolveCache(ustep, tstep, k, atmp, prob, cache,
247+
true, true, true, tType(dt), invγdt, tType(nlalg.new_W_dt_cutoff), t)
247248
else
248249
nlcache = NLNewtonCache(ustep, tstep, k, atmp, dz, J, W, true,
249250
true, true, tType(dt), du1, uf, jac_config,
@@ -330,8 +331,8 @@ function build_nlsolver(
330331
end
331332
prob = NonlinearProblem(NonlinearFunction{false}(nlf), copy(ztmp), nlp_params)
332333
cache = init(prob, nlalg.alg)
333-
nlcache = NonlinearSolveCache(
334-
nothing, tstep, nothing, nothing, invγdt, prob, cache, true)
334+
nlcache = NonlinearSolveCache(nothing, tstep, nothing, nothing, prob, cache,
335+
true, true, true, tType(dt), invγdt, tType(nlalg.new_W_dt_cutoff), t)
335336
else
336337
nlcache = NLNewtonConstantCache(tstep, J, W, true, true, true, tType(dt), uf,
337338
invγdt, tType(nlalg.new_W_dt_cutoff), t)

lib/OrdinaryDiffEqSDIRK/src/kencarp_kvaerno_perform_step.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ end
8888
nlsolver.c = γ
8989
z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step)
9090
nlsolvefail(nlsolver) && return
91-
isnewton(nlsolver) && set_new_W!(nlsolver, false)
91+
set_new_W!(nlsolver, false)
9292

9393
################################## Solve Step 3
9494

@@ -315,7 +315,7 @@ end
315315
nlsolver.c = 2γ
316316
z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step)
317317
nlsolvefail(nlsolver) && return
318-
isnewton(nlsolver) && set_new_W!(nlsolver, false)
318+
set_new_W!(nlsolver, false)
319319

320320
################################## Solve Step 3
321321

@@ -554,7 +554,7 @@ end
554554
nlsolver.c = c2
555555
z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step)
556556
nlsolvefail(nlsolver) && return
557-
isnewton(nlsolver) && set_new_W!(nlsolver, false)
557+
set_new_W!(nlsolver, false)
558558

559559
################################## Solve Step 3
560560

@@ -719,7 +719,7 @@ end
719719
nlsolver.c = γ
720720
z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step)
721721
nlsolvefail(nlsolver) && return
722-
isnewton(nlsolver) && set_new_W!(nlsolver, false)
722+
set_new_W!(nlsolver, false)
723723

724724
################################## Solve Step 3
725725

@@ -1008,7 +1008,7 @@ end
10081008
markfirststage!(nlsolver)
10091009
z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step)
10101010
nlsolvefail(nlsolver) && return
1011-
isnewton(nlsolver) && set_new_W!(nlsolver, false)
1011+
set_new_W!(nlsolver, false)
10121012

10131013
################################## Solve Step 3
10141014

@@ -1262,7 +1262,7 @@ end
12621262
nlsolver.c = γ
12631263
z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step)
12641264
nlsolvefail(nlsolver) && return
1265-
isnewton(nlsolver) && set_new_W!(nlsolver, false)
1265+
set_new_W!(nlsolver, false)
12661266

12671267
################################## Solve Step 3
12681268

@@ -1612,7 +1612,7 @@ end
16121612
nlsolver.c = 2γ
16131613
z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step)
16141614
nlsolvefail(nlsolver) && return
1615-
isnewton(nlsolver) && set_new_W!(nlsolver, false)
1615+
set_new_W!(nlsolver, false)
16161616

16171617
################################## Solve Step 3
16181618

@@ -2028,7 +2028,7 @@ end
20282028
nlsolver.c = 2γ
20292029
z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step)
20302030
nlsolvefail(nlsolver) && return
2031-
isnewton(nlsolver) && set_new_W!(nlsolver, false)
2031+
set_new_W!(nlsolver, false)
20322032

20332033
################################## Solve Step 3
20342034

@@ -2449,7 +2449,7 @@ end
24492449
nlsolver.c = 2γ
24502450
z₂ = nlsolve!(nlsolver, integrator, cache, repeat_step)
24512451
nlsolvefail(nlsolver) && return
2452-
isnewton(nlsolver) && set_new_W!(nlsolver, false)
2452+
set_new_W!(nlsolver, false)
24532453

24542454
################################## Solve Step 3
24552455

0 commit comments

Comments
 (0)