1- export BatchSampler, MetaSampler, MultiBatchSampler
2-
31using Random
42
"""
    AbstractSampler

Abstract supertype of the samplers in this file (e.g. `BatchSampler` and
`MultiBatchSampler` are declared as subtypes). Samplers are consumed through `sample`.
"""
abstract type AbstractSampler end
64
5+ # ####
6+ # BatchSampler
7+ # ####
8+
9+ export BatchSampler
# Sampler that draws `batch_size` indices using `rng`.
# The `names` type parameter selects which traces are sampled; `nothing` means
# "use every key of the traces" (see the `sample(::BatchSampler{nothing}, ...)` method).
struct BatchSampler{names} <: AbstractSampler
    batch_size::Int           # number of indices drawn per call to `sample`
    rng::Random.AbstractRNG   # random number generator used to draw the indices
end
1214
1315"""
14- BatchSampler{names}(;batch_size, rng=Random.GLOBAL_RNG, transformer=identity )
15- BatchSampler{names}(batch_size ;rng=Random.GLOBAL_RNG, transformer=identity )
16+ BatchSampler{names}(;batch_size, rng=Random.GLOBAL_RNG)
17+ BatchSampler{names}(batch_size ;rng=Random.GLOBAL_RNG)
1618
1719Uniformly sample a batch of examples for each trace specified in `names`.
1820By default, all the traces will be sampled.
@@ -22,16 +24,22 @@ See also [`sample`](@ref).
# Convenience constructors. Every form funnels into the keyword constructor on the
# last line, which calls the default positional constructor
# `BatchSampler{names}(batch_size, rng)`.
# When `names` is omitted it defaults to `nothing`.
BatchSampler(batch_size; kw...) = BatchSampler(; batch_size=batch_size, kw...)
BatchSampler(; kw...) = BatchSampler{nothing}(; kw...)
function BatchSampler{names}(batch_size; kw...) where {names}
    return BatchSampler{names}(; batch_size=batch_size, kw...)
end
function BatchSampler{names}(; batch_size, rng=Random.GLOBAL_RNG) where {names}
    return BatchSampler{names}(batch_size, rng)
end
2628
# Entry points for `BatchSampler`: resolve which trace names to sample, then delegate
# to the three-argument method.
#   * `BatchSampler{nothing}` samples every key of `t`.
#   * `BatchSampler{names}` samples exactly the traces listed in the type parameter.
sample(s::BatchSampler{nothing}, t::AbstractTraces) = sample(s, t, keys(t))
sample(s::BatchSampler{names}, t::AbstractTraces) where {names} = sample(s, t, names)
2931
# Draw `s.batch_size` indices (with replacement) from `1:length(t)` using `s.rng`, and
# return a `NamedTuple` keyed by `names` whose values are `t[x][inds]` for each name `x`.
# The same index vector is used for every trace, so the sampled entries stay aligned.
function sample(s::BatchSampler, t::AbstractTraces, names)
    inds = rand(s.rng, 1:length(t), s.batch_size)
    return NamedTuple{names}(t[x][inds] for x in names)
end
3436
37+ # ####
38+ # MetaSampler
39+ # ####
40+
41+ export MetaSampler
42+
3543"""
3644 MetaSampler(::NamedTuple)
3745
@@ -52,6 +60,11 @@ MetaSampler(; kw...) = MetaSampler(NamedTuple(kw))
5260
# Run every sampler held in `s.samplers` against the same `t`; the result has whatever
# container shape `map` produces for `s.samplers` (presumably a NamedTuple — the struct
# definition is outside this view, TODO confirm).
sample(s::MetaSampler, t) = map(x -> sample(x, t), s.samplers)
5462
63+ # ####
64+ # MultiBatchSampler
65+ # ####
66+
67+ export MultiBatchSampler
5568
5669"""
5770 MultiBatchSampler(sampler, n)
@@ -71,3 +84,56 @@ struct MultiBatchSampler{S<:AbstractSampler} <: AbstractSampler
7184end
7285
# Call the wrapped sampler `m.n` times on `t` and collect the batches in a vector.
sample(m::MultiBatchSampler, t) = [sample(m.sampler, t) for _ in 1:m.n]
87+
88+ # ####
89+ # NStepBatchSampler
90+ # ####
91+
92+ export NStepBatchSampler
93+
"""
    NStepBatchSampler{traces}(; n, γ, batch_size=32, stack_size=nothing, rng=Random.GLOBAL_RNG)

Configuration for n-step batch sampling. The `traces` type parameter is forwarded to
`sample` as `Val(traces)` to select the sampling method.

# Keywords
- `n`: number of steps in the horizon; counting starts from 1.
- `γ`: factor multiplied into the accumulated reward at each step of the horizon.
- `batch_size`: number of indices drawn per call to `sample`.
- `stack_size`: number of consecutive frames stacked per sampled state, or `nothing`
  for no stacking.
- `rng`: random number generator used to draw indices.
"""
Base.@kwdef mutable struct NStepBatchSampler{traces}
    n::Int # !!! n starts from 1
    γ::Float32
    batch_size::Int = 32
    stack_size::Union{Nothing,Int} = nothing
    rng::Any = Random.GLOBAL_RNG
end
101+
"""
    select_last_dim(xs, inds)

Index `xs` along its last dimension with `inds`, taking every element of the leading
dimensions. Indexing goes through `@views`, so non-scalar selections are views rather
than copies.
"""
function select_last_dim(xs::AbstractArray{T,N}, inds) where {T,N}
    leading = ntuple(_ -> Colon(), Val(N - 1))
    return @views xs[leading..., inds]
end

"""
    select_last_frame(xs)

