Skip to content

Commit dc220ed

Browse files
authored
Merge pull request #1 from findmyway/main
init
2 parents 7bfe22d + 04edcc3 commit dc220ed

15 files changed

+793
-2
lines changed

Project.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@ uuid = "6486599b-a3cd-4e92-a99a-2cea90cc8c3c"
33
authors = ["Jun Tian <[email protected]> and contributors"]
44
version = "0.1.0"
55

6+
[deps]
7+
CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1"
8+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
9+
Term = "22787eb5-b846-44ae-b979-8e399b8463ab"
10+
611
[compat]
712
julia = "1.6"
813

README.md

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,58 @@
33
[![Build Status](https://github.com/JuliaReinforcementLearning/Trajectories.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/JuliaReinforcementLearning/Trajectories.jl/actions/workflows/CI.yml?query=branch%3Amain)
44
[![Coverage](https://codecov.io/gh/JuliaReinforcementLearning/Trajectories.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/JuliaReinforcementLearning/Trajectories.jl)
55
[![PkgEval](https://JuliaCI.github.io/NanosoldierReports/pkgeval_badges/T/Trajectories.svg)](https://JuliaCI.github.io/NanosoldierReports/pkgeval_badges/report.html)
6+
7+
## Design
8+
9+
```
10+
┌────────────────────────────┐
11+
│(state=..., action=..., ...)│
12+
└──────────────┬─────────────┘
13+
push! │ append!
14+
┌───────────────────▼───────────────────┐
15+
│ Trajectory │
16+
│ ┌─────────────────────────────────┐ │
17+
│ │ Traces │ │
18+
│ │ ┌───────────────────┐ │ │
19+
│ │ state: │CircularArrayBuffer│ │ │
20+
│ │ └───────────────────┘ │ │
21+
│ │ ┌───────────────────┐ │ │
22+
│ │ action:│CircularArrayBuffer│ │ │
23+
│ │ └───────────────────┘ │ │
24+
│ │ ...... │ │
25+
│ └─────────────────────────────────┘ │
26+
│ Sampler │
27+
└───────────────────┬───────────────────┘
28+
│ batch sampling
29+
┌──────────────▼─────────────┐
30+
│(state=..., action=..., ...)│
31+
└────────────────────────────┘
32+
```
33+
34+
```
35+
┌──────────────┐ ┌──────────────┐
36+
│Single Element│ │Batch Elements│
37+
└──────┬───────┘ └──────┬───────┘
38+
│ │
39+
push! └──────┐ ┌───────┘ append!
40+
│ │
41+
┌─────────────┼────┼─────────────────────────────┐
42+
│ ┌──▼────▼──┐ AsyncTrajectory │
43+
│ │Channel In│ │
44+
│ └─────┬────┘ │
45+
│ take! │ │
46+
│ ┌─────▼─────┐ push! ┌────────────┐ │
47+
│ │RateLimiter├──────────► Trajectory │ │
48+
│ └─────┬─────┘ append! └────*───────┘ │
49+
│ │ * │
50+
│ put! │********************** │
51+
│ │ batch sampling │
52+
│ ┌─────▼─────┐ │
53+
│ │Channel Out│ │
54+
│ └───────────┘ │
55+
└────────────────────────────────────────────────┘
56+
```
57+
58+
## Acknowledgement
59+
60+
This async version is mainly inspired by [deepmind/reverb](https://github.com/deepmind/reverb).

src/Trajectories.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
# Package entry point: the module body only stitches the source files together.
module Trajectories

# The files are included in this exact order; later files may rely on names
# defined by earlier ones.
include("traces.jl")
include("episodes.jl")
include("trajectory.jl")
include("samplers.jl")
include("async_trajectory.jl")
include("rendering.jl")
# `common/common.jl` runs `using CircularArrayBuffers`, defines the trace-name
# tuples (`SART`, `SLART`, ...) and includes the concrete circular-buffer traces.
# NOTE(review): it is included last -- if any file above needs
# `CircularArrayBuffer` or those tuples at load time, this include has to move
# up; confirm against the files not visible here.
include("common/common.jl")

end

src/async_trajectory.jl

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
export AsyncTrajectory, UPDATE, SAMPLE, NO_OP
2+
3+
# Decision tags returned by the rate limiter. Each tag is a field-less type
# with a single shared instance, so decisions can be compared with `===`.

"Tag type of [`UPDATE`](@ref): take a message from the input channel next."
struct Update end

"Tag type of [`SAMPLE`](@ref): put a sampled batch into the output channel next."
struct SampleDecision end

"Tag type of [`NO_OP`](@ref): neither update nor sample right now."
struct NoOp end

const UPDATE = Update()
const SAMPLE = SampleDecision()
const NO_OP = NoOp()
9+
10+
11+
"""
    RateLimiter(n_samping_per_update=0.25; min_sampling_length=0, buffer_range=0)

Decide whether an asynchronous trajectory should next insert new elements
(`UPDATE`), sample a batch (`SAMPLE`), or do nothing (`NO_OP`).

# Arguments
- `n_samping_per_update`: a positive number (floats are supported); the target
  number of sampled batches per inserted element.
- `min_sampling_length`: minimal number of inserted elements before sampling
  starts. Until it is reached the rate limiter always returns `UPDATE`.
- `buffer_range`: tolerated offset between the actual and the estimated number
  of updates. If a single number is provided, it is transformed into
  `-buffer_range:buffer_range`.

# Behavior
Once `min_sampling_length` has been reached, each call computes

    n_estimated_updates = n_sample / n_samping_per_update + min_sampling_length

- if `n_estimated_updates < n_update + first(buffer_range)`, sampling lags
  behind and the rate limiter returns `SAMPLE`;
- if `n_estimated_updates > n_update + last(buffer_range)`, updating lags
  behind and the rate limiter returns `UPDATE`;
- otherwise the decision depends on the availability of `in_channel` and
  `out_channel`.
"""
struct RateLimiter
    n_samping_per_update::Float64
    min_sampling_length::UInt
    buffer_range::UnitRange{Int}
    # Mutable latch inside an immutable struct. `Base.RefValue{Bool}` is the
    # concrete type produced by `Ref(false)`; the abstract `Ref{Bool}` would
    # make every access to this field dynamically typed.
    is_min_sampling_length_reached::Base.RefValue{Bool}
end

# Keyword front end: forwards to the positional three-argument methods.
function RateLimiter(n_samping_per_update=0.25; min_sampling_length=0, buffer_range=0)
    return RateLimiter(n_samping_per_update, min_sampling_length, buffer_range)
end

# A single number `b` means the symmetric range `-b:b`. The colon syntax is used
# because `range(start, stop)` with two positional arguments needs Julia >= 1.7,
# while this package declares compatibility with Julia 1.6.
function RateLimiter(n_samping_per_update, min_sampling_length, buffer_range::Number)
    half_width = convert(Int, buffer_range)
    return RateLimiter(n_samping_per_update, min_sampling_length, -half_width:half_width)
end

# An explicit range: start with the "minimum length reached" latch unset.
function RateLimiter(n_samping_per_update, min_sampling_length, buffer_range)
    return RateLimiter(n_samping_per_update, min_sampling_length, buffer_range, Ref(false))
end
35+
36+
"""
37+
"""
    (r::RateLimiter)(n_update, n_sample, is_in_ready, is_out_ready)

Return the next action for the trajectory loop: `UPDATE`, `SAMPLE` or `NO_OP`.

# Arguments
- `n_update`: number of elements inserted into `trajectory`.
- `n_sample`: number of batches sampled from `trajectory`.
- `is_in_ready`: whether the `in_channel` has elements to put into `trajectory`.
- `is_out_ready`: whether the `out_channel` is ready to consume a new sample.
"""
function (r::RateLimiter)(n_update, n_sample, is_in_ready, is_out_ready)
    reached = r.is_min_sampling_length_reached
    # Latch: once enough elements have been inserted the flag stays set.
    n_update >= r.min_sampling_length && (reached[] = true)
    # Warm-up phase: keep filling the trajectory.
    reached[] || return UPDATE

    n_estimated_updates = n_sample / r.n_samping_per_update + r.min_sampling_length
    # Outside the tolerated band the decision is forced.
    n_estimated_updates < n_update + r.buffer_range[begin] && return SAMPLE
    n_estimated_updates > n_update + r.buffer_range[end] && return UPDATE

    # Inside the band: do whatever the channels allow, preferring sampling.
    return is_out_ready ? SAMPLE : (is_in_ready ? UPDATE : NO_OP)
end
68+
69+
#####
70+
71+
"""
    CallMsg(f, args, kw)

Message describing a deferred call: the function `f`, its positional
arguments `args` (a tuple) and its keyword arguments `kw`.
"""
struct CallMsg
    f
    args::Tuple
    kw
end
76+
77+
"""
    AsyncTrajectory(trajectory, rate_limiter; channel_in=Channel(1), channel_out=Channel(1))

Wrap `trajectory` so that it is updated and sampled by a background task.

The task repeatedly asks `rate_limiter` what to do next:

- `UPDATE`: take one `CallMsg` from `channel_in` and apply it to `trajectory`;
- `SAMPLE`: put `rand(trajectory)` into `channel_out`;
- anything else (`NO_OP`): yield to other tasks and ask again.

# Fields
- `trajectory`, `rate_limiter`, `channel_in`, `channel_out`: the constructor
  arguments.
- `task`: the background task running the loop.
- `n_update_ref`: `Ref` counting the elements inserted so far.
- `n_sample_ref`: `Ref` counting the batches sampled so far.
"""
struct AsyncTrajectory
    trajectory
    rate_limiter
    channel_in
    channel_out
    task
    n_update_ref
    n_sample_ref

    function AsyncTrajectory(
        trajectory, rate_limiter; channel_in=Channel(1), channel_out=Channel(1)
    )
        n_update_ref = Ref(0)
        n_sample_ref = Ref(0)
        task = @async while true
            # The output side is ready while its buffer still has a free slot.
            # `sz_max` is the capacity the `Channel` was created with.
            is_out_ready = Base.n_avail(channel_out) < channel_out.sz_max
            decision = rate_limiter(
                n_update_ref[], n_sample_ref[], isready(channel_in), is_out_ready
            )
            if decision === UPDATE
                msg = take!(channel_in)
                if msg.f === Base.push!
                    push!(trajectory, msg.args...; msg.kw...)
                    n_update_ref[] += 1
                elseif msg.f === Base.append!
                    append!(trajectory, msg.args...; msg.kw...)
                    # NOTE(review): assumes the first positional argument is the
                    # appended batch and that its `length` is the number of
                    # inserted elements -- confirm against `append!(trajectory, ...)`.
                    n_update_ref[] += length(first(msg.args))
                else
                    # Any other call is applied as is and does not count as an update.
                    msg.f(trajectory, msg.args...; msg.kw...)
                end
            elseif decision === SAMPLE
                put!(channel_out, rand(trajectory))
                n_sample_ref[] += 1
            else
                # Nothing to do: hand control back to the scheduler, otherwise
                # this cooperative loop would spin without ever letting
                # producers or consumers run.
                yield()
            end
        end
        return new(
            trajectory,
            rate_limiter,
            channel_in,
            channel_out,
            task,
            n_update_ref,
            n_sample_ref,
        )
    end
end
117+
118+
# Producer/consumer front end of `AsyncTrajectory`. `push!` and `append!` do not
# touch the wrapped trajectory directly: they enqueue a `CallMsg` on the input
# channel, which the background task applies later. `take!` fetches the next
# sampled batch from the output channel. All three block according to the usual
# `Channel` semantics of `put!`/`take!`.
function Base.push!(t::AsyncTrajectory, args...; kw...)
    return put!(t.channel_in, CallMsg(Base.push!, args, kw))
end

function Base.append!(t::AsyncTrajectory, args...; kw...)
    return put!(t.channel_in, CallMsg(Base.append!, args, kw))
end

Base.take!(t::AsyncTrajectory) = take!(t.channel_out)
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
export CircularArraySARTTraces
2+
3+
# `Traces` named by `SART` whose four traces are all backed by a
# `CircularArrayBuffer`.
const CircularArraySARTTraces = Traces{
    SART,
    <:Tuple{
        <:Trace{<:CircularArrayBuffer},
        <:Trace{<:CircularArrayBuffer},
        <:Trace{<:CircularArrayBuffer},
        <:Trace{<:CircularArrayBuffer}
    }
}

"""
    CircularArraySARTTraces(; capacity, state, action, reward, terminal)

Build `Traces` with circular buffers for `state`, `action`, `reward` and
`terminal`.

Each of the four trace keywords is an `eltype => size` pair describing a single
element. `reward` and `terminal` hold `capacity` steps, while `state` and
`action` hold `capacity + 1` steps.
"""
function CircularArraySARTTraces(;
    capacity::Int,
    state=Int => (),
    action=Int => (),
    reward=Float32 => (),
    terminal=Bool => ()
)
    # Allocate a circular buffer holding `n_steps` elements described by `spec`.
    function allocate(spec, n_steps)
        T, sz = spec
        return CircularArrayBuffer{T}(sz..., n_steps)
    end

    return Traces(
        state=allocate(state, capacity + 1),  # !!! state is one step longer
        action=allocate(action, capacity + 1),  # !!! action is one step longer
        reward=allocate(reward, capacity),
        terminal=allocate(terminal, capacity),
    )
end
33+
34+
"""
    Random.rand(s::BatchSampler, t::CircularArraySARTTraces)

Sample a batch of `s.batch_size` transitions from `t` using `s.rng`.

Indices are drawn from `1:length(t)`; the `next_*` entries are read one step
later (`inds .+ 1`) from the same traces. Returns a named tuple with the
fields `state`, `action`, `reward`, `terminal`, `next_state` and `next_action`.
"""
function Random.rand(s::BatchSampler, t::CircularArraySARTTraces)
    inds = rand(s.rng, 1:length(t), s.batch_size)
    # The successor step of every sampled index.
    # NOTE(review): relies on `state`/`action` holding one more step than
    # `length(t)` -- confirm against the definition of `length` for `Traces`.
    inds′ = inds .+ 1
    return (
        state=t[:state][inds],
        action=t[:action][inds],
        reward=t[:reward][inds],
        terminal=t[:terminal][inds],
        next_state=t[:state][inds′],
        # Must come from the `:action` trace, not from `:state`.
        next_action=t[:action][inds′],
    )
end
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
export CircularArraySLARTTraces
2+
3+
# `Traces` named by `SLART` whose five traces are all backed by a
# `CircularArrayBuffer`.
const CircularArraySLARTTraces = Traces{
    SLART,
    <:Tuple{
        <:Trace{<:CircularArrayBuffer},
        <:Trace{<:CircularArrayBuffer},
        <:Trace{<:CircularArrayBuffer},
        <:Trace{<:CircularArrayBuffer},
        <:Trace{<:CircularArrayBuffer}
    }
}

"""
    CircularArraySLARTTraces(; capacity, state, legal_actions_mask, action, reward, terminal)

Build `Traces` with circular buffers for `state`, `legal_actions_mask`,
`action`, `reward` and `terminal`.

Each of the five trace keywords is an `eltype => size` pair describing a single
element. `reward` and `terminal` hold `capacity` steps, while `state`,
`legal_actions_mask` and `action` hold `capacity + 1` steps.
"""
function CircularArraySLARTTraces(;
    capacity::Int,
    state=Int => (),
    legal_actions_mask=Bool => (),
    action=Int => (),
    reward=Float32 => (),
    terminal=Bool => ()
)
    # Allocate a circular buffer holding `n_steps` elements described by `spec`.
    function allocate(spec, n_steps)
        T, sz = spec
        return CircularArrayBuffer{T}(sz..., n_steps)
    end

    return Traces(
        state=allocate(state, capacity + 1),  # !!! state is one step longer
        legal_actions_mask=allocate(legal_actions_mask, capacity + 1),  # !!! legal_actions_mask is one step longer
        action=allocate(action, capacity + 1),  # !!! action is one step longer
        reward=allocate(reward, capacity),
        terminal=allocate(terminal, capacity),
    )
end
37+
38+
"""
    Random.rand(s::BatchSampler, t::CircularArraySLARTTraces)

Sample a batch of `s.batch_size` transitions from `t` using `s.rng`.

Indices are drawn from `1:length(t)`; the `next_*` entries are read one step
later (`inds .+ 1`) from the same traces. Returns a named tuple with the
fields `state`, `legal_actions_mask`, `action`, `reward`, `terminal`,
`next_state`, `next_legal_actions_mask` and `next_action`.
"""
function Random.rand(s::BatchSampler, t::CircularArraySLARTTraces)
    inds = rand(s.rng, 1:length(t), s.batch_size)
    # The successor step of every sampled index.
    # NOTE(review): relies on `state`/`legal_actions_mask`/`action` holding one
    # more step than `length(t)` -- confirm against the definition of `length`
    # for `Traces`.
    inds′ = inds .+ 1
    return (
        state=t[:state][inds],
        legal_actions_mask=t[:legal_actions_mask][inds],
        action=t[:action][inds],
        reward=t[:reward][inds],
        terminal=t[:terminal][inds],
        next_state=t[:state][inds′],
        next_legal_actions_mask=t[:legal_actions_mask][inds′],
        # Must come from the `:action` trace, not from `:state`.
        next_action=t[:action][inds′],
    )
end

src/common/common.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
using CircularArrayBuffers
2+
3+
# Canonical trace names. The longer tuples extend the base ones with the
# `next_*` entries.
const SART = (:state, :action, :reward, :terminal)
const SARTSA = (SART..., :next_state, :next_action)
const SLART = (:state, :legal_actions_mask, :action, :reward, :terminal)
const SLARTSLA = (SLART..., :next_state, :next_legal_actions_mask, :next_action)
7+
8+
include("sum_tree.jl")
9+
include("CircularArraySARTTraces.jl")
10+
include("CircularArraySLARTTraces.jl")

0 commit comments

Comments
 (0)