Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ uuid = "6486599b-a3cd-4e92-a99a-2cea90cc8c3c"
authors = ["Jun Tian <[email protected]> and contributors"]
version = "0.1.0"

[deps]
CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Term = "22787eb5-b846-44ae-b979-8e399b8463ab"

[compat]
julia = "1.6"

Expand Down
55 changes: 55 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,58 @@
[![Build Status](https://github.com/JuliaReinforcementLearning/Trajectories.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/JuliaReinforcementLearning/Trajectories.jl/actions/workflows/CI.yml?query=branch%3Amain)
[![Coverage](https://codecov.io/gh/JuliaReinforcementLearning/Trajectories.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/JuliaReinforcementLearning/Trajectories.jl)
[![PkgEval](https://JuliaCI.github.io/NanosoldierReports/pkgeval_badges/T/Trajectories.svg)](https://JuliaCI.github.io/NanosoldierReports/pkgeval_badges/report.html)

## Design

```
┌────────────────────────────┐
│(state=..., action=..., ...)│
└──────────────┬─────────────┘
push! │ append!
┌───────────────────▼───────────────────┐
│ Trajectory │
│ ┌─────────────────────────────────┐ │
│ │ Traces │ │
│ │ ┌───────────────────┐ │ │
│ │ state: │CircularArrayBuffer│ │ │
│ │ └───────────────────┘ │ │
│ │ ┌───────────────────┐ │ │
│ │ action:│CircularArrayBuffer│ │ │
│ │ └───────────────────┘ │ │
│ │ ...... │ │
│ └─────────────────────────────────┘ │
| Sampler |
└───────────────────┬───────────────────┘
│ batch sampling
┌──────────────▼─────────────┐
│(state=..., action=..., ...)│
└────────────────────────────┘
```

```
┌──────────────┐ ┌──────────────┐
│Single Element│ │Batch Elements│
└──────┬───────┘ └──────┬───────┘
│ │
push! └──────┐ ┌───────┘ append!
│ │
┌─────────────┼────┼─────────────────────────────┐
│ ┌──▼────▼──┐ AsyncTrajectory │
│ │Channel In│ │
│ └─────┬────┘ │
│ take! │ │
│ ┌─────▼─────┐ push! ┌────────────┐ │
│ │RateLimiter├──────────► Trajectory │ │
│ └─────┬─────┘ append! └────*───────┘ │
│ │ * │
│ put! │********************** │
│ │ batch sampling │
│ ┌─────▼─────┐ │
│ │Channel Out│ │
│ └───────────┘ │
└────────────────────────────────────────────────┘
```

## Acknowledgement

This async version is mainly inspired by [deepmind/reverb](https://github.com/deepmind/reverb).
8 changes: 7 additions & 1 deletion src/Trajectories.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
module Trajectories

# Write your package code here.
include("traces.jl")
include("episodes.jl")
include("trajectory.jl")
include("samplers.jl")
include("async_trajectory.jl")
include("rendering.jl")
include("common/common.jl")

end
120 changes: 120 additions & 0 deletions src/async_trajectory.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
export AsyncTrajectory, UPDATE, SAMPLE, NO_OP

struct Update end
const UPDATE = Update()
struct SampleDecision end
const SAMPLE = SampleDecision()
struct NoOp end
const NO_OP = NoOp()


"""
RateLimiter(n_samping_per_update=0.25; min_sampling_length=0, buffer_range=0:0)

- `n_samping_per_update`, a positive number, float number is supported.
- `min_sampling_length`, minimal number of elements to start sampling.
- `buffer_range`, if a single number is provided, it will be transformed into
`-buffer_range:buffer_range`. Once the `trajectory` reaches
`min_sampling_length`, if the length of `trajectory` is greater than the upper
bound plus `min_sampling_length`, the `rate_limiter` always return `SAMPLE`.
While if the length of `trajectory` is less than `min_sampling_length` minus
the lower bound of `buffer_range`, then the `rate_limiter` always return
`UPDATE`. In other cases, whether to `SAMPLE` or `UPDATE` depends on the
availability of `in_channel` or `out_channel`.
"""
struct RateLimiter
n_samping_per_update::Float64
min_sampling_length::UInt
buffer_range::UnitRange{Int}
is_min_sampling_length_reached::Ref{Bool}
end

RateLimiter(n_samping_per_update=0.25; min_sampling_length=0, buffer_range=0) = RateLimiter(n_samping_per_update, min_sampling_length, buffer_range)
RateLimiter(n_samping_per_update, min_sampling_length, buffer_range::Number) = RateLimiter(n_samping_per_update, convert(UInt, min_sampling_length), range(-buffer_range, buffer_range))
RateLimiter(n_samping_per_update, min_sampling_length, buffer_range) = RateLimiter(n_samping_per_update, min_sampling_length, buffer_range, Ref(false))

"""
(r::RateLimiter)(n_update, n_sample, is_in_ready, is_out_ready)

- `n_update`, number of elements inserted into `trajectory`.
- `n_sample`, number of batches sampled from `trajectory`.
- `is_in_ready`, the `in_channel` has elements to put into `trajectory` or not.
- `is_out_ready`, the `out_channel` is ready to consume new samplings or not.
"""
function (r::RateLimiter)(n_update, n_sample, is_in_ready, is_out_ready)
if n_update >= r.min_sampling_length
r.is_min_sampling_length_reached[] = true
end

if r.is_min_sampling_length_reached[]
n_estimated_updates = n_sample / r.n_samping_per_update + r.min_sampling_length
if n_estimated_updates < n_update + r.buffer_range[begin]
SAMPLE
elseif n_estimated_updates > n_update + r.buffer_range[end]
UPDATE
else
if is_out_ready
SAMPLE
elseif is_in_ready
UPDATE
else
NO_OP
end
end
else
UPDATE
end
end

#####

struct CallMsg
f::Any
args::Tuple
kw::Any
end

struct AsyncTrajectory
trajectory
rate_limiter
channel_in
channel_out
task
n_update_ref
n_sample_ref

function AsyncTrajectory(trajectory, rate_limiter; channel_in=Channel(1), channel_out=Channel(1))
n_update_ref = Ref(0)
n_sample_ref = Ref(0)
task = @async while true
decision = rate_limiter(n_update, n_sample, isready(channel_in), Base.n_avail(channel_out) < channel_out_size)
if decision === UPDATE
msg = take!(channel_in)
if msg.f === Base.push!
push!(trajectory, msg.args...; msg.kw...)
n_update_ref[] += 1
elseif msg.f === Base.append!
append!(trajectory, msg.args...; msg.kw...)
n_update_ref[] += length(msg.data)
else
msg.f(trajectory, msg.args...; msg.kw...)
end
elseif decision === SAMPLE
put!(channel_out, rand(trajectory))
n_sample_ref[] += 1
end
end
new(
trajectory,
channel_in,
channel_out,
task,
n_update_ref,
n_sample_ref
)
end
end

Base.push!(t::AsyncTrajectory, args...; kw...) = put!(t.in, CallMsg(Base.push!, args, kw))
Base.append!(t::AsyncTrajectory, args...; kw...) = put!(t.in, CallMsg(Base.append!, args, kw))
Base.take!(t::AsyncTrajectory) = take!(t.out)
45 changes: 45 additions & 0 deletions src/common/CircularArraySARTTraces.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
export CircularArraySARTTraces

const CircularArraySARTTraces = Traces{
SART,
<:Tuple{
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer}
}
}


function CircularArraySARTTraces(;
capacity::Int,
state=Int => (),
action=Int => (),
reward=Float32 => (),
terminal=Bool => ()
)
state_eltype, state_size = state
action_eltype, action_size = action
reward_eltype, reward_size = reward
terminal_eltype, terminal_size = terminal

Traces(
state=CircularArrayBuffer{state_eltype}(state_size..., capacity + 1), # !!! state is one step longer
action=CircularArrayBuffer{action_eltype}(action_size..., capacity + 1), # !!! action is one step longer
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
)
end

function Random.rand(s::BatchSampler, t::CircularArraySARTTraces)
inds = rand(s.rng, 1:length(t), s.batch_size)
inds′ = inds .+ 1
(
state=t[:state][inds],
action=t[:action][inds],
reward=t[:reward][inds],
terminal=t[:terminal][inds],
next_state=t[:state][inds′],
next_action=t[:state][inds′]
)
end
51 changes: 51 additions & 0 deletions src/common/CircularArraySLARTTraces.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
export CircularArraySLARTTraces

const CircularArraySLARTTraces = Traces{
SLART,
<:Tuple{
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer}
}
}


function CircularArraySLARTTraces(;
capacity::Int,
state=Int => (),
legal_actions_mask=Bool => (),
action=Int => (),
reward=Float32 => (),
terminal=Bool => ()
)
state_eltype, state_size = state
action_eltype, action_size = action
legal_actions_mask_eltype, legal_actions_mask_size = legal_actions_mask
reward_eltype, reward_size = reward
terminal_eltype, terminal_size = terminal

Traces(
state=CircularArrayBuffer{state_eltype}(state_size..., capacity + 1), # !!! state is one step longer
legal_actions_mask=CircularArrayBuffer{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1), # !!! legal_actions_mask is one step longer
action=CircularArrayBuffer{action_eltype}(action_size..., capacity + 1), # !!! action is one step longer
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
)
end

function Random.rand(s::BatchSampler, t::CircularArraySLARTTraces)
inds = rand(s.rng, 1:length(t), s.batch_size)
inds′ = inds .+ 1
(
state=t[:state][inds],
legal_actions_mask=t[:legal_actions_mask][inds],
action=t[:action][inds],
reward=t[:reward][inds],
terminal=t[:terminal][inds],
next_state=t[:state][inds′],
next_legal_actions_mask=t[:legal_actions_mask][inds′],
next_action=t[:state][inds′]
)
end
10 changes: 10 additions & 0 deletions src/common/common.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
using CircularArrayBuffers

const SART = (:state, :action, :reward, :terminal)
const SARTSA = (:state, :action, :reward, :terminal, :next_state, :next_action)
const SLART = (:state, :legal_actions_mask, :action, :reward, :terminal)
const SLARTSLA = (:state, :legal_actions_mask, :action, :reward, :terminal, :next_state, :next_legal_actions_mask, :next_action)

include("sum_tree.jl")
include("CircularArraySARTTraces.jl")
include("CircularArraySLARTTraces.jl")
Loading