Skip to content

Commit 5c8926e

Browse files
feat: better inbounds handling and propagtion in generated functions
1 parent 31f7a54 commit 5c8926e

File tree

9 files changed

+74
-17
lines changed

9 files changed

+74
-17
lines changed

src/systems/abstractsystem.jl

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,9 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
172172
if wrap_code === nothing
173173
wrap_code = isscalar ? identity : (identity, identity)
174174
end
175+
if !get(kwargs, :checkbounds, false)
176+
wrap_code = wrap_code .∘ wrap_inbounds(false)
177+
end
175178
pre, sol_states = get_substitutions_and_solved_unknowns(sys, isscalar ? [exprs] : exprs)
176179
if postprocess_fbody === nothing
177180
postprocess_fbody = pre
@@ -226,6 +229,13 @@ function wrap_assignments(isscalar, assignments; let_block = false)
226229
end
227230
end
228231

232+
function wrap_inbounds(isscalar)
233+
function wrapper(expr)
234+
Func(expr.args, [], :(@inbounds begin; $(toexpr(expr.body)); end))
235+
end
236+
return isscalar ? wrapper : (wrapper, wrapper)
237+
end
238+
229239
function wrap_parameter_dependencies(sys::AbstractSystem, isscalar)
230240
wrap_assignments(isscalar, [eq.lhs eq.rhs for eq in parameter_dependencies(sys)])
231241
end
@@ -785,7 +795,7 @@ end
785795
SymbolicIndexingInterface.supports_tuple_observed(::AbstractSystem) = true
786796

787797
function SymbolicIndexingInterface.observed(
788-
sys::AbstractSystem, sym; eval_expression = false, eval_module = @__MODULE__)
798+
sys::AbstractSystem, sym; eval_expression = false, eval_module = @__MODULE__, checkbounds = true)
789799
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
790800
if sym isa Symbol
791801
_sym = get(ic.symbol_to_variable, sym, nothing)
@@ -808,7 +818,7 @@ function SymbolicIndexingInterface.observed(
808818
end
809819
end
810820
end
811-
_fn = build_explicit_observed_function(sys, sym; eval_expression, eval_module)
821+
_fn = build_explicit_observed_function(sys, sym; eval_expression, eval_module, checkbounds)
812822

813823
if is_time_dependent(sys)
814824
return _fn
@@ -1671,11 +1681,12 @@ struct ObservedFunctionCache{S}
16711681
steady_state::Bool
16721682
eval_expression::Bool
16731683
eval_module::Module
1684+
checkbounds::Bool
16741685
end
16751686

16761687
function ObservedFunctionCache(
1677-
sys; steady_state = false, eval_expression = false, eval_module = @__MODULE__)
1678-
return ObservedFunctionCache(sys, Dict(), steady_state, eval_expression, eval_module)
1688+
sys; steady_state = false, eval_expression = false, eval_module = @__MODULE__, checkbounds = true)
1689+
return ObservedFunctionCache(sys, Dict(), steady_state, eval_expression, eval_module, checkbounds)
16791690
end
16801691

16811692
# This is hit because ensemble problems do a deepcopy
@@ -1694,7 +1705,7 @@ function (ofc::ObservedFunctionCache)(obsvar, args...)
16941705
obs = get!(ofc.dict, value(obsvar)) do
16951706
SymbolicIndexingInterface.observed(
16961707
ofc.sys, obsvar; eval_expression = ofc.eval_expression,
1697-
eval_module = ofc.eval_module)
1708+
eval_module = ofc.eval_module, checkbounds = ofc.checkbounds)
16981709
end
16991710
if ofc.steady_state
17001711
obs = let fn = obs

