Skip to content

Commit 7974510

Browse files
committed
fix tests
1 parent 17ff1ed commit 7974510

File tree

6 files changed

+33
-44
lines changed

6 files changed

+33
-44
lines changed

src/samplers.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ Uniformly sample a batch of examples for each trace specified in `names`. By def
1717
1818
See also [`sample`](@ref).
1919
"""
20+
BatchSampler(batch_size; kw...) = BatchSampler(; batch_size=batch_size, kw...)
2021
BatchSampler(; kw...) = BatchSampler{nothing}(; kw...)
22+
BatchSampler{names}(batch_size; kw...) where {names} = BatchSampler{names}(; batch_size=batch_size, kw...)
2123
BatchSampler{names}(; batch_size, rng=Random.GLOBAL_RNG, transformer=identity) where {names} = BatchSampler{names}(batch_size, rng, transformer)
2224

2325
sample(s::BatchSampler{nothing}, t::AbstractTraces) = sample(s, t, keys(t))

src/traces.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ end
298298

299299
Base.size(t::Traces) = (mapreduce(length, min, t.traces),)
300300

301-
for f in (:push!, :pushfirst!, :append!, :prepend!)
301+
for f in (:push!, :pushfirst!)
302302
@eval function Base.$f(ts::Traces, xs::NamedTuple)
303303
for (k, v) in pairs(xs)
304304
t = ts.traces[ts.inds[k]]
@@ -311,6 +311,15 @@ for f in (:push!, :pushfirst!, :append!, :prepend!)
311311
end
312312
end
313313

314+
for f in (:append!, :prepend!)
315+
@eval function Base.$f(ts::Traces, xs::Traces)
316+
for k in keys(xs)
317+
t = ts.traces[ts.inds[k]]
318+
$f(t, xs[k])
319+
end
320+
end
321+
end
322+
314323
for f in (:pop!, :popfirst!, :empty!)
315324
@eval function Base.$f(ts::Traces)
316325
for t in ts.traces

src/trajectory.jl

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,14 @@ Base.@kwdef struct Trajectory{C,S,T}
3030
function Trajectory(container::C, sampler::S, controller::T) where {C,S,T<:AsyncInsertSampleRatioController}
3131
t = Threads.@spawn while true
3232
for msg in controller.ch_in
33-
if msg.f === Base.push! || msg.f === Base.append!
34-
n_pre = length(container)
35-
msg.f(container, msg.args...; msg.kw...)
36-
n_post = length(container)
37-
controller.n_inserted += n_post - n_pre
33+
if msg.f === Base.push!
34+
x, = msg.args
35+
msg.f(container, x)
36+
controller.n_inserted += 1
37+
elseif msg.f === Base.append!
38+
x, = msg.args
39+
msg.f(container, x)
40+
controller.n_inserted += length(x)
3841
else
3942
msg.f(container, msg.args...; msg.kw...)
4043
end
@@ -78,10 +81,10 @@ struct CallMsg
7881
kw::Any
7982
end
8083

81-
Base.push!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}, args...; kw...) = put!(t.controller.ch_in, CallMsg(Base.push!, args, kw))
82-
Base.append!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}, args...; kw...) = put!(t.controller.ch_in, CallMsg(Base.append!, args, kw))
84+
Base.push!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}, x) = put!(t.controller.ch_in, CallMsg(Base.push!, (x,), NamedTuple()))
85+
Base.append!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}, x) = put!(t.controller.ch_in, CallMsg(Base.append!, (x,), NamedTuple()))
8386

84-
function Base.append!(t::Trajectory, x::AbstractVector)
87+
function Base.append!(t::Trajectory, x)
8588
append!(t.container, x)
8689
on_insert!(t.controller, length(x))
8790
end

test/samplers.jl

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ end
3333
sampler=MetaSampler(policy=BatchSampler(3), critic=BatchSampler(5)),
3434
)
3535

36-
append!(t, (a=rand(Int, 10), b=rand(Bool, 10)))
36+
append!(t, Traces(a=rand(Int, 10), b=rand(Bool, 10)))
3737

3838
batches = collect(t)
3939

@@ -50,37 +50,12 @@ end
5050
sampler=MetaSampler(policy=BatchSampler(3), critic=MultiBatchSampler(BatchSampler(5), 2)),
5151
)
5252

53-
append!(t, (a=rand(Int, 10), b=rand(Bool, 10)))
53+
append!(t, Traces(a=rand(Int, 10), b=rand(Bool, 10)))
5454

5555
batches = collect(t)
5656

5757
@test length(batches) == 10
5858
@test length(batches[1][:policy][:a]) == 3
5959
@test length(batches[1][:critic]) == 2 # we sampled 2 batches for critic
6060
@test length(batches[1][:critic][1][:b]) == 5 #each batch is 5 samples
61-
end
62-
63-
@testset "async trajectories" begin
64-
threshould = 100
65-
ratio = 1 / 4
66-
t = Trajectory(
67-
container=Traces(
68-
a=Int[],
69-
b=Bool[]
70-
),
71-
sampler=BatchSampler(3),
72-
controller=AsyncInsertSampleRatioController(ratio, threshould)
73-
)
74-
75-
n = 100
76-
insert_task = @async for i in 1:n
77-
append!(t, (a=[i, i, i, i], b=[false, true, false, true]))
78-
end
79-
80-
s = 0
81-
sample_task = @async for _ in t
82-
s += 1
83-
end
84-
sleep(1)
85-
@test s == (n - threshould * ratio) + 1
8661
end

test/traces.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
@test t[:a][end] == 3
1212
@test t[:b][end] == true
1313

14-
append!(t, (a=[4, 5], b=[false, false]))
14+
append!(t, Traces(a=[4, 5], b=[false, false]))
1515
@test length(t[:a]) == 5
1616
@test t[:b][end-1:end] == [false, false]
1717

@@ -71,7 +71,7 @@ end
7171
@test length(t3) == 1
7272
@test t3[1] == (a=1, b=false)
7373

74-
append!(t3, (; a=[2, 3], b=[false, true]))
74+
append!(t3, Traces(; a=[2, 3], b=[false, true]))
7575
@test length(t3) == 3
7676

7777
@test t3[:a][1:3] == [1, 2, 3]
@@ -102,7 +102,7 @@ end
102102

103103
empty!(t8)
104104
push!(t8, (a=1, b=false, aa=1, bb=false))
105-
append!(t8, (a=[2, 3], b=[true, true], aa=[2, 3], bb=[true, true]))
105+
append!(t8, Traces(a=[2, 3], b=[true, true], aa=[2, 3], bb=[true, true]))
106106

107107
@test length(t8) == 3
108108

@@ -124,7 +124,7 @@ end
124124
push!(t, (state=1, action=1.0))
125125
@test length(t) == 1
126126

127-
append!(t, (state=[2, 3], action=[2.0, 3.0]))
127+
append!(t, Traces(state=[2, 3], action=[2.0, 3.0]))
128128
@test length(t) == 3
129129

130130
@test t[:state] == [1, 2, 3]

test/trajectories.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ end
3232

3333
@test length(batches) == 0 # threshold not reached yet
3434

35-
append!(t, (a=[1, 2, 3], b=[false, true, false]))
35+
append!(t, Traces(a=[1, 2, 3], b=[false, true, false]))
3636

3737
for batch in t
3838
push!(batches, batch)
@@ -48,7 +48,7 @@ end
4848

4949
@test length(batches) == 1 # 4 inserted, threshold is 4, ratio is 0.25
5050

51-
append!(t, (a=[5, 6, 7], b=[true, true, true]))
51+
append!(t, Traces(a=[5, 6, 7], b=[true, true, true]))
5252

5353
for batch in t
5454
push!(batches, batch)
@@ -66,7 +66,7 @@ end
6666

6767
n = 100
6868
for i in 1:n
69-
append!(t, (a=[i, i, i, i], b=[false, true, false, true]))
69+
append!(t, Traces(a=[i, i, i, i], b=[false, true, false, true]))
7070
end
7171

7272
s = 0
@@ -90,7 +90,7 @@ end
9090

9191
n = 100
9292
insert_task = @async for i in 1:n
93-
append!(t, (a=[i, i, i, i], b=[false, true, false, true]))
93+
append!(t, Traces(a=[i, i, i, i], b=[false, true, false, true]))
9494
end
9595

9696
s = 0

0 commit comments

Comments
 (0)