Skip to content

Commit 6dc7c02

Browse files
committed
Merge remote-tracking branch 'origin/main' into breaking
2 parents 79150ba + ab6f38a commit 6dc7c02

File tree

13 files changed

+172
-32
lines changed

13 files changed

+172
-32
lines changed

.github/workflows/Benchmarking.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,14 @@ jobs:
5959
echo "DPPL_COMMIT_URL=$COMMIT_URL" >> $GITHUB_ENV
6060
6161
- name: Find Existing Comment
62-
uses: peter-evans/find-comment@v3
62+
uses: peter-evans/find-comment@v4
6363
id: find_comment
6464
with:
6565
issue-number: ${{ github.event.pull_request.number }}
6666
comment-author: github-actions[bot]
6767

6868
- name: Post Benchmark Results as PR Comment
69-
uses: peter-evans/create-or-update-comment@v4
69+
uses: peter-evans/create-or-update-comment@v5
7070
with:
7171
issue-number: ${{ github.event.pull_request.number }}
7272
body: |

HISTORY.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,18 @@
44

55
Removed the method `returned(::Model, values, keys)`; please use `returned(::Model, ::AbstractDict{<:VarName})` instead.
66

7+
## 0.38.7
8+
9+
Made a small tweak to DynamicPPL's compiler output to avoid potential undefined variables when resuming model functions midway through (e.g. with Libtask in Turing's SMC/PG samplers).
10+
11+
## 0.38.6
12+
13+
Renamed keyword argument `only_ddpl` to `only_dppl` for `Experimental.is_suitable_varinfo`.
14+
15+
## 0.38.5
16+
17+
Improve performance of VarNamedVector, mostly by changing how it handles contiguification.
18+
719
## 0.38.4
820

921
Improve performance of VarNamedVector. It should now be very nearly on par with Metadata for all models we've benchmarked on.

ext/DynamicPPLJETExt.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@ using DynamicPPL: DynamicPPL
44
using JET: JET
55

