Skip to content

Commit a36ddc3

Browse files
Tweak code
1 parent 31ce648 commit a36ddc3

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/samplers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ end
186186

187187
function StatsBase.sample(s::NStepBatchSampler{names}, ts::EpisodesBuffer) where {names}
188188
valid_range = valid_range_nbatchsampler(s, ts)
189-
valid_range = valid_range[valid_range findall(ts.sampleable_inds)] # Ensure that the valid range is within the sampleable indices
189+
valid_range = valid_range[valid_range .∈ (findall(ts.sampleable_inds),)] # Ensure that the valid range is within the sampleable indices, probably could be done more efficiently by refactoring `valid_range_nbatchsampler`
190190
inds = rand(s.rng, valid_range, s.batch_size)
191191
StatsBase.sample(s, ts, Val(names), inds)
192192
end

0 commit comments

Comments
 (0)