|
91 | 91 | end |
92 | 92 | end |
93 | 93 |
|
| 94 | + @testset "LKJCholesky" begin |
| 95 | + @testset "uplo=$uplo" for uplo in ['L', 'U'] |
| 96 | + @model demo_lkj(d) = x ~ LKJCholesky(d, 1.0, uplo) |
| 97 | + @testset "d=$d" for d in [2, 3, 5] |
| 98 | + model = demo_lkj(d) |
| 99 | + dist = LKJCholesky(d, 1.0, uplo) |
| 100 | + values_original = rand(model) |
| 101 | + vis = DynamicPPL.TestUtils.setup_varinfos( |
| 102 | + model, values_original, (@varname(x),) |
| 103 | + ) |
| 104 | + @testset "$(short_varinfo_name(vi))" for vi in vis |
| 105 | + val = vi[@varname(x), dist] |
| 106 | + # Ensure that `reconstruct` works as intended. |
| 107 | + @test val isa Cholesky |
| 108 | + @test val.uplo == uplo |
| 109 | + |
| 110 | + @test length(vi[:]) == d^2 |
| 111 | + lp = logpdf(dist, val) |
| 112 | + lp_model = logjoint(model, vi) |
| 113 | + @test lp_model ≈ lp |
| 114 | + # Linked. |
| 115 | + vi_linked = DynamicPPL.link!!(deepcopy(vi), model) |
| 116 | + @test length(vi_linked[:]) == d * (d - 1) ÷ 2 |
| 117 | + # Should now include the log-absdet-jacobian correction. |
| 118 | + @test !(getlogp(vi_linked) ≈ lp) |
| 119 | + # Invlinked. |
| 120 | + vi_invlinked = DynamicPPL.invlink!!(deepcopy(vi_linked), model) |
| 121 | + @test length(vi_invlinked[:]) == d^2 |
| 122 | + @test getlogp(vi_invlinked) ≈ lp |
| 123 | + end |
| 124 | + end |
| 125 | + end |
| 126 | + end |
| 127 | + |
94 | 128 | # Related: https://github.com/TuringLang/DynamicPPL.jl/issues/504 |
95 | | - @testset "dirichlet" begin |
| 129 | + @testset "Dirichlet" begin |
96 | 130 | @model demo_dirichlet(d::Int) = x ~ Dirichlet(d, 1.0) |
97 | 131 | @testset "d=$d" for d in [2, 3, 5] |
98 | 132 | model = demo_dirichlet(d) |
99 | 133 | vis = DynamicPPL.TestUtils.setup_varinfos(model, rand(model), (@varname(x),)) |
100 | 134 | @testset "$(short_varinfo_name(vi))" for vi in vis |
101 | 135 | lp = logpdf(Dirichlet(d, 1.0), vi[:]) |
102 | 136 | @test length(vi[:]) == d |
103 | | - @test getlogp(vi) ≈ lp |
| 137 | + lp_model = logjoint(model, vi) |
| 138 | + @test lp_model ≈ lp |
104 | 139 | # Linked. |
105 | 140 | vi_linked = DynamicPPL.link!!(deepcopy(vi), model) |
106 | 141 | @test length(vi_linked[:]) == d - 1 |
|
0 commit comments