diff --git a/src/common/CircularArraySARTTraces.jl b/src/common/CircularArraySARTTraces.jl index f29093a..3fd601b 100644 --- a/src/common/CircularArraySARTTraces.jl +++ b/src/common/CircularArraySARTTraces.jl @@ -1,10 +1,10 @@ export CircularArraySARTTraces const CircularArraySARTTraces = Traces{ - SSAART, + SS′AA′RT, <:Tuple{ - <:MultiplexTraces{SS,<:Trace{<:CircularArrayBuffer}}, - <:MultiplexTraces{AA,<:Trace{<:CircularArrayBuffer}}, + <:MultiplexTraces{SS′,<:Trace{<:CircularArrayBuffer}}, + <:MultiplexTraces{AA′,<:Trace{<:CircularArrayBuffer}}, <:Trace{<:CircularArrayBuffer}, <:Trace{<:CircularArrayBuffer}, } @@ -22,8 +22,8 @@ function CircularArraySARTTraces(; reward_eltype, reward_size = reward terminal_eltype, terminal_size = terminal - MultiplexTraces{SS}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) + - MultiplexTraces{AA}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) + + MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) + + MultiplexTraces{AA′}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) + Traces( reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity), terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity), diff --git a/src/common/CircularArraySLARTTraces.jl b/src/common/CircularArraySLARTTraces.jl index dc5e5ec..db9ac03 100644 --- a/src/common/CircularArraySLARTTraces.jl +++ b/src/common/CircularArraySLARTTraces.jl @@ -1,11 +1,11 @@ export CircularArraySLARTTraces const CircularArraySLARTTraces = Traces{ - SSLLAART, + SS′LL′AA′RT, <:Tuple{ - <:MultiplexTraces{SS,<:Trace{<:CircularArrayBuffer}}, - <:MultiplexTraces{LL,<:Trace{<:CircularArrayBuffer}}, - <:MultiplexTraces{AA,<:Trace{<:CircularArrayBuffer}}, + <:MultiplexTraces{SS′,<:Trace{<:CircularArrayBuffer}}, + <:MultiplexTraces{LL′,<:Trace{<:CircularArrayBuffer}}, + <:MultiplexTraces{AA′,<:Trace{<:CircularArrayBuffer}}, <:Trace{<:CircularArrayBuffer}, <:Trace{<:CircularArrayBuffer}, } @@ -25,9 +25,9 @@ function CircularArraySLARTTraces(; reward_eltype, reward_size = reward terminal_eltype, terminal_size = terminal - MultiplexTraces{SS}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) + - MultiplexTraces{LL}(CircularArrayBuffer{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1)) + - MultiplexTraces{AA}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) + + MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) + + MultiplexTraces{LL′}(CircularArrayBuffer{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1)) + + MultiplexTraces{AA′}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) + Traces( reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity), terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity), diff --git a/src/common/common.jl b/src/common/common.jl index f4c6fad..78d28b3 100644 --- a/src/common/common.jl +++ b/src/common/common.jl @@ -1,15 +1,15 @@ -export SS, LL, AA, RT, SSART, SSAART, SSLLAART +export SS′, LL′, AA′, RT, SS′ART, SS′AA′RT, SS′L′ART, SS′LL′AA′RT using CircularArrayBuffers -const SS = (:state, :next_state) -const LL = (:legal_actions_mask, :next_legal_actions_mask) -const AA = (:action, :next_action) +const SS′ = (:state, :next_state) +const LL′ = (:legal_actions_mask, :next_legal_actions_mask) +const AA′ = (:action, :next_action) const RT = (:reward, :terminal) -const SSART = (SS..., :action, RT...) -const SSAART = (SS..., AA..., RT...) -const SSLART = (SS..., :legal_actions_mask, :action, RT...) -const SSLLAART = (SS..., LL..., AA..., RT...) +const SS′ART = (SS′..., :action, RT...) +const SS′AA′RT = (SS′..., AA′..., RT...) +const SS′L′ART = (SS′..., :next_legal_actions_mask, :action, RT...) +const SS′LL′AA′RT = (SS′..., LL′..., AA′..., RT...) include("sum_tree.jl") include("CircularArraySARTTraces.jl") diff --git a/src/samplers.jl b/src/samplers.jl index fef23bb..927d06e 100644 --- a/src/samplers.jl +++ b/src/samplers.jl @@ -31,7 +31,7 @@ sample(s::BatchSampler{names}, t::AbstractTraces) where {names} = sample(s, t, n function sample(s::BatchSampler, t::AbstractTraces, names) inds = rand(s.rng, 1:length(t), s.batch_size) - NamedTuple{names}(t[x][inds] for x in names) + NamedTuple{names}(map(x -> t[x][inds], names)) end ##### @@ -99,7 +99,7 @@ mutable struct NStepBatchSampler{traces} rng::Any end -NStepBatchSampler(; kw...) = NStepBatchSampler{SSART}(; kw...) +NStepBatchSampler(; kw...) = NStepBatchSampler{SS′ART}(; kw...) NStepBatchSampler{names}(; n, γ, batch_size=32, stack_size=nothing, rng=Random.GLOBAL_RNG) where {names} = NStepBatchSampler{names}(n, γ, batch_size, stack_size, rng) function sample(s::NStepBatchSampler{names}, ts) where {names} @@ -108,7 +108,7 @@ function sample(s::NStepBatchSampler{names}, ts) where {names} sample(s, ts, Val(names), inds) end -function sample(nbs::NStepBatchSampler, ts, ::Val{SSART}, inds) +function sample(nbs::NStepBatchSampler, ts, ::Val{SS′ART}, inds) if isnothing(nbs.stack_size) s = ts[:state][inds] s′ = ts[:next_state][inds.+(nbs.n-1)] @@ -129,11 +129,11 @@ function sample(nbs::NStepBatchSampler, ts, ::Val{SSART}, inds) foldr(((rr, tt), init) -> rr + nbs.γ * init * (1 - tt), zip(r⃗, t⃗); init=0.0f0) end - NamedTuple{SSART}((s, s′, a, r, t)) + NamedTuple{SS′ART}((s, s′, a, r, t)) end -function sample(s::NStepBatchSampler, ts, ::Val{SSLART}, inds) +function sample(s::NStepBatchSampler, ts, ::Val{SS′L′ART}, inds) s, s′, a, r, t = sample(s, ts, Val(SSART), inds) - l = consecutive_view(ts[:legal_actions_mask], inds) + l = consecutive_view(ts[:next_legal_actions_mask], inds) NamedTuple{SSLART}((s, s′, l, a, r, t)) end \ No newline at end of file diff --git a/test/samplers.jl b/test/samplers.jl index b6ac3b6..8f43fa5 100644 --- a/test/samplers.jl +++ b/test/samplers.jl @@ -106,7 +106,7 @@ end @test size(xs2.terminal) == (batch_size,) inds = [3, 5, 7] - xs3 = RLTrajectories.sample(s1, t2, Val(SSART), inds) + xs3 = RLTrajectories.sample(s1, t2, Val(SS′ART), inds) @test xs3.state == cat( (