@@ -31,7 +31,7 @@ sample(s::BatchSampler{names}, t::AbstractTraces) where {names} = sample(s, t, n
3131
3232function sample (s:: BatchSampler , t:: AbstractTraces , names)
3333 inds = rand (s. rng, 1 : length (t), s. batch_size)
34- NamedTuple {names} (t[x][inds] for x in names)
34+ NamedTuple {names} (map (x -> t[x][inds], names) )
3535end
3636
3737# ####
@@ -99,7 +99,7 @@ mutable struct NStepBatchSampler{traces}
9999 rng:: Any
100100end
101101
102- NStepBatchSampler (; kw... ) = NStepBatchSampler {SSART } (; kw... )
102+ NStepBatchSampler (; kw... ) = NStepBatchSampler {SS′ART } (; kw... )
103103NStepBatchSampler {names} (; n, γ, batch_size= 32 , stack_size= nothing , rng= Random. GLOBAL_RNG) where {names} = NStepBatchSampler {names} (n, γ, batch_size, stack_size, rng)
104104
105105function sample (s:: NStepBatchSampler{names} , ts) where {names}
@@ -108,7 +108,7 @@ function sample(s::NStepBatchSampler{names}, ts) where {names}
108108 sample (s, ts, Val (names), inds)
109109end
110110
111- function sample (nbs:: NStepBatchSampler , ts, :: Val{SSART } , inds)
111+ function sample (nbs:: NStepBatchSampler , ts, :: Val{SS′ART } , inds)
112112 if isnothing (nbs. stack_size)
113113 s = ts[:state ][inds]
114114 s′ = ts[:next_state ][inds.+ (nbs. n- 1 )]
@@ -129,11 +129,11 @@ function sample(nbs::NStepBatchSampler, ts, ::Val{SSART}, inds)
129129 foldr (((rr, tt), init) -> rr + nbs. γ * init * (1 - tt), zip (r⃗, t⃗); init= 0.0f0 )
130130 end
131131
132- NamedTuple {SSART } ((s, s′, a, r, t))
132+ NamedTuple {SS′ART } ((s, s′, a, r, t))
133133end
134134
135- function sample (s:: NStepBatchSampler , ts, :: Val{SSLART } , inds)
135+ function sample (s:: NStepBatchSampler , ts, :: Val{SS′L′ART } , inds)
136136 s, s′, a, r, t = sample (s, ts, Val (SSART), inds)
137- l = consecutive_view (ts[:legal_actions_mask ], inds)
137+ l = consecutive_view (ts[:next_legal_actions_mask ], inds)
138138 NamedTuple {SSLART} ((s, s′, l, a, r, t))
139139end
0 commit comments