diff --git a/base/reduce.jl b/base/reduce.jl index 425d8d6f28f2f..00088027eacba 100644 --- a/base/reduce.jl +++ b/base/reduce.jl @@ -1,5 +1,10 @@ # This file is a part of Julia. License is MIT: https://julialang.org/license +## Helpers +eltype_or_default_eltype(itr::T) where T = eltype_or_default_eltype(itr, IteratorEltype(T)) +eltype_or_default_eltype(itr::T, ::HasEltype) where T = eltype(T) +eltype_or_default_eltype(itr, ::EltypeUnknown) = @default_eltype(itr) + ## reductions ## ###### Generic (map)reduce functions ###### @@ -354,10 +359,102 @@ julia> reduce(*, [2; 3; 4]; init=-1) -24 ``` """ -reduce(op, itr; kw...) = mapreduce(identity, op, itr; kw...) +function reduce(op, itr::T; kw...) where T + # Redispatch, adding traits + reduce(op, itr, eltype_or_default_eltype(itr), IteratorSize(T); kw...) +end + +function reduce(op, itr, et, isize; kw...) + # Fallback: if nothing interesting is being done with the traits + # or the operation + return mapreduce(identity, op, itr; kw...) +end reduce(op, a::Number) = a # Do we want this? +##### Operation specific reduce optimisations + +## vcat + +function reduce(::typeof(vcat), xs, T::Type{<:AbstractVector{V}}, isize) where V + x_state = iterate(xs) + x_state === nothing && return reduce_empty(vcat, T) + x1, state = x_state + + ret = Vector(x1) + + hinted_size = 0 + if !(isize isa SizeUnknown) + # Assume first element has representitive size, unless that would make this too large + SIZEHINT_CAP = 10^5 + hinted_size = min(SIZEHINT_CAP, length(xs)*length(x1)) + sizehint!(ret, hinted_size) + end + + x_state = iterate(xs, state) + while(x_state !== nothing) + x, state = x_state + append!(ret, x) + x_state = iterate(xs, state) + end + + if length(ret) < hinted_size/2 # it is only allowable to be at most 2x too much memory + sizehint!(ret, length(ret)) + end + + return ret +end + +## hcat + +function reduce(::typeof(hcat), xs, T::Type{<:AbstractVector{V}}, isize::SizeUnknown) where V + x_state = iterate(xs) + x_state === nothing && return reduce_empty(hcat, T) + x1, state = x_state + + dim1_size = length(x1) + dim2_size = 1 + ret_vec = Vector(x1) + + x_state = iterate(xs, state) + + while(x_state !== nothing) + x, state = x_state + append!(ret_vec, x) + dim2_size += 1 + x_state = iterate(xs, state) + end + + # Reshape will throw errors if anything was the wrong size + return reshape(ret_vec, (dim1_size, dim2_size)) +end + +function reduce(::typeof(hcat), xs, T::Type{<:AbstractVector{V}}, isize) where V + # Size is known + x_state = iterate(xs) + x_state === nothing && return reduce_empty(hcat, T) + x1, state = x_state + + dim1_size = size(x1,1) + dim2_size = length(xs) + + ret = similar(x1, eltype(x1), (dim1_size, dim2_size)) + copyto!(ret, 1, x1, 1) + + x_state = iterate(xs, state) + offset = length(x1)+1 + while(x_state !== nothing) + x, state = x_state + length(x)==dim1_size || throw(DimensionMismatch("hcat")) + copyto!(ret, offset, x, 1) + offset += length(x) + x_state = iterate(xs, state) + end + + return ret +end + + ###### Specific reduction functions ###### ## sum diff --git a/base/reducedim.jl b/base/reducedim.jl index 996729be8bc4c..6e11e41fb1290 100644 --- a/base/reducedim.jl +++ b/base/reducedim.jl @@ -349,7 +349,7 @@ julia> reduce(max, a, dims=1) 4 8 12 16 ``` """ -reduce(op, A::AbstractArray; kw...) = mapreduce(identity, op, A; kw...) +reduce(op, A::AbstractArray; kw...) ##### Specific reduction functions ##### """ diff --git a/test/reduce.jl b/test/reduce.jl index eb585e8a630f1..b6d216fccc883 100644 --- a/test/reduce.jl +++ b/test/reduce.jl @@ -526,6 +526,27 @@ test18695(r) = sum( t^2 for t in r ) end end +@testset "optimized reduce(vcat/hcat, A) for iterators" begin + v_v_same = [rand(128) for ii in 1:100] + # the following 2 are not optimized, but we want to make sure + # that they still hit the normal reduce methods + v_v_diff = [rand(128), rand(Float32,128), rand(Int, 128)] + v_v_diff_typed = Union{Vector{Float64},Vector{Float32},Vector{Int}}[rand(128), rand(Float32,128), rand(Int, 128)] + + for v_v in (v_v_same, v_v_diff, v_v_diff_typed) + # Cover all combinations of iterator traits. + g_v = (x for x in v_v) + f_g_v = Iterators.filter(x->true, g_v) + f_v_v = Iterators.filter(x->true, v_v); + hcat_expected = hcat(v_v...) + vcat_expected = vcat(v_v...) + @testset "$(typeof(data))" for data in (v_v, g_v, f_g_v, f_v_v) + @test reduce(hcat, data) == hcat_expected + @test reduce(vcat, data) == vcat_expected + end + end +end + # offset axes i = Base.Slice(-3:3) x = [j^2 for j in i]