Skip to content

Commit 74c2ad7

Browse files
committed
Reduce unnecessary allocations and reuse code
1 parent 657c549 commit 74c2ad7

File tree

9 files changed

+204
-173
lines changed

9 files changed

+204
-173
lines changed

Project.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
1616
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
1717
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1818
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
19+
MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
1920
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
2021
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
2122
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
@@ -42,8 +43,8 @@ NonlinearSolveZygoteExt = "Zygote"
4243

4344
[compat]
4445
ADTypes = "0.2"
45-
ArrayInterface = "6.0.24, 7"
4646
Aqua = "0.8"
47+
ArrayInterface = "6.0.24, 7"
4748
BandedMatrices = "1"
4849
BenchmarkTools = "1"
4950
ConcreteStructs = "0.2"
@@ -70,9 +71,9 @@ Reexport = "0.2, 1"
7071
SafeTestsets = "0.1"
7172
SciMLBase = "2.9"
7273
SciMLOperators = "0.3"
73-
SimpleNonlinearSolve = "0.1.23"
74+
SimpleNonlinearSolve = "1" # FIXME: Don't update the version in this PR. Using it to test
7475
SparseArrays = "<0.0.1, 1"
75-
SparseDiffTools = "2.12"
76+
SparseDiffTools = "2.14"
7677
StaticArrays = "1"
7778
StaticArraysCore = "1.4"
7879
Symbolics = "5"

src/NonlinearSolve.jl

Lines changed: 56 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -8,25 +8,24 @@ import Reexport: @reexport
88
import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_workload
99

1010
@recompile_invalidations begin
11-
using DiffEqBase,
12-
LazyArrays, LinearAlgebra, LinearSolve, Printf, SparseArrays,
11+
using DiffEqBase, LazyArrays, LinearAlgebra, LinearSolve, Printf, SparseArrays,
1312
SparseDiffTools
14-
using FastBroadcast: @..
15-
import ArrayInterface: restructure
1613

1714
import ADTypes: AbstractFiniteDifferencesMode
18-
import ArrayInterface: undefmatrix,
15+
import ArrayInterface: undefmatrix, restructure, can_setindex,
1916
matrix_colors, parameterless_type, ismutable, issingular, fast_scalar_indexing
2017
import ConcreteStructs: @concrete
2118
import EnumX: @enumx
19+
import FastBroadcast: @..
2220
import ForwardDiff
2321
import ForwardDiff: Dual
2422
import LinearSolve: ComposePreconditioner, InvPreconditioner, needs_concrete_A
23+
import MaybeInplace: @bb
2524
import RecursiveArrayTools: ArrayPartition,
2625
AbstractVectorOfArray, recursivecopy!, recursivefill!
2726
import SciMLBase: AbstractNonlinearAlgorithm, NLStats, _unwrap_val, has_jac, isinplace
2827
import SciMLOperators: FunctionOperator
29-
import StaticArraysCore: StaticArray, SVector, SArray, MArray
28+
import StaticArraysCore: StaticArray, SVector, SArray, MArray, Size, SMatrix
3029
import UnPack: @unpack
3130

3231
using ADTypes, LineSearches, SciMLBase, SimpleNonlinearSolve
@@ -55,13 +54,13 @@ isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip
5554
function Base.show(io::IO, alg::AbstractNonlinearSolveAlgorithm)
5655
str = "$(nameof(typeof(alg)))("
5756
modifiers = String[]
58-
if _getproperty(alg, Val(:ad)) !== nothing
57+
if __getproperty(alg, Val(:ad)) !== nothing
5958
push!(modifiers, "ad = $(nameof(typeof(alg.ad)))()")
6059
end
61-
if _getproperty(alg, Val(:linsolve)) !== nothing
60+
if __getproperty(alg, Val(:linsolve)) !== nothing
6261
push!(modifiers, "linsolve = $(nameof(typeof(alg.linsolve)))()")
6362
end
64-
if _getproperty(alg, Val(:linesearch)) !== nothing
63+
if __getproperty(alg, Val(:linesearch)) !== nothing
6564
ls = alg.linesearch
6665
if ls isa LineSearch
6766
ls.method !== nothing &&
@@ -70,7 +69,7 @@ function Base.show(io::IO, alg::AbstractNonlinearSolveAlgorithm)
7069
push!(modifiers, "linesearch = $(nameof(typeof(alg.linesearch)))()")
7170
end
7271
end
73-
if _getproperty(alg, Val(:radius_update_scheme)) !== nothing
72+
if __getproperty(alg, Val(:radius_update_scheme)) !== nothing
7473
push!(modifiers, "radius_update_scheme = $(alg.radius_update_scheme)")
7574
end
7675
str = str * join(modifiers, ", ")
@@ -107,7 +106,7 @@ function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
107106
end
108107
end
109108

110-
trace = _getproperty(cache, Val{:trace}())
109+
trace = __getproperty(cache, Val{:trace}())
111110
if trace !== nothing
112111
update_trace!(trace, cache.stats.nsteps, get_u(cache), get_fu(cache), nothing,
113112
nothing, nothing; last = Val(true))
@@ -134,52 +133,52 @@ include("jacobian.jl")
134133
include("ad.jl")
135134
include("default.jl")
136135

137-
@setup_workload begin
138-
nlfuncs = ((NonlinearFunction{false}((u, p) -> u .* u .- p), 0.1),
139-
(NonlinearFunction{false}((u, p) -> u .* u .- p), [0.1]),
140-
(NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p), [0.1]))
141-
probs_nls = NonlinearProblem[]
142-
for T in (Float32, Float64), (fn, u0) in nlfuncs
143-
push!(probs_nls, NonlinearProblem(fn, T.(u0), T(2)))
144-
end
145-
146-
nls_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), PseudoTransient(),
147-
GeneralBroyden(), GeneralKlement(), DFSane(), nothing)
148-
149-
probs_nlls = NonlinearLeastSquaresProblem[]
150-
nlfuncs = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), [0.1, 0.0]),
151-
(NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)), [0.1, 0.1]),
152-
(NonlinearFunction{true}((du, u, p) -> du[1] = u[1] * u[1] - p,
153-
resid_prototype = zeros(1)), [0.1, 0.0]),
154-
(NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p),
155-
resid_prototype = zeros(4)), [0.1, 0.1]))
156-
for (fn, u0) in nlfuncs
157-
push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0))
158-
end
159-
nlfuncs = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), Float32[0.1, 0.0]),
160-
(NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)),
161-
Float32[0.1, 0.1]),
162-
(NonlinearFunction{true}((du, u, p) -> du[1] = u[1] * u[1] - p,
163-
resid_prototype = zeros(Float32, 1)), Float32[0.1, 0.0]),
164-
(NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p),
165-
resid_prototype = zeros(Float32, 4)), Float32[0.1, 0.1]))
166-
for (fn, u0) in nlfuncs
167-
push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0f0))
168-
end
169-
170-
nlls_algs = (LevenbergMarquardt(), GaussNewton(),
171-
LevenbergMarquardt(; linsolve = LUFactorization()),
172-
GaussNewton(; linsolve = LUFactorization()))
173-
174-
@compile_workload begin
175-
for prob in probs_nls, alg in nls_algs
176-
solve(prob, alg, abstol = 1e-2)
177-
end
178-
for prob in probs_nlls, alg in nlls_algs
179-
solve(prob, alg, abstol = 1e-2)
180-
end
181-
end
182-
end
136+
# @setup_workload begin
137+
# nlfuncs = ((NonlinearFunction{false}((u, p) -> u .* u .- p), 0.1),
138+
# (NonlinearFunction{false}((u, p) -> u .* u .- p), [0.1]),
139+
# (NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p), [0.1]))
140+
# probs_nls = NonlinearProblem[]
141+
# for T in (Float32, Float64), (fn, u0) in nlfuncs
142+
# push!(probs_nls, NonlinearProblem(fn, T.(u0), T(2)))
143+
# end
144+
145+
# nls_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), PseudoTransient(),
146+
# GeneralBroyden(), GeneralKlement(), DFSane(), nothing)
147+
148+
# probs_nlls = NonlinearLeastSquaresProblem[]
149+
# nlfuncs = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), [0.1, 0.0]),
150+
# (NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)), [0.1, 0.1]),
151+
# (NonlinearFunction{true}((du, u, p) -> du[1] = u[1] * u[1] - p,
152+
# resid_prototype = zeros(1)), [0.1, 0.0]),
153+
# (NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p),
154+
# resid_prototype = zeros(4)), [0.1, 0.1]))
155+
# for (fn, u0) in nlfuncs
156+
# push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0))
157+
# end
158+
# nlfuncs = ((NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), Float32[0.1, 0.0]),
159+
# (NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)),
160+
# Float32[0.1, 0.1]),
161+
# (NonlinearFunction{true}((du, u, p) -> du[1] = u[1] * u[1] - p,
162+
# resid_prototype = zeros(Float32, 1)), Float32[0.1, 0.0]),
163+
# (NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p),
164+
# resid_prototype = zeros(Float32, 4)), Float32[0.1, 0.1]))
165+
# for (fn, u0) in nlfuncs
166+
# push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0f0))
167+
# end
168+
169+
# nlls_algs = (LevenbergMarquardt(), GaussNewton(),
170+
# LevenbergMarquardt(; linsolve = LUFactorization()),
171+
# GaussNewton(; linsolve = LUFactorization()))
172+
173+
# @compile_workload begin
174+
# for prob in probs_nls, alg in nls_algs
175+
# solve(prob, alg, abstol = 1e-2)
176+
# end
177+
# for prob in probs_nlls, alg in nlls_algs
178+
# solve(prob, alg, abstol = 1e-2)
179+
# end
180+
# end
181+
# end
183182

