Skip to content

Commit 158080c

Browse files
Merge pull request #114 from vyudu/pretty-print
feat: prettyprint parameters
2 parents 98b2285 + 9ea685e commit 158080c

File tree

6 files changed

+75
-13
lines changed

6 files changed

+75
-13
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ version = "0.3.38"
66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
88
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
9+
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
910
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
1011
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
1112

@@ -14,6 +15,7 @@ Accessors = "0.1.36"
1415
Aqua = "0.8"
1516
ArrayInterface = "7.9"
1617
Pkg = "1"
18+
PrettyTables = "2.4.0"
1719
RuntimeGeneratedFunctions = "0.5.12"
1820
SafeTestsets = "0.0.1"
1921
StaticArrays = "1.9"

docs/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
33
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
44
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
55
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
6+
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
67
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
78

89
[compat]

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ getp
9292
setp
9393
setp_oop
9494
ParameterIndexingProxy
95+
show_params
9596
```
9697

9798
#### Parameter timeseries

docs/src/complete_sii.md

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
This tutorial will show how to define the entire Symbolic Indexing Interface on an
55
`ExampleSystem`:
66

7-
```julia
7+
```@example implementing_sii
8+
using SymbolicIndexingInterface
89
struct ExampleSystem
910
state_index::Dict{Symbol,Int}
1011
parameter_index::Dict{Symbol,Int}
@@ -24,7 +25,7 @@ supports specific functionality. Consider the following struct, which needs to i
2425

2526
These are the simple functions which describe how to turn symbols into indices.
2627

27-
```julia
28+
```@example implementing_sii
2829
function SymbolicIndexingInterface.is_variable(sys::ExampleSystem, sym)
2930
haskey(sys.state_index, sym)
3031
end
@@ -65,7 +66,7 @@ end
6566
6667
SymbolicIndexingInterface.constant_structure(::ExampleSystem) = true
6768
68-
function SymbolicIndexingInterface.all_solvable_symbols(sys::ExampleSystem)
69+
function SymbolicIndexingInterface.all_variable_symbols(sys::ExampleSystem)
6970
return vcat(
7071
collect(keys(sys.state_index)),
7172
collect(keys(sys.observed)),
@@ -74,7 +75,7 @@ end
7475
7576
function SymbolicIndexingInterface.all_symbols(sys::ExampleSystem)
7677
return vcat(
77-
all_solvable_symbols(sys),
78+
all_variable_symbols(sys),
7879
collect(keys(sys.parameter_index)),
7980
sys.independent_variable === nothing ? Symbol[] : sys.independent_variable
8081
)
@@ -90,7 +91,7 @@ end
9091
These are for handling symbolic expressions and generating equations which are not directly
9192
in the solution vector.
9293

93-
```julia
94+
```@example implementing_sii
9495
using RuntimeGeneratedFunctions
9596
RuntimeGeneratedFunctions.init(@__MODULE__)
9697
@@ -167,7 +168,7 @@ not typically useful for solution objects, it may be useful for integrators. Typ
167168
the default implementations for `getp` and `setp` will suffice, and manually defining
168169
them is not necessary.
169170

170-
```julia
171+
```@example implementing_sii
171172
function SymbolicIndexingInterface.parameter_values(sys::ExampleSystem)
172173
sys.p
173174
end
@@ -183,7 +184,7 @@ the system's symbols. This also requires that the type implement
183184

184185
Consider the following `ExampleIntegrator`
185186

186-
```julia
187+
```@example implementing_sii
187188
mutable struct ExampleIntegrator
188189
u::Vector{Float64}
189190
p::Vector{Float64}
@@ -199,8 +200,8 @@ SymbolicIndexingInterface.current_time(sys::ExampleIntegrator) = sys.t
199200
```
200201

201202
Then the following example would work:
202-
```julia
203-
sys = ExampleSystem(Dict(:x => 1, :y => 2, :z => 3), Dict(:a => 1, :b => 2), :t, Dict())
203+
```@example implementing_sii
204+
sys = ExampleSystem(Dict(:x => 1, :y => 2, :z => 3), Dict(:a => 1, :b => 2), :t, Dict(), Dict())
204205
integrator = ExampleIntegrator([1.0, 2.0, 3.0], [4.0, 5.0], 6.0, sys)
205206
getx = getsym(sys, :x)
206207
getx(integrator) # 1.0
@@ -289,7 +290,7 @@ interface and allows using [`getp`](@ref) and [`setp`](@ref) to get and set para
289290
values. This allows for a cleaner interface for parameter indexing. Consider the
290291
following example for `ExampleIntegrator`:
291292

