diff --git a/.gitignore b/.gitignore index 5a4b8af..679106e 100644 --- a/.gitignore +++ b/.gitignore @@ -23,4 +23,5 @@ docs/site/ # environment. Manifest.toml -.vscode \ No newline at end of file +.vscode +script/ \ No newline at end of file diff --git a/src/deim.jl b/src/deim.jl index 95a19a7..7ecf101 100644 --- a/src/deim.jl +++ b/src/deim.jl @@ -26,6 +26,45 @@ function deim_interpolation_indices(basis::AbstractMatrix)::Vector{Int} return indices end +""" +$(TYPEDSIGNATURES) + +Compute the QDEIM interpolation indices for the given projection basis. +""" +function qdeim_interpolation_indices(basis::AbstractMatrix)::Vector{Int} + dim = size(basis, 2) + return qr(basis', ColumnNorm()).p[1:dim] +end + +""" +$(TYPEDSIGNATURES) + +Compute the ODEIM interpolation indices for the given projection basis. +""" +function odeim_interpolation_indices(basis::AbstractMatrix, m::Int)::Vector{Int} + dim = size(basis, 2) + @assert m >= dim && m <= size(basis,1) "Invalid sampling dimension" + + # Compute the first dim points with QDEIM + p = qdeim_interpolation_indices(basis) + + # select points n+1, ..., m + for _ in (length(p) + 1):m + _, S, W = svd(basis[p, :]) + gap = S[end - 1]^2 - S[end]^2 # eigengap + proj_basis = W' * basis' + r = gap .+ sum(proj_basis.^2, dims=1) + r -= sqrt.((gap .+ sum(proj_basis.^2, dims=1)).^2 - 4 * gap * (proj_basis[end, :].^2)') + indices = sortperm(vec(r), rev=true) + e = 1 + while any(indices[e] .== p) + e += 1 + end + push!(p, indices[e]) + end + return p +end + """ $(SIGNATURES) @@ -51,6 +90,10 @@ where ``P=[\\mathbf e_{\\rho_1},\\dots,\\mathbf e_{\\rho_m}]\\in\\mathbb R^{n\\t algorithm, and ``\\mathbf e_{\\rho_i}=[0,\\ldots,0,1,0,\\ldots,0]^T\\in\\mathbb R^n`` is the ``\\rho_i``-th column of the identity matrix ``I_n\\in\\mathbb R^{n\\times n}``. +Besides the standard DEIM algorithm for interpolation, this method also supports the QDEIM +and the ODEIM algorithms. The ODEIM algorithm requires an additional parameter `odeim_dim` +to specify the number of the oversampled interpolation points. + # Arguments - `full_vars::AbstractVector`: the dependent variables ``\\underset{n\\times 1}{\\mathbf y}`` in FOM. - `linear_coeffs::AbstractMatrix`: the coefficient matrix ``\\underset{n\\times n}A`` of linear terms in FOM. @@ -59,15 +102,22 @@ the ``\\rho_i``-th column of the identity matrix ``I_n\\in\\mathbb R^{n\\times n - `reduced_vars::AbstractVector`: the dependent variables ``\\underset{k\\times 1}{\\hat{\\mathbf y}}`` in the reduced-order model. - `linear_projection_matrix::AbstractMatrix`: the projection matrix ``\\underset{n\\times k}V`` for the dependent variables ``\\mathbf y``. - `nonlinear_projection_matrix::AbstractMatrix`: the projection matrix ``\\underset{n\\times m}U`` for the nonlinear functions ``\\mathbf F``. +- `interpolation_algo::Symbol`: the interpolation algorithm, which can be `:deim`, `:qdeim`, or `:odeim`. # Return - `reduced_rhss`: the right-hand side of ROM. - `linear_projection_eqs`: the linear projection mapping ``\\mathbf y=V\\hat{\\mathbf y}``. + +# References +- [DEIM](https://epubs.siam.org/doi/abs/10.1137/110822724): Chaturantabut and Sorensen, 2012. +- [QDEIM](http://epubs.siam.org/doi/10.1137/15M1019271): Drmac and Gugercin, 2016. +- [ODEIM](https://epubs.siam.org/doi/10.1137/19M1307391): Peherstorfer, Drmac, and Gugercin, 2020. """ function deim(full_vars::AbstractVector, linear_coeffs::AbstractMatrix, constant_part::AbstractVector, nonlinear_part::AbstractVector, reduced_vars::AbstractVector, linear_projection_matrix::AbstractMatrix, - nonlinear_projection_matrix::AbstractMatrix; kwargs...) + nonlinear_projection_matrix::AbstractMatrix, + interpolation_algo::Symbol, odeim_dim::Integer; kwargs...) # rename variables for convenience y = full_vars A = linear_coeffs @@ -81,7 +131,13 @@ function deim(full_vars::AbstractVector, linear_coeffs::AbstractMatrix, linear_projection_eqs = Symbolics.scalarize(y .~ V * ŷ) linear_projection_dict = Dict(eq.lhs => eq.rhs for eq in linear_projection_eqs) - indices = deim_interpolation_indices(U) # DEIM interpolation indices + if interpolation_algo == :deim + indices = deim_interpolation_indices(U) # DEIM interpolation indices + elseif interpolation_algo == :qdeim + indices = qdeim_interpolation_indices(U) # QDEIM interpolation indices + elseif interpolation_algo == :odeim + indices = odeim_interpolation_indices(U, odeim_dim) # ODEIM interpolation indices + end # the DEIM projector (not DEIM basis) satisfies # F(original_vars) ≈ projector * F(pod_basis * reduced_vars)[indices] projector = ((@view U[indices, :])' \ (U' * V))' @@ -91,8 +147,9 @@ function deim(full_vars::AbstractVector, linear_coeffs::AbstractMatrix, Â = V' * A * V ĝ = V' * g reduced_rhss = Â * ŷ + ĝ + F̂ - reduced_rhss, linear_projection_eqs + return reduced_rhss, linear_projection_eqs end + """ $(FUNCTIONNAME)( sys::ModelingToolkit.ODESystem, @@ -100,6 +157,7 @@ end pod_dim::Integer; deim_dim::Integer = pod_dim, name::Symbol = Symbol(nameof(sys), :_deim), + interpolation_algo::Symbol = :deim, kwargs... ) -> ModelingToolkit.ODESystem @@ -116,11 +174,16 @@ The LHS of equations in `sys` are all assumed to be 1st order derivatives. Use The POD basis used for DEIM interpolation is obtained from the snapshot matrix of the nonlinear terms, which is computed by executing the runtime-generated function for -nonlinear expressions. +nonlinear expressions. + +Additional to the DEIM algorithm, this function also supports the QDEIM and ODEIM. For ODEIM, +the `odeim_dim` parameter specifies the number of oversampled interpolation points. """ function deim(sys::ODESystem, snapshot::AbstractMatrix, pod_dim::Integer; - deim_dim::Integer = pod_dim, name::Symbol = Symbol(nameof(sys), :_deim), - kwargs...)::ODESystem + deim_dim::Integer = pod_dim, odeim_dim::Integer = 2*pod_dim, + name::Symbol = Symbol(nameof(sys), :_deim), + interpolation_algo::Symbol = :deim, kwargs...)::ODESystem + @assert interpolation_algo ∈ (:deim, :qdeim, :odeim) "Invalid interpolation algorithm" sys = deepcopy(sys) @set! sys.name = name @@ -158,7 +221,7 @@ function deim(sys::ODESystem, snapshot::AbstractMatrix, pod_dim::Integer; reduce!(deim_reducer, TSVD()) U = deim_reducer.rbasis # DEIM projection basis - reduced_rhss, linear_projection_eqs = deim(dvs, A, g, F, ŷ, V, U; kwargs...) + reduced_rhss, linear_projection_eqs = deim(dvs, A, g, F, ŷ, V, U, interpolation_algo, odeim_dim; kwargs...) reduced_deqs = D.(ŷ) ~ reduced_rhss @set! sys.eqs = [Symbolics.scalarize(reduced_deqs); eqs] diff --git a/test/deim.jl b/test/deim.jl index 05d61d7..afb1f3d 100644 --- a/test/deim.jl +++ b/test/deim.jl @@ -37,18 +37,40 @@ sol = solve(ode_prob, Tsit5(), saveat = 1.0) snapshot_simpsys = Array(sol.original_sol) pod_dim = 3 + +# test DEIM deim_sys = @test_nowarn deim(simp_sys, snapshot_simpsys, pod_dim) +# test QDEIM +qdeim_sys = @test_nowarn deim(simp_sys, snapshot_simpsys, pod_dim; interpolation_algo=:qdeim) +# test ODEIM +odeim_sys = @test_nowarn deim(simp_sys, snapshot_simpsys, pod_dim; interpolation_algo=:odeim) -# check the number of dependent variables in the new system +# DEIM: check the number of dependent variables in the new system @test length(ModelingToolkit.get_states(deim_sys)) == pod_dim - deim_prob = ODEProblem(deim_sys, nothing, tspan) - deim_sol = solve(deim_prob, Tsit5(), saveat = 1.0) nₓ = length(sol[x]) nₜ = length(sol[t]) -# test solution retrieva +# test solution retrieval +@test size(deim_sol[v(x, t)]) == (nₓ, nₜ) +@test size(deim_sol[w(x, t)]) == (nₓ, nₜ) + +# QDEIM: check the number of dependent variables in the new system +@test length(ModelingToolkit.get_states(qdeim_sys)) == pod_dim +deim_prob = ODEProblem(qdeim_sys, nothing, tspan) +deim_sol = solve(deim_prob, Tsit5(), saveat = 1.0) + +# test solution retrieval @test size(deim_sol[v(x, t)]) == (nₓ, nₜ) @test size(deim_sol[w(x, t)]) == (nₓ, nₜ) + +# ODEIM: check the number of dependent variables in the new system +@test length(ModelingToolkit.get_states(odeim_sys)) == pod_dim +deim_prob = ODEProblem(odeim_sys, nothing, tspan) +deim_sol = solve(deim_prob, Tsit5(), saveat = 1.0) + +# test solution retrieval +@test size(deim_sol[v(x, t)]) == (nₓ, nₜ) +@test size(deim_sol[w(x, t)]) == (nₓ, nₜ) \ No newline at end of file