Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ AbstractTrees = "0.3, 0.4"
ArrayInterface = "6, 7"
BifurcationKit = "0.4"
BlockArrays = "1.1"
BoundaryValueDiffEqAscher = "1.1.0"
BoundaryValueDiffEqMIRK = "1.4.0"
BoundaryValueDiffEqAscher = "1.6.0"
BoundaryValueDiffEqMIRK = "1.7.0"
CasADi = "1.0.6"
ChainRulesCore = "1"
Combinatorics = "1"
Expand Down
66 changes: 36 additions & 30 deletions ext/MTKCasADiDynamicOptExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@ struct CasADiDynamicOptProblem{uType, tType, isinplace, P, F, K} <:
end
end

function (M::MXLinearInterpolation)(τ)
function (M::MXLinearInterpolation)(τ)
nt = (τ - M.t[1]) / M.dt
i = 1 + floor(Int, nt)
Δ = nt - i + 1

(i > length(M.t) || i < 1) && error("Cannot extrapolate past the tspan.")
if i < length(M.t)
M.u[:, i] + Δ*(M.u[:, i + 1] - M.u[:, i])
M.u[:, i] + Δ * (M.u[:, i + 1] - M.u[:, i])
else
M.u[:, i]
end
Expand All @@ -74,7 +74,7 @@ The constraints are:
function MTK.CasADiDynamicOptProblem(sys::ODESystem, u0map, tspan, pmap;
dt = nothing,
steps = nothing,
guesses = Dict(), kwargs...)
guesses = Dict(), kwargs...)
MTK.warn_overdetermined(sys, u0map)
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
f, u0, p = MTK.process_SciMLProblem(ODEInputFunction, sys, _u0map, pmap;
Expand Down Expand Up @@ -104,21 +104,21 @@ function init_model(sys, tspan, steps, u0map, pmap, u0; is_free_t = false)
subject_to!(opti, tₛ >= lo)
subject_to!(opti, tₛ >= hi)
end
pmap[te_sym] = tₛ
pmap[te_sym] = tₛ
tsteps = LinRange(0, 1, steps)
else
tₛ = MX(1)
tsteps = LinRange(tspan[1], tspan[2], steps)
end

U = CasADi.variable!(opti, length(states), steps)
V = CasADi.variable!(opti, length(ctrls), steps)
set_initial!(opti, U, DM(repeat(u0, 1, steps)))
c0 = MTK.value.([pmap[c] for c in ctrls])
!isempty(c0) && set_initial!(opti, V, DM(repeat(c0, 1, steps)))

U_interp = MXLinearInterpolation(U, tsteps, tsteps[2]-tsteps[1])
V_interp = MXLinearInterpolation(V, tsteps, tsteps[2]-tsteps[1])
U_interp = MXLinearInterpolation(U, tsteps, tsteps[2] - tsteps[1])
V_interp = MXLinearInterpolation(V, tsteps, tsteps[2] - tsteps[1])
for (i, ct) in enumerate(ctrls)
pmap[ct] = V[i, :]
end
Expand Down Expand Up @@ -185,8 +185,8 @@ function add_user_constraints!(model::CasADiModel, sys, tspan, pmap; is_free_t)
x = MTK.operation(st)
t = only(MTK.arguments(st))
MTK.symbolic_type(t) === MTK.NotSymbolic() || continue
if haskey(stidxmap, x(iv))
idx = stidxmap[x(iv)]
if haskey(stidxmap, x(iv))
idx = stidxmap[x(iv)]
cv = U
else
idx = ctidxmap[x(iv)]
Expand All @@ -196,11 +196,11 @@ function add_user_constraints!(model::CasADiModel, sys, tspan, pmap; is_free_t)
end

if cons isa Equation
subject_to!(opti, cons.lhs - cons.rhs==0)
subject_to!(opti, cons.lhs - cons.rhs == 0)
elseif cons.relational_op === Symbolics.geq
subject_to!(opti, cons.lhs - cons.rhs0)
subject_to!(opti, cons.lhs - cons.rhs0)
else
subject_to!(opti, cons.lhs - cons.rhs0)
subject_to!(opti, cons.lhs - cons.rhs0)
end
end
end
Expand All @@ -227,8 +227,8 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
x = operation(st)
t = only(arguments(st))
MTK.symbolic_type(t) === MTK.NotSymbolic() || continue
if haskey(stidxmap, x(iv))
idx = stidxmap[x(iv)]
if haskey(stidxmap, x(iv))
idx = stidxmap[x(iv)]
cv = U
else
idx = ctidxmap[x(iv)]
Expand All @@ -244,7 +244,8 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
op = MTK.operation(int)
arg = only(arguments(MTK.value(int)))
lo, hi = (op.domain.domain.left, op.domain.domain.right)
!isequal((lo, hi), tspan) && error("Non-whole interval bounds for integrals are not currently supported for CasADiDynamicOptProblem.")
!isequal((lo, hi), tspan) &&
error("Non-whole interval bounds for integrals are not currently supported for CasADiDynamicOptProblem.")
# Approximate integral as sum.
intmap[int] = dt * tₛ * sum(arg)
end
Expand All @@ -253,7 +254,8 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
minimize!(opti, MX(MTK.value(consolidate(jcosts))))
end

function substitute_casadi_vars(model::CasADiModel, sys, pmap, exprs; auxmap::Dict = Dict(), is_free_t)
function substitute_casadi_vars(
model::CasADiModel, sys, pmap, exprs; auxmap::Dict = Dict(), is_free_t)
@unpack opti, U, V, tₛ = model
iv = MTK.get_iv(sys)
sts = unknowns(sys)
Expand Down Expand Up @@ -281,44 +283,44 @@ end

function add_solve_constraints(prob, tableau)
@unpack A, α, c = tableau
@unpack model, f, p = prob
@unpack model, f, p = prob
@unpack opti, U, V, tₛ = model
solver_opti = copy(opti)

tsteps = U.t
tsteps = U.t
dt = tsteps[2] - tsteps[1]

nᵤ = size(U.u, 1)
nᵥ = size(V.u, 1)

if MTK.is_explicit(tableau)
K = MX[]
for k in 1:length(tsteps)-1
for k in 1:(length(tsteps) - 1)
τ = tsteps[k]
for (i, h) in enumerate(c)
ΔU = sum([A[i, j] * K[j] for j in 1:(i - 1)], init = MX(zeros(nᵤ)))
Uₙ = U.u[:, k] + ΔU*dt
Uₙ = U.u[:, k] + ΔU * dt
Vₙ = V.u[:, k]
Kₙ = tₛ * f(Uₙ, Vₙ, p, τ + h * dt) # scale the time
push!(K, Kₙ)
end
ΔU = dt * sum([α[i] * K[i] for i in 1:length(α)])
subject_to!(solver_opti, U.u[:, k] + ΔU == U.u[:, k+1])
subject_to!(solver_opti, U.u[:, k] + ΔU == U.u[:, k + 1])
empty!(K)
end
else
for k in 1:length(tsteps)-1
for k in 1:(length(tsteps) - 1)
τ = tsteps[k]
Kᵢ = variable!(solver_opti, nᵤ, length(α))
ΔUs = A * Kᵢ' # the stepsize at each stage of the implicit method
for (i, h) in enumerate(c)
ΔU = ΔUs[i,:]'
Uₙ = U.u[:,k] + ΔU*dt
Vₙ = V.u[:,k]
subject_to!(solver_opti, Kᵢ[:,i] == tₛ * f(Uₙ, Vₙ, p, τ + h*dt))
ΔU = ΔUs[i, :]'
Uₙ = U.u[:, k] + ΔU * dt
Vₙ = V.u[:, k]
subject_to!(solver_opti, Kᵢ[:, i] == tₛ * f(Uₙ, Vₙ, p, τ + h * dt))
end
ΔU_tot = dt*(Kᵢ*α)
subject_to!(solver_opti, U.u[:, k] + ΔU_tot == U.u[:,k+1])
ΔU_tot = dt * (Kᵢ * α)
subject_to!(solver_opti, U.u[:, k] + ΔU_tot == U.u[:, k + 1])
end
end
solver_opti
Expand All @@ -331,7 +333,10 @@ end

NOTE: the solver should be passed in as a string to CasADi. "ipopt"
"""
function DiffEqBase.solve(prob::CasADiDynamicOptProblem, solver::Union{String, Symbol} = "ipopt", tableau_getter = MTK.constructDefault; plugin_options::Dict = Dict(), solver_options::Dict = Dict(), silent = false)
function DiffEqBase.solve(
prob::CasADiDynamicOptProblem, solver::Union{String, Symbol} = "ipopt",
tableau_getter = MTK.constructDefault; plugin_options::Dict = Dict(),
solver_options::Dict = Dict(), silent = false)
@unpack model, u0, p, tspan, f = prob
tableau = tableau_getter()
@unpack opti, U, V, tₛ = model
Expand Down Expand Up @@ -366,7 +371,8 @@ function DiffEqBase.solve(prob::CasADiDynamicOptProblem, solver::Union{String, S
end

if failed
ode_sol = SciMLBase.solution_new_retcode(ode_sol, SciMLBase.ReturnCode.ConvergenceFailure)
ode_sol = SciMLBase.solution_new_retcode(
ode_sol, SciMLBase.ReturnCode.ConvergenceFailure)
!isnothing(input_sol) && (input_sol = SciMLBase.solution_new_retcode(
input_sol, SciMLBase.ReturnCode.ConvergenceFailure))
end
Expand Down
3 changes: 2 additions & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,8 @@ export AnalysisPoint, get_sensitivity_function, get_comp_sensitivity_function,
function FMIComponent end

include("systems/optimal_control_interface.jl")
export AbstractDynamicOptProblem, JuMPDynamicOptProblem, InfiniteOptDynamicOptProblem, CasADiDynamicOptProblem
export AbstractDynamicOptProblem, JuMPDynamicOptProblem, InfiniteOptDynamicOptProblem,
CasADiDynamicOptProblem
export DynamicOptSolution

end # module
2 changes: 1 addition & 1 deletion src/systems/optimal_control_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ function constructDefault(T::Type = Float64)
A = map(T, A)
α = map(T, α)
c = map(T, c)

DiffEqBase.ImplicitRKTableau(A, c, α, 5)
end

Expand Down
8 changes: 4 additions & 4 deletions test/bvproblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ daesolvers = [Ascher2, Ascher4, Ascher6]

for solver in solvers
sol = solve(bvp, solver(), dt = 0.01)
@test_broken isapprox(sol.u[end], osol.u[end]; atol = 0.01)
@test_broken sol.u[1] == [1.0, 2.0]
@test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
@test sol.u[1] == [1.0, 2.0]
end

# Test out of place
Expand All @@ -39,8 +39,8 @@ daesolvers = [Ascher2, Ascher4, Ascher6]

for solver in solvers
sol = solve(bvp2, solver(), dt = 0.01)
@test_broken isapprox(sol.u[end], osol.u[end]; atol = 0.01)
@test_broken sol.u[1] == [1.0, 2.0]
@test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
@test sol.u[1] == [1.0, 2.0]
end
end

Expand Down
8 changes: 5 additions & 3 deletions test/extensions/dynamic_optimization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ const M = ModelingToolkit
@test jsol.sol(0.6)[1] ≈ 3.5
@test jsol.sol(0.3)[1] ≈ 7.0

cprob = CasADiDynamicOptProblem(lksys, u0map, tspan, parammap; guesses = guess, dt = 0.01)
cprob = CasADiDynamicOptProblem(
lksys, u0map, tspan, parammap; guesses = guess, dt = 0.01)
csol = solve(cprob, "ipopt", constructTsitouras5, silent = true)
@test csol.sol(0.6)[1] ≈ 3.5
@test csol.sol(0.3)[1] ≈ 7.0
Expand All @@ -87,7 +88,8 @@ const M = ModelingToolkit
jsol = solve(jprob, Ipopt.Optimizer, constructRadauIA3, silent = true) # 12.190 s, 9.68 GiB
@test all(u -> u > [1, 1], jsol.sol.u)

cprob = CasADiDynamicOptProblem(lksys, u0map, tspan, parammap; guesses = guess, dt = 0.01)
cprob = CasADiDynamicOptProblem(
lksys, u0map, tspan, parammap; guesses = guess, dt = 0.01)
csol = solve(cprob, "ipopt", constructRadauIA3, silent = true)
@test all(u -> u > [1, 1], csol.sol.u)
end
Expand Down Expand Up @@ -220,7 +222,7 @@ end
jprob = JuMPDynamicOptProblem(rocket, u0map, (ts, te), pmap; dt = 0.001, cse = false)
jsol = solve(jprob, Ipopt.Optimizer, constructRadauIIA5, silent = true)
@test jsol.sol.u[end][1] > 1.012

cprob = CasADiDynamicOptProblem(rocket, u0map, (ts, te), pmap; dt = 0.001, cse = false)
csol = solve(cprob, "ipopt"; silent = true)
@test csol.sol.u[end][1] > 1.012
Expand Down
Loading