Skip to content

Commit c89ed6f

Browse files
Merge pull request #60 from CasBex/59_sumtree
2 parents 85de617 + 3f3e99f commit c89ed6f

File tree

4 files changed

+95
-5
lines changed

4 files changed

+95
-5
lines changed

Project.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,18 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1616
[compat]
1717
Adapt = "3"
1818
CircularArrayBuffers = "0.1"
19+
DataStructures = "0.18"
1920
ElasticArrays = "1"
2021
MacroTools = "0.5"
2122
OnlineStats = "1"
2223
StackViews = "0.1"
23-
julia = "1.9"
24-
DataStructures = "0.18"
2524
StatsBase = "0.34"
25+
julia = "1.9"
2626

2727
[extras]
2828
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
29+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
2930
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3031

3132
[targets]
32-
test = ["Test", "CUDA"]
33+
test = ["Test", "CUDA", "StableRNGs"]

src/common/sum_tree.jl

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,28 @@ function Base.empty!(t::SumTree)
131131
t
132132
end
133133

134+
"""
135+
correct_sample(t::SumTree, leaf_ind)
136+
Check whether the sampled leaf is valid and if not return another valid leaf close to it. Used to correct samples with zero priority which may occur due to numerical errors with floats.
137+
"""
138+
function correct_sample(t::SumTree, leaf_ind)
    # `leaf_ind` is an index into `t.tree`. Returns `(p, ind)` where `p` is the
    # priority read at tree index `ind`. When the sampled leaf already has a nonzero
    # priority, both loops are skipped and the input is returned unchanged.
    p = t.tree[leaf_ind]
    tmp_ind = leaf_ind
    # Walk backwards until p != 0 or until the leftmost leaf is reached. The guard
    # only allows stepping to `tmp_ind - 1` while it lies in the second half of
    # `t.tree` (presumably the leaf region -- confirm against the SumTree layout).
    while iszero(p) && (tmp_ind - 1) * 2 > length(t.tree)
        tmp_ind -= 1
        p = t.tree[tmp_ind]
    end
    # Nothing nonzero to the left: restart from the sampled leaf and walk forwards.
    iszero(p) && (tmp_ind = leaf_ind)
    # Walk forwards until p != 0 or until the rightmost stored leaf is reached.
    # The bound is strict: the check runs *before* the increment, so `<=` would
    # step to leaf position `t.length + 1`, i.e. one past the last stored entry
    # (reading an unused slot or indexing outside `t.tree`).
    while iszero(p) && (tmp_ind - t.nparents) < t.length
        tmp_ind += 1
        p = t.tree[tmp_ind]
    end
    return p, tmp_ind
end
154+
155+
134156
function Base.get(t::SumTree, v)
135157
parent_ind = 1
136158
leaf_ind = parent_ind
@@ -152,7 +174,7 @@ function Base.get(t::SumTree, v)
152174
if leaf_ind <= t.nparents
153175
leaf_ind += t.capacity
154176
end
155-
p = t.tree[leaf_ind]
177+
p, leaf_ind = correct_sample(t, leaf_ind)
156178
ind = leaf_ind - t.nparents
157179
real_ind = ind >= t.first ? ind - t.first + 1 : ind + t.capacity - t.first + 1
158180
real_ind, p
@@ -172,4 +194,4 @@ function Random.rand(rng::AbstractRNG, t::SumTree{T}, n::Int) where {T}
172194
inds, priorities
173195
end
174196

175-
Random.rand(t::SumTree, n::Int) = rand(Random.GLOBAL_RNG, t, n)
197+
# Convenience method: draw `n` samples from `t` with the default RNG.
# `Random.default_rng()` is the documented accessor for the default RNG and is
# used here in place of the `Random.GLOBAL_RNG` constant.
Random.rand(t::SumTree, n::Int) = rand(Random.default_rng(), t, n)

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
using ReinforcementLearningTrajectories
22
using CircularArrayBuffers, DataStructures
3+
using StableRNGs
34
using Test
45
using CUDA
56
using Adapt
7+
using Random
68
import ReinforcementLearningTrajectories.StatsBase.sample
9+
import StatsBase.countmap
710

811
struct TestAdaptor end
912

@@ -13,6 +16,7 @@ Adapt.adapt_storage(to::TestAdaptor, x) = CUDA.functional() ? CUDA.cu(x) : x
1316

1417
@testset "ReinforcementLearningTrajectories.jl" begin
1518
include("traces.jl")
19+
include("sum_tree.jl")
1620
include("common.jl")
1721
include("samplers.jl")
1822
include("controllers.jl")

test/sum_tree.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""
    gen_rand_sumtree(n, seed, type::DataType=Float32)

Build a `SumTree(type, n)` and append `n` random values of `type` drawn from
`StableRNG(seed)`. Return the filled tree.
"""
function gen_rand_sumtree(n, seed, type::DataType=Float32)
    priorities = rand(StableRNG(seed), type, n)
    tree = SumTree(type, n)
    append!(tree, priorities)
    return tree
end
7+
8+
"""
    gen_sumtree_with_zeros(n, seed, type::DataType=Float32)

Generate a random sum tree via `gen_rand_sumtree(n, seed, type)` and multiply its
entries by a random `Bool` mask drawn from `StableRNG(seed)`, so that the entries
where the mask is `false` become zero.
"""
function gen_sumtree_with_zeros(n, seed, type::DataType=Float32)
    mask = rand(StableRNG(seed), Bool, n)
    return copy_multiply(gen_rand_sumtree(n, seed, type), mask)
end
13+
14+
"""
    copy_multiply(stree, m)

