Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/common/CircularArraySARTTraces.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
export CircularArraySARTTraces

const CircularArraySARTTraces = Traces{
SSAART,
SS′AA′RT,
<:Tuple{
<:MultiplexTraces{SS,<:Trace{<:CircularArrayBuffer}},
<:MultiplexTraces{AA,<:Trace{<:CircularArrayBuffer}},
<:MultiplexTraces{SS,<:Trace{<:CircularArrayBuffer}},
<:MultiplexTraces{AA,<:Trace{<:CircularArrayBuffer}},
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer},
}
Expand All @@ -22,8 +22,8 @@ function CircularArraySARTTraces(;
reward_eltype, reward_size = reward
terminal_eltype, terminal_size = terminal

MultiplexTraces{SS}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
MultiplexTraces{AA}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
MultiplexTraces{SS}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
MultiplexTraces{AA}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
Traces(
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
Expand Down
14 changes: 7 additions & 7 deletions src/common/CircularArraySLARTTraces.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
export CircularArraySLARTTraces

const CircularArraySLARTTraces = Traces{
SSLLAART,
SS′LL′AA′RT,
<:Tuple{
<:MultiplexTraces{SS,<:Trace{<:CircularArrayBuffer}},
<:MultiplexTraces{LL,<:Trace{<:CircularArrayBuffer}},
<:MultiplexTraces{AA,<:Trace{<:CircularArrayBuffer}},
<:MultiplexTraces{SS,<:Trace{<:CircularArrayBuffer}},
<:MultiplexTraces{LL,<:Trace{<:CircularArrayBuffer}},
<:MultiplexTraces{AA,<:Trace{<:CircularArrayBuffer}},
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer},
}
Expand All @@ -25,9 +25,9 @@ function CircularArraySLARTTraces(;
reward_eltype, reward_size = reward
terminal_eltype, terminal_size = terminal

MultiplexTraces{SS}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
MultiplexTraces{LL}(CircularArrayBuffer{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1)) +
MultiplexTraces{AA}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
MultiplexTraces{SS}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
MultiplexTraces{LL}(CircularArrayBuffer{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1)) +
MultiplexTraces{AA}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
Traces(
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
Expand Down
16 changes: 8 additions & 8 deletions src/common/common.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
export SS, LL, AA, RT, SSART, SSAART, SSLLAART
export SS, LL, AA, RT, SS′ART, SS′AA′RT, SS′L′ART, SS′LL′AA′RT

using CircularArrayBuffers

const SS = (:state, :next_state)
const LL = (:legal_actions_mask, :next_legal_actions_mask)
const AA = (:action, :next_action)
const SS = (:state, :next_state)
const LL = (:legal_actions_mask, :next_legal_actions_mask)
const AA = (:action, :next_action)
const RT = (:reward, :terminal)
const SSART = (SS..., :action, RT...)
const SSAART = (SS..., AA..., RT...)
const SSLART = (SS..., :legal_actions_mask, :action, RT...)
const SSLLAART = (SS..., LL..., AA..., RT...)
const SS′ART = (SS..., :action, RT...)
const SS′AA′RT = (SS..., AA..., RT...)
const SS′L′ART = (SS..., :next_legal_actions_mask, :action, RT...)
const SS′LL′AA′RT = (SS..., LL..., AA..., RT...)

include("sum_tree.jl")
include("CircularArraySARTTraces.jl")
Expand Down
12 changes: 6 additions & 6 deletions src/samplers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ sample(s::BatchSampler{names}, t::AbstractTraces) where {names} = sample(s, t, n

function sample(s::BatchSampler, t::AbstractTraces, names)
inds = rand(s.rng, 1:length(t), s.batch_size)
NamedTuple{names}(t[x][inds] for x in names)
NamedTuple{names}(map(x -> t[x][inds], names))
end

#####
Expand Down Expand Up @@ -99,7 +99,7 @@ mutable struct NStepBatchSampler{traces}
rng::Any
end

NStepBatchSampler(; kw...) = NStepBatchSampler{SSART}(; kw...)
NStepBatchSampler(; kw...) = NStepBatchSampler{SS′ART}(; kw...)
NStepBatchSampler{names}(; n, γ, batch_size=32, stack_size=nothing, rng=Random.GLOBAL_RNG) where {names} = NStepBatchSampler{names}(n, γ, batch_size, stack_size, rng)

function sample(s::NStepBatchSampler{names}, ts) where {names}
Expand All @@ -108,7 +108,7 @@ function sample(s::NStepBatchSampler{names}, ts) where {names}
sample(s, ts, Val(names), inds)
end

function sample(nbs::NStepBatchSampler, ts, ::Val{SSART}, inds)
function sample(nbs::NStepBatchSampler, ts, ::Val{SS′ART}, inds)
if isnothing(nbs.stack_size)
s = ts[:state][inds]
s′ = ts[:next_state][inds.+(nbs.n-1)]
Expand All @@ -129,11 +129,11 @@ function sample(nbs::NStepBatchSampler, ts, ::Val{SSART}, inds)
foldr(((rr, tt), init) -> rr + nbs.γ * init * (1 - tt), zip(r⃗, t⃗); init=0.0f0)
end

NamedTuple{SSART}((s, s′, a, r, t))
NamedTuple{SS′ART}((s, s′, a, r, t))
end

function sample(s::NStepBatchSampler, ts, ::Val{SSLART}, inds)
function sample(s::NStepBatchSampler, ts, ::Val{SS′L′ART}, inds)
s, s′, a, r, t = sample(s, ts, Val(SSART), inds)
l = consecutive_view(ts[:legal_actions_mask], inds)
l = consecutive_view(ts[:next_legal_actions_mask], inds)
NamedTuple{SSLART}((s, s′, l, a, r, t))
end
2 changes: 1 addition & 1 deletion test/samplers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ end
@test size(xs2.terminal) == (batch_size,)

inds = [3, 5, 7]
xs3 = RLTrajectories.sample(s1, t2, Val(SSART), inds)
xs3 = RLTrajectories.sample(s1, t2, Val(SS′ART), inds)

@test xs3.state == cat(
(
Expand Down