Skip to content

Commit 9095619

Browse files
authored
Merge pull request #21 from findmyway/minor_enhancements
minor updates when adapting this package in RL.jl
2 parents f2239bd + 7974510 commit 9095619

File tree

8 files changed

+49
-63
lines changed

8 files changed

+49
-63
lines changed

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,13 @@ version = "0.1.0"
44

55
[deps]
66
CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1"
7-
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
87
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
98
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
109
StackViews = "cae243ae-269e-4f55-b966-ac2d0dc13c15"
1110

1211
[compat]
1312
CircularArrayBuffers = "0.1"
1413
MacroTools = "0.5"
15-
MLUtils = "0.2"
1614
StackViews = "0.1"
1715
julia = "1.6"
1816

src/patch.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
import MLUtils
2-
3-
MLUtils.batch(x::AbstractArray{<:Number}) = x
4-
51
#####
62

73
import StackViews: StackView

src/samplers.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,33 @@
11
export BatchSampler, MetaSampler, MultiBatchSampler
22

3-
using MLUtils: batch
4-
53
using Random
64

75
abstract type AbstractSampler end
86

9-
struct BatchSampler <: AbstractSampler
7+
struct BatchSampler{names} <: AbstractSampler
108
batch_size::Int
119
rng::Random.AbstractRNG
1210
transformer::Any
1311
end
1412

1513
"""
16-
BatchSampler(batch_size; rng=Random.GLOBAL_RNG, transformer=identity)
14+
BatchSampler{names}(;batch_size, rng=Random.GLOBAL_RNG, transformer=identity)
1715
18-
Uniformly sample a batch of examples for each trace.
16+
Uniformly sample a batch of examples for each trace specified in `names`. By default, all the traces will be sampled.
1917
2018
See also [`sample`](@ref).
2119
"""
22-
BatchSampler(batch_size; rng=Random.GLOBAL_RNG, transformer=batch) = BatchSampler(batch_size, rng, transformer)
20+
BatchSampler(batch_size; kw...) = BatchSampler(; batch_size=batch_size, kw...)
21+
BatchSampler(; kw...) = BatchSampler{nothing}(; kw...)
22+
BatchSampler{names}(batch_size; kw...) where {names} = BatchSampler{names}(; batch_size=batch_size, kw...)
23+
BatchSampler{names}(; batch_size, rng=Random.GLOBAL_RNG, transformer=identity) where {names} = BatchSampler{names}(batch_size, rng, transformer)
24+
25+
sample(s::BatchSampler{nothing}, t::AbstractTraces) = sample(s, t, keys(t))
26+
sample(s::BatchSampler{names}, t::AbstractTraces) where {names} = sample(s, t, names)
2327

24-
function sample(s::BatchSampler, t::AbstractTraces)
28+
function sample(s::BatchSampler, t::AbstractTraces, names)
2529
inds = rand(s.rng, 1:length(t), s.batch_size)
26-
map(s.transformer, t[inds])
30+
NamedTuple{names}(s.transformer(t[x][inds]) for x in names)
2731
end
2832

