Skip to content

Commit 68b7cd0

Browse files
committed
add NStepSampler
1 parent e868ba2 commit 68b7cd0

File tree

5 files changed

+83
-13
lines changed

5 files changed

+83
-13
lines changed

src/ReinforcementLearningTrajectories.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@ export RLTrajectories
55

66
include("patch.jl")
77
include("traces.jl")
8+
include("common/common.jl")
89
include("samplers.jl")
910
include("controllers.jl")
1011
include("trajectory.jl")
1112
include("normalization.jl")
12-
include("common/common.jl")
1313

1414
end

src/common/CircularArraySLARTTraces.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,4 @@ function CircularArraySLARTTraces(;
3232
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
3333
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
3434
)
35-
end
35+
end

src/common/common.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
export SS, LL, AA, RT, SSART, SSAART, SSLLAART
2+
13
using CircularArrayBuffers
24

35
# Building blocks: trace names that come in (current, next) pairs, plus the
# reward/terminal pair.
const SS = (:state, :next_state)
const LL = (:legal_actions_mask, :next_legal_actions_mask)
const AA = (:action, :next_action)
const RT = (:reward, :terminal)

# Composite trace-name tuples assembled by splatting the building blocks above.
# The splat order fixes the order of the names in each tuple.
const SSART = (SS..., :action, RT...)
const SSAART = (SS..., AA..., RT...)
const SSLLAART = (SS..., LL..., AA..., RT...)
912

src/samplers.jl

Lines changed: 73 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
1-
export BatchSampler, MetaSampler, MultiBatchSampler
2-
31
using Random
42

53
abstract type AbstractSampler end
64

5+
#####
6+
# BatchSampler
7+
#####
8+
9+
export BatchSampler
710
struct BatchSampler{names} <: AbstractSampler
811
batch_size::Int
912
rng::Random.AbstractRNG
10-
transformer::Any
1113
end
1214

1315
"""
14-
BatchSampler{names}(;batch_size, rng=Random.GLOBAL_RNG, transformer=identity)
15-
BatchSampler{names}(batch_size ;rng=Random.GLOBAL_RNG, transformer=identity)
16+
BatchSampler{names}(;batch_size, rng=Random.GLOBAL_RNG)
17+
BatchSampler{names}(batch_size ;rng=Random.GLOBAL_RNG)
1618
1719
Uniformly sample a batch of examples for each trace specified in `names`.
1820
By default, all the traces will be sampled.
@@ -22,16 +24,22 @@ See also [`sample`](@ref).
2224
# Positional `batch_size`: forward it as the `batch_size` keyword.
function BatchSampler(batch_size; kw...)
    return BatchSampler(; batch_size=batch_size, kw...)
end

# No trace names given: use `nothing` as the `names` type parameter.
function BatchSampler(; kw...)
    return BatchSampler{nothing}(; kw...)
end

# Explicit trace names with a positional `batch_size`.
function BatchSampler{names}(batch_size; kw...) where {names}
    return BatchSampler{names}(; batch_size=batch_size, kw...)
end
25-
BatchSampler{names}(; batch_size, rng=Random.GLOBAL_RNG, transformer=identity) where {names} = BatchSampler{names}(batch_size, rng, transformer)
27+
# Keyword constructor: `batch_size` is required, `rng` defaults to the global RNG.
function BatchSampler{names}(; batch_size, rng=Random.GLOBAL_RNG) where {names}
    return BatchSampler{names}(batch_size, rng)
end
2628

2729
# A sampler built without explicit names samples every trace in `t`.
function sample(s::BatchSampler{nothing}, t::AbstractTraces)
    return sample(s, t, keys(t))
end

# Otherwise only the traces listed in the `names` type parameter are sampled.
function sample(s::BatchSampler{names}, t::AbstractTraces) where {names}
    return sample(s, t, names)
end
2931

3032
function sample(s::BatchSampler, t::AbstractTraces, names)
3133
inds = rand(s.rng, 1:length(t), s.batch_size)
32-
NamedTuple{names}(s.transformer(t[x][inds] for x in names))
34+
NamedTuple{names}(t[x][inds] for x in names)
3335
end
3436

37+
#####
38+
# MetaSampler
39+
#####
40+
41+
export MetaSampler
42+
3543
"""
3644
MetaSampler(::NamedTuple)
3745
@@ -52,6 +60,11 @@ MetaSampler(; kw...) = MetaSampler(NamedTuple(kw))
5260

