Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/Overlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,10 @@ end
end

@reactant_overlay @noinline function Base.mapreduce(
f, op, A::Union{AbstractArray,Base.Iterators.Zip,Base.Iterators.Enumerate}; kwargs...
f,
op,
A::Union{AbstractArray,Base.Iterators.Zip,Base.Iterators.Enumerate,Base.Generator};
kwargs...,
)
if use_overlayed_version(A)
return TracedRArrayOverrides.overloaded_mapreduce(f, op, A; kwargs...)
Expand Down
21 changes: 17 additions & 4 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ function _parent end
_parent_type(::Type{Array}) = Array
_parent_type(::Type{Array{T}}) where {T} = Array{T}
_parent_type(::Type{Array{T,N}}) where {T,N} = Array{T,N}
_parent_type(::Type{<:Slices{P}}) where {P} = P

include("accelerators/Accelerators.jl")

Expand Down Expand Up @@ -179,15 +180,19 @@ include("TracedRArray.jl")
include("ConcreteRArray.jl")

use_overlayed_version(x) = false
use_overlayed_version(x::Base.Iterators.Zip) = any(use_overlayed_version, x.is)
function use_overlayed_version(x::F) where {F<:Function}
return use_overlayed_version(getfield.(Ref(x), fieldnames(F)))
end
use_overlayed_version(x::Base.Generator) = use_overlayed_version((x.f, x.iter))
use_overlayed_version(x::Base.Iterators.Zip) = use_overlayed_version(x.is)
use_overlayed_version(x::Base.Iterators.Enumerate) = use_overlayed_version(x.itr)
use_overlayed_version(iter::Tuple) = any(use_overlayed_version, iter)
use_overlayed_version(iter::NamedTuple) = any(use_overlayed_version, values(iter))
use_overlayed_version(x::Vector) = looped_any(use_overlayed_version, x)
use_overlayed_version(iter::Tuple) = looped_any(use_overlayed_version, iter)
use_overlayed_version(iter::NamedTuple) = looped_any(use_overlayed_version, values(iter))
use_overlayed_version(::TracedRArray) = true
use_overlayed_version(::TracedRNumber) = true
use_overlayed_version(::Number) = false
use_overlayed_version(::MissingTracedValue) = true
use_overlayed_version(::Vector{<:AnyTracedRArray}) = true
use_overlayed_version(::AbstractArray{<:TracedRNumber}) = true
use_overlayed_version(rng::ReactantRNG) = use_overlayed_version(rng.seed)
function use_overlayed_version(x::AbstractArray)
Expand All @@ -196,6 +201,14 @@ function use_overlayed_version(x::AbstractArray)
return use_overlayed_version(a)
end

## We avoid calling into `any` to avoid triggering the `any` overlay
function looped_any(f::F, itr) where {F}
@inbounds for x in itr
f(x) && return true
end
return false
end

# StdLib Overloads
include("stdlibs/LinearAlgebra.jl")
include("stdlibs/Random.jl")
Expand Down
4 changes: 4 additions & 0 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1527,6 +1527,10 @@ function unwrapped_broadcast(f::F, x::Base.Iterators.Enumerate) where {F}
end
end

function unwrapped_broadcast(f::F, x::Base.Generator) where {F}
return unwrapped_broadcast_with_iterate(f, Base.Generator(TracedCall(x.f), x.iter))
end

unwrapped_broadcast(f::F, xs) where {F} = unwrapped_broadcast_with_iterate(f, xs)

function unwrapped_broadcast_with_iterate(f::F, itr) where {F}
Expand Down
2 changes: 1 addition & 1 deletion src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,7 @@ function finalize_mlir_fn(
skipped_results = Reactant.TracedType[]
for (k, v) in seen_results
v isa Reactant.TracedType || continue
if any(Base.Fix1(===, k), skipped_args)
if Reactant.looped_any(Base.Fix1(===, k), skipped_args)
push!(skipped_results, v)

_, argpath = get_argidx(v, argprefix)
Expand Down
13 changes: 13 additions & 0 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1598,3 +1598,16 @@ end
x_ra = Reactant.to_rarray(x)
@test @jit(clamp!(x_ra, 0.5, Inf32)) ≈ clamp!(x, 0.5, Inf32)
end

@testset "Base.Generator" begin
points = eachcol(rand(Float32, 2, 6))
params = rand(Float32, 4, 2)
points_ra = Reactant.to_rarray(points)
params_ra = Reactant.to_rarray(params)

function f_generator(points, params)
return sum(params * point for point in points)
end

@test @jit(f_generator(points_ra, params_ra)) ≈ f_generator(points, params)
end
Loading