2933
"""

src/traces.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ end
298298

299299
Base.size(t::Traces) = (mapreduce(length, min, t.traces),)
300300

301-
for f in (:push!, :pushfirst!, :append!, :prepend!)
301+
for f in (:push!, :pushfirst!)
302302
@eval function Base.$f(ts::Traces, xs::NamedTuple)
303303
for (k, v) in pairs(xs)
304304
t = ts.traces[ts.inds[k]]
@@ -311,6 +311,15 @@ for f in (:push!, :pushfirst!, :append!, :prepend!)
311311
end
312312
end
313313

314+
for f in (:append!, :prepend!)
315+
@eval function Base.$f(ts::Traces, xs::Traces)
316+
for k in keys(xs)
317+
t = ts.traces[ts.inds[k]]
318+
$f(t, xs[k])
319+
end
320+
end
321+
end
322+
314323
for f in (:pop!, :popfirst!, :empty!)
315324
@eval function Base.$f(ts::Traces)
316325
for t in ts.traces

src/trajectory.jl

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,14 @@ Base.@kwdef struct Trajectory{C,S,T}
3030
function Trajectory(container::C, sampler::S, controller::T) where {C,S,T<:AsyncInsertSampleRatioController}
3131
t = Threads.@spawn while true
3232
for msg in controller.ch_in
33-
if msg.f === Base.push! || msg.f === Base.append!
34-
n_pre = length(container)
35-
msg.f(container, msg.args...; msg.kw...)
36-
n_post = length(container)
37-
controller.n_inserted += n_post - n_pre
33+
if msg.f === Base.push!
34+
x, = msg.args
35+
msg.f(container, x)
36+
controller.n_inserted += 1
37+
elseif msg.f === Base.append!
38+
x, = msg.args
39+
msg.f(container, x)
40+
controller.n_inserted += length(x)
3841
else
3942
msg.f(container, msg.args...; msg.kw...)
4043
end
@@ -65,11 +68,11 @@ function Base.bind(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}
6568
bind(t.controler.ch_out, task)
6669
end
6770

71+
# !!! by default we assume `x` is a complete example which contains all the traces
72+
# When doing partial inserting, the result is undefined
6873
function Base.push!(t::Trajectory, x)
69-
n_pre = length(t.container)
7074
push!(t.container, x)
71-
n_post = length(t.container)
72-
on_insert!(t.controller, n_post - n_pre)
75+
on_insert!(t.controller, 1)
7376
end
7477

7578
struct CallMsg
@@ -78,16 +81,17 @@ struct CallMsg
7881
kw::Any
7982
end
8083

81-
Base.push!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}, args...; kw...) = put!(t.controller.ch_in, CallMsg(Base.push!, args, kw))
82-
Base.append!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}, args...; kw...) = put!(t.controller.ch_in, CallMsg(Base.append!, args, kw))
84+
Base.push!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}, x) = put!(t.controller.ch_in, CallMsg(Base.push!, (x,), NamedTuple()))
85+
Base.append!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}, x) = put!(t.controller.ch_in, CallMsg(Base.append!, (x,), NamedTuple()))
8386

8487
function Base.append!(t::Trajectory, x)
85-
n_pre = length(t.container)
8688
append!(t.container, x)
87-
n_post = length(t.container)
88-
on_insert!(t.controller, n_post - n_pre)
89+
on_insert!(t.controller, length(x))
8990
end
9091

92+
# !!! bypass the controller
93+
sample(t::Trajectory) = sample(t.sampler, t.container)
94+
9195
function Base.take!(t::Trajectory)
9296
res = on_sample!(t.controller)
9397
if isnothing(res)

test/samplers.jl

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ end
3333
sampler=MetaSampler(policy=BatchSampler(3), critic=BatchSampler(5)),
3434
)
3535

36-
append!(t, (a=rand(Int, 10), b=rand(Bool, 10)))
36+
append!(t, Traces(a=rand(Int, 10), b=rand(Bool, 10)))
3737

3838
batches = collect(t)
3939

@@ -50,37 +50,12 @@ end
5050
sampler=MetaSampler(policy=BatchSampler(3), critic=MultiBatchSampler(BatchSampler(5), 2)),
5151
)
5252

53-
append!(t, (a=rand(Int, 10), b=rand(Bool, 10)))
53+
append!(t, Traces(a=rand(Int, 10), b=rand(Bool, 10)))
5454

5555
batches = collect(t)
5656

5757
@test length(batches) == 10
5858
@test length(batches[1][:policy][:a]) == 3
5959
@test length(batches[1][:critic]) == 2 # we sampled 2 batches for critic
6060
@test length(batches[1][:critic][1][:b]) == 5 #each batch is 5 samples
61-
end
62-
63-
@testset "async trajectories" begin
64-
threshould = 100
65-
ratio = 1 / 4
66-
t = Trajectory(
67-
container=Traces(
68-
a=Int[],
69-
b=Bool[]
70-
),
71-
sampler=BatchSampler(3),
72-
controller=AsyncInsertSampleRatioController(ratio, threshould)
73-
)
74-
75-
n = 100
76-
insert_task = @async for i in 1:n
77-
append!(t, (a=[i, i, i, i], b=[false, true, false, true]))
78-
end
79-
80-
s = 0
81-
sample_task = @async for _ in t
82-
s += 1
83-
end
84-
sleep(1)
85-
@test s == (n - threshould * ratio) + 1
8661
end

test/traces.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
@test t[:a][end] == 3
1212
@test t[:b][end] == true
1313

14-
append!(t, (a=[4, 5], b=[false, false]))
14+
append!(t, Traces(a=[4, 5], b=[false, false]))
1515
@test length(t[:a]) == 5
1616
@test t[:b][end-1:end] == [false, false]
1717

@@ -71,7 +71,7 @@ end
7171
@test length(t3) == 1
7272
@test t3[1] == (a=1, b=false)
7373

74-
append!(t3, (; a=[2, 3], b=[false, true]))
74+
append!(t3, Traces(; a=[2, 3], b=[false, true]))
7575
@test length(t3) == 3
7676

7777
@test t3[:a][1:3] == [1, 2, 3]
@@ -102,7 +102,7 @@ end
102102

103103
empty!(t8)
104104
push!(t8, (a=1, b=false, aa=1, bb=false))
105-
append!(t8, (a=[2, 3], b=[true, true], aa=[2, 3], bb=[true, true]))
105+
append!(t8, Traces(a=[2, 3], b=[true, true], aa=[2, 3], bb=[true, true]))
106106

107107
@test length(t8) == 3
108108

@@ -124,7 +124,7 @@ end
124124
push!(t, (state=1, action=1.0))
125125
@test length(t) == 1
126126

127-
append!(t, (state=[2, 3], action=[2.0, 3.0]))
127+
append!(t, Traces(state=[2, 3], action=[2.0, 3.0]))
128128
@test length(t) == 3
129129

130130
@test t[:state] == [1, 2, 3]

test/trajectories.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ end
3232

3333
@test length(batches) == 0 # threshold not reached yet
3434

35-
append!(t, (a=[1, 2, 3], b=[false, true, false]))
35+
append!(t, Traces(a=[1, 2, 3], b=[false, true, false]))
3636

3737
for batch in t
3838
push!(batches, batch)
@@ -48,7 +48,7 @@ end
4848

4949
@test length(batches) == 1 # 4 inserted, threshold is 4, ratio is 0.25
5050

51-
append!(t, (a=[5, 6, 7], b=[true, true, true]))
51+
append!(t, Traces(a=[5, 6, 7], b=[true, true, true]))
5252

5353
for batch in t
5454
push!(batches, batch)
@@ -66,7 +66,7 @@ end
6666

6767
n = 100
6868
for i in 1:n
69-
append!(t, (a=[i, i, i, i], b=[false, true, false, true]))
69+
append!(t, Traces(a=[i, i, i, i], b=[false, true, false, true]))
7070
end
7171

7272
s = 0
@@ -90,7 +90,7 @@ end
9090

9191
n = 100
9292
insert_task = @async for i in 1:n
93-
append!(t, (a=[i, i, i, i], b=[false, true, false, true]))
93+
append!(t, Traces(a=[i, i, i, i], b=[false, true, false, true]))
9494
end
9595

9696
s = 0

0 commit comments

Comments
 (0)