Skip to content

Commit 184888a

Browse files
authored
Fix for tests in #516 (#517)
* fixed tests for linking of dirichlet with different dimensionality * added usage of same logp in TestUtils.setup_varinfos
1 parent 2062ed3 commit 184888a

File tree

3 files changed

+21
-49
lines changed

3 files changed

+21
-49
lines changed

src/test_utils.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,10 @@ function setup_varinfos(model::Model, example_values::NamedTuple, varnames)
5151
svi_typed = SimpleVarInfo(example_values)
5252
svi_untyped = SimpleVarInfo(OrderedDict())
5353

54+
lp = getlogp(vi_typed)
5455
return map((vi_untyped, vi_typed, svi_typed, svi_untyped)) do vi
5556
# Set them all to the same values.
56-
update_values!!(vi, example_values, varnames)
57+
DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp)
5758
end
5859
end
5960

test/linking.jl

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -91,21 +91,26 @@ end
9191
end
9292
end
9393

94+
# Related: https://github.com/TuringLang/DynamicPPL.jl/issues/504
9495
@testset "dirichlet" begin
95-
@model demo_dirichlet() = x ~ Dirichlet(2, 1.0)
96-
model = demo_dirichlet()
97-
vis = DynamicPPL.TestUtils.setup_varinfos(model, rand(model), (@varname(x),))
98-
@testset "$(short_varinfo_name(vi))" for vi in vis
99-
@test length(vi[:]) == 2
100-
@test iszero(getlogp(vi))
101-
# Linked.
102-
vi_linked = DynamicPPL.link!!(deepcopy(vi), model)
103-
@test length(vi_linked[:]) == 1
104-
@test !iszero(getlogp(vi_linked)) # should now include the log-absdet-jacobian correction
105-
# Invlinked.
106-
vi_invlinked = DynamicPPL.invlink!!(deepcopy(vi_linked), model)
107-
@test length(vi_invlinked[:]) == 2
108-
@test iszero(getlogp(vi_invlinked))
96+
@model demo_dirichlet(d::Int) = x ~ Dirichlet(d, 1.0)
97+
@testset "d=$d" for d in [2, 3, 5]
98+
model = demo_dirichlet(d)
99+
vis = DynamicPPL.TestUtils.setup_varinfos(model, rand(model), (@varname(x),))
100+
@testset "$(short_varinfo_name(vi))" for vi in vis
101+
lp = logpdf(Dirichlet(d, 1.0), vi[:])
102+
@test length(vi[:]) == d
103+
@test getlogp(vi) lp
104+
# Linked.
105+
vi_linked = DynamicPPL.link!!(deepcopy(vi), model)
106+
@test length(vi_linked[:]) == d - 1
107+
# Should now include the log-absdet-jacobian correction.
108+
@test !(getlogp(vi_linked) lp)
109+
# Invlinked.
110+
vi_invlinked = DynamicPPL.invlink!!(deepcopy(vi_linked), model)
111+
@test length(vi_invlinked[:]) == d
112+
@test getlogp(vi_invlinked) lp
113+
end
109114
end
110115
end
111116
end

test/varinfo.jl

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -273,40 +273,6 @@
273273
@test vals_prev == vi.metadata.x.vals
274274
end
275275

276-
# See https://github.com/TuringLang/DynamicPPL.jl/issues/504
277-
@testset "Dimentionality checks" begin
278-
@model function demo2d()
279-
return x ~ Dirichlet(2, 1.0)
280-
end
281-
model = demo2d()
282-
vi = VarInfo(model) # make VarInfo -> sample from prior and compute logdensity
283-
getlogp(vi) 0.0 # zero because Dirichlet(1) == Uniform over Simplex
284-
spl = SampleFromPrior() # create dummy sampler for linking
285-
DynamicPPL.link!!(vi, spl, model) # transform to unconstrained space
286-
!(0.0 getlogp(last(DynamicPPL.evaluate!!(model, vi)))) # non-zero now due to log(abs(determinant(jacobian)))
287-
x = vi[spl] # extract unconstrained values
288-
newx = deepcopy(x) # simulate making a change to x
289-
vinew = deepcopy(vi)
290-
vinew[spl] = newx
291-
@test vinew[spl] == newx
292-
293-
@model function demo3d()
294-
return x ~ Dirichlet(3, 1.0) # increase K to 3
295-
end
296-
model = demo3d()
297-
vi = VarInfo(model) # make VarInfo -> sample from prior and compute logdensity
298-
getlogp(vi) 0.0 # zero because Dirichlet(1) == Uniform over Simplex
299-
spl = SampleFromPrior() # create dummy sampler for linking
300-
DynamicPPL.link!!(vi, spl, model) # transform to unconstrained space
301-
!(0.0 getlogp(last(DynamicPPL.evaluate!!(model, vi)))) # non-zero now due to log(abs(determinant(jacobian)))
302-
x = vi[spl] # extract unconstrained values
303-
newx = deepcopy(x) # simulate making a change to x
304-
vinew = deepcopy(vi)
305-
vinew[spl] = newx
306-
307-
@test vinew[spl] == newx
308-
end
309-
310276
@testset "istrans" begin
311277
@model demo_constrained() = x ~ truncated(Normal(), 0, Inf)
312278
model = demo_constrained()

0 commit comments

Comments
 (0)