66
function DynamicPPL.Experimental.is_suitable_varinfo(
7-
model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo; only_ddpl::Bool=true
7+
model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo; only_dppl::Bool=true
88
)
99
f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(model, varinfo)
1010
# If specified, we only check errors originating somewhere in the DynamicPPL.jl.
1111
# This way we don't just fall back to untyped if the user's code is the issue.
12-
result = if only_ddpl
12+
result = if only_dppl
1313
JET.report_call(f, argtypes; target_modules=(JET.AnyFrameModule(DynamicPPL),))
1414
else
1515
JET.report_call(f, argtypes)
@@ -18,15 +18,15 @@ function DynamicPPL.Experimental.is_suitable_varinfo(
1818
end
1919

2020
function DynamicPPL.Experimental._determine_varinfo_jet(
21-
model::DynamicPPL.Model; only_ddpl::Bool=true
21+
model::DynamicPPL.Model; only_dppl::Bool=true
2222
)
2323
# Generate a typed varinfo to test model type stability with
2424
varinfo = DynamicPPL.typed_varinfo(model)
2525

2626
# Check type stability of evaluation (i.e. DefaultContext)
2727
model = DynamicPPL.setleafcontext(model, DynamicPPL.DefaultContext())
2828
eval_issuccess, eval_result = DynamicPPL.Experimental.is_suitable_varinfo(
29-
model, varinfo; only_ddpl
29+
model, varinfo; only_dppl
3030
)
3131
if !eval_issuccess
3232
@debug "Evaluation with typed varinfo failed with the following issues:"
@@ -36,7 +36,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet(
3636
# Check type stability of initialisation (i.e. InitContext)
3737
model = DynamicPPL.setleafcontext(model, DynamicPPL.InitContext())
3838
init_issuccess, init_result = DynamicPPL.Experimental.is_suitable_varinfo(
39-
model, varinfo; only_ddpl
39+
model, varinfo; only_dppl
4040
)
4141
if !init_issuccess
4242
@debug "Initialisation with typed varinfo failed with the following issues:"

src/compiler.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -461,9 +461,16 @@ function generate_tilde(left, right)
461461
elseif $isassumption
462462
$(generate_tilde_assume(left, dist, vn))
463463
else
464-
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
465-
if !$(DynamicPPL.inargnames)($vn, __model__)
466-
$left = $(DynamicPPL.getconditioned_nested)(
464+
# If `vn` is not in `argnames`, then it's definitely been conditioned on (if
465+
# it's not in `argnames` and wasn't conditioned on, then `isassumption` would
466+
# be true).
467+
$left = if $(DynamicPPL.inargnames)($vn, __model__)
468+
# This is a no-op and looks redundant, but defining the compiler output this
469+
# way ensures that the variable `$left` is always defined. See
470+
# https://github.com/TuringLang/DynamicPPL.jl/pull/1110.
471+
$left
472+
else
473+
$(DynamicPPL.getconditioned_nested)(
467474
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
468475
)
469476
end

src/experimental.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ Check if the `model` supports evaluation using the provided `varinfo`.
1616
- `varinfo`: The varinfo to verify the support for.
1717
1818
# Keyword Arguments
19-
- `only_ddpl`: If `true`, only consider error reports occuring in the tilde pipeline. Default: `true`.
19+
- `only_dppl`: If `true`, only consider error reports occuring in the tilde pipeline. Default: `true`.
2020
2121
# Returns
2222
- `issuccess`: `true` if the model supports the varinfo, otherwise `false`.
@@ -28,7 +28,7 @@ function is_suitable_varinfo end
2828
function _determine_varinfo_jet end
2929

3030
"""
31-
determine_suitable_varinfo(model; only_ddpl::Bool=true)
31+
determine_suitable_varinfo(model; only_dppl::Bool=true)
3232
3333
Return a suitable varinfo for the given `model`.
3434
@@ -42,7 +42,7 @@ See also: [`DynamicPPL.Experimental.is_suitable_varinfo`](@ref).
4242
- `model`: The model for which to determine the varinfo.
4343
4444
# Keyword Arguments
45-
- `only_ddpl`: If `true`, only consider error reports within DynamicPPL.jl.
45+
- `only_dppl`: If `true`, only consider error reports within DynamicPPL.jl.
4646
4747
# Examples
4848
@@ -83,10 +83,10 @@ julia> vi isa typeof(DynamicPPL.typed_varinfo(model_with_static_support()))
8383
true
8484
```
8585
"""
86-
function determine_suitable_varinfo(model::DynamicPPL.Model; only_ddpl::Bool=true)
86+
function determine_suitable_varinfo(model::DynamicPPL.Model; only_dppl::Bool=true)
8787
# If JET.jl has been loaded, and thus `determine_varinfo` has been defined, we use that.
8888
return if Base.get_extension(DynamicPPL, :DynamicPPLJETExt) !== nothing
89-
_determine_varinfo_jet(model; only_ddpl)
89+
_determine_varinfo_jet(model; only_dppl)
9090
else
9191
# Warn the user.
9292
@warn "JET.jl is not loaded. Assumes the model is compatible with typed varinfo."

src/varinfo.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1297,6 +1297,10 @@ function _link_metadata!!(
12971297
metadata = setindex_internal!!(metadata, val_new, vn, transform_from_linked)
12981298
set_transformed!(metadata, true, vn)
12991299
end
1300+
# Linking can often change the sizes of variables, causing inactive elements. We don't
1301+
# want to keep them around, since typically linking is done once and then the VarInfo
1302+
# is evaluated multiple times. Hence we contiguify here.
1303+
metadata = contiguify!(metadata)
13001304
return metadata, cumulative_logjac
13011305
end
13021306

@@ -1465,6 +1469,10 @@ function _invlink_metadata!!(
14651469
metadata = setindex_internal!!(metadata, tovec(new_val), vn, new_transform)
14661470
set_transformed!(metadata, false, vn)
14671471
end
1472+
# Linking can often change the sizes of variables, causing inactive elements. We don't
1473+
# want to keep them around, since typically linking is done once and then the VarInfo
1474+
# is evaluated multiple times. Hence we contiguify here.
1475+
metadata = contiguify!(metadata)
14681476
return metadata, cumulative_inv_logjac
14691477
end
14701478

src/varnamedvector.jl

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -341,10 +341,13 @@ function ==(vnv_left::VarNamedVector, vnv_right::VarNamedVector)
341341
vnv_left.num_inactive == vnv_right.num_inactive
342342
end
343343

344-
function is_concretely_typed(vnv::VarNamedVector)
345-
return isconcretetype(eltype(vnv.varnames)) &&
346-
isconcretetype(eltype(vnv.vals)) &&
347-
isconcretetype(eltype(vnv.transforms))
344+
function is_tightly_typed(vnv::VarNamedVector)
345+
k = eltype(vnv.varnames)
346+
v = eltype(vnv.vals)
347+
t = eltype(vnv.transforms)
348+
return (isconcretetype(k) || k === Union{}) &&
349+
(isconcretetype(v) || v === Union{}) &&
350+
(isconcretetype(t) || t === Union{})
348351
end
349352

350353
getidx(vnv::VarNamedVector, vn::VarName) = vnv.varname_to_index[vn]
@@ -880,7 +883,16 @@ function loosen_types!!(
880883
return if vn_type == K && val_type == V && transform_type == T
881884
vnv
882885
elseif isempty(vnv)
883-
VarNamedVector(vn_type[], val_type[], transform_type[])
886+
VarNamedVector(
887+
Dict{vn_type,Int}(),
888+
Vector{vn_type}(),
889+
UnitRange{Int}[],
890+
Vector{val_type}(),
891+
Vector{transform_type}(),
892+
BitVector(),
893+
Dict{Int,Int}();
894+
check_consistency=false,
895+
)
884896
else
885897
# TODO(mhauru) We allow a `vnv` to have any AbstractVector type as its vals, but
886898
# then here always revert to Vector.
@@ -944,7 +956,7 @@ julia> vnv_tight.transforms
944956
```
945957
"""
946958
function tighten_types!!(vnv::VarNamedVector)
947-
return if is_concretely_typed(vnv)
959+
return if is_tightly_typed(vnv)
948960
# There can not be anything to tighten, so short-circuit.
949961
vnv
950962
elseif isempty(vnv)
@@ -1020,6 +1032,7 @@ function insert_internal!!(
10201032
end
10211033
vnv = loosen_types!!(vnv, typeof(vn), eltype(val), typeof(transform))
10221034
insert_internal!(vnv, val, vn, transform)
1035+
vnv = tighten_types!!(vnv)
10231036
return vnv
10241037
end
10251038

@@ -1029,6 +1042,7 @@ function update_internal!!(
10291042
transform_resolved = transform === nothing ? gettransform(vnv, vn) : transform
10301043
vnv = loosen_types!!(vnv, typeof(vn), eltype(val), typeof(transform_resolved))
10311044
update_internal!(vnv, val, vn, transform)
1045+
vnv = tighten_types!!(vnv)
10321046
return vnv
10331047
end
10341048

@@ -1104,6 +1118,9 @@ care about them.
11041118
11051119
This is in a sense the reverse operation of `vnv[:]`.
11061120
1121+
The return value may share memory with the input `vnv`, and thus one can not be mutated
1122+
safely without affecting the other.
1123+
11071124
Unflatten recontiguifies the internal storage, getting rid of any inactive entries.
11081125
11091126
# Examples
@@ -1125,15 +1142,20 @@ function unflatten(vnv::VarNamedVector, vals::AbstractVector)
11251142
),
11261143
)
11271144
end
1128-
new_ranges = deepcopy(vnv.ranges)
1129-
recontiguify_ranges!(new_ranges)
1145+
new_ranges = vnv.ranges
1146+
num_inactive = vnv.num_inactive
1147+
if has_inactive(vnv)
1148+
new_ranges = recontiguify_ranges!(new_ranges)
1149+
num_inactive = Dict{Int,Int}()
1150+
end
11301151
return VarNamedVector(
11311152
vnv.varname_to_index,
11321153
vnv.varnames,
11331154
new_ranges,
11341155
vals,
11351156
vnv.transforms,
1136-
vnv.is_unconstrained;
1157+
vnv.is_unconstrained,
1158+
num_inactive;
11371159
check_consistency=false,
11381160
)
11391161
end
@@ -1428,6 +1450,9 @@ julia> vnv[@varname(x)] # All the values are still there.
14281450
```
14291451
"""
14301452
function contiguify!(vnv::VarNamedVector)
1453+
if !has_inactive(vnv)
1454+
return vnv
1455+
end
14311456
# Extract the re-contiguified values.
14321457
# NOTE: We need to do this before we update the ranges.
14331458
old_vals = copy(vnv.vals)

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
44
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
55
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
66
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
7+
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
78
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
89
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
910
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
@@ -34,6 +35,7 @@ AbstractMCMC = "5"
3435
AbstractPPL = "0.13"
3536
Accessors = "0.1"
3637
Aqua = "0.8"
38+
BangBang = "0.4"
3739
Bijectors = "0.15.1"
3840
Combinatorics = "1"
3941
DifferentiationInterface = "0.6.41, 0.7"

test/accumulators.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ using DynamicPPL:
117117
@test at_all64[:LogLikelihood] == ll_f64
118118

119119
@test haskey(AccumulatorTuple(lp_f64), Val(:LogPrior))
120-
@test ~haskey(AccumulatorTuple(lp_f64), Val(:LogLikelihood))
120+
@test !haskey(AccumulatorTuple(lp_f64), Val(:LogLikelihood))
121121
@test length(AccumulatorTuple(lp_f64, ll_f64)) == 2
122122
@test keys(at_all64) == (:LogPrior, :LogLikelihood)
123123
@test collect(at_all64) == [lp_f64, ll_f64]

test/ext/DynamicPPLJETExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
DynamicPPL.NTVarInfo
5757
# Should fail if we're including errors in the model body.
5858
@test DynamicPPL.Experimental.determine_suitable_varinfo(
59-
demo5(); only_ddpl=false
59+
demo5(); only_dppl=false
6060
) isa DynamicPPL.UntypedVarInfo
6161
end
6262

0 commit comments

Comments
 (0)