184183
export RadiusUpdateSchemes
185184

src/jacobian.jl

Lines changed: 52 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@
33
Jᵀ
44
end
55

6-
SciMLBase.isinplace(JᵀJ::KrylovJᵀJ) = isinplace(JᵀJ.Jᵀ)
6+
__maybe_symmetric(x::KrylovJᵀJ) = x.JᵀJ
7+
8+
isinplace(JᵀJ::KrylovJᵀJ) = isinplace(JᵀJ.Jᵀ)
79

10+
# Select if we are going to use sparse differentiation or not
811
sparsity_detection_alg(_, _) = NoSparsityDetection()
912
function sparsity_detection_alg(f, ad::AbstractSparseADType)
1013
if f.sparsity === nothing
@@ -33,13 +36,21 @@ function jacobian!!(J::Union{AbstractMatrix{<:Number}, Nothing}, cache)
3336
@unpack f, uf, u, p, jac_cache, alg, fu2 = cache
3437
iip = isinplace(cache)
3538
if iip
36-
has_jac(f) ? f.jac(J, u, p) :
37-
sparse_jacobian!(J, alg.ad, jac_cache, uf, fu2, _maybe_mutable(u, alg.ad))
39+
if has_jac(f)
40+
f.jac(J, u, p)
41+
else
42+
sparse_jacobian!(J, alg.ad, jac_cache, uf, fu2, u)
43+
end
44+
return J
3845
else
39-
return has_jac(f) ? f.jac(u, p) :
40-
sparse_jacobian!(J, alg.ad, jac_cache, uf, _maybe_mutable(u, alg.ad))
46+
if has_jac(f)
47+
return f.jac(u, p)
48+
elseif can_setindex(typeof(J))
49+
return sparse_jacobian!(J, alg.ad, jac_cache, uf, u)
50+
else
51+
return sparse_jacobian(alg.ad, jac_cache, uf, u)
52+
end
4153
end
42-
return J
4354
end
4455
# Scalar case
4556
jacobian!!(::Number, cache) = last(value_derivative(cache.uf, cache.u))
@@ -59,13 +70,13 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val
5970
alg_wants_jac = (concrete_jac(alg) !== nothing && concrete_jac(alg))
6071

6172
# NOTE: The deepcopy is needed here since we are using the resid_prototype elsewhere
62-
fu = f.resid_prototype === nothing ? (iip ? _mutable_zero(u) : _mutable(f(u, p))) :
73+
fu = f.resid_prototype === nothing ? (iip ? zero(u) : f(u, p)) :
6374
(iip ? deepcopy(f.resid_prototype) : f.resid_prototype)
6475
if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac)
6576
sd = sparsity_detection_alg(f, alg.ad)
6677
ad = alg.ad
67-
jac_cache = iip ? sparse_jacobian_cache(ad, sd, uf, fu, _maybe_mutable(u, ad)) :
68-
sparse_jacobian_cache(ad, sd, uf, _maybe_mutable(u, ad); fx = fu)
78+
jac_cache = iip ? sparse_jacobian_cache(ad, sd, uf, fu, u) :
79+
sparse_jacobian_cache(ad, sd, uf, __maybe_mutable(u, ad); fx = fu)
6980
else
7081
jac_cache = nothing
7182
end
@@ -76,11 +87,11 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val
7687
JacVec(uf, u; fu, autodiff = __get_nonsparse_ad(alg.ad))
7788
else
7889
if iip
79-
jvp = (_, u, v) -> (du = similar(fu); f.jvp(du, v, u, p); du)
80-
jvp! = (du, _, u, v) -> f.jvp(du, v, u, p)
90+
jvp = (_, u, v) -> (du_ = similar(fu); f.jvp(du_, v, u, p); du_)
91+
jvp! = (du_, _, u, v) -> f.jvp(du_, v, u, p)
8192
else
8293
jvp = (_, u, v) -> f.jvp(v, u, p)
83-
jvp! = (du, _, u, v) -> (du .= f.jvp(v, u, p))
94+
jvp! = (du_, _, u, v) -> (du_ .= f.jvp(v, u, p))
8495
end
8596
op = SparseDiffTools.FwdModeAutoDiffVecProd(f, u, (), jvp, jvp!)
8697
FunctionOperator(op, u, fu; isinplace = Val(true), outofplace = Val(false),
@@ -89,24 +100,27 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val
89100
else
90101
if has_analytic_jac
91102
f.jac_prototype === nothing ? undefmatrix(u) : f.jac_prototype
103+
elseif f.jac_prototype === nothing
104+
init_jacobian(jac_cache; preserve_immutable = Val(true))
92105
else
93-
f.jac_prototype === nothing ? init_jacobian(jac_cache) : f.jac_prototype
106+
f.jac_prototype
94107
end
95108
end
96109

97-
du = _mutable_zero(u)
110+
du = copy(u)
98111

99112
if needsJᵀJ
100113
JᵀJ, Jᵀfu = __init_JᵀJ(J, _vec(fu), uf, u; f,
101-
vjp_autodiff = __get_nonsparse_ad(_getproperty(alg, Val(:vjp_autodiff))),
114+
vjp_autodiff = __get_nonsparse_ad(__getproperty(alg, Val(:vjp_autodiff))),
102115
jvp_autodiff = __get_nonsparse_ad(alg.ad))
103116
end
104117

105118
if linsolve_init
106119
linprob_A = alg isa PseudoTransient ?
107120
(J - (1 / (convert(eltype(u), alg.alpha_initial))) * I) :
108121
(needsJᵀJ ? __maybe_symmetric(JᵀJ) : J)
109-
linsolve = __setup_linsolve(linprob_A, needsJᵀJ ? Jᵀfu : fu, du, p, alg)
122+
linsolve = linsolve_caches(linprob_A, needsJᵀJ ? Jᵀfu : fu, du, p, alg;
123+
linsolve_kwargs)
110124
else
111125
linsolve = nothing
112126
end
@@ -115,22 +129,33 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val
115129
return uf, linsolve, J, fu, jac_cache, du
116130
end
117131

118-
function __setup_linsolve(A, b, u, p, alg)
119-
linprob = LinearProblem(A, _vec(b); u0 = _vec(u))
132+
## Special Handling for Scalars
133+
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number, p,
134+
::Val{false}; linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false),
135+
kwargs...) where {needsJᵀJ, F}
136+
# NOTE: Scalar `u` assumes scalar output from `f`
137+
uf = SciMLBase.JacobianWrapper{false}(f, p)
138+
needsJᵀJ && return uf, nothing, u, nothing, nothing, u, u, u
139+
return uf, FakeLinearSolveJLCache(u, u), u, nothing, nothing, u
140+
end
120141