Return a deep copy of `stree` whose entries have been multiplied elementwise
(with broadcasting) by `m`. The argument `stree` itself is left untouched.
"""
function copy_multiply(stree, m)
    scaled = deepcopy(stree)
    scaled .= scaled .* m
    return scaled
end
19+
20+
"""
    sumtree_nozero(t::SumTree, rng::AbstractRNG, iters=1)
    sumtree_nozero(n::Integer, seed::Integer, iters=1)
    sumtree_nozero(n, seeds::AbstractVector, iters=1)

Return `true` if none of `iters` draws `rand(rng, t)` yields a zero priority, and
`false` as soon as one does.

The `(n, seed)` method builds the tree with `gen_sumtree_with_zeros(n, seed)` and
samples with `StableRNG(seed)`; the `(n, seeds)` method requires the check to hold
for every seed in `seeds`.
"""
function sumtree_nozero(t::SumTree, rng::AbstractRNG, iters=1)
    # `1:iters`, not bare `iters`: a number iterates as a single element, so
    # `for _ in iters` would draw exactly once regardless of the requested count.
    for _ in 1:iters
        (_, p) = rand(rng, t)
        iszero(p) && return false
    end
    return true
end

function sumtree_nozero(n::Integer, seed::Integer, iters=1)
    return sumtree_nozero(gen_sumtree_with_zeros(n, seed), StableRNG(seed), iters)
end

function sumtree_nozero(n, seeds::AbstractVector, iters=1)
    return all(sumtree_nozero(n, seed, iters) for seed in seeds)
end
29+
30+
31+
"""
    sumtree_distribution!(indices, priorities, t::SumTree, rng::AbstractRNG, iters=1000*t.length)
    sumtree_distribution!(indices, priorities, n, seed, iters=1000*n)

Draw `iters` samples from `t`, storing each sampled index and priority in the
first `iters` slots of `indices` and `priorities`. Return the signed differences
(estimated minus exact) between the empirical frequency of each sampled index and
its entry in `t` divided by `t.tree[1]` (presumably the total priority held at the
root -- confirm against the SumTree layout).

Frequencies are normalised by `length(indices)`, so callers are expected to pass
buffers of exactly `iters` elements.

The `(n, seed)` method samples from `gen_rand_sumtree(n, seed)` with
`StableRNG(seed)`.
"""
function sumtree_distribution!(indices, priorities, t::SumTree, rng::AbstractRNG, iters=1000*t.length)
    for draw in 1:iters
        indices[draw], priorities[draw] = rand(rng, t)
    end
    nslots = length(indices)
    est_pdf = Dict(idx => cnt / nslots for (idx, cnt) in countmap(indices))
    total = t.tree[1]
    ex_pdf = Dict(idx => prio / total for (idx, prio) in (1:length(t) .=> t))
    # Only indices that were actually sampled contribute an error term.
    return [freq - ex_pdf[idx] for (idx, freq) in est_pdf]
end

function sumtree_distribution!(indices, priorities, n, seed, iters=1000*n)
    return sumtree_distribution!(
        indices, priorities, gen_rand_sumtree(n, seed), StableRNG(seed), iters
    )
end
42+
"""
    sumtree_distribution(n, seeds::AbstractVector, iters=1000*n)

Run `sumtree_distribution!` once per seed, using `gen_rand_sumtree(n, seed)` as the
tree and `StableRNG(seed)` as the sampler, and return one error vector per seed in
the order of `seeds`. Seeds are processed in parallel.
"""
function sumtree_distribution(n, seeds::AbstractVector, iters=1000*n)
    seedlist = collect(seeds)
    results = Vector{Vector{Float64}}(undef, length(seedlist))
    isempty(seedlist) && return results

    # Split the seeds into at most `nthreads()` contiguous chunks, one task each.
    ntasks = min(Threads.nthreads(), length(seedlist))
    chunklen = cld(length(seedlist), ntasks)
    @sync for chunk in Iterators.partition(eachindex(seedlist), chunklen)
        Threads.@spawn begin
            # Scratch buffers are owned by this task rather than looked up through
            # `Threads.threadid()`: a task is not pinned to a thread, and thread ids
            # are not guaranteed to stay within `1:Threads.nthreads()`.
            inds = zeros(Float32, iters)
            prios = zeros(Float32, iters)
            for ix in chunk
                seed = seedlist[ix]
                # Each task writes to a disjoint set of `results` slots.
                results[ix] = sumtree_distribution!(
                    inds, prios, gen_rand_sumtree(n, seed), StableRNG(seed), iters
                )
            end
        end
    end
    return results
end
51+
52+
@testset "SumTree" begin
    n = 1024
    seeds = 1:100
    nozero_iters = 1024
    distr_iters = 1024 * 10_000
    abstol = 0.05  # upper bound on the sum of squared errors for one seed
    maxerr = 0.01  # upper bound on the magnitude of any single error

    # Sampling must never return a zero-priority entry.
    @test sumtree_nozero(n, seeds, nozero_iters)
    # The errors are signed (estimated minus exact frequency), so compare their
    # magnitudes; a bare `x .< maxerr` would accept any negative deviation.
    @test all(x -> all(e -> abs(e) < maxerr, x) && sum(abs2, x) < abstol,
        sumtree_distribution(n, seeds, distr_iters))
end

0 commit comments

Comments
 (0)