Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/systems/diffeqs/basic_transformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ function change_independent_variable(

# Create a utility that performs the chain rule on an expression, followed by insertion of the new independent variable:
# e.g. (d/dt)(f(t)) -> (d/dt)(f(u(t))) -> df(u(t))/du(t) * du(t)/dt -> df(u)/du * uˍt(u)
function transform(ex)
function transform(ex::T) where {T}
# 1) Replace the argument of every function; e.g. f(t) -> f(u(t))
for var in vars(ex; op = Nothing) # loop over all variables in expression (op = Nothing prevents interpreting "D(f(t))" as one big variable)
is_function_of_iv1 = iscall(var) && isequal(only(arguments(var)), iv1) # of the form f(t)?
Expand All @@ -175,7 +175,7 @@ function change_independent_variable(
# 3) Set new independent variable
ex = substitute(ex, iv2_of_iv1 => iv2; fold) # set e.g. u(t) -> u everywhere
ex = substitute(ex, iv1 => iv1_of_iv2; fold) # set e.g. t -> t(u) everywhere
return ex
return ex::T
end

# Use the utility function to transform everything in the system!
Expand Down
28 changes: 28 additions & 0 deletions test/basic_transformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,31 @@ end
# compare to analytical solution (x(t) = v*t, y(t) = v*t - g*t^2/2)
@test all(isapprox.(sol[Mx.y], sol[Mx.x - g * (Mx.t_units)^2 / 2]; atol = 1e-10))
end

@testset "Change independent variable, no equations" begin
# make this "look" like the standard library RealInput
@mtkmodel Input begin
@variables begin
u(t)
end
end
@named input_sys = Input()
input_sys = complete(input_sys)
# test no failures
@test change_independent_variable(input_sys, input_sys.u) isa ODESystem

@mtkmodel NestedInput begin
@components begin
in = Input()
end
@variables begin
x(t)
end
@equations begin
D(x) ~ in.u
end
end
@named nested_input_sys = NestedInput()
nested_input_sys = complete(nested_input_sys; flatten = false)
@test change_independent_variable(nested_input_sys, nested_input_sys.x) isa ODESystem
end
Loading