|
91 | 91 | end |
92 | 92 | end |
93 | 93 |
|
| 94 | + # Related: https://github.com/TuringLang/DynamicPPL.jl/issues/504 |
94 | 95 | @testset "dirichlet" begin |
95 | | - @model demo_dirichlet() = x ~ Dirichlet(2, 1.0) |
96 | | - model = demo_dirichlet() |
97 | | - vis = DynamicPPL.TestUtils.setup_varinfos(model, rand(model), (@varname(x),)) |
98 | | - @testset "$(short_varinfo_name(vi))" for vi in vis |
99 | | - @test length(vi[:]) == 2 |
100 | | - @test iszero(getlogp(vi)) |
101 | | - # Linked. |
102 | | - vi_linked = DynamicPPL.link!!(deepcopy(vi), model) |
103 | | - @test length(vi_linked[:]) == 1 |
104 | | - @test !iszero(getlogp(vi_linked)) # should now include the log-absdet-jacobian correction |
105 | | - # Invlinked. |
106 | | - vi_invlinked = DynamicPPL.invlink!!(deepcopy(vi_linked), model) |
107 | | - @test length(vi_invlinked[:]) == 2 |
108 | | - @test iszero(getlogp(vi_invlinked)) |
| 96 | + @model demo_dirichlet(d::Int) = x ~ Dirichlet(d, 1.0) |
| 97 | + @testset "d=$d" for d in [2, 3, 5] |
| 98 | + model = demo_dirichlet(d) |
| 99 | + vis = DynamicPPL.TestUtils.setup_varinfos(model, rand(model), (@varname(x),)) |
| 100 | + @testset "$(short_varinfo_name(vi))" for vi in vis |
| 101 | + lp = logpdf(Dirichlet(d, 1.0), vi[:]) |
| 102 | + @test length(vi[:]) == d |
| 103 | + @test getlogp(vi) ≈ lp |
| 104 | + # Linked. |
| 105 | + vi_linked = DynamicPPL.link!!(deepcopy(vi), model) |
| 106 | + @test length(vi_linked[:]) == d - 1 |
| 107 | + # Should now include the log-absdet-jacobian correction. |
| 108 | + @test !(getlogp(vi_linked) ≈ lp) |
| 109 | + # Invlinked. |
| 110 | + vi_invlinked = DynamicPPL.invlink!!(deepcopy(vi_linked), model) |
| 111 | + @test length(vi_invlinked[:]) == d |
| 112 | + @test getlogp(vi_invlinked) ≈ lp |
| 113 | + end |
109 | 114 | end |
110 | 115 | end |
111 | 116 | end |
0 commit comments