diff --git a/src/systems/if_lifting.jl b/src/systems/if_lifting.jl index da069cc76e..aeb8afc0a8 100644 --- a/src/systems/if_lifting.jl +++ b/src/systems/if_lifting.jl @@ -111,8 +111,8 @@ function (cw::CondRewriter)(expr, dep) # and ELSE branch is true # similarly for expression being false return (ifelse(rw_cond, rw_conda, rw_condb), - implies(ctrue, truea) | implies(cfalse, trueb), - implies(ctrue, falsea) | implies(cfalse, falseb)) + ctrue & truea | cfalse & trueb, + ctrue & falsea | cfalse & falseb) elseif operation(expr) == Base.:(!) # NOT of expression (a,) = arguments(expr) (rw, ctrue, cfalse) = cw(a, dep) diff --git a/test/if_lifting.jl b/test/if_lifting.jl index 9c58e676d0..085dc600c8 100644 --- a/test/if_lifting.jl +++ b/test/if_lifting.jl @@ -124,3 +124,43 @@ end end @test_nowarn @mtkbuild sys=SimpleAbs() additional_passes=[IfLifting] end + +@testset "Nested conditions are handled properly" begin + @mtkmodel RampModel begin + @variables begin + x(t) + y(t) + end + @parameters begin + start_time = 1.0 + duration = 1.0 + height = 1.0 + end + @equations begin + y ~ ifelse(start_time < t, + ifelse(t < start_time + duration, + (t - start_time) * height / duration, height), + 0.0) + D(x) ~ y + end + end + @mtkbuild sys = RampModel() + @mtkbuild sys2=RampModel() additional_passes=[IfLifting] + prob = ODEProblem(sys, [sys.x => 1.0], (0.0, 3.0)) + prob2 = ODEProblem(sys2, [sys.x => 1.0], (0.0, 3.0)) + sol = solve(prob) + sol2 = solve(prob2) + @test sol(0.99)[1] > 1.0 + @test sol2(0.99)[1] == 1.0 + # During ramp + # D(x) ~ t - 1 + # x ~ t^2 / 2 - t + C, and `x(1) ~ 1` => `C = 3/2` + # x(1.01) ~ 1.01^2 / 2 - 1.01 + 3/2 ~ 1.00005 + @test sol2(1.01)[1] ≈ 1.00005 + @test sol2(2)[1] ≈ 1.5 + # After ramp + # D(x) ~ 1 + # x ~ t + C and `x(2) ~ 3/2` => `C = -1/2` + # x(3) ~ 3 - 1/2 + @test sol2(3)[1] ≈ 5 / 2 +end