src/systems/callbacks.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -583,9 +583,14 @@ function compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps;
583583
cmap = map(x -> x => getdefault(x), cs)
584584
condit = substitute(condit, cmap)
585585
end
586+
if !get(kwargs, :checkbounds, false)
587+
inbounds_wrapper = wrap_inbounds(!(condit isa AbstractArray))
588+
else
589+
inbounds_wrapper = condit isa AbstractArray ? (identity, identity) : identity
590+
end
586591
expr = build_function(
587592
condit, u, t, p...; expression = Val{true},
588-
wrap_code = condition_header(sys) .∘
593+
wrap_code = condition_header(sys) .∘ inbounds_wrapper .∘
589594
wrap_array_vars(sys, condit; dvs, ps, inputs = true) .∘
590595
wrap_parameter_dependencies(sys, !(condit isa AbstractArray)),
591596
kwargs...)
@@ -671,6 +676,7 @@ function compile_affect(eqs::Vector{Equation}, cb, sys, dvs, ps; outputidxs = no
671676
t = get_iv(sys)
672677
integ = gensym(:MTKIntegrator)
673678
pre = get_preprocess_constants(rhss)
679+
inbounds_wrapper = get(kwargs, :checkbounds, false) ? (identity, identity) : wrap_inbounds(false)
674680
rf_oop, rf_ip = build_function(rhss, u, p..., t; expression = Val{true},
675681
wrap_code = callback_save_header(sys, cb) .∘
676682
add_integrator_header(sys, integ, outvar) .∘

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,9 @@ function generate_tgrad(
116116
else
117117
(ps,)
118118
end
119+
if !get(kwargs, :checkbounds, false)
120+
wrap_code = wrap_code .∘ wrap_inbounds(false)
121+
end
119122
wrap_code = wrap_code .∘ wrap_array_vars(sys, tgrad; dvs, ps) .∘
120123
wrap_parameter_dependencies(sys, !(tgrad isa AbstractArray))
121124
return build_function(tgrad,
@@ -137,6 +140,9 @@ function generate_jacobian(sys::AbstractODESystem, dvs = unknowns(sys),
137140
else
138141
(ps,)
139142
end
143+
if !get(kwargs, :checkbounds, false)
144+
wrap_code = wrap_code .∘ wrap_inbounds(false)
145+
end
140146
wrap_code = wrap_code .∘ wrap_array_vars(sys, jac; dvs, ps) .∘
141147
wrap_parameter_dependencies(sys, false)
142148
return build_function(jac,
@@ -208,6 +214,10 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys),
208214
p = map.(x -> time_varying_as_func(value(x), sys), reorder_parameters(sys, ps))
209215
t = get_iv(sys)
210216

217+
if !get(kwargs, :checkbounds, false)
218+
wrap_code = wrap_code .∘ wrap_inbounds(false)
219+
end
220+
211221
if isdde
212222
build_function(rhss, u, DDE_HISTORY_FUN, p..., t; kwargs...,
213223
wrap_code = wrap_code .∘ wrap_mtkparameters(sys, false, 3) .∘
@@ -439,7 +449,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
439449
ArrayInterface.restructure(u0 .* u0', M)
440450
end
441451

442-
observedfun = ObservedFunctionCache(sys; steady_state, eval_expression, eval_module)
452+
observedfun = ObservedFunctionCache(sys; steady_state, eval_expression, eval_module, checkbounds)
443453

444454
jac_prototype = if sparse
445455
uElType = u0 === nothing ? Float64 : eltype(u0)
@@ -531,7 +541,7 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
531541
_jac = nothing
532542
end
533543

534-
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
544+
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))
535545

536546
jac_prototype = if sparse
537547
uElType = u0 === nothing ? Float64 : eltype(u0)

src/systems/diffeqs/odesystem.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,12 @@ function build_explicit_observed_function(sys, ts;
629629
oop_mtkp_wrapper = mtkparams_wrapper
630630
end
631631

632+
if !checkbounds
633+
inbounds_wrapper = wrap_inbounds(false)
634+
else
635+
inbounds_wrapper = (identity, identity)
636+
end
637+
632638
# Need to keep old method of building the function since it uses `output_type`,
633639
# which can't be provided to `build_function`
634640
return_value = if isscalar
@@ -641,14 +647,14 @@ function build_explicit_observed_function(sys, ts;
641647
oop_fn = Func(args, [],
642648
pre(Let(obsexprs,
643649
return_value,
644-
false))) |> array_wrapper[1] |> oop_mtkp_wrapper |> toexpr
650+
false))) |> array_wrapper[1] |> oop_mtkp_wrapper |> inbounds_wrapper[1] |> toexpr
645651
oop_fn = expression ? oop_fn : eval_or_rgf(oop_fn; eval_expression, eval_module)
646652

647653
if !isscalar
648654
iip_fn = build_function(ts,
649655
args...;
650656
postprocess_fbody = pre,
651-
wrap_code = mtkparams_wrapper .∘ array_wrapper .∘
657+
wrap_code = inbounds_wrapper .∘ mtkparams_wrapper .∘ array_wrapper .∘
652658
wrap_assignments(isscalar, obsexprs),
653659
expression = Val{true})[2]
654660
if !expression

src/systems/diffeqs/sdesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,7 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
589589
M = calculate_massmatrix(sys)
590590
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0', M)
591591

592-
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
592+
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))
593593

594594
SDEFunction{iip, specialize}(f, g,
595595
sys = sys,

src/systems/discrete_system/discrete_system.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,9 @@ end
234234
function generate_function(
235235
sys::DiscreteSystem, dvs = unknowns(sys), ps = parameters(sys); wrap_code = identity, kwargs...)
236236
exprs = [eq.rhs for eq in equations(sys)]
237+
if !get(kwargs, :checkbounds, false)
238+
wrap_code = wrap_code .∘ wrap_inbounds(false)
239+
end
237240
wrap_code = wrap_code .∘ wrap_array_vars(sys, exprs) .∘
238241
wrap_parameter_dependencies(sys, false)
239242
generate_custom_function(sys, exprs, dvs, ps; wrap_code, kwargs...)
@@ -327,7 +330,7 @@ function SciMLBase.DiscreteFunction{iip, specialize}(
327330
f = SciMLBase.wrapfun_iip(f, (u0, u0, p, t))
328331
end
329332

330-
observedfun = ObservedFunctionCache(sys)
333+
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))
331334

332335
DiscreteFunction{iip, specialize}(f;
333336
sys = sys,

src/systems/jumps/jumpsystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple,
406406
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false)
407407
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT
408408

409-
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
409+
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))
410410

