Skip to content

ForwardDiff on ODE with callback crashes #3914

@hersle

Description

@hersle

Use ForwardDiff to differentiate the termination time of an ODE whose `ContinuousCallback` stops integration when x = 2:

# Minimal reproducer: use ForwardDiff to differentiate the time at which the
# solution of D(x) = x reaches x = 2. That time is detected by a
# ContinuousCallback that terminates the integration.
using ModelingToolkit, OrdinaryDiffEq
using ModelingToolkit: t_nounits as t, D_nounits as D
# `ForwardDiff.derivative` is called below; load it explicitly so the script
# reproduces the error in a fresh session instead of failing with UndefVarError.
using ForwardDiff

@variables x(t)
@mtkbuild sys = System([D(x) ~ x], t)

# Zero crossing of u[1] - 2.0 (the system's single state, x) stops the solve.
callback = ContinuousCallback((u, t, integrator) -> u[1] - 2.0, terminate!)

# The initial value is a NaN placeholder; `t_when_x_is_2` replaces it via `remake`.
prob = ODEProblem(sys, [x => NaN], (0.0, 10.0); callback)

"""
    t_when_x_is_2(x0)

Solve `prob` with initial condition `x(0) = x0` and return the last time point of
the solution. This is the time at which the terminating callback fired, or the end
of the time span `(0.0, 10.0)` if it never fired.
"""
function t_when_x_is_2(x0)
    newprob = remake(prob; u0 = [sys.x => x0])
    sol = solve(newprob)
    return sol[t][end]
end

# Analytically x(t) = x0 * exp(t), so t = log(2 / x0) and the expected
# derivative at x0 = 1.0 is -1.0.
ForwardDiff.derivative(t_when_x_is_2, 1.0)

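This fails during callback root-finding. DiffEqBase's internal ITP solver calls `copysign` with a `Float64` and a `ForwardDiff.Dual`, and errors with: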

ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{typeof(t_when_x_is_2), Float64}, Float64, 1})
The type `Float64` exists, but no method is defined for this combination of argument types when trying to construct it.

Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat
   @ Base rounding.jl:265
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:900
  Float64(::IrrationalConstants.Log4π)
   @ IrrationalConstants ~/.julia/packages/IrrationalConstants/lWTip/src/macro.jl:131
  ...

Stacktrace:
  [1] copysign(x::Float64, y::ForwardDiff.Dual{ForwardDiff.Tag{typeof(t_when_x_is_2), Float64}, Float64, 1})
    @ Base ./floatfuncs.jl:8
  [2] solve(::IntervalNonlinearProblem{…}, ::DiffEqBase.InternalITP; maxiters::Int64, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/Rfwzp/src/internal_itp.jl:61
  [3] solve
    @ ~/.julia/packages/DiffEqBase/Rfwzp/src/internal_itp.jl:29 [inlined]
  [4] #bisection#20
    @ ~/.julia/packages/DiffEqBase/Rfwzp/src/callbacks.jl:366 [inlined]
  [5] bisection
    @ ~/.julia/packages/DiffEqBase/Rfwzp/src/callbacks.jl:362 [inlined]
  [6] find_callback_time(integrator::OrdinaryDiffEqCore.ODEIntegrator{…}, callback::ContinuousCallback{…}, counter::Int64)
    @ DiffEqBase ~/.julia/packages/DiffEqBase/Rfwzp/src/callbacks.jl:434
  [7] macro expansion
    @ ~/.julia/packages/DiffEqBase/Rfwzp/src/callbacks.jl:130 [inlined]
  [8] find_first_continuous_callback
    @ ~/.julia/packages/DiffEqBase/Rfwzp/src/callbacks.jl:125 [inlined]
  [9] find_first_continuous_callback
    @ ~/.julia/packages/DiffEqBase/Rfwzp/src/callbacks.jl:123 [inlined]
 [10] handle_callbacks!
    @ ~/.julia/packages/OrdinaryDiffEqCore/TUQ8i/src/integrators/integrator_utils.jl:379 [inlined]
 [11] _loopfooter!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…})
    @ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/TUQ8i/src/integrators/integrator_utils.jl:284
 [12] loopfooter!
    @ ~/.julia/packages/OrdinaryDiffEqCore/TUQ8i/src/integrators/integrator_utils.jl:248 [inlined]
 [13] solve!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…})
    @ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/TUQ8i/src/solve.jl:611
 [14] #__solve#62
    @ ~/.julia/packages/OrdinaryDiffEqCore/TUQ8i/src/solve.jl:7 [inlined]
 [15] __solve
    @ ~/.julia/packages/OrdinaryDiffEqCore/TUQ8i/src/solve.jl:1 [inlined]
 [16] #__solve#3
    @ ~/.julia/packages/OrdinaryDiffEqDefault/ydhv7/src/default_alg.jl:48 [inlined]
 [17] __solve
    @ ~/.julia/packages/OrdinaryDiffEqDefault/ydhv7/src/default_alg.jl:47 [inlined]
 [18] #__solve#54
    @ ~/.julia/packages/DiffEqBase/Rfwzp/src/solve.jl:896 [inlined]
 [19] __solve
    @ ~/.julia/packages/DiffEqBase/Rfwzp/src/solve.jl:887 [inlined]
 [20] solve_call(::ODEProblem{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/Rfwzp/src/solve.jl:127
 [21] solve_call
    @ ~/.julia/packages/DiffEqBase/Rfwzp/src/solve.jl:84 [inlined]
 [22] #solve_up#40
    @ ~/.julia/packages/DiffEqBase/Rfwzp/src/solve.jl:665 [inlined]
 [23] solve_up
    @ ~/.julia/packages/DiffEqBase/Rfwzp/src/solve.jl:658 [inlined]
 [24] #solve#38
    @ ~/.julia/packages/DiffEqBase/Rfwzp/src/solve.jl:553 [inlined]
 [25] solve(::ODEProblem{Vector{…}, Tuple{…}, true, MTKParameters{…}, ODEFunction{…}, @Kwargs{…}, SciMLBase.StandardODEProblem})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/Rfwzp/src/solve.jl:543
 [26] t_when_x_is_2(x0::ForwardDiff.Dual{ForwardDiff.Tag{typeof(t_when_x_is_2), Float64}, Float64, 1})
    @ Main ./REPL[65]:11
 [27] derivative(f::typeof(t_when_x_is_2), x::Float64)
    @ ForwardDiff ~/.julia/packages/ForwardDiff/Or6Qh/src/derivative.jl:14
 [28] top-level scope
    @ REPL[65]:15
Some type information was truncated. Use `show(err)` to see complete types.

Metadata

Assignees

No one assigned

    Labels

    bug — Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions