diff --git a/src/systems/diffeqs/basic_transformations.jl b/src/systems/diffeqs/basic_transformations.jl index d41e948d64..1a09de1fc0 100644 --- a/src/systems/diffeqs/basic_transformations.jl +++ b/src/systems/diffeqs/basic_transformations.jl @@ -155,7 +155,7 @@ function change_independent_variable( # Create a utility that performs the chain rule on an expression, followed by insertion of the new independent variable: # e.g. (d/dt)(f(t)) -> (d/dt)(f(u(t))) -> df(u(t))/du(t) * du(t)/dt -> df(u)/du * uˍt(u) - function transform(ex) + function transform(ex::T) where {T} # 1) Replace the argument of every function; e.g. f(t) -> f(u(t)) for var in vars(ex; op = Nothing) # loop over all variables in expression (op = Nothing prevents interpreting "D(f(t))" as one big variable) is_function_of_iv1 = iscall(var) && isequal(only(arguments(var)), iv1) # of the form f(t)? @@ -175,7 +175,7 @@ function change_independent_variable( # 3) Set new independent variable ex = substitute(ex, iv2_of_iv1 => iv2; fold) # set e.g. u(t) -> u everywhere ex = substitute(ex, iv1 => iv1_of_iv2; fold) # set e.g. t -> t(u) everywhere - return ex + return ex::T end # Use the utility function to transform everything in the system! diff --git a/test/basic_transformations.jl b/test/basic_transformations.jl index b593deb345..173c4ad062 100644 --- a/test/basic_transformations.jl +++ b/test/basic_transformations.jl @@ -231,3 +231,31 @@ end # compare to analytical solution (x(t) = v*t, y(t) = v*t - g*t^2/2) @test all(isapprox.(sol[Mx.y], sol[Mx.x - g * (Mx.t_units)^2 / 2]; atol = 1e-10)) end + +@testset "Change independent variable, no equations" begin + # make this "look" like the standard library RealInput + @mtkmodel Input begin + @variables begin + u(t) + end + end + @named input_sys = Input() + input_sys = complete(input_sys) + # test no failures + @test change_independent_variable(input_sys, input_sys.u) isa ODESystem + + @mtkmodel NestedInput begin + @components begin + in = Input() + end + @variables begin + x(t) + end + @equations begin + D(x) ~ in.u + end + end + @named nested_input_sys = NestedInput() + nested_input_sys = complete(nested_input_sys; flatten = false) + @test change_independent_variable(nested_input_sys, nested_input_sys.x) isa ODESystem +end