Select the final slice of `xs` along its last dimension.
"""
function select_last_frame(xs::AbstractArray{T,N}) where {T,N}
    # `lastindex` equals `size(xs, N)` for 1-based arrays and stays correct otherwise.
    return select_last_dim(xs, lastindex(xs, N))
end

"""
    consecutive_view(cb, inds; n_stack=nothing, n_horizon=nothing)

Select windows of consecutive entries of `cb` along its last dimension.

For each `x` in `inds`:
- with neither keyword, entry `x` is selected;
- with `n_stack`, the `n_stack` entries ending at `x` (`x - n_stack + 1` to `x`);
- with `n_horizon`, the `n_horizon` entries starting at `x` (`x` to `x + n_horizon - 1`);
- with both, every combination `x + i + j` of a stack offset `i` and a horizon offset `j`.

The window axes are inserted before the `inds` axis, so the result gains one trailing
dimension per keyword given.
"""
function consecutive_view(cb, inds; n_stack=nothing, n_horizon=nothing)
    return consecutive_view(cb, inds, n_stack, n_horizon)
end

consecutive_view(cb, inds, ::Nothing, ::Nothing) = select_last_dim(cb, inds)

function consecutive_view(cb, inds, n_stack::Int, ::Nothing)
    window = [x + i for i in (-n_stack + 1):0, x in inds]
    return select_last_dim(cb, window)
end

function consecutive_view(cb, inds, ::Nothing, n_horizon::Int)
    window = [x + j for j in 0:(n_horizon - 1), x in inds]
    return select_last_dim(cb, window)
end

function consecutive_view(cb, inds, n_stack::Int, n_horizon::Int)
    window = [x + i + j for i in (-n_stack + 1):0, j in 0:(n_horizon - 1), x in inds]
    return select_last_dim(cb, window)
end
110+
# Draw `s.batch_size` start indices for n-step sampling and dispatch on the trace names
# carried in the sampler's type parameter.
#
# The valid range ends at `length(ts) - s.n + 1`, leaving room for the `s.n`-step
# horizon that starts at each index. When frame stacking is enabled it starts at
# `s.stack_size`, leaving room for the stacked frames that end at each index;
# otherwise it starts at 1.
# Extreme case: `s.stack_size == 1` and `s.n == 1` gives the full range `1:length(ts)`.
function sample(s::NStepBatchSampler{names}, ts) where {names}
    first_valid = something(s.stack_size, 1)
    last_valid = length(ts) - s.n + 1
    inds = rand(s.rng, first_valid:last_valid, s.batch_size)
    return sample(s, ts, Val(names), inds)
end
116+
# Build an n-step batch for the `SSART` trace layout from the start indices `inds`.
#
# Returns `NamedTuple{SSART}` holding, in order: state, next state, action, n-step
# reward and terminal flag. (`SSART` is defined outside this block; presumably it is
# `(:state, :next_state, :action, :reward, :terminal)` — confirm against its definition.)
#
# - States are taken at `inds`, next states at `inds .+ (s.n - 1)`; both are
#   frame-stacked when `s.stack_size` is set.
# - The terminal flag of a sample is true if any step of its horizon is terminal.
# - The reward is accumulated backwards over the horizon as
#   `r_k + γ * (1 - t_k) * acc`, so rewards after a terminal step are dropped.
function sample(s::NStepBatchSampler, ts, ::Val{SSART}, inds)
    # Locals are named so that they do not overwrite the sampler `s`, whose fields are
    # still needed below.
    state = consecutive_view(ts[:state], inds; n_stack=s.stack_size)
    next_state = consecutive_view(ts[:next_state], inds .+ (s.n - 1); n_stack=s.stack_size)
    action = consecutive_view(ts[:action], inds)
    t_horizon = consecutive_view(ts[:terminal], inds; n_horizon=s.n)
    r_horizon = consecutive_view(ts[:reward], inds; n_horizon=s.n)

    # Internal invariant: one column per sampled index, one row per horizon step.
    @assert ndims(t_horizon) == 2
    # `vec` drops the singleton row so the flags have the same shape as the rewards.
    terminal = vec(any(t_horizon; dims=1))

    @assert ndims(r_horizon) == 2
    γ = s.γ
    reward = map(eachcol(r_horizon), eachcol(t_horizon)) do r⃗, t⃗
        # `foldr` calls `op(element, accumulator)`, so the `(reward, terminal)` pair
        # comes first and the running return second.
        foldr(((rr, tt), acc) -> rr + γ * acc * (1 - tt), zip(r⃗, t⃗); init=0.0f0)
    end

    # The `NamedTuple{names}` constructor takes a single tuple of values.
    return NamedTuple{SSART}((state, next_state, action, reward, terminal))
end
134+
# Build an n-step batch for the `SSLART` trace layout: the `SSART` batch plus the
# legal-actions mask taken at `inds`, inserted after the next state.
# (`SSLART` and `SSART` are defined outside this block.)
function sample(s::NStepBatchSampler, ts, ::Val{SSLART}, inds)
    # Destructure into fresh names so the sampler `s` is not overwritten.
    state, next_state, action, reward, terminal = sample(s, ts, Val(SSART), inds)
    legal_mask = consecutive_view(ts[:legal_actions_mask], inds)
    # The `NamedTuple{names}` constructor takes a single tuple of values.
    return NamedTuple{SSLART}((state, next_state, legal_mask, action, reward, terminal))
end
0 commit comments