5361
# Apply every wrapped sampler to the same container `t`; `map` preserves the
# structure of `s.samplers`.
function sample(s::MetaSampler, t)
    return map(s.samplers) do sampler
        sample(sampler, t)
    end
end
5462

63+
#####
64+
# MultiBatchSampler
65+
#####
66+
67+
export MultiBatchSampler
5568

5669
"""
5770
MultiBatchSampler(sampler, n)
@@ -71,3 +84,56 @@ struct MultiBatchSampler{S<:AbstractSampler} <: AbstractSampler
7184
end
7285

7386
# Draw `m.n` independent batches from `t` using the wrapped sampler and collect
# them in a vector.
function sample(m::MultiBatchSampler, t)
    inner, n_batches = m.sampler, m.n
    return [sample(inner, t) for _ in 1:n_batches]
end
87+
88+
#####
89+
# NStepBatchSampler
90+
#####
91+
92+
export NStepBatchSampler
93+
94+
"""
    NStepBatchSampler{traces}(; n, γ, batch_size=32, stack_size=nothing, rng=Random.GLOBAL_RNG)

Configuration for sampling batches of `n`-step transitions.

The `traces` type parameter names the traces to return; `sample` dispatches on it.

# Fields
- `n`: length of the look-ahead horizon (counted from 1).
- `γ`: discount factor used when folding the rewards of a horizon.
- `batch_size`: number of start indices drawn per call to `sample`.
- `stack_size`: number of consecutive frames to stack, or `nothing` for no stacking.
- `rng`: random number generator used to draw the start indices.
"""
Base.@kwdef mutable struct NStepBatchSampler{traces}
    # NOTE: `n` starts from 1.
    n::Int
    γ::Float32
    batch_size::Int = 32
    stack_size::Union{Nothing,Int} = nothing
    rng::Any = Random.GLOBAL_RNG
end
101+
102+
# Index `xs` along its last dimension with `inds`, keeping every leading
# dimension whole. Uses `@views`, so no slice is copied.
function select_last_dim(xs::AbstractArray{T,N}, inds) where {T,N}
    leading = ntuple(_ -> Colon(), Val(N - 1))
    return @views xs[leading..., inds]
end

# Select the entry at the final index of the last dimension of `xs`.
function select_last_frame(xs::AbstractArray{T,N}) where {T,N}
    last_index = size(xs, N)
    return select_last_dim(xs, last_index)
end
104+
105+
# Keyword entry point: forwards to the positional methods below, which dispatch
# on whether `n_stack` / `n_horizon` are `nothing` or an `Int`.
function consecutive_view(cb, inds; n_stack=nothing, n_horizon=nothing)
    return consecutive_view(cb, inds, n_stack, n_horizon)
end

# Neither stacking nor horizon: index the last dimension with `inds` directly.
function consecutive_view(cb, inds, ::Nothing, ::Nothing)
    return select_last_dim(cb, inds)
end

# Stacking only: for each index take it and the `n_stack - 1` indices before it.
function consecutive_view(cb, inds, n_stack::Int, ::Nothing)
    stacked = [x + back for back in (1-n_stack):0, x in inds]
    return select_last_dim(cb, stacked)
end

# Horizon only: for each index take it and the `n_horizon - 1` indices after it.
function consecutive_view(cb, inds, ::Nothing, n_horizon::Int)
    ahead = [x + fwd for fwd in 0:(n_horizon-1), x in inds]
    return select_last_dim(cb, ahead)
end

# Both: a stack of `n_stack` indices for every step of the horizon.
function consecutive_view(cb, inds, n_stack::Int, n_horizon::Int)
    grid = [x + back + fwd for back in (1-n_stack):0, fwd in 0:(n_horizon-1), x in inds]
    return select_last_dim(cb, grid)
end
110+
111+
# Draw `s.batch_size` start indices uniformly from the valid range, then
# dispatch on the trace names carried in the sampler's type parameter.
function sample(s::NStepBatchSampler{names}, ts) where {names}
    # With stacking, the first valid index is `stack_size`; without it, 1.
    # The last valid index leaves `n` steps available inside `ts`.
    # NOTE: think about the extreme case where s.stack_size == 1 and s.n == 1.
    first_valid = isnothing(s.stack_size) ? 1 : s.stack_size
    last_valid = length(ts) - s.n + 1
    inds = rand(s.rng, first_valid:last_valid, s.batch_size)
    return sample(s, ts, Val(names), inds)
end
116+
117+
"""
    sample(s::NStepBatchSampler, ts, ::Val{SSART}, inds)

