-
-
Notifications
You must be signed in to change notification settings - Fork 232
Description
Describe the bug 🐞
I'm seeing a bug in the order of the expressions generated for `initializeprobmap` when using ModelingToolkitNeuralNets (MTKNN). Currently, MTKNN forces the defaults for the NN inputs to zeros because of the following initialization warning:
┌ Warning: Internal error: Variable (nn₊input₊u(t))[2] was marked as being in 0 ~ (LuxCore.stateless_apply(Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(2 => 5, tanh), layer_2 = Dense(5 => 5, tanh), layer_3 = Dense(5 => 2)), nothing), nn₊input₊u(t), convert(nn₊T, nn₊p)))[2] - (nn₊output₊u(t))[2], but was actually zero
└ @ ModelingToolkit.StructuralTransformations ~/.julia/dev/ModelingToolkit/src/structural_transformation/utils.jl:237
If I don't provide the defaults, I hit the `OverrideInit` dispatch of `_initialize_dae!`, and the generated code for the `initializeprobmap` `getu` function is:
julia> prob.f.initializeprobmap.obsfn.var"#515#_fn"
RuntimeGeneratedFunction(#=in ModelingToolkit=#, #=using ModelingToolkit=#, :((var"##arg#5964805160111424296", ___mtkparameters___)->begin
#= /home/sebastian/.julia/packages/SymbolicUtils/ij6YM/src/code.jl:385 =#
#= /home/sebastian/.julia/packages/SymbolicUtils/ij6YM/src/code.jl:386 =#
#= /home/sebastian/.julia/packages/SymbolicUtils/ij6YM/src/code.jl:387 =#
begin
var"##arg#1300299927463060512" = ___mtkparameters___[1]
var"##arg#5212838221877633925" = ___mtkparameters___[2]
var"##arg#14786571317958501831" = ___mtkparameters___[3]
begin
var"nn₊p[1]" = var"##arg#1300299927463060512"[1]
var"nn₊p[2]" = var"##arg#1300299927463060512"[2]
var"nn₊p[3]" = var"##arg#1300299927463060512"[3]
var"nn₊p[4]" = var"##arg#1300299927463060512"[4]
var"nn₊p[5]" = var"##arg#1300299927463060512"[5]
var"nn₊p[6]" = var"##arg#1300299927463060512"[6]
var"nn₊p[7]" = var"##arg#1300299927463060512"[7]
var"nn₊p[8]" = var"##arg#1300299927463060512"[8]
var"nn₊p[9]" = var"##arg#1300299927463060512"[9]
var"nn₊p[10]" = var"##arg#1300299927463060512"[10]
var"nn₊p[11]" = var"##arg#1300299927463060512"[11]
var"nn₊p[12]" = var"##arg#1300299927463060512"[12]
var"nn₊p[13]" = var"##arg#1300299927463060512"[13]
var"nn₊p[14]" = var"##arg#1300299927463060512"[14]
var"nn₊p[15]" = var"##arg#1300299927463060512"[15]
var"nn₊p[16]" = var"##arg#1300299927463060512"[16]
var"nn₊p[17]" = var"##arg#1300299927463060512"[17]
var"nn₊p[18]" = var"##arg#1300299927463060512"[18]
var"nn₊p[19]" = var"##arg#1300299927463060512"[19]
var"nn₊p[20]" = var"##arg#1300299927463060512"[20]
var"nn₊p[21]" = var"##arg#1300299927463060512"[21]
var"nn₊p[22]" = var"##arg#1300299927463060512"[22]
var"nn₊p[23]" = var"##arg#1300299927463060512"[23]
var"nn₊p[24]" = var"##arg#1300299927463060512"[24]
var"nn₊p[25]" = var"##arg#1300299927463060512"[25]
var"nn₊p[26]" = var"##arg#1300299927463060512"[26]
var"nn₊p[27]" = var"##arg#1300299927463060512"[27]
var"nn₊p[28]" = var"##arg#1300299927463060512"[28]
var"nn₊p[29]" = var"##arg#1300299927463060512"[29]
var"nn₊p[30]" = var"##arg#1300299927463060512"[30]
var"nn₊p[31]" = var"##arg#1300299927463060512"[31]
var"nn₊p[32]" = var"##arg#1300299927463060512"[32]
var"nn₊p[33]" = var"##arg#1300299927463060512"[33]
var"nn₊p[34]" = var"##arg#1300299927463060512"[34]
var"nn₊p[35]" = var"##arg#1300299927463060512"[35]
var"nn₊p[36]" = var"##arg#1300299927463060512"[36]
var"nn₊p[37]" = var"##arg#1300299927463060512"[37]
var"nn₊p[38]" = var"##arg#1300299927463060512"[38]
var"nn₊p[39]" = var"##arg#1300299927463060512"[39]
var"nn₊p[40]" = var"##arg#1300299927463060512"[40]
var"nn₊p[41]" = var"##arg#1300299927463060512"[41]
var"nn₊p[42]" = var"##arg#1300299927463060512"[42]
var"nn₊p[43]" = var"##arg#1300299927463060512"[43]
var"nn₊p[44]" = var"##arg#1300299927463060512"[44]
var"nn₊p[45]" = var"##arg#1300299927463060512"[45]
var"nn₊p[46]" = var"##arg#1300299927463060512"[46]
var"nn₊p[47]" = var"##arg#1300299927463060512"[47]
var"nn₊p[48]" = var"##arg#1300299927463060512"[48]
var"nn₊p[49]" = var"##arg#1300299927463060512"[49]
var"nn₊p[50]" = var"##arg#1300299927463060512"[50]
var"nn₊p[51]" = var"##arg#1300299927463060512"[51]
var"nn₊p[52]" = var"##arg#1300299927463060512"[52]
var"nn₊p[53]" = var"##arg#1300299927463060512"[53]
var"nn₊p[54]" = var"##arg#1300299927463060512"[54]
var"nn₊p[55]" = var"##arg#1300299927463060512"[55]
var"nn₊p[56]" = var"##arg#1300299927463060512"[56]
var"nn₊p[57]" = var"##arg#1300299927463060512"[57]
t = var"##arg#1300299927463060512"[58]
lotka₊δ = var"##arg#5212838221877633925"[1]
lotka₊α = var"##arg#5212838221877633925"[2]
nn₊T = var"##arg#14786571317958501831"[1]
begin
nn₊p = reshape(view(var"##arg#1300299927463060512", 1:57), (57,))
begin
begin
var"lotka₊x(t)" = 3.1
var"lotka₊y(t)" = 1.5
var"(nn₊output₊u(t))[1]" = (getindex)((LuxCore.stateless_apply)(Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(2 => 5, tanh), layer_2 = Dense(5 => 5, tanh), layer_3 = Dense(5 => 2)), nothing), begin
#= /home/sebastian/.julia/packages/SymbolicUtils/ij6YM/src/code.jl:480 =#
(SymbolicUtils.Code.create_array)(OffsetArrays.OffsetVector{SymbolicUtils.BasicSymbolic{Real}, Vector{SymbolicUtils.BasicSymbolic{Real}}}, nothing, Val{1}(), Val{(2,)}(), var"(nn₊input₊u(t))[1]", var"(nn₊input₊u(t))[2]")
end, (convert)(nn₊T, nn₊p)), 1)
var"(nn₊output₊u(t))[2]" = (getindex)((LuxCore.stateless_apply)(Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(2 => 5, tanh), layer_2 = Dense(5 => 5, tanh), layer_3 = Dense(5 => 2)), nothing), begin
#= /home/sebastian/.julia/packages/SymbolicUtils/ij6YM/src/code.jl:480 =#
(SymbolicUtils.Code.create_array)(OffsetArrays.OffsetVector{SymbolicUtils.BasicSymbolic{Real}, Vector{SymbolicUtils.BasicSymbolic{Real}}}, nothing, Val{1}(), Val{(2,)}(), var"(nn₊input₊u(t))[1]", var"(nn₊input₊u(t))[2]")
end, (convert)(nn₊T, nn₊p)), 2)
var"(nn₊input₊u(t))[1]" = var"lotka₊x(t)"
var"(nn₊input₊u(t))[2]" = var"lotka₊y(t)"
begin
#= /home/sebastian/.julia/packages/SymbolicUtils/ij6YM/src/code.jl:480 =#
(SymbolicUtils.Code.create_array)(Array, nothing, Val{1}(), Val{(4,)}(), var"lotka₊x(t)", var"lotka₊y(t)", var"(nn₊input₊u(t))[1]", var"(nn₊input₊u(t))[2]")
end
end
end
end
end
end
end))
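Note how `var"(nn₊input₊u(t))[1]"` and `var"(nn₊input₊u(t))[2]"` are used (as arguments to the `create_array` calls that build the network input) before they are assigned from `var"lotka₊x(t)"` and `var"lotka₊y(t)"`.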
Expected behavior
`solve` should run without throwing an error.
Minimal Reproducible Example 👇
Without an MRE, we would only be able to help you to a limited extent, and attention to the issue would be limited. To learn more about MREs, refer to Wikipedia and Stack Overflow.
using Test
using ModelingToolkitNeuralNets
using ModelingToolkit
using ModelingToolkitStandardLibrary.Blocks
using OrdinaryDiffEqRosenbrock, OrdinaryDiffEqNonlinearSolve
using SymbolicIndexingInterface
using StableRNGs
"""
    lotka_ude()

Build the `:lotka` component used as the mechanistic part of the example.

The states `x` and `y` default to `3.1` and `1.5`; the parameters `α` and `δ`
default to `1.3` and `1.8` and are marked non-tunable. Each derivative equation
receives an additive term from the `nn_in` array connector, and the `nn_out`
array connector is set equal to the states `x` and `y`.

Returns the `ODESystem` named `:lotka` with `nn_in` and `nn_out` as subsystems.
"""
function lotka_ude()
    @variables t x(t) = 3.1 y(t) = 1.5
    @parameters α = 1.3 [tunable = false] δ = 1.8 [tunable = false]

    # Array connectors through which the external block is attached.
    @named nn_in = RealInputArray(nin=2)
    @named nn_out = RealOutputArray(nout=2)

    iv = ModelingToolkit.t_nounits
    ddt = ModelingToolkit.D_nounits

    # Known linear dynamics plus the term supplied through `nn_in`.
    dynamics = [
        ddt(x) ~ α * x + nn_in.u[1],
        ddt(y) ~ -δ * y + nn_in.u[2],
    ]
    # The states are exposed on `nn_out`.
    outputs = [
        nn_out.u[1] ~ x,
        nn_out.u[2] ~ y,
    ]

    return ODESystem(
        vcat(dynamics, outputs), iv; name=:lotka, systems=[nn_in, nn_out]
    )
end
"""
    lotka_true()

Build the fully specified reference system named `:lotka_true`.

The states `x` and `y` default to `3.1` and `1.5`; the parameters `α`, `β`, `γ`
and `δ` default to `1.3`, `0.9`, `0.8` and `1.8`.

Returns the `ODESystem` for the two equations below.
"""
function lotka_true()
    @variables t x(t) = 3.1 y(t) = 1.5
    @parameters α = 1.3 β = 0.9 γ = 0.8 δ = 1.8

    iv = ModelingToolkit.t_nounits
    ddt = ModelingToolkit.D_nounits

    prey = ddt(x) ~ α * x - β * x * y
    # NOTE(review): `γ` is declared above but never used; the interaction term
    # here uses `δ`. Confirm whether `γ * x * y` was intended.
    predator = ddt(y) ~ -δ * y + δ * x * y

    return ODESystem([prey, predator], iv; name=:lotka_true)
end
# Reproduction script: couple the `lotka_ude` component to a neural network
# block, simplify, and solve. The statement order is part of the reproduction.
model = lotka_ude()
# 2-input / 2-output feed-forward chain wrapped as a ModelingToolkit block;
# the RNG is seeded so the run is repeatable.
chain = multi_layer_feed_forward(2, 2)
@named nn = NeuralNetworkBlock(2, 2; chain, rng=StableRNG(42))
# Close the loop: network output feeds `model.nn_in`, and `model.nn_out`
# feeds the network input.
eqs = [connect(model.nn_in, nn.output)
connect(model.nn_out, nn.input)]
ude_sys = complete(ODESystem(
eqs, ModelingToolkit.t_nounits, systems=[model, nn],
name=:ude_sys,
# Left disabled on purpose: these are the zero defaults for the network inputs
# mentioned in the description; the reported error occurs without them.
# defaults=[nn.input.u => [0.0, 0.0]]
))
sys = structural_simplify(ude_sys)
prob = ODEProblem{true,SciMLBase.FullSpecialize}(sys, [], (0, 1.0), [])
# Standalone initialization problem at t = 0.0.
# NOTE(review): the reported stack trace originates from the `solve(prob, ...)`
# call below, so this solve presumably succeeds — confirm.
iprob = ModelingToolkit.InitializationProblem(sys, 0.0)
solve(iprob)
# This call raises the reported `UndefVarError` while evaluating
# `initializeprobmap` during DAE initialization.
solve(prob, Rodas5P())
Error & Stacktrace
ERROR: UndefVarError: `(nn₊input₊u(t))[1]` not defined in local scope
Suggestion: check for an assignment to a local variable that shadows a global of the same name.
Stacktrace:
[1] macro expansion
@ ~/.julia/packages/SymbolicUtils/ij6YM/src/code.jl:480 [inlined]
[2] macro expansion
@ ~/.julia/packages/RuntimeGeneratedFunctions/M9ZX8/src/RuntimeGeneratedFunctions.jl:163 [inlined]
[3] macro expansion
@ ./none:0 [inlined]
[4] generated_callfunc
@ ./none:0 [inlined]
[5] (::RuntimeGeneratedFunctions.RuntimeGeneratedFunction{…})(::Vector{…}, ::MTKParameters{…})
@ RuntimeGeneratedFunctions ~/.julia/packages/RuntimeGeneratedFunctions/M9ZX8/src/RuntimeGeneratedFunctions.jl:150
[6] (::ModelingToolkit.var"#fn2#272"{…})(u::Vector{…}, p::MTKParameters{…})
@ ModelingToolkit ~/.julia/dev/ModelingToolkit/src/systems/abstractsystem.jl:840
[7] (::SymbolicIndexingInterface.TimeIndependentObservedFunction{…})(::NotTimeseries, prob::SciMLBase.NonlinearSolution{…})
@ SymbolicIndexingInterface ~/.julia/packages/SymbolicIndexingInterface/cwAFH/src/state_indexing.jl:142
[8] (::SymbolicIndexingInterface.TimeIndependentObservedFunction{…})(prob::SciMLBase.NonlinearSolution{…})
@ SymbolicIndexingInterface ~/.julia/packages/SymbolicIndexingInterface/cwAFH/src/value_provider_interface.jl:166
[9] _initialize_dae!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…}, prob::ODEProblem{…}, alg::OrdinaryDiffEqCore.OverrideInit{…}, isinplace::Val{…})
@ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/HwWWN/src/initialize_dae.jl:174
[10] _initialize_dae!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…}, prob::ODEProblem{…}, alg::OrdinaryDiffEqCore.DefaultInit, x::Val{…})
@ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/HwWWN/src/initialize_dae.jl:60
[11] initialize_dae!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…}, initializealg::OrdinaryDiffEqCore.DefaultInit)
@ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/HwWWN/src/initialize_dae.jl:50
[12] __init(prob::ODEProblem{…}, alg::Rodas5P{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::Tuple{}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::Nothing, dense::Bool, calck::Bool, dt::Float64, dtmin::Float64, dtmax::Float64, force_dtmin::Bool, adaptive::Bool, gamma::Rational{…}, abstol::Nothing, reltol::Nothing, qmin::Rational{…}, qmax::Int64, qsteady_min::Int64, qsteady_max::Rational{…}, beta1::Nothing, beta2::Nothing, qoldinit::Rational{…}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEqCore.DefaultInit, kwargs::@Kwargs{})
@ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/HwWWN/src/solve.jl:503
[13] __init (repeats 5 times)
@ ~/.julia/packages/OrdinaryDiffEqCore/HwWWN/src/solve.jl:11 [inlined]
[14] __solve(::ODEProblem{…}, ::Rodas5P{…}; kwargs::@Kwargs{})
@ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/HwWWN/src/solve.jl:6
[15] __solve
@ ~/.julia/packages/OrdinaryDiffEqCore/HwWWN/src/solve.jl:1 [inlined]
[16] #solve_call#44
@ ~/.julia/packages/DiffEqBase/uqSeD/src/solve.jl:612 [inlined]
[17] solve_call(_prob::ODEProblem{…}, args::Rodas5P{…})
@ DiffEqBase ~/.julia/packages/DiffEqBase/uqSeD/src/solve.jl:569
[18] solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::MTKParameters{…}, args::Rodas5P{…}; kwargs::@Kwargs{})
@ DiffEqBase ~/.julia/packages/DiffEqBase/uqSeD/src/solve.jl:1092
[19] solve_up
@ ~/.julia/packages/DiffEqBase/uqSeD/src/solve.jl:1078 [inlined]
[20] solve(prob::ODEProblem{…}, args::Rodas5P{…}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::@Kwargs{})
@ DiffEqBase ~/.julia/packages/DiffEqBase/uqSeD/src/solve.jl:1015
[21] solve(prob::ODEProblem{…}, args::Rodas5P{…})
@ DiffEqBase ~/.julia/packages/DiffEqBase/uqSeD/src/solve.jl:1005
[22] top-level scope
Environment (please complete the following information):
- Output of
using Pkg; Pkg.status()
- Output of
using Pkg; Pkg.status(; mode = PKGMODE_MANIFEST)
- Output of
versioninfo()
Additional context
Add any other context about the problem here.