121-
weight = similar(u)
122-
recursivefill!(weight, true)
142+
# Linear Solve Cache
143+
function linsolve_caches(A, b, u, p, alg; linsolve_kwargs = (;))
144+
if alg.linsolve === nothing && A isa SMatrix && linsolve_kwargs === (;)
145+
# Default handling for SArrays in LinearSolve is not great. Some parts are patched
146+
# but there are quite a few unnecessary allocations
147+
return FakeLinearSolveJLCache(A, b)
148+
end
149+
150+
linprob = LinearProblem(A, _vec(b); u0 = _vec(u), linsolve_kwargs...)
151+
152+
weight = __init_ones(u)
123153

124154
Pl, Pr = wrapprecs(alg.precs(A, nothing, u, p, nothing, nothing, nothing, nothing,
125155
nothing)..., weight)
126156
return init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr)
127157
end
128-
__setup_linsolve(A::KrylovJᵀJ, b, u, p, alg) = __setup_linsolve(A.JᵀJ, b, u, p, alg)
129-
130-
__get_nonsparse_ad(::AutoSparseForwardDiff) = AutoForwardDiff()
131-
__get_nonsparse_ad(::AutoSparseFiniteDiff) = AutoFiniteDiff()
132-
__get_nonsparse_ad(::AutoSparseZygote) = AutoZygote()
133-
__get_nonsparse_ad(ad) = ad
158+
linsolve_caches(A::KrylovJᵀJ, b, u, p, alg) = linsolve_caches(A.JᵀJ, b, u, p, alg)
134159

135160
__init_JᵀJ(J::Number, args...; kwargs...) = zero(J), zero(J)
136161
function __init_JᵀJ(J::AbstractArray, fu, args...; kwargs...)
@@ -180,24 +205,7 @@ function __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf)
180205
end
181206
end
182207

183-
__maybe_symmetric(x) = Symmetric(x)
184-
__maybe_symmetric(x::Number) = x
185-
# LinearSolve with `nothing` doesn't dispatch correctly here
186-
__maybe_symmetric(x::StaticArray) = x
187-
__maybe_symmetric(x::SparseArrays.AbstractSparseMatrix) = x
188-
__maybe_symmetric(x::SciMLOperators.AbstractSciMLOperator) = x
189-
__maybe_symmetric(x::KrylovJᵀJ) = x.JᵀJ
190-
191-
## Special Handling for Scalars
192-
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number, p,
193-
::Val{false}; linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false),
194-
kwargs...) where {needsJᵀJ, F}
195-
# NOTE: Scalar `u` assumes scalar output from `f`
196-
uf = SciMLBase.JacobianWrapper{false}(f, p)
197-
needsJᵀJ && return uf, nothing, u, nothing, nothing, u, u, u
198-
return uf, nothing, u, nothing, nothing, u
199-
end
200-
208+
# Generic Handling of Krylov Methods for Normal Form Linear Solves
201209
function __update_JᵀJ!(iip::Val, cache, sym::Symbol, J)
202210
return __update_JᵀJ!(iip, cache, sym, getproperty(cache, sym), J)
203211
end

src/klement.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::GeneralKleme
8787
linsolve_alg = alg_.linsolve === nothing && u isa Array ? LUFactorization() :
8888
nothing
8989
alg = set_linsolve(alg_, linsolve_alg)
90-
linsolve = __setup_linsolve(J, _vec(fu), _vec(du), p, alg)
90+
linsolve = linsolve_caches(J, _vec(fu), _vec(du), p, alg)
9191
end
9292

9393
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu, u,

src/levenberg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
232232
fill!(mat_tmp, zero(eltype(u)))
233233
rhs_tmp = vcat(_vec(fu1), _vec(u))
234234
fill!(rhs_tmp, zero(eltype(u)))
235-
linsolve = __setup_linsolve(mat_tmp, rhs_tmp, u, p, alg)
235+
linsolve = linsolve_caches(mat_tmp, rhs_tmp, u, p, alg)
236236
end
237237

238238
return LevenbergMarquardtCache{iip, !_unwrap_val(linsolve_with_JᵀJ)}(f, alg, u, copy(u),

0 commit comments

Comments
 (0)