Build a batch of `n`-step transitions from `ts` starting at each index in `inds`.

Returns a `NamedTuple` with the names in `SSART`:
- `state`: view of `ts[:state]` at `inds` (stacked when `s.stack_size` is set).
- `next_state`: view of `ts[:next_state]` at `inds .+ (s.n - 1)`.
- `action`: view of `ts[:action]` at `inds`.
- `reward`: discounted sum of the rewards over the horizon, one per start index.
  Rewards after a terminal step are dropped.
- `terminal`: `true` when any step within the horizon is terminal, one per start index.
"""
function sample(s::NStepBatchSampler, ts, ::Val{SSART}, inds)
    # Read the sampler's settings first and never rebind `s`: the locals below
    # hold arrays, while `s` must remain the sampler.
    n, γ, n_stack = s.n, s.γ, s.stack_size

    states = consecutive_view(ts[:state], inds; n_stack=n_stack)
    next_states = consecutive_view(ts[:next_state], inds .+ (n - 1); n_stack=n_stack)
    actions = consecutive_view(ts[:action], inds)
    t_horizon = consecutive_view(ts[:terminal], inds; n_horizon=n)
    r_horizon = consecutive_view(ts[:reward], inds; n_horizon=n)

    # Each column holds the horizon of one start index.
    @assert ndims(t_horizon) == 2
    # `any(...; dims=1)` returns a 1×B matrix; flatten it so it matches `rewards`.
    terminals = vec(any(t_horizon; dims=1))

    @assert ndims(r_horizon) == 2
    rewards = map(eachcol(r_horizon), eachcol(t_horizon)) do r⃗, t⃗
        # `foldr` calls `op(element, accumulator)`. A terminal step zeroes the
        # contribution of everything that follows it.
        foldr(((rr, tt), acc) -> rr + γ * acc * (1 - tt), zip(r⃗, t⃗); init=0.0f0)
    end

    return NamedTuple{SSART}((states, next_states, actions, rewards, terminals))
end
134+
135+
# NOTE(review): `SSLART` is not among the trace-name constants shown in
# common.jl, so it is defined here, next to its only use. If it is (or later
# becomes) defined there with the same value, this line can be dropped.
const SSLART = (SS..., :legal_actions_mask, :action, RT...)

"""
    sample(s::NStepBatchSampler, ts, ::Val{SSLART}, inds)

Same as the `SSART` method, with the `legal_actions_mask` trace at `inds` added.
Returns a `NamedTuple` with the names in `SSLART`.
"""
function sample(s::NStepBatchSampler, ts, ::Val{SSLART}, inds)
    # Destructure positionally, in `SSART` order, under names that do not
    # shadow the sampler `s`.
    state, next_state, action, reward, terminal = sample(s, ts, Val(SSART), inds)
    legal_actions_mask = consecutive_view(ts[:legal_actions_mask], inds)
    return NamedTuple{SSLART}((
        state, next_state, legal_actions_mask, action, reward, terminal
    ))
end

src/trajectory.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@ Base.@kwdef struct Trajectory{C,S,T}
2424
container::C
2525
sampler::S
2626
controller::T = InsertSampleRatioController()
27+
transformer::Any = identity
2728

28-
Trajectory(c::C, s::S, t::T=InsertSampleRatioController()) where {C,S,T} = new{C,S,T}(c, s, t)
29+
Trajectory(c::C, s::S, t::T=InsertSampleRatioController(), f=identity) where {C,S,T} = new{C,S,T}(c, s, t, f)
2930

30-
function Trajectory(container::C, sampler::S, controller::T) where {C,S,T<:AsyncInsertSampleRatioController}
31+
function Trajectory(container::C, sampler::S, controller::T, transformer) where {C,S,T<:AsyncInsertSampleRatioController}
3132
t = Threads.@spawn while true
3233
for msg in controller.ch_in
3334
if msg.f === Base.push!
@@ -54,7 +55,7 @@ Base.@kwdef struct Trajectory{C,S,T}
5455

5556
bind(controller.ch_in, t)
5657
bind(controller.ch_out, t)
57-
new{C,S,T}(container, sampler, controller)
58+
new{C,S,T}(container, sampler, controller, transformer)
5859
end
5960
end
6061

@@ -97,7 +98,7 @@ function Base.take!(t::Trajectory)
9798
if isnothing(res)
9899
nothing
99100
else
100-
sample(t.sampler, t.container)
101+
sample(t.sampler, t.container) |> t.transformer
101102
end
102103
end
103104

0 commit comments

Comments
 (0)