Skip to content

Commit 5fcf427

Browse files
committed
feat: split out non-generator changes from #1642
1 parent 7f15e65 commit 5fcf427

File tree

4 files changed

+16
-3
lines changed

4 files changed

+16
-3
lines changed

src/Reactant.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ use_overlayed_version(::TracedRArray) = true
193193
use_overlayed_version(::TracedRNumber) = true
194194
use_overlayed_version(::Number) = false
195195
use_overlayed_version(::MissingTracedValue) = true
196+
use_overlayed_version(::Vector{<:AnyTracedRArray}) = true
196197
use_overlayed_version(::AbstractArray{<:TracedRNumber}) = true
197198
use_overlayed_version(rng::ReactantRNG) = use_overlayed_version(rng.seed)
198199
function use_overlayed_version(x::AbstractArray)

src/TracedRArray.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1509,6 +1509,9 @@ end
15091509

15101510
struct BroadcastIterator{F}
15111511
f::F
1512+
1513+
BroadcastIterator{F}(f::F) where {F} = new{F}(f)
1514+
BroadcastIterator(f::F) where {F} = new{F}(f)
15121515
end
15131516

15141517
(fn::BroadcastIterator)(args...) = fn.f((args...,))
@@ -1536,15 +1539,15 @@ function unrolled_map(f::F, itr) where {F}
15361539
y === nothing && return []
15371540

15381541
first, state = y
1539-
res_first = Reactant.call_with_reactant(f, first)
1542+
res_first = @opcall call(f, first)
15401543
result = [res_first]
15411544

15421545
while true
15431546
y = Reactant.call_with_reactant(iterate, itr, state)
15441547
y === nothing && break
15451548

15461549
val, state = y
1547-
res = Reactant.call_with_reactant(f, val)
1550+
res = @opcall call(f, val)
15481551
push!(result, res)
15491552
end
15501553

test/basic.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1496,6 +1496,15 @@ end
14961496

14971497
@test @jit(nested_mapreduce_hcat(x_ra, y_ra)) nested_mapreduce_hcat(x, y)
14981498
end
1499+
1500+
@testset "mapreduce vector" begin
1501+
x = [rand(Float32, 2, 3) for _ in 1:10]
1502+
x_ra = Reactant.to_rarray(x)
1503+
1504+
@test @jit(mapreduce_vector(x_ra)) mapreduce_vector(x)
1505+
hlo = repr(@code_hlo optimize = false mapreduce_vector(x_ra))
1506+
@test contains(hlo, "call")
1507+
end
14991508
end
15001509

15011510
@testset "Base.Generator" begin

test/integration/fillarrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,5 @@ end
2525
x = OneElement(3.4f0, (3, 4), (32, 32))
2626
rx = Reactant.to_rarray(x)
2727

28-
@test @jit(fn(rx, rx)) fn(x, x)
28+
@test @jit(fn(rx, rx)) fn(x, x) atol = 1e-3 rtol = 1e-3
2929
end

0 commit comments

Comments
 (0)