292-
```julia
293+
```@example implementing_sii
293294
function Base.getproperty(obj::ExampleIntegrator, sym::Symbol)
294295
if sym === :ps
295296
return ParameterIndexingProxy(obj)
@@ -301,8 +302,8 @@ end
301302

302303
This enables the following API:
303304

304-
```julia
305-
integrator = ExampleIntegrator([1.0, 2.0, 3.0], [4.0, 5.0], 6.0, Dict(:x => 1, :y => 2, :z => 3), Dict(:a => 1, :b => 2), :t)
305+
```@example implementing_sii
306+
integrator = ExampleIntegrator([1.0, 2.0, 3.0], [4.0, 5.0], 6.0, sys)
306307
307308
integrator.ps[:a] # 4.0
308309
getp(integrator, :a)(integrator) # functionally the same as above
@@ -311,6 +312,11 @@ integrator.ps[:b] = 3.0
311312
setp(integrator, :b)(integrator, 3.0) # functionally the same as above
312313
```
313314

315+
The parameters will display as a table:
316+
```@example implementing_sii
317+
integrator.ps
318+
```
319+
314320
## Parameter Timeseries
315321

316322
If a solution object includes modified parameter values (such as through callbacks) during the

src/SymbolicIndexingInterface.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using RuntimeGeneratedFunctions
44
import StaticArraysCore: MArray, similar_type
55
import ArrayInterface
66
using Accessors: @reset
7+
using PrettyTables # for pretty printing
78

89
RuntimeGeneratedFunctions.init(@__MODULE__)
910

@@ -44,7 +45,7 @@ include("batched_interface.jl")
4445
export ProblemState
4546
include("problem_state.jl")
4647

47-
export ParameterIndexingProxy
48+
export ParameterIndexingProxy, show_params
4849
include("parameter_indexing_proxy.jl")
4950

5051
export remake_buffer

src/parameter_indexing_proxy.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,54 @@ end
1717
function Base.setindex!(p::ParameterIndexingProxy, val, idx)
1818
return setp(p.wrapped, idx)(p.wrapped, val)
1919
end
20+
21+
function Base.show(io::IO, ::MIME"text/plain", pip::ParameterIndexingProxy)
22+
show_params(io, pip; num_rows = 20, show_all = false, scalarize = true)
23+
end
24+
25+
"""
26+
show_params(io::IO, pip::ParameterIndexingProxy; num_rows = 20, show_all = false, scalarize = true, kwargs...)
27+
28+
Method for customizing the table output. Keyword args:
29+
- num_rows
30+
- show_all: whether to show all parameters. Overrides `num_rows`.
31+
- scalarize: whether to scalarize array symbolics in the table output.
32+
- kwargs... are passed to the pretty_table call.
33+
"""
34+
function show_params(io::IO, pip::ParameterIndexingProxy; num_rows = 20,
35+
show_all = false, scalarize = true, kwargs...)
36+
params = Any[]
37+
vals = Any[]
38+
for p in parameter_symbols(pip.wrapped)
39+
if symbolic_type(p) === ArraySymbolic() && scalarize
40+
val = getp(pip.wrapped, p)(pip.wrapped)
41+
for (_p, _v) in zip(collect(p), val)
42+
push!(params, _p)
43+
push!(vals, _v)
44+
end
45+
else
46+
push!(params, p)
47+
val = getp(pip.wrapped, p)(pip.wrapped)
48+
push!(vals, val)
49+
end
50+
end
51+
52+
num_shown = if show_all
53+
length(params)
54+
else
55+
if num_rows > length(params)
56+
length(params)
57+
else
58+
num_rows
59+
end
60+
end
61+
62+
pretty_table(io, [params[1:num_shown] vals[1:num_shown]];
63+
header = ["Parameter", "Value"],
64+
kwargs...)
65+
66+
if num_shown < length(params)
67+
println(io,
68+
"$num_shown of $(length(params)) params shown. To show all the parameters, call `show_params(io, ps, show_all = true)`. Adjust the number of rows with the num_rows kwarg. Consult `show_params` docstring for more options.")
69+
end
70+
end

0 commit comments

Comments
 (0)