Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ version = "0.1.0"

[deps]
CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Term = "22787eb5-b846-44ae-b979-8e399b8463ab"

Expand Down
60 changes: 14 additions & 46 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,54 +6,22 @@

## Design

```
┌────────────────────────────┐
│(state=..., action=..., ...)│
└──────────────┬─────────────┘
push! │ append!
┌───────────────────▼───────────────────┐
│ Trajectory │
│ ┌─────────────────────────────────┐ │
│ │ Traces │ │
│ │ ┌───────────────────┐ │ │
│ │ state: │CircularArrayBuffer│ │ │
│ │ └───────────────────┘ │ │
│ │ ┌───────────────────┐ │ │
│ │ action:│CircularArrayBuffer│ │ │
│ │ └───────────────────┘ │ │
│ │ ...... │ │
│ └─────────────────────────────────┘ │
| Sampler |
└───────────────────┬───────────────────┘
│ batch sampling
┌──────────────▼─────────────┐
│(state=..., action=..., ...)│
└────────────────────────────┘
```
A typical example of `Trajectory`:

![](https://user-images.githubusercontent.com/5612003/167291629-0e2d4f0f-7c54-460c-a94f-9eb4148cdca0.png)

Exported APIs are:

```julia
push!(trajectory; [trace_name=value]...)
append!(trajectory; [trace_name=value]...)

for sample in trajectory
# consume samples from the trajectory
end
```
┌──────────────┐ ┌──────────────┐
│Single Element│ │Batch Elements│
└──────┬───────┘ └──────┬───────┘
│ │
push! └──────┐ ┌───────┘ append!
│ │
┌─────────────┼────┼─────────────────────────────┐
│ ┌──▼────▼──┐ AsyncTrajectory │
│ │Channel In│ │
│ └─────┬────┘ │
│ take! │ │
│ ┌─────▼─────┐ push! ┌────────────┐ │
│ │RateLimiter├──────────► Trajectory │ │
│ └─────┬─────┘ append! └────*───────┘ │
│ │ * │
│ put! │********************** │
│ │ batch sampling │
│ ┌─────▼─────┐ │
│ │Channel Out│ │
│ └───────────┘ │
└────────────────────────────────────────────────┘
```

A wide variety of `container`s, `sampler`s, and `controler`s are provided. For the full list, please refer to the documentation.

## Acknowledgement

Expand Down
4 changes: 2 additions & 2 deletions src/Trajectories.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
module Trajectories

include("samplers.jl")
include("controlers.jl")
include("traces.jl")
include("episodes.jl")
include("trajectory.jl")
include("samplers.jl")
include("async_trajectory.jl")
include("rendering.jl")
include("common/common.jl")

Expand Down
120 changes: 0 additions & 120 deletions src/async_trajectory.jl

This file was deleted.

11 changes: 10 additions & 1 deletion src/common/CircularArraySARTTraces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,14 @@ function Random.rand(s::BatchSampler, t::CircularArraySARTTraces)
terminal=t[:terminal][inds],
next_state=t[:state][inds′],
next_action=t[:state][inds′]
)
) |> s.transformer
end

# Push a `(state, action)` pair into the traces.
#
# When the `:state` trace already holds exactly one more element than the
# `:terminal` trace, the most recent `:state` and `:action` entries are removed
# first, so the new pair replaces them instead of being appended after them.
# NOTE(review): presumably this drops a pending state/action pair that never
# received a matching reward/terminal — confirm against the callers.
#
# Returns the result of the final `push!` on the `:action` trace.
function Base.push!(t::CircularArraySARTTraces, x::NamedTuple{SA})
    one_ahead = length(t[:state]) == length(t[:terminal]) + 1
    if one_ahead
        pop!(t[:state])
        pop!(t[:action])
    end

    push!(t[:state], x[:state])
    return push!(t[:action], x[:action])
end
17 changes: 14 additions & 3 deletions src/common/CircularArraySLARTTraces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ function CircularArraySLARTTraces(;
)
end

function Random.rand(s::BatchSampler, t::CircularArraySLARTTraces)
function sample(s::BatchSampler, t::CircularArraySLARTTraces)
inds = rand(s.rng, 1:length(t), s.batch_size)
inds′ = inds .+ 1
(
Expand All @@ -47,5 +47,16 @@ function Random.rand(s::BatchSampler, t::CircularArraySLARTTraces)
next_state=t[:state][inds′],
next_legal_actions_mask=t[:legal_actions_mask][inds′],
next_action=t[:state][inds′]
)
end
) |> s.transformer
end

# Push a `(state, legal_actions_mask, action)` triple into the traces.
#
# When the `:state` trace already holds exactly one more element than the
# `:terminal` trace, the most recent `:state`, `:legal_actions_mask` and
# `:action` entries are removed first, so the new triple replaces them.
# NOTE(review): presumably this drops a pending entry that never received a
# matching reward/terminal — confirm against the callers.
#
# Returns the result of the final `push!` on the `:action` trace.
function Base.push!(t::CircularArraySLARTTraces, x::NamedTuple{SLA})
    one_ahead = length(t[:state]) == length(t[:terminal]) + 1
    if one_ahead
        pop!(t[:state])
        pop!(t[:legal_actions_mask])
        pop!(t[:action])
    end

    push!(t[:state], x[:state])
    push!(t[:legal_actions_mask], x[:legal_actions_mask])
    return push!(t[:action], x[:action])
end
3 changes: 3 additions & 0 deletions src/common/common.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
using CircularArrayBuffers

# Tuples of trace names. Each constant can be used as the field-name parameter
# of a `NamedTuple` type, e.g. `NamedTuple{SA}` matches any named tuple whose
# fields are exactly `(:state, :action)`, in that order.
# The letters abbreviate the listed names: S = state, L = legal_actions_mask,
# A = action, R = reward, T = terminal.
const SA = (:state, :action)
const SLA = (:state, :legal_actions_mask, :action)
const RT = (:reward, :terminal)
const SART = (:state, :action, :reward, :terminal)
# SART followed by the next state and next action.
const SARTSA = (:state, :action, :reward, :terminal, :next_state, :next_action)
const SLART = (:state, :legal_actions_mask, :action, :reward, :terminal)
Expand Down
61 changes: 61 additions & 0 deletions src/controlers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
export InsertSampleRatioControler, AsyncInsertSampleRatioControler

mutable struct InsertSampleRatioControler
    # Factor applied to the number of insertions beyond `threshold` when
    # deciding whether another sample is allowed.
    ratio::Float64
    # Number of insertions that must have happened before sampling may start.
    threshold::Int
    # Running count of inserted elements.
    n_inserted::Int
    # Running count of granted samples.
    n_sampled::Int
end

"""
    InsertSampleRatioControler(ratio, threshold)

Create a controler for a [`Trajectory`](@ref) with both counters (`n_inserted`
and `n_sampled`) starting at zero.

`threshold` is the minimum number of insertions required before any sampling
takes place. `ratio` balances the number of samplings against the number of
insertions made beyond `threshold`.
"""
function InsertSampleRatioControler(ratio, threshold)
    return InsertSampleRatioControler(ratio, threshold, 0, 0)
end

# Record that `n` elements were inserted.
#
# Non-positive `n` is ignored and `nothing` is returned; otherwise `n` is added
# to `c.n_inserted` and the updated count is returned.
function on_insert!(c::InsertSampleRatioControler, n::Int)
    n > 0 || return nothing
    c.n_inserted += n
    return c.n_inserted
end

# Decide whether a sample may be taken right now.
#
# A sample is granted only when at least `c.threshold` elements have been
# inserted and the number of samples granted so far does not exceed
# `(c.n_inserted - c.threshold) * c.ratio`. On success `c.n_sampled` is
# incremented and `true` is returned; in every other case the function
# returns `nothing` (not `false`).
function on_sample!(c::InsertSampleRatioControler)
    c.n_inserted >= c.threshold || return nothing

    allowed = (c.n_inserted - c.threshold) * c.ratio
    c.n_sampled <= allowed || return nothing

    c.n_sampled += 1
    return true
end

#####

# Same counters as `InsertSampleRatioControler`, plus a pair of channels.
# NOTE(review): how `ch_in` / `ch_out` are consumed is not visible in this
# file — presumably incoming elements and outgoing batches; confirm.
mutable struct AsyncInsertSampleRatioControler
    ratio::Float64
    threshold::Int
    n_inserted::Int
    n_sampled::Int
    ch_in::Channel
    ch_out::Channel
end

# Keyword-friendly constructor.
#
# `ch_in_sz` and `ch_out_sz` are the sizes passed to `Channel(...)` for the
# input and output channels (both default to 1). `n_inserted` and `n_sampled`
# set the initial counter values (both default to 0).
function AsyncInsertSampleRatioControler(
    ratio,
    threshold;
    ch_in_sz=1,
    ch_out_sz=1,
    n_inserted=0,
    n_sampled=0,
)
    inbound = Channel(ch_in_sz)
    outbound = Channel(ch_out_sz)
    return AsyncInsertSampleRatioControler(
        ratio, threshold, n_inserted, n_sampled, inbound, outbound
    )
end
Loading