From 555f1de1115cbe23dd61a83c9c0a2c3263e4b3f6 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Tue, 3 May 2022 22:36:44 +0800 Subject: [PATCH 01/11] sync --- src/Trajectories.jl | 2 +- src/common/CircularArraySARTTraces.jl | 9 +++++++++ src/common/CircularArraySLARTTraces.jl | 13 ++++++++++++- src/common/common.jl | 3 +++ src/samplers.jl | 26 ++++++++++++++++++++++++++ src/traces.jl | 16 ++++++++++++++++ src/trajectory.jl | 15 ++++++--------- test/runtests.jl | 1 + test/trajectories.jl | 18 ++++++++++++++++++ 9 files changed, 92 insertions(+), 11 deletions(-) create mode 100644 test/trajectories.jl diff --git a/src/Trajectories.jl b/src/Trajectories.jl index 38f2c05..ce05556 100644 --- a/src/Trajectories.jl +++ b/src/Trajectories.jl @@ -1,9 +1,9 @@ module Trajectories +include("samplers.jl") include("traces.jl") include("episodes.jl") include("trajectory.jl") -include("samplers.jl") include("async_trajectory.jl") include("rendering.jl") include("common/common.jl") diff --git a/src/common/CircularArraySARTTraces.jl b/src/common/CircularArraySARTTraces.jl index 84cf42b..607db2d 100644 --- a/src/common/CircularArraySARTTraces.jl +++ b/src/common/CircularArraySARTTraces.jl @@ -43,3 +43,12 @@ function Random.rand(s::BatchSampler, t::CircularArraySARTTraces) next_action=t[:state][inds′] ) end + +function Base.push!(t::CircularArraySARTTraces, x::NamedTuple{SA}) + if length(t[:state]) == length(t[:terminal]) + 1 + pop!(t[:state]) + pop!(t[:action]) + end + push!(t[:state], x[:state]) + push!(t[:action], x[:action]) +end diff --git a/src/common/CircularArraySLARTTraces.jl b/src/common/CircularArraySLARTTraces.jl index 5b03b78..c004b34 100644 --- a/src/common/CircularArraySLARTTraces.jl +++ b/src/common/CircularArraySLARTTraces.jl @@ -48,4 +48,15 @@ function Random.rand(s::BatchSampler, t::CircularArraySLARTTraces) next_legal_actions_mask=t[:legal_actions_mask][inds′], next_action=t[:state][inds′] ) -end \ No newline at end of file +end + +function Base.push!(t::CircularArraySLARTTraces, x::NamedTuple{SLA}) + if length(t[:state]) == length(t[:terminal]) + 1 + pop!(t[:state]) + pop!(t[:legal_actions_mask]) + pop!(t[:action]) + end + push!(t[:state], x[:state]) + push!(t[:legal_actions_mask], x[:legal_actions_mask]) + push!(t[:action], x[:action]) +end diff --git a/src/common/common.jl b/src/common/common.jl index 7ba3ad8..271b149 100644 --- a/src/common/common.jl +++ b/src/common/common.jl @@ -1,5 +1,8 @@ using CircularArrayBuffers +const SA = (:state, :action) +const SLA = (:state, :legal_actions_mask, :action) +const RT = (:reward, :terminal) const SART = (:state, :action, :reward, :terminal) const SARTSA = (:state, :action, :reward, :terminal, :next_state, :next_action) const SLART = (:state, :legal_actions_mask, :action, :reward, :terminal) diff --git a/src/samplers.jl b/src/samplers.jl index 0a8ea34..12af508 100644 --- a/src/samplers.jl +++ b/src/samplers.jl @@ -1,3 +1,5 @@ +export BatchSampler + using Random struct BatchSampler @@ -5,4 +7,28 @@ struct BatchSampler rng::Random.AbstractRNG end +""" + BatchSampler(batch_size; rng=Random.GLOBAL_RNG) + +Uniformly sample a batch of examples for each trace. +""" BatchSampler(batch_size; rng=Random.GLOBAL_RNG) = BatchSampler(batch_size, rng) + +""" + MinLengthSampler(min_length, sampler) + +A wrapper of `sampler`. When the `length` of traces is less than `min_length`, +`nothing` is returned. Otherwise, apply the `sampler` to the traces. +""" +struct MinLengthSampler{S} + min_length::Int + sampler::S +end + +function Random.rand(s::MinLengthSampler, t) + if length(t) < s.min_length + nothing + else + rand(s.sampler, t) + end +end \ No newline at end of file diff --git a/src/traces.jl b/src/traces.jl index 6ba56fe..29f2344 100644 --- a/src/traces.jl +++ b/src/traces.jl @@ -32,6 +32,13 @@ Base.pop!(t::Trace) = pop!(t.x) Base.popfirst!(t::Trace) = popfirst!(t.x) Base.empty!(t::Trace) = empty!(t.x) +## + +function Random.rand(s::BatchSampler, t::Trace) + inds = rand(s.rng, 1:length(t), s.batch_size) + t[inds] +end + ##### """ @@ -50,6 +57,7 @@ end Base.keys(t::Traces) = keys(t.traces) Base.haskey(t::Traces, s::Symbol) = haskey(t.traces, s) Base.getindex(t::Traces, x) = getindex(t.traces, x) +Base.length(t::Traces) = mapreduce(length, min, t.traces) Base.push!(t::Traces; kw...) = push!(t, values(kw)) @@ -70,3 +78,11 @@ end Base.pop!(t::Traces) = map(pop!, t.traces) Base.popfirst!(t::Traces) = map(popfirst!, t.traces) Base.empty!(t::Traces) = map(empty!, t.traces) + +## +function Random.rand(s::BatchSampler, t::Traces) + inds = rand(s.rng, 1:length(t), s.batch_size) + map(t.traces) do x + x[inds] + end +end \ No newline at end of file diff --git a/src/trajectory.jl b/src/trajectory.jl index 4590ec7..2bd241c 100644 --- a/src/trajectory.jl +++ b/src/trajectory.jl @@ -1,18 +1,15 @@ -struct Trajectory{T,S} +export Trajectory + +Base.@kwdef struct Trajectory{T,S} traces::T sampler::S end Base.rand(t::Trajectory) = rand(t.sampler, t.traces) -Base.push!(t::Trajectory; kw...) = push!(t, values(kw)) -Base.push!(t::Trajectory, x) = push!(t.traces, x) - -Base.append!(t::Trajectory; kw...) = append!(t, values(kw)) -Base.append!(t::Trajectory, x) = append(t.traces, x) - -Base.pop!(t::Trajectory) = pop!(t.traces) -Base.empty!(t::Trajectory) = empty!(t.traces) +Base.push!(t::Trajectory; kw...) = push!(t.traces; kw...) +Base.append!(t::Trajectory; kw...) = append!(t.traces; kw...) Base.getindex(t::Trajectory, k) = getindex(t.traces, k) +Base.setindex!(t::Trajectory, v, ks...) = setindex!(t.traces, v, ks...) Base.length(t::Trajectory) = length(t.traces) diff --git a/test/runtests.jl b/test/runtests.jl index 10fed17..91680af 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,4 +3,5 @@ using Test @testset "Trajectories.jl" begin include("traces.jl") + include("trajectories.jl") end diff --git a/test/trajectories.jl b/test/trajectories.jl new file mode 100644 index 0000000..64695ca --- /dev/null +++ b/test/trajectories.jl @@ -0,0 +1,18 @@ +@testset "trajectories" begin + t = Trajectory( + traces=Traces( + a=Int[], + b=Bool[] + ), + sampler=BatchSampler(3) + ) + + for i in 1:10 + push!(t; a=i, b=isodd(i)) + end + + batch = rand(t) + + @test size(batch[:a]) == (3,) + @test size(batch[:b]) == (3,) +end \ No newline at end of file From a72a473f384b149ab6dd00419d5019a634095b3e Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Wed, 4 May 2022 12:18:26 +0800 Subject: [PATCH 02/11] adjust to RL.jl --- src/async_trajectory.jl | 16 +++++++++++----- src/trajectory.jl | 1 + 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/async_trajectory.jl b/src/async_trajectory.jl index c685844..1d4b291 100644 --- a/src/async_trajectory.jl +++ b/src/async_trajectory.jl @@ -76,6 +76,7 @@ end struct AsyncTrajectory trajectory + sampler rate_limiter channel_in channel_out @@ -83,29 +84,34 @@ struct AsyncTrajectory n_update_ref n_sample_ref - function AsyncTrajectory(trajectory, rate_limiter; channel_in=Channel(1), channel_out=Channel(1)) + function AsyncTrajectory(trajectory, sampler, rate_limiter; channel_in=Channel(1), channel_out=Channel(1)) n_update_ref = Ref(0) n_sample_ref = Ref(0) task = @async while true - decision = rate_limiter(n_update, n_sample, isready(channel_in), Base.n_avail(channel_out) < channel_out_size) + decision = rate_limiter(n_update, n_sample, isready(channel_in), Base.n_avail(channel_out) < length(channel_out.data)) if decision === UPDATE msg = take!(channel_in) if msg.f === Base.push! + n_pre = length(trajectory) push!(trajectory, msg.args...; msg.kw...) - n_update_ref[] += 1 + n_post = length(trajectory) + n_update_ref[] += n_post - n_pre elseif msg.f === Base.append! + n_pre = length(trajectory) append!(trajectory, msg.args...; msg.kw...) - n_update_ref[] += length(msg.data) + n_post = length(trajectory) + n_update_ref[] += n_post - n_pre else msg.f(trajectory, msg.args...; msg.kw...) end elseif decision === SAMPLE - put!(channel_out, rand(trajectory)) + put!(channel_out, rand(sampler, trajectory)) n_sample_ref[] += 1 end end new( trajectory, + sampler, channel_in, channel_out, task, diff --git a/src/trajectory.jl b/src/trajectory.jl index 2bd241c..6bc8e52 100644 --- a/src/trajectory.jl +++ b/src/trajectory.jl @@ -3,6 +3,7 @@ export Trajectory Base.@kwdef struct Trajectory{T,S} traces::T sampler::S + rate_limiter::RateLimiter end Base.rand(t::Trajectory) = rand(t.sampler, t.traces) From ec7ac653a9acf744476d1a3017bad84883e7380f Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Wed, 4 May 2022 23:27:18 +0800 Subject: [PATCH 03/11] sync local changes --- Project.toml | 1 + src/episodes.jl | 5 ++- src/trajectory.jl | 82 +++++++++++++++++++++++++++++++++++++++++------ 3 files changed, 76 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index d4818b6..61f2c96 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.1.0" [deps] CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1" +MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Term = "22787eb5-b846-44ae-b979-8e399b8463ab" diff --git a/src/episodes.jl b/src/episodes.jl index e419c30..933ae92 100644 --- a/src/episodes.jl +++ b/src/episodes.jl @@ -56,9 +56,8 @@ struct Episodes episodes::Vector{Episode} end -Base.lastindex(e::Episodes) = lastindex(e.episodes) -Base.length(e::Episodes) = length(e.episodes) -Base.getindex(e::Episodes, I...) = getindex(e.episodes, I...) +# TODO: we may need a range map +# https://github.com/google/guava/wiki/NewCollectionTypesExplained#rangemap Base.push!(e::Episodes, x::Episode) = push!(e.episodes, x) Base.append!(e::Episodes, x::AbstractVector{<:Episode}) = append!(e.episodes, x) diff --git a/src/trajectory.jl b/src/trajectory.jl index 6bc8e52..4e9f6bb 100644 --- a/src/trajectory.jl +++ b/src/trajectory.jl @@ -1,16 +1,80 @@ export Trajectory -Base.@kwdef struct Trajectory{T,S} - traces::T +""" + Trajectory(container, sampler, controler) + +The `container` is used to store experiences. Common ones are [`Traces`](@ref) +or [`Episodes`](@ref). The `sampler` is used to sample experience batches from +the `container`. The `controler` controls whether it is time to sample a batch +or not. + +Supported methoes are: + +- `push!(t::Trajectory, experience)`, add one experience into the trajectory. +- `append!(t::Trajectory, batch)`, add a batch of experiences into the trajectory. +- `take!(t::Trajectory)`, take a batch of experiences from the trajectory. Note + that `nothing` may be returned, indicating that it's not ready to sample yet. +""" +Base.@kwdef struct Trajectory{C,S,T} + container::C sampler::S - rate_limiter::RateLimiter + controler::T +end + +function Base.push!(t::Trajectory, x) + n_pre = length(t) + push!(t.container, x) + n_post = length(t) + on_insert!(t.controler, n_post - n_pre) +end + +function Base.append!(t::Trajectory, x) + n_pre = length(t) + append!(t.container, x) + n_post = length(t) + on_insert!(t.controler, n_post - n_pre) end -Base.rand(t::Trajectory) = rand(t.sampler, t.traces) +function Base.take!(t::Trajectory) + res = on_take!(t.controler) + if isnothing(res) + nothing + else + rand(t.sampler, t.container) + end +end -Base.push!(t::Trajectory; kw...) = push!(t.traces; kw...) -Base.append!(t::Trajectory; kw...) = append!(t.traces; kw...) +function Base.iterate(t::Trajectory) + x = take!(t) + if isnothing(x) + nothing + else + x, true + end +end + +Base.iterate(t::Trajectory, state) = iterate(t) + +##### + +mutable struct InsertSampleRatioControler + n_inserted::Int + n_sampled::Int + threshold::Int + ratio::Float64 +end + +function on_insert!(c::InsertSampleRatioControler, n::Int) + if n > 0 + c.n_inserted += n + end +end -Base.getindex(t::Trajectory, k) = getindex(t.traces, k) -Base.setindex!(t::Trajectory, v, ks...) = setindex!(t.traces, v, ks...) -Base.length(t::Trajectory) = length(t.traces) +function on_take!(c::InsertSampleRatioControler) + if c.n_inserted >= c.threshold + if c.n_sampled <= c.n_inserted * c.ratio + c.n_sampled += 1 + true + end + end +end \ No newline at end of file From 98775e0b5e6d9cfc87706833fd144b433b040f74 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Thu, 5 May 2022 00:16:53 +0800 Subject: [PATCH 04/11] add tests for trajectory --- src/trajectory.jl | 33 +++++++++++++++++-------- test/trajectories.jl | 57 ++++++++++++++++++++++++++++++++++++++------ 2 files changed, 73 insertions(+), 17 deletions(-) diff --git a/src/trajectory.jl b/src/trajectory.jl index 4e9f6bb..cc35e19 100644 --- a/src/trajectory.jl +++ b/src/trajectory.jl @@ -1,4 +1,4 @@ -export Trajectory +export Trajectory, InsertSampleRatioControler """ Trajectory(container, sampler, controler) @@ -21,22 +21,26 @@ Base.@kwdef struct Trajectory{C,S,T} controler::T end +Base.push!(t::Trajectory; kw...) = push!(t, values(kw)) + function Base.push!(t::Trajectory, x) - n_pre = length(t) + n_pre = length(t.container) push!(t.container, x) - n_post = length(t) + n_post = length(t.container) on_insert!(t.controler, n_post - n_pre) end +Base.append!(t::Trajectory; kw...) = append!(t, values(kw)) + function Base.append!(t::Trajectory, x) - n_pre = length(t) + n_pre = length(t.container) append!(t.container, x) - n_post = length(t) + n_post = length(t.container) on_insert!(t.controler, n_post - n_pre) end function Base.take!(t::Trajectory) - res = on_take!(t.controler) + res = on_sample!(t.controler) if isnothing(res) nothing else @@ -58,21 +62,30 @@ Base.iterate(t::Trajectory, state) = iterate(t) ##### mutable struct InsertSampleRatioControler + ratio::Float64 + threshold::Int n_inserted::Int n_sampled::Int - threshold::Int - ratio::Float64 end +""" + InsertSampleRatioControler(ratio, threshold) + +Used in [`Trajectory`](@ref). The `threshold` means the minimal number of +insertings before sampling. The `ratio` balances the number of insertings and +the number of samplings. +""" +InsertSampleRatioControler(ratio, threshold) = InsertSampleRatioControler(ratio, threshold, 0, 0) + function on_insert!(c::InsertSampleRatioControler, n::Int) if n > 0 c.n_inserted += n end end -function on_take!(c::InsertSampleRatioControler) +function on_sample!(c::InsertSampleRatioControler) if c.n_inserted >= c.threshold - if c.n_sampled <= c.n_inserted * c.ratio + if c.n_sampled < c.n_inserted * c.ratio c.n_sampled += 1 true end diff --git a/test/trajectories.jl b/test/trajectories.jl index 64695ca..fe41227 100644 --- a/test/trajectories.jl +++ b/test/trajectories.jl @@ -1,18 +1,61 @@ @testset "trajectories" begin t = Trajectory( - traces=Traces( + container=Traces( a=Int[], b=Bool[] ), - sampler=BatchSampler(3) + sampler=BatchSampler(3), + controler=InsertSampleRatioControler(0.25, 4) ) - for i in 1:10 - push!(t; a=i, b=isodd(i)) + batches = [] + + for batch in t + push!(batches, batch) + end + + @test length(batches) == 0 # threshold not reached yet + + append!(t; a=[1, 2, 3], b=[false, true, false]) + + for batch in t + push!(batches, batch) + end + + @test length(batches) == 0 # threshold not reached yet + + push!(t; a=4, b=true) + + for batch in t + push!(batches, batch) + end + + @test length(batches) == 1 # 4 inserted, ratio is 0.25 + + append!(t; a=[5, 6, 7], b=[true, true, true]) + + for batch in t + push!(batches, batch) + end + + @test length(batches) == 2 # 7 inserted, ratio is 0.25 + + push!(t; a=8, b=true) + + for batch in t + push!(batches, batch) end - batch = rand(t) + @test length(batches) == 2 # 8 inserted, ratio is 0.25 - @test size(batch[:a]) == (3,) - @test size(batch[:b]) == (3,) + n = 100 + for i in 1:n + append!(t; a=[i, i, i, i], b=[false, true, false, true]) + end + + s = 0 + for _ in t + s += 1 + end + @test s == n end \ No newline at end of file From f462e8807def28d192d958f13c38c48bedf0e467 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Fri, 6 May 2022 01:02:06 +0800 Subject: [PATCH 05/11] sync local change --- src/async_trajectory.jl | 2 +- src/common/CircularArraySLARTTraces.jl | 2 +- src/episodes.jl | 43 +++++++++++++++++++++----- src/samplers.jl | 21 ++----------- src/traces.jl | 6 ++-- src/trajectory.jl | 2 +- test/traces.jl | 10 ++++++ 7 files changed, 54 insertions(+), 32 deletions(-) diff --git a/src/async_trajectory.jl b/src/async_trajectory.jl index 1d4b291..19ae167 100644 --- a/src/async_trajectory.jl +++ b/src/async_trajectory.jl @@ -105,7 +105,7 @@ struct AsyncTrajectory msg.f(trajectory, msg.args...; msg.kw...) end elseif decision === SAMPLE - put!(channel_out, rand(sampler, trajectory)) + put!(channel_out, sample(sampler, trajectory)) n_sample_ref[] += 1 end end diff --git a/src/common/CircularArraySLARTTraces.jl b/src/common/CircularArraySLARTTraces.jl index c004b34..3410dfd 100644 --- a/src/common/CircularArraySLARTTraces.jl +++ b/src/common/CircularArraySLARTTraces.jl @@ -35,7 +35,7 @@ function CircularArraySLARTTraces(; ) end -function Random.rand(s::BatchSampler, t::CircularArraySLARTTraces) +function sample(s::BatchSampler, t::CircularArraySLARTTraces) inds = rand(s.rng, 1:length(t), s.batch_size) inds′ = inds .+ 1 ( diff --git a/src/episodes.jl b/src/episodes.jl index 933ae92..d6abb10 100644 --- a/src/episodes.jl +++ b/src/episodes.jl @@ -11,9 +11,14 @@ struct Episode{T} is_done::Ref{Bool} end +Base.getindex(e::Episode, s::Symbol) = getindex(e.traces, s) +Base.keys(e::Episode) = keys(e.traces) + Base.getindex(e::Episode) = getindex(e.is_done) Base.setindex!(e::Episode, x::Bool) = setindex!(e.is_done, x) +Base.length(e::Episode) = length(e.traces) + Episode(t::Traces) = Episode(t, Ref(false)) function Base.push!(t::Episode, x) @@ -54,16 +59,23 @@ A container for multiple [`Episode`](@ref)s. `init` is a parameterness function struct Episodes init::Any episodes::Vector{Episode} + inds::Vector{Tuple{Int,Int}} end -# TODO: we may need a range map -# https://github.com/google/guava/wiki/NewCollectionTypesExplained#rangemap +Base.length(e::Episodes) = length(e.inds) -Base.push!(e::Episodes, x::Episode) = push!(e.episodes, x) -Base.append!(e::Episodes, x::AbstractVector{<:Episode}) = append!(e.episodes, x) -Base.pop!(e::Episodes) = pop!(e.episodes) -Base.popfirst!(e::Episodes) = popfirst!(e.episodes) -Base.empty!(e::Episodes) = empty!(e.episodes) +function Base.push!(e::Episodes, x::Episode) + push!(e.episodes, x) + for i in 1:length(x) + push!(e.inds, (length(e.episodes), i)) + end +end + +function Base.append!(e::Episodes, xs::AbstractVector{<:Episode}) + for x in xs + push!(e, x) + end +end function Base.push!(e::Episodes, x) if isempty(e.episodes) || e.episodes[end][] @@ -72,5 +84,22 @@ function Base.push!(e::Episodes, x) push!(e.episodes, episode) else push!(e.episodes[end], x) + push!(e.inds, (length(e.episodes), length(e.episodes[end]))) end +end + +function Base.append!(e::Episodes, x) + n_pre = length(e.episodes[end]) + append!(e.episodes[end], x) + n_post = length(e.episodes[end]) + for i in n_pre:n_post + push!(e.inds, (lengthe.episodes, i)) + end +end + +## + +function sample(s::BatchSampler, e::Episodes) + inds = rand(s.rng, 1:length(t), s.batch_size) + # TODO: batch end \ No newline at end of file diff --git a/src/samplers.jl b/src/samplers.jl index 12af508..ea5d229 100644 --- a/src/samplers.jl +++ b/src/samplers.jl @@ -11,24 +11,7 @@ end BatchSampler(batch_size; rng=Random.GLOBAL_RNG) Uniformly sample a batch of examples for each trace. -""" -BatchSampler(batch_size; rng=Random.GLOBAL_RNG) = BatchSampler(batch_size, rng) - -""" - MinLengthSampler(min_length, sampler) -A wrapper of `sampler`. When the `length` of traces is less than `min_length`, -`nothing` is returned. Otherwise, apply the `sampler` to the traces. +See also [`sample`](@ref). """ -struct MinLengthSampler{S} - min_length::Int - sampler::S -end - -function Random.rand(s::MinLengthSampler, t) - if length(t) < s.min_length - nothing - else - rand(s.sampler, t) - end -end \ No newline at end of file +BatchSampler(batch_size; rng=Random.GLOBAL_RNG) = BatchSampler(batch_size, rng) diff --git a/src/traces.jl b/src/traces.jl index 29f2344..a4f5750 100644 --- a/src/traces.jl +++ b/src/traces.jl @@ -1,4 +1,4 @@ -export Trace, Traces +export Trace, Traces, sample """ Trace(data) @@ -34,7 +34,7 @@ Base.empty!(t::Trace) = empty!(t.x) ## -function Random.rand(s::BatchSampler, t::Trace) +function sample(s::BatchSampler, t::Trace) inds = rand(s.rng, 1:length(t), s.batch_size) t[inds] end @@ -80,7 +80,7 @@ Base.popfirst!(t::Traces) = map(popfirst!, t.traces) Base.empty!(t::Traces) = map(empty!, t.traces) ## -function Random.rand(s::BatchSampler, t::Traces) +function sample(s::BatchSampler, t::Traces) inds = rand(s.rng, 1:length(t), s.batch_size) map(t.traces) do x x[inds] diff --git a/src/trajectory.jl b/src/trajectory.jl index cc35e19..5c5722a 100644 --- a/src/trajectory.jl +++ b/src/trajectory.jl @@ -44,7 +44,7 @@ function Base.take!(t::Trajectory) if isnothing(res) nothing else - rand(t.sampler, t.container) + sample(t.sampler, t.container) end end diff --git a/test/traces.jl b/test/traces.jl index 4307f74..cfcd72e 100644 --- a/test/traces.jl +++ b/test/traces.jl @@ -13,8 +13,12 @@ pop!(t) @test length(t) == 2 + s = BatchSampler(2) + @test size(sample(s, t)) == (2,) + empty!(t) @test length(t) == 0 + end @testset "Trace 2d" begin @@ -25,6 +29,9 @@ end @test length(t) == 3 @test t[1] == [1, 4] @test @view(t[2:3]) == [2 3; 5 6] + + s = BatchSampler(5) + @test size(sample(s, t)) == (2, 5) end @testset "Traces" begin @@ -44,4 +51,7 @@ end append!(t; a=[4, 5], b=[false, false]) @test length(t[:a]) == 5 @test t[:b][end-1:end] == [false, false] + + s = BatchSampler(5) + @test size(sample(s, t)[:a]) == (5,) end \ No newline at end of file From 9fb2994432f5cc56fe596b88d720b64b78da8dda Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Sat, 7 May 2022 00:21:56 +0800 Subject: [PATCH 06/11] import MLUtils --- Project.toml | 1 + src/episodes.jl | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 61f2c96..f374c82 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.1.0" [deps] CircularArrayBuffers = "9de3a189-e0c0-4e15-ba3b-b14b9fb0aec1" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Term = "22787eb5-b846-44ae-b979-8e399b8463ab" diff --git a/src/episodes.jl b/src/episodes.jl index d6abb10..bc3a690 100644 --- a/src/episodes.jl +++ b/src/episodes.jl @@ -1,5 +1,7 @@ export Episode, Episodes +using MLUtils: batch + """ Episode(traces) @@ -101,5 +103,5 @@ end function sample(s::BatchSampler, e::Episodes) inds = rand(s.rng, 1:length(t), s.batch_size) - # TODO: batch + batch([@view(s.episodes[e.inds[i][1]][e.inds[i][2]]) for i in inds]) end \ No newline at end of file From d63fd7570949775cc7928c36d1523be3cb2ef566 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Sat, 7 May 2022 23:32:32 +0800 Subject: [PATCH 07/11] remove AsyncTrajectory --- src/async_trajectory.jl | 126 ------------------------------------ src/trajectory.jl | 140 +++++++++++++++++++++++++++++++--------- 2 files changed, 108 insertions(+), 158 deletions(-) delete mode 100644 src/async_trajectory.jl diff --git a/src/async_trajectory.jl b/src/async_trajectory.jl deleted file mode 100644 index 19ae167..0000000 --- a/src/async_trajectory.jl +++ /dev/null @@ -1,126 +0,0 @@ -export AsyncTrajectory, UPDATE, SAMPLE, NO_OP - -struct Update end -const UPDATE = Update() -struct SampleDecision end -const SAMPLE = SampleDecision() -struct NoOp end -const NO_OP = NoOp() - - -""" - RateLimiter(n_samping_per_update=0.25; min_sampling_length=0, buffer_range=0:0) - -- `n_samping_per_update`, a positive number, float number is supported. -- `min_sampling_length`, minimal number of elements to start sampling. -- `buffer_range`, if a single number is provided, it will be transformed into - `-buffer_range:buffer_range`. Once the `trajectory` reaches - `min_sampling_length`, if the length of `trajectory` is greater than the upper - bound plus `min_sampling_length`, the `rate_limiter` always return `SAMPLE`. - While if the length of `trajectory` is less than `min_sampling_length` minus - the lower bound of `buffer_range`, then the `rate_limiter` always return - `UPDATE`. In other cases, whether to `SAMPLE` or `UPDATE` depends on the - availability of `in_channel` or `out_channel`. -""" -struct RateLimiter - n_samping_per_update::Float64 - min_sampling_length::UInt - buffer_range::UnitRange{Int} - is_min_sampling_length_reached::Ref{Bool} -end - -RateLimiter(n_samping_per_update=0.25; min_sampling_length=0, buffer_range=0) = RateLimiter(n_samping_per_update, min_sampling_length, buffer_range) -RateLimiter(n_samping_per_update, min_sampling_length, buffer_range::Number) = RateLimiter(n_samping_per_update, convert(UInt, min_sampling_length), range(-buffer_range, buffer_range)) -RateLimiter(n_samping_per_update, min_sampling_length, buffer_range) = RateLimiter(n_samping_per_update, min_sampling_length, buffer_range, Ref(false)) - -""" - (r::RateLimiter)(n_update, n_sample, is_in_ready, is_out_ready) - -- `n_update`, number of elements inserted into `trajectory`. -- `n_sample`, number of batches sampled from `trajectory`. -- `is_in_ready`, the `in_channel` has elements to put into `trajectory` or not. -- `is_out_ready`, the `out_channel` is ready to consume new samplings or not. -""" -function (r::RateLimiter)(n_update, n_sample, is_in_ready, is_out_ready) - if n_update >= r.min_sampling_length - r.is_min_sampling_length_reached[] = true - end - - if r.is_min_sampling_length_reached[] - n_estimated_updates = n_sample / r.n_samping_per_update + r.min_sampling_length - if n_estimated_updates < n_update + r.buffer_range[begin] - SAMPLE - elseif n_estimated_updates > n_update + r.buffer_range[end] - UPDATE - else - if is_out_ready - SAMPLE - elseif is_in_ready - UPDATE - else - NO_OP - end - end - else - UPDATE - end -end - -##### - -struct CallMsg - f::Any - args::Tuple - kw::Any -end - -struct AsyncTrajectory - trajectory - sampler - rate_limiter - channel_in - channel_out - task - n_update_ref - n_sample_ref - - function AsyncTrajectory(trajectory, sampler, rate_limiter; channel_in=Channel(1), channel_out=Channel(1)) - n_update_ref = Ref(0) - n_sample_ref = Ref(0) - task = @async while true - decision = rate_limiter(n_update, n_sample, isready(channel_in), Base.n_avail(channel_out) < length(channel_out.data)) - if decision === UPDATE - msg = take!(channel_in) - if msg.f === Base.push! - n_pre = length(trajectory) - push!(trajectory, msg.args...; msg.kw...) - n_post = length(trajectory) - n_update_ref[] += n_post - n_pre - elseif msg.f === Base.append! - n_pre = length(trajectory) - append!(trajectory, msg.args...; msg.kw...) - n_post = length(trajectory) - n_update_ref[] += n_post - n_pre - else - msg.f(trajectory, msg.args...; msg.kw...) - end - elseif decision === SAMPLE - put!(channel_out, sample(sampler, trajectory)) - n_sample_ref[] += 1 - end - end - new( - trajectory, - sampler, - channel_in, - channel_out, - task, - n_update_ref, - n_sample_ref - ) - end -end - -Base.push!(t::AsyncTrajectory, args...; kw...) = put!(t.in, CallMsg(Base.push!, args, kw)) -Base.append!(t::AsyncTrajectory, args...; kw...) = put!(t.in, CallMsg(Base.append!, args, kw)) -Base.take!(t::AsyncTrajectory) = take!(t.out) diff --git a/src/trajectory.jl b/src/trajectory.jl index 5c5722a..1aa324b 100644 --- a/src/trajectory.jl +++ b/src/trajectory.jl @@ -1,5 +1,72 @@ export Trajectory, InsertSampleRatioControler +using Base.Threads + + +##### + +mutable struct InsertSampleRatioControler + ratio::Float64 + threshold::Int + n_inserted::Int + n_sampled::Int +end + +""" + InsertSampleRatioControler(ratio, threshold) + +Used in [`Trajectory`](@ref). The `threshold` means the minimal number of +insertings before sampling. The `ratio` balances the number of insertings and +the number of samplings. +""" +InsertSampleRatioControler(ratio, threshold) = InsertSampleRatioControler(ratio, threshold, 0, 0) + +function on_insert!(c::InsertSampleRatioControler, n::Int) + if n > 0 + c.n_inserted += n + end +end + +function on_sample!(c::InsertSampleRatioControler) + if c.n_inserted >= c.threshold + if c.n_sampled < (c.n_inserted - n.threshold) * c.ratio + c.n_sampled += 1 + true + end + end +end + +##### + +mutable struct AsyncInsertSampleRatioControler + ratio::Float64 + threshold::Int + n_inserted::Int + n_sampled::Int + ch_in::Channel + ch_out::Channel +end + +function AsyncInsertSampleRatioControler( + ratio, + threshold, + ; ch_in_sz=1, + ch_out_sz=1, + n_inserted=0, + n_sampled=0 +) + AsyncInsertSampleRatioControler( + ratio, + threshold, + n_inserted, + n_sampled, + Channel(ch_in_sz), + Channel(ch_out_sz) + ) +end + +##### + """ Trajectory(container, sampler, controler) @@ -19,8 +86,38 @@ Base.@kwdef struct Trajectory{C,S,T} container::C sampler::S controler::T + + Trajectory(c::C, s::S, t::T) where {C,S,T} = new{C,S,T}(c, s, t) + + function Trajectory(container::C, sampler::S, controler::T) where {C,S,T<:AsyncInsertSampleRatioControler} + t = Threads.@spawn while true + for msg in controler.in + if msg.f === Base.push! || msg.f === Base.append! + n_pre = length(trajectory) + msg.f(trajectory, msg.args...; msg.kw...) + n_post = length(trajectory) + controler.n_inserted += n_post - n_pre + else + msg.f(trajectory, msg.args...; msg.kw...) + end + + if controler.n_inserted >= controler.threshold + if controler.n_sampled < (controler.n_inserted - controler.threshold) * controler.ratio + batch = sample(sampler, container) + put!(controler.ch_out, batch) + controler.n_sampled += 1 + end + end + end + end + + bind(controler.in, t) + bind(controler.out, t) + new{C,S,T}(container, sampler, controler) + end end + Base.push!(t::Trajectory; kw...) = push!(t, values(kw)) function Base.push!(t::Trajectory, x) @@ -30,6 +127,15 @@ function Base.push!(t::Trajectory, x) on_insert!(t.controler, n_post - n_pre) end +struct CallMsg + f::Any + args::Tuple + kw::Any +end + +Base.push!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioControler}, args...; kw...) = put!(t.controler.ch_in, CallMsg(Base.push!, args, kw)) +Base.append!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioControler}, args...; kw...) = append!(t.controler.ch_in, CallMsg(Base.push!, args, kw)) + Base.append!(t::Trajectory; kw...) = append!(t, values(kw)) function Base.append!(t::Trajectory, x) @@ -59,35 +165,5 @@ end Base.iterate(t::Trajectory, state) = iterate(t) -##### - -mutable struct InsertSampleRatioControler - ratio::Float64 - threshold::Int - n_inserted::Int - n_sampled::Int -end - -""" - InsertSampleRatioControler(ratio, threshold) - -Used in [`Trajectory`](@ref). The `threshold` means the minimal number of -insertings before sampling. The `ratio` balances the number of insertings and -the number of samplings. -""" -InsertSampleRatioControler(ratio, threshold) = InsertSampleRatioControler(ratio, threshold, 0, 0) - -function on_insert!(c::InsertSampleRatioControler, n::Int) - if n > 0 - c.n_inserted += n - end -end - -function on_sample!(c::InsertSampleRatioControler) - if c.n_inserted >= c.threshold - if c.n_sampled < c.n_inserted * c.ratio - c.n_sampled += 1 - true - end - end -end \ No newline at end of file +Base.iterate(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioControler}, args...) = iterate(t.controler.ch_out, args...) +Base.take!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioControler}) = take!(t.controler.ch_out) \ No newline at end of file From f574b3d5d52d541ff354bac891c5c510c18811e1 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Sun, 8 May 2022 15:06:51 +0800 Subject: [PATCH 08/11] fix tests --- src/Trajectories.jl | 2 +- src/common/CircularArraySARTTraces.jl | 2 +- src/common/CircularArraySLARTTraces.jl | 2 +- src/controlers.jl | 61 +++++++++++++++++++ src/episodes.jl | 2 +- src/samplers.jl | 5 +- src/traces.jl | 4 +- src/trajectory.jl | 84 +++----------------------- test/trajectories.jl | 29 ++++++++- 9 files changed, 107 insertions(+), 84 deletions(-) create mode 100644 src/controlers.jl diff --git a/src/Trajectories.jl b/src/Trajectories.jl index ce05556..40e876b 100644 --- a/src/Trajectories.jl +++ b/src/Trajectories.jl @@ -1,10 +1,10 @@ module Trajectories include("samplers.jl") +include("controlers.jl") include("traces.jl") include("episodes.jl") include("trajectory.jl") -include("async_trajectory.jl") include("rendering.jl") include("common/common.jl") diff --git a/src/common/CircularArraySARTTraces.jl b/src/common/CircularArraySARTTraces.jl index 607db2d..f140b00 100644 --- a/src/common/CircularArraySARTTraces.jl +++ b/src/common/CircularArraySARTTraces.jl @@ -41,7 +41,7 @@ function Random.rand(s::BatchSampler, t::CircularArraySARTTraces) terminal=t[:terminal][inds], next_state=t[:state][inds′], next_action=t[:state][inds′] - ) + ) |> s.transformer end function Base.push!(t::CircularArraySARTTraces, x::NamedTuple{SA}) diff --git a/src/common/CircularArraySLARTTraces.jl b/src/common/CircularArraySLARTTraces.jl index 3410dfd..83e5d0d 100644 --- a/src/common/CircularArraySLARTTraces.jl +++ b/src/common/CircularArraySLARTTraces.jl @@ -47,7 +47,7 @@ function sample(s::BatchSampler, t::CircularArraySLARTTraces) next_state=t[:state][inds′], next_legal_actions_mask=t[:legal_actions_mask][inds′], next_action=t[:state][inds′] - ) + ) |> s.transformer end function Base.push!(t::CircularArraySLARTTraces, x::NamedTuple{SLA}) diff --git a/src/controlers.jl b/src/controlers.jl new file mode 100644 index 0000000..bc3852b --- /dev/null +++ b/src/controlers.jl @@ -0,0 +1,61 @@ +export InsertSampleRatioControler, AsyncInsertSampleRatioControler + +mutable struct InsertSampleRatioControler + ratio::Float64 + threshold::Int + n_inserted::Int + n_sampled::Int +end + +""" + InsertSampleRatioControler(ratio, threshold) + +Used in [`Trajectory`](@ref). The `threshold` means the minimal number of +insertings before sampling. The `ratio` balances the number of insertings and +the number of samplings. +""" +InsertSampleRatioControler(ratio, threshold) = InsertSampleRatioControler(ratio, threshold, 0, 0) + +function on_insert!(c::InsertSampleRatioControler, n::Int) + if n > 0 + c.n_inserted += n + end +end + +function on_sample!(c::InsertSampleRatioControler) + if c.n_inserted >= c.threshold + if c.n_sampled <= (c.n_inserted - c.threshold) * c.ratio + c.n_sampled += 1 + true + end + end +end + +##### + +mutable struct AsyncInsertSampleRatioControler + ratio::Float64 + threshold::Int + n_inserted::Int + n_sampled::Int + ch_in::Channel + ch_out::Channel +end + +function AsyncInsertSampleRatioControler( + ratio, + threshold, + ; ch_in_sz=1, + ch_out_sz=1, + n_inserted=0, + n_sampled=0 +) + AsyncInsertSampleRatioControler( + ratio, + threshold, + n_inserted, + n_sampled, + Channel(ch_in_sz), + Channel(ch_out_sz) + ) +end diff --git a/src/episodes.jl b/src/episodes.jl index bc3a690..dcd5712 100644 --- a/src/episodes.jl +++ b/src/episodes.jl @@ -103,5 +103,5 @@ end function sample(s::BatchSampler, e::Episodes) inds = rand(s.rng, 1:length(t), s.batch_size) - batch([@view(s.episodes[e.inds[i][1]][e.inds[i][2]]) for i in inds]) + batch([@view(s.episodes[e.inds[i][1]][e.inds[i][2]]) for i in inds]) |> s.transformer end \ No newline at end of file diff --git a/src/samplers.jl b/src/samplers.jl index ea5d229..b4c073c 100644 --- a/src/samplers.jl +++ b/src/samplers.jl @@ -5,13 +5,14 @@ using Random struct BatchSampler batch_size::Int rng::Random.AbstractRNG + transformer::Any end """ - BatchSampler(batch_size; rng=Random.GLOBAL_RNG) + BatchSampler(batch_size; rng=Random.GLOBAL_RNG, transformer=identity) Uniformly sample a batch of examples for each trace. See also [`sample`](@ref). """ -BatchSampler(batch_size; rng=Random.GLOBAL_RNG) = BatchSampler(batch_size, rng) +BatchSampler(batch_size; rng=Random.GLOBAL_RNG, transformer=identity) = BatchSampler(batch_size, rng, identity) diff --git a/src/traces.jl b/src/traces.jl index a4f5750..02a96ab 100644 --- a/src/traces.jl +++ b/src/traces.jl @@ -36,7 +36,7 @@ Base.empty!(t::Trace) = empty!(t.x) function sample(s::BatchSampler, t::Trace) inds = rand(s.rng, 1:length(t), s.batch_size) - t[inds] + t[inds] |> s.transformer end ##### @@ -84,5 +84,5 @@ function sample(s::BatchSampler, t::Traces) inds = rand(s.rng, 1:length(t), s.batch_size) map(t.traces) do x x[inds] - end + end |> s.transformer end \ No newline at end of file diff --git a/src/trajectory.jl b/src/trajectory.jl index 1aa324b..fe7373e 100644 --- a/src/trajectory.jl +++ b/src/trajectory.jl @@ -1,72 +1,8 @@ -export Trajectory, InsertSampleRatioControler +export Trajectory using Base.Threads -##### - -mutable struct InsertSampleRatioControler - ratio::Float64 - threshold::Int - n_inserted::Int - n_sampled::Int -end - -""" - InsertSampleRatioControler(ratio, threshold) - -Used in [`Trajectory`](@ref). The `threshold` means the minimal number of -insertings before sampling. The `ratio` balances the number of insertings and -the number of samplings. -""" -InsertSampleRatioControler(ratio, threshold) = InsertSampleRatioControler(ratio, threshold, 0, 0) - -function on_insert!(c::InsertSampleRatioControler, n::Int) - if n > 0 - c.n_inserted += n - end -end - -function on_sample!(c::InsertSampleRatioControler) - if c.n_inserted >= c.threshold - if c.n_sampled < (c.n_inserted - n.threshold) * c.ratio - c.n_sampled += 1 - true - end - end -end - -##### - -mutable struct AsyncInsertSampleRatioControler - ratio::Float64 - threshold::Int - n_inserted::Int - n_sampled::Int - ch_in::Channel - ch_out::Channel -end - -function AsyncInsertSampleRatioControler( - ratio, - threshold, - ; ch_in_sz=1, - ch_out_sz=1, - n_inserted=0, - n_sampled=0 -) - AsyncInsertSampleRatioControler( - ratio, - threshold, - n_inserted, - n_sampled, - Channel(ch_in_sz), - Channel(ch_out_sz) - ) -end - -##### - """ Trajectory(container, sampler, controler) @@ -91,18 +27,18 @@ Base.@kwdef struct Trajectory{C,S,T} function Trajectory(container::C, sampler::S, controler::T) where {C,S,T<:AsyncInsertSampleRatioControler} t = Threads.@spawn while true - for msg in controler.in + for msg in controler.ch_in if msg.f === Base.push! || msg.f === Base.append! - n_pre = length(trajectory) - msg.f(trajectory, msg.args...; msg.kw...) - n_post = length(trajectory) + n_pre = length(container) + msg.f(container, msg.args...; msg.kw...) + n_post = length(container) controler.n_inserted += n_post - n_pre else - msg.f(trajectory, msg.args...; msg.kw...) + msg.f(container, msg.args...; msg.kw...) end if controler.n_inserted >= controler.threshold - if controler.n_sampled < (controler.n_inserted - controler.threshold) * controler.ratio + if controler.n_sampled <= (controler.n_inserted - controler.threshold) * controler.ratio batch = sample(sampler, container) put!(controler.ch_out, batch) controler.n_sampled += 1 @@ -111,8 +47,8 @@ Base.@kwdef struct Trajectory{C,S,T} end end - bind(controler.in, t) - bind(controler.out, t) + bind(controler.ch_in, t) + bind(controler.ch_out, t) new{C,S,T}(container, sampler, controler) end end @@ -134,7 +70,7 @@ struct CallMsg end Base.push!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioControler}, args...; kw...) = put!(t.controler.ch_in, CallMsg(Base.push!, args, kw)) -Base.append!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioControler}, args...; kw...) = append!(t.controler.ch_in, CallMsg(Base.push!, args, kw)) +Base.append!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioControler}, args...; kw...) = put!(t.controler.ch_in, CallMsg(Base.append!, args, kw)) Base.append!(t::Trajectory; kw...) = append!(t, values(kw)) diff --git a/test/trajectories.jl b/test/trajectories.jl index fe41227..c6b1d9e 100644 --- a/test/trajectories.jl +++ b/test/trajectories.jl @@ -30,7 +30,7 @@ push!(batches, batch) end - @test length(batches) == 1 # 4 inserted, ratio is 0.25 + @test length(batches) == 1 # 4 inserted, threshold is 4, ratio is 0.25 append!(t; a=[5, 6, 7], b=[true, true, true]) @@ -38,7 +38,7 @@ push!(batches, batch) end - @test length(batches) == 2 # 7 inserted, ratio is 0.25 + @test length(batches) == 1 # 7 inserted, threshold is 4, ratio is 0.25 push!(t; a=8, b=true) @@ -58,4 +58,29 @@ s += 1 end @test s == n +end + +@testset "async trajectories" begin + threshould = 100 + ratio = 1 / 4 + t = Trajectory( + container=Traces( + a=Int[], + b=Bool[] + ), + sampler=BatchSampler(3), + controler=AsyncInsertSampleRatioControler(ratio, threshould) + ) + + n = 100 + insert_task = @async for i in 1:n + append!(t; a=[i, i, i, i], b=[false, true, false, true]) + end + + s = 0 + sample_task = @async for _ in t + s += 1 + end + sleep(1) + @test s == (n - threshould * ratio) + 1 end \ No newline at end of file From 5722d214a93463dd28a942f5d7cf5937f68cbc7b Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Sun, 8 May 2022 17:39:49 +0800 Subject: [PATCH 09/11] minor modification --- src/rendering.jl | 39 +++++++++++++++++++++++++++++++++++---- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/src/rendering.jl b/src/rendering.jl index 39d0515..1de5afa 100644 --- a/src/rendering.jl +++ b/src/rendering.jl @@ -2,7 +2,7 @@ using Term const TRACE_COLORS = ("bright_green", "hot_pink", "bright_blue", "light_coral", "bright_cyan", "sandy_brown", "violet") -Base.show(io::IO, ::MIME"text/plain", t::Union{Trace,Traces,Episode,Episodes}) = tprint(io, convert(Term.AbstractRenderable, t; width=displaysize(io)[2]) |> string) +Base.show(io::IO, ::MIME"text/plain", t::Union{Trace,Traces,Episode,Episodes,Trajectory}) = tprint(io, convert(Term.AbstractRenderable, t; width=displaysize(io)[2]) |> string) inner_convert(::Type{Term.AbstractRenderable}, s::String; style="gray1", width=88) = Panel(s, width=width, style=style, justify=:center) inner_convert(t::Type{Term.AbstractRenderable}, x::Union{Symbol,Number}; kw...) = inner_convert(t, string(x); kw...) @@ -63,7 +63,8 @@ function Base.convert(::Type{Term.AbstractRenderable}, t::Traces; width=88) style="yellow3", subtitle="$N traces in total", subtitle_justify=:right, - width=width + width=width, + fit=true ) end @@ -74,7 +75,8 @@ function Base.convert(::Type{Term.AbstractRenderable}, e::Episode; width=88) style="green_yellow", subtitle=e[] ? "Episode END" : "Episode growing...", subtitle_justify=:right, - width=width + width=width, + fit=true ) end @@ -99,6 +101,35 @@ function Base.convert(::Type{Term.AbstractRenderable}, e::Episodes; width=88) subtitle="$n episodes in total", subtitle_justify=:right, width=width, + fit=true, style="wheat1" ) -end \ No newline at end of file +end + +function Base.convert(r::Type{Term.AbstractRenderable}, t::Trajectory; width=88) + Panel( + convert(r, t.container; width=width - 8) / + Panel(convert(Term.Tree, t.sampler); title="sampler", style="yellow3", fit=true, width=width - 8) / + Panel(convert(Term.Tree, t.controler); title="controler", style="yellow3", fit=true, width=width - 8); + title="Trajectory", + style="yellow3", + width=width, + fit=true + ) +end + +# general converter + +Base.convert(::Type{Term.Tree}, x) = Tree(to_tree_body(x); title=to_tree_title(x)) +Base.convert(::Type{Term.Tree}, x::Tree) = x + +function to_tree_body(x) + pts = propertynames(x) + if length(pts) > 0 + Dict("$p => $(summary(getproperty(x, p)))" => to_tree_body(getproperty(x, p)) for p in pts) + else + x + end +end + +to_tree_title(x) = "$(summary(x))" \ No newline at end of file From 016b57e4da27e3370692d34c336fc496b0b54f9e Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Sun, 8 May 2022 18:09:19 +0800 Subject: [PATCH 10/11] update readme --- README.md | 99 ++++++++++++++++++++++++++++++------------------------- 1 file changed, 55 insertions(+), 44 deletions(-) diff --git a/README.md b/README.md index 8dab63d..0858da5 100644 --- a/README.md +++ b/README.md @@ -6,54 +6,65 @@ ## Design +A typical example of `Trajectory`: + ``` - ┌────────────────────────────┐ - │(state=..., action=..., ...)│ - └──────────────┬─────────────┘ - push! │ append! - ┌───────────────────▼───────────────────┐ - │ Trajectory │ - │ ┌─────────────────────────────────┐ │ - │ │ Traces │ │ - │ │ ┌───────────────────┐ │ │ - │ │ state: │CircularArrayBuffer│ │ │ - │ │ └───────────────────┘ │ │ - │ │ ┌───────────────────┐ │ │ - │ │ action:│CircularArrayBuffer│ │ │ - │ │ └───────────────────┘ │ │ - │ │ ...... │ │ - │ └─────────────────────────────────┘ │ - | Sampler | - └───────────────────┬───────────────────┘ - │ batch sampling - ┌──────────────▼─────────────┐ - │(state=..., action=..., ...)│ - └────────────────────────────┘ +╭──── Trajectory ──────────────────────────────────────────────────────────────────────────────────────────────────╮ +│ ╭──── Traces ────────────────────────────────────────────────────────────────────────────────────────────────╮ │ +│ │ ╭──── state: Trac... ────╮╭──── action: Tra... ────╮╭──── reward: Tra... ────╮╭──── terminal: T... ────╮ │ │ +│ │ │ ╭──────────────────╮ ││ ╭──────────────────╮ ││ ╭──────────────────╮ ││ ╭──────────────────╮ │ │ │ +│ │ │ │ SubArray │ ││ │ 1 │ ││ │ 1.0 │ ││ │ false │ │ │ │ +│ │ │ │ (2,3) │ ││ ╰──────────────────╯ ││ ╰──────────────────╯ ││ ╰──────────────────╯ │ │ │ +│ │ │ ╰──────────────────╯ ││ ╭──────────────────╮ ││ ╭──────────────────╮ ││ ╭──────────────────╮ │ │ │ +│ │ │ ╭──────────────────╮ ││ │ 2 │ ││ │ 2.0 │ ││ │ false │ │ │ │ +│ │ │ │ SubArray │ ││ ╰──────────────────╯ ││ ╰──────────────────╯ ││ ╰──────────────────╯ │ │ │ +│ │ │ │ (2,3) │ ││ ╭──────────────────╮ ││ ╭──────────────────╮ ││ ╭──────────────────╮ │ │ │ +│ │ │ ╰──────────────────╯ ││ │ ... │ ││ │ ... │ ││ │ ... │ │ │ │ +│ │ │ ╭──────────────────╮ ││ ╰──────────────────╯ ││ ╰──────────────────╯ ││ ╰──────────────────╯ │ │ │ +│ │ │ │ ... │ ││ ╭──────────────────╮ ││ ╭──────────────────╮ ││ ╭──────────────────╮ │ │ │ +│ │ │ ╰──────────────────╯ ││ │ 3 │ ││ │ 3.0 │ ││ │ true │ │ │ │ +│ │ │ ╭──────────────────╮ ││ ╰──────────────────╯ ││ ╰──────────────────╯ ││ ╰──────────────────╯ │ │ │ +│ │ │ │ SubArray │ ││ ╭──────────────────╮ │╰───────── size: (4,) ───╯╰───────── size: (4,) ───╯ │ │ +│ │ │ │ (2,3) │ ││ │ 3 │ │ │ │ +│ │ │ ╰──────────────────╯ ││ ╰──────────────────╯ │ │ │ +│ │ │ ╭──────────────────╮ │╰───────── size: (5,) ───╯ │ │ +│ │ │ │ SubArray │ │ │ │ +│ │ │ │ (2,3) │ │ │ │ +│ │ │ ╰──────────────────╯ │ │ │ +│ │ ╰──── size: (2, 3, 5) ───╯ │ │ +│ ╰────────────────────────────────────────────────────────────────────────────────────── 4 traces in total ───╯ │ +│ ╭──── sampler ───────────────────────────────────────────────────────╮ │ +│ │ BatchSampler │ │ +│ │ ━━━━━━━━━━━━━━ │ │ +│ │ │ │ │ +│ │ ├── transformer => identity (generic function...: identity │ │ +│ │ ├── rng => Random._GLOBAL_RNG: Random._GLOBAL_RNG() │ │ +│ │ └── batch_size => Int64: 5 │ │ +│ ╰────────────────────────────────────────────────────────────────────╯ │ +│ ╭──── controler ────────────────────────────╮ │ +│ │ InsertSampleRatioControler │ │ +│ │ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━ │ │ +│ │ │ │ │ +│ │ ├── threshold => Int64: 4 │ │ +│ │ ├── n_sampled => Int64: 0 │ │ +│ │ ├── ratio => Float64: 0.25 │ │ +│ │ └── n_inserted => Int64: 4 │ │ +│ ╰───────────────────────────────────────────╯ │ +╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ ``` +Exported APIs are: + +```julia +push!(trajectory; [trace_name=value]...) +append!(trajectory; [trace_name=value]...) + +for sample in trajectory + # consume samples from the trajectory +end ``` - ┌──────────────┐ ┌──────────────┐ - │Single Element│ │Batch Elements│ - └──────┬───────┘ └──────┬───────┘ - │ │ - push! └──────┐ ┌───────┘ append! - │ │ - ┌─────────────┼────┼─────────────────────────────┐ - │ ┌──▼────▼──┐ AsyncTrajectory │ - │ │Channel In│ │ - │ └─────┬────┘ │ - │ take! │ │ - │ ┌─────▼─────┐ push! ┌────────────┐ │ - │ │RateLimiter├──────────► Trajectory │ │ - │ └─────┬─────┘ append! └────*───────┘ │ - │ │ * │ - │ put! │********************** │ - │ │ batch sampling │ - │ ┌─────▼─────┐ │ - │ │Channel Out│ │ - │ └───────────┘ │ - └────────────────────────────────────────────────┘ -``` + +A wide variety of `container`s, `sampler`s, and `controler`s are provided. For the full list, please read the doc. ## Acknowledgement From edd4c52e74768f61efc0367dd05e153bb0391891 Mon Sep 17 00:00:00 2001 From: Jun Tian Date: Sun, 8 May 2022 18:17:28 +0800 Subject: [PATCH 11/11] update readme --- README.md | 45 +-------------------------------------------- 1 file changed, 1 insertion(+), 44 deletions(-) diff --git a/README.md b/README.md index 0858da5..f650d21 100644 --- a/README.md +++ b/README.md @@ -8,50 +8,7 @@ A typical example of `Trajectory`: -``` -╭──── Trajectory ──────────────────────────────────────────────────────────────────────────────────────────────────╮ -│ ╭──── Traces ────────────────────────────────────────────────────────────────────────────────────────────────╮ │ -│ │ ╭──── state: Trac... ────╮╭──── action: Tra... ────╮╭──── reward: Tra... ────╮╭──── terminal: T... ────╮ │ │ -│ │ │ ╭──────────────────╮ ││ ╭──────────────────╮ ││ ╭──────────────────╮ ││ ╭──────────────────╮ │ │ │ -│ │ │ │ SubArray │ ││ │ 1 │ ││ │ 1.0 │ ││ │ false │ │ │ │ -│ │ │ │ (2,3) │ ││ ╰──────────────────╯ ││ ╰──────────────────╯ ││ ╰──────────────────╯ │ │ │ -│ │ │ ╰──────────────────╯ ││ ╭──────────────────╮ ││ ╭──────────────────╮ ││ ╭──────────────────╮ │ │ │ -│ │ │ ╭──────────────────╮ ││ │ 2 │ ││ │ 2.0 │ ││ │ false │ │ │ │ -│ │ │ │ SubArray │ ││ ╰──────────────────╯ ││ ╰──────────────────╯ ││ ╰──────────────────╯ │ │ │ -│ │ │ │ (2,3) │ ││ ╭──────────────────╮ ││ ╭──────────────────╮ ││ ╭──────────────────╮ │ │ │ -│ │ │ ╰──────────────────╯ ││ │ ... │ ││ │ ... │ ││ │ ... │ │ │ │ -│ │ │ ╭──────────────────╮ ││ ╰──────────────────╯ ││ ╰──────────────────╯ ││ ╰──────────────────╯ │ │ │ -│ │ │ │ ... │ ││ ╭──────────────────╮ ││ ╭──────────────────╮ ││ ╭──────────────────╮ │ │ │ -│ │ │ ╰──────────────────╯ ││ │ 3 │ ││ │ 3.0 │ ││ │ true │ │ │ │ -│ │ │ ╭──────────────────╮ ││ ╰──────────────────╯ ││ ╰──────────────────╯ ││ ╰──────────────────╯ │ │ │ -│ │ │ │ SubArray │ ││ ╭──────────────────╮ │╰───────── size: (4,) ───╯╰───────── size: (4,) ───╯ │ │ -│ │ │ │ (2,3) │ ││ │ 3 │ │ │ │ -│ │ │ ╰──────────────────╯ ││ ╰──────────────────╯ │ │ │ -│ │ │ ╭──────────────────╮ │╰───────── size: (5,) ───╯ │ │ -│ │ │ │ SubArray │ │ │ │ -│ │ │ │ (2,3) │ │ │ │ -│ │ │ ╰──────────────────╯ │ │ │ -│ │ ╰──── size: (2, 3, 5) ───╯ │ │ -│ ╰────────────────────────────────────────────────────────────────────────────────────── 4 traces in total ───╯ │ -│ ╭──── sampler ───────────────────────────────────────────────────────╮ │ -│ │ BatchSampler │ │ -│ │ ━━━━━━━━━━━━━━ │ │ -│ │ │ │ │ -│ │ ├── transformer => identity (generic function...: identity │ │ -│ │ ├── rng => Random._GLOBAL_RNG: Random._GLOBAL_RNG() │ │ -│ │ └── batch_size => Int64: 5 │ │ -│ ╰────────────────────────────────────────────────────────────────────╯ │ -│ ╭──── controler ────────────────────────────╮ │ -│ │ InsertSampleRatioControler │ │ -│ │ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━ │ │ -│ │ │ │ │ -│ │ ├── threshold => Int64: 4 │ │ -│ │ ├── n_sampled => Int64: 0 │ │ -│ │ ├── ratio => Float64: 0.25 │ │ -│ │ └── n_inserted => Int64: 4 │ │ -│ ╰───────────────────────────────────────────╯ │ -╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ -``` +![](https://user-images.githubusercontent.com/5612003/167291629-0e2d4f0f-7c54-460c-a94f-9eb4148cdca0.png) Exported APIs are: