diff --git a/src/Overlay.jl b/src/Overlay.jl index ec92063dd7..951ad3c7c1 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -156,7 +156,10 @@ end end @reactant_overlay @noinline function Base.mapreduce( - f, op, A::Union{AbstractArray,Base.Iterators.Zip,Base.Iterators.Enumerate}; kwargs... + f, + op, + A::Union{AbstractArray,Base.Iterators.Zip,Base.Iterators.Enumerate,Base.Generator}; + kwargs..., ) if use_overlayed_version(A) return TracedRArrayOverrides.overloaded_mapreduce(f, op, A; kwargs...) diff --git a/src/Reactant.jl b/src/Reactant.jl index 25d62f6464..da18116d15 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -60,6 +60,7 @@ function _parent end _parent_type(::Type{Array}) = Array _parent_type(::Type{Array{T}}) where {T} = Array{T} _parent_type(::Type{Array{T,N}}) where {T,N} = Array{T,N} +_parent_type(::Type{<:Slices{P}}) where {P} = P include("accelerators/Accelerators.jl") @@ -179,15 +180,19 @@ include("TracedRArray.jl") include("ConcreteRArray.jl") use_overlayed_version(x) = false -use_overlayed_version(x::Base.Iterators.Zip) = any(use_overlayed_version, x.is) +function use_overlayed_version(x::F) where {F<:Function} + return use_overlayed_version(getfield.(Ref(x), fieldnames(F))) +end +use_overlayed_version(x::Base.Generator) = use_overlayed_version((x.f, x.iter)) +use_overlayed_version(x::Base.Iterators.Zip) = use_overlayed_version(x.is) use_overlayed_version(x::Base.Iterators.Enumerate) = use_overlayed_version(x.itr) -use_overlayed_version(iter::Tuple) = any(use_overlayed_version, iter) -use_overlayed_version(iter::NamedTuple) = any(use_overlayed_version, values(iter)) +use_overlayed_version(x::Vector) = looped_any(use_overlayed_version, x) +use_overlayed_version(iter::Tuple) = looped_any(use_overlayed_version, iter) +use_overlayed_version(iter::NamedTuple) = looped_any(use_overlayed_version, values(iter)) use_overlayed_version(::TracedRArray) = true use_overlayed_version(::TracedRNumber) = true use_overlayed_version(::Number) = false use_overlayed_version(::MissingTracedValue) = true -use_overlayed_version(::Vector{<:AnyTracedRArray}) = true use_overlayed_version(::AbstractArray{<:TracedRNumber}) = true use_overlayed_version(rng::ReactantRNG) = use_overlayed_version(rng.seed) function use_overlayed_version(x::AbstractArray) @@ -196,6 +201,14 @@ function use_overlayed_version(x::AbstractArray) return use_overlayed_version(a) end +## We avoid calling into `any` to avoid triggering the `any` overlay +function looped_any(f::F, itr) where {F} + @inbounds for x in itr + f(x) && return true + end + return false +end + # StdLib Overloads include("stdlibs/LinearAlgebra.jl") include("stdlibs/Random.jl") diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 033329e6df..519115f05d 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -1527,6 +1527,10 @@ function unwrapped_broadcast(f::F, x::Base.Iterators.Enumerate) where {F} end end +function unwrapped_broadcast(f::F, x::Base.Generator) where {F} + return unwrapped_broadcast_with_iterate(f, Base.Generator(TracedCall(x.f), x.iter)) +end + unwrapped_broadcast(f::F, xs) where {F} = unwrapped_broadcast_with_iterate(f, xs) function unwrapped_broadcast_with_iterate(f::F, itr) where {F} diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 554791af8f..98b83ea006 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -660,7 +660,7 @@ function finalize_mlir_fn( skipped_results = Reactant.TracedType[] for (k, v) in seen_results v isa Reactant.TracedType || continue - if any(Base.Fix1(===, k), skipped_args) + if Reactant.looped_any(Base.Fix1(===, k), skipped_args) push!(skipped_results, v) _, argpath = get_argidx(v, argprefix) diff --git a/test/basic.jl b/test/basic.jl index 03778fa6c0..c27e91d88b 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -1598,3 +1598,16 @@ end x_ra = Reactant.to_rarray(x) @test @jit(clamp!(x_ra, 0.5, Inf32)) ≈ clamp!(x, 0.5, Inf32) end + +@testset "Base.Generator" begin + points = eachcol(rand(Float32, 2, 6)) + params = rand(Float32, 4, 2) + points_ra = Reactant.to_rarray(points) + params_ra = Reactant.to_rarray(params) + + function f_generator(points, params) + return sum(params * point for point in points) + end + + @test @jit(f_generator(points_ra, params_ra)) ≈ f_generator(points, params) +end