411411
df = DiscreteFunction{true, true}(f; sys = sys, observed = observedfun)
412412
DiscreteProblem(df, u0, tspan, p; kwargs...)
@@ -504,7 +504,7 @@ function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothi
504504
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false,
505505
check_length = false)
506506
f = (du, u, p, t) -> (du .= 0; nothing)
507-
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
507+
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))
508508
df = ODEFunction(f; sys, observed = observedfun)
509509
return ODEProblem(df, u0, tspan, p; kwargs...)
510510
end

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,9 @@ function generate_jacobian(
227227
jac = calculate_jacobian(sys, sparse = sparse, simplify = simplify)
228228
pre, sol_states = get_substitutions_and_solved_unknowns(sys)
229229
p = reorder_parameters(sys, ps)
230+
if !get(kwargs, :checkbounds, false)
231+
wrap_code = wrap_code .∘ wrap_inbounds(false)
232+
end
230233
wrap_code = wrap_code .∘ wrap_array_vars(sys, jac; dvs = vs, ps) .∘
231234
wrap_parameter_dependencies(sys, false)
232235
return build_function(
@@ -251,6 +254,9 @@ function generate_hessian(
251254
hess = calculate_hessian(sys, sparse = sparse, simplify = simplify)
252255
pre = get_preprocess_constants(hess)
253256
p = reorder_parameters(sys, ps)
257+
if !get(kwargs, :checkbounds, false)
258+
wrap_code = wrap_code .∘ wrap_inbounds(false)
259+
end
254260
wrap_code = wrap_code .∘ wrap_array_vars(sys, hess; dvs = vs, ps) .∘
255261
wrap_parameter_dependencies(sys, false)
256262
return build_function(hess, vs, p...; postprocess_fbody = pre, wrap_code, kwargs...)
@@ -266,6 +272,9 @@ function generate_function(
266272
dvs′ = only(dvs)
267273
end
268274
pre, sol_states = get_substitutions_and_solved_unknowns(sys)
275+
if !get(kwargs, :checkbounds, false)
276+
wrap_code = wrap_code .∘ wrap_inbounds(false)
277+
end
269278
wrap_code = wrap_code .∘ wrap_array_vars(sys, rhss; dvs, ps) .∘
270279
wrap_parameter_dependencies(sys, scalar)
271280
p = reorder_parameters(sys, value.(ps))
@@ -342,7 +351,7 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(s
342351
_jac = nothing
343352
end
344353

345-
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
354+
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))
346355

347356
if length(dvs) == length(equations(sys))
348357
resid_prototype = nothing
@@ -383,7 +392,7 @@ function SciMLBase.IntervalNonlinearFunction(
383392
f(u, p) = f_oop(u, p)
384393
f(u, p::MTKParameters) = f_oop(u, p...)
385394

386-
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
395+
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))
387396

388397
IntervalNonlinearFunction{false}(f; observed = observedfun, sys = sys)
389398
end
@@ -579,6 +588,9 @@ function SCCNonlinearFunction{iip}(
579588
cmap, cs = get_cmap(sys)
580589
cmap_assignments = [eq.lhs eq.rhs for eq in cmap]
581590
rhss = [eq.rhs - eq.lhs for eq in _eqs]
591+
if !get(kwargs, :checkbounds, false)
592+
wrap_code = wrap_code .∘ wrap_inbounds(false)
593+
end
582594
wrap_code = wrap_assignments(false, cmap_assignments) .∘
583595
(wrap_array_vars(sys, rhss; dvs = _dvs, cachesyms)) .∘
584596
wrap_parameter_dependencies(sys, false) .∘

src/systems/optimization/optimizationsystem.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,9 @@ function generate_gradient(sys::OptimizationSystem, vs = unknowns(sys),
199199
grad = calculate_gradient(sys)
200200
pre = get_preprocess_constants(grad)
201201
p = reorder_parameters(sys, ps)
202+
if !get(kwargs, :checkbounds, false)
203+
wrap_code = wrap_code .∘ wrap_inbounds(false)
204+
end
202205
wrap_code = wrap_code .∘ wrap_array_vars(sys, grad; dvs = vs, ps) .∘
203206
wrap_parameter_dependencies(sys, !(grad isa AbstractArray))
204207
return build_function(grad, vs, p...; postprocess_fbody = pre, wrap_code,
@@ -219,6 +222,9 @@ function generate_hessian(
219222
end
220223
pre = get_preprocess_constants(hess)
221224
p = reorder_parameters(sys, ps)
225+
if !get(kwargs, :checkbounds, false)
226+
wrap_code = wrap_code .∘ wrap_inbounds(false)
227+
end
222228
wrap_code = wrap_code .∘ wrap_array_vars(sys, hess; dvs = vs, ps) .∘
223229
wrap_parameter_dependencies(sys, false)
224230
return build_function(hess, vs, p...; postprocess_fbody = pre, wrap_code,
@@ -235,6 +241,9 @@ function generate_function(sys::OptimizationSystem, vs = unknowns(sys),
235241
else
236242
(ps,)
237243
end
244+
if !get(kwargs, :checkbounds, false)
245+
wrap_code = wrap_code .∘ wrap_inbounds(false)
246+
end
238247
wrap_code = wrap_code .∘ wrap_array_vars(sys, eqs; dvs = vs, ps) .∘
239248
wrap_parameter_dependencies(sys, !(eqs isa AbstractArray))
240249
return build_function(eqs, vs, p...; wrap_code,
@@ -419,7 +428,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
419428
hess_prototype = nothing
420429
end
421430

422-
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
431+
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds)
423432

424433
if length(cstr) > 0
425434
@named cons_sys = ConstraintsSystem(cstr, dvs, ps; checks)

0 commit comments

Comments
 (0)