Skip to content

Commit a3897f4

Browse files
committed
Turn off struct broadcast by default.
1 parent 4ea69d5 commit a3897f4

File tree

4 files changed

+75
-30
lines changed

4 files changed

+75
-30
lines changed

src/StructArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,5 +38,6 @@ function GPUArraysCore.backend(::Type{T}) where {T<:StructArray}
3838
isconsistent || throw(ArgumentError("all component arrays must have the same GPU backend"))
3939
return backend
4040
end
41+
_use_default_bc(::GPUArraysCore.AbstractGPUArrayStyle) = true
4142

4243
end # module

src/staticarrays_support.jl

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -38,35 +38,10 @@ end
3838
# This looks costly, but the compiler should be able to optimize them away
3939
Broadcast._axes(bc::Broadcasted{<:StructStaticArrayStyle}, ::Nothing) = axes(replace_structarray(bc))
4040

41-
to_staticstyle(@nospecialize(x::Type)) = x
42-
to_staticstyle(::Type{StructStaticArrayStyle{N}}) where {N} = StaticArrayStyle{N}
43-
44-
"""
45-
replace_structarray(bc::Broadcasted)
46-
47-
An internal function transforms the `Broadcasted` with `StructArray` into
48-
an equivalent one without it. This is not a must if the root `BroadcastStyle`
49-
supports `AbstractArray`. But some `BroadcastStyle` limits the input array types,
50-
e.g. `StaticArrayStyle`, thus we have to omit all `StructArray`.
51-
"""
52-
function replace_structarray(bc::Broadcasted{Style}) where {Style}
53-
args = replace_structarray_args(bc.args)
54-
return Broadcasted{to_staticstyle(Style)}(bc.f, args, nothing)
55-
end
56-
function replace_structarray(A::StructArray)
57-
f = Instantiator(eltype(A))
58-
args = Tuple(components(A))
59-
return Broadcasted{StaticArrayStyle{ndims(A)}}(f, args, nothing)
60-
end
61-
replace_structarray(@nospecialize(A)) = A
62-
63-
replace_structarray_args(args::Tuple) = (replace_structarray(args[1]), replace_structarray_args(tail(args))...)
64-
replace_structarray_args(::Tuple{}) = ()
65-
6641
# StaticArrayStyle has no similar defined.
6742
# Overload `Base.copy` instead.
68-
@inline function Base.copy(bc::Broadcasted{StructStaticArrayStyle{M}}) where {M}
69-
sa = copy(convert(Broadcasted{StaticArrayStyle{M}}, bc))
43+
@inline function try_struct_copy(bc::Broadcasted{StaticArrayStyle{M}}) where {M}
44+
sa = copy(bc)
7045
ET = eltype(sa)
7146
isnonemptystructtype(ET) || return sa
7247
elements = Tuple(sa)

src/structarray.jl

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,8 @@ function Base.showarg(io::IO, s::StructArray{T}, toplevel) where T
494494
end
495495

496496
# broadcast
497-
import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown
497+
import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown, ArrayConflict
498+
using Base.Broadcast: combine_styles
498499

499500
struct StructArrayStyle{S, N} <: AbstractArrayStyle{N} end
500501

@@ -524,6 +525,49 @@ Base.@pure cst(::Type{SA}) where {SA} = combine_style_types(array_types(SA).para
524525

525526
BroadcastStyle(::Type{SA}) where {SA<:StructArray} = StructArrayStyle{typeof(cst(SA)), ndims(SA)}()
526527

528+
_use_default_bc(::Any) = false
529+
_use_default_bc(::DefaultArrayStyle) = true
530+
_use_default_bc(::ArrayConflict) = true
531+
532+
function Base.copy(bc::Broadcasted{StructArrayStyle{S, N}}) where {S, N}
533+
if _use_default_bc(S())
534+
return @invoke copy(bc::Broadcasted)
535+
else
536+
return try_struct_copy(replace_structarray(bc))
537+
end
538+
end
539+
try_struct_copy(bc::Broadcasted) = copy(bc)
540+
541+
"""
542+
replace_structarray(bc::Broadcasted)
543+
544+
An internal function transforms the `Broadcasted` with `StructArray` into
545+
an equivalent one without it. This is not a must if the root `BroadcastStyle`
546+
supports `AbstractArray`. But some `BroadcastStyle` limits the input array types,
547+
e.g. `StaticArrayStyle`, thus we have to omit all `StructArray`.
548+
"""
549+
function replace_structarray(bc::Broadcasted{Style}) where {Style}
550+
args = replace_structarray_args(bc.args)
551+
Style′ = parant_style(Style())
552+
return Broadcasted{Style′}(bc.f, args, nothing)
553+
end
554+
function replace_structarray(A::StructArray)
555+
f = Instantiator(eltype(A))
556+
args = Tuple(components(A))
557+
Style = typeof(combine_styles(args...))
558+
return Broadcasted{Style}(f, args, nothing)
559+
end
560+
replace_structarray(@nospecialize(A)) = A
561+
562+
replace_structarray_args(args::Tuple) = (replace_structarray(args[1]), replace_structarray_args(tail(args))...)
563+
replace_structarray_args(::Tuple{}) = ()
564+
565+
parant_style(@nospecialize(x)) = typeof(x)
566+
parant_style(::StructArrayStyle{S, N}) where {S, N} = S
567+
parant_style(::StructArrayStyle{S, N}) where {N, S<:AbstractArrayStyle{N}} = S
568+
parant_style(::StructArrayStyle{S, N}) where {S<:AbstractArrayStyle{Any}, N} = S
569+
parant_style(::StructArrayStyle{S, N}) where {S<:AbstractArrayStyle, N} = typeof(S(Val(N)))
570+
527571
# Here we use `similar` defined for `S` to build the dest Array.
528572
function Base.similar(bc::Broadcasted{StructArrayStyle{S, N}}, ::Type{ElType}) where {S, N, ElType}
529573
bc′ = convert(Broadcasted{S}, bc)
@@ -532,12 +576,22 @@ end
532576

533577
# Unwrapper to recover the behaviour defined by parent style.
534578
@inline function Base.copyto!(dest::AbstractArray, bc::Broadcasted{StructArrayStyle{S, N}}) where {S, N}
535-
return copyto!(dest, convert(Broadcasted{S}, bc))
579+
bc′ = _use_default_bc(S()) ? convert(Broadcasted{S}, bc) : replace_structarray(bc)
580+
return copyto!(dest, bc′)
536581
end
537582

538583
@inline function Broadcast.materialize!(::StructArrayStyle{S}, dest, bc::Broadcasted) where {S}
539-
return Broadcast.materialize!(S(), dest, bc)
584+
bc′ = _use_default_bc(S()) ? bc : replace_structarray(bc)
585+
return Broadcast.materialize!(S(), dest, bc′)
540586
end
541587

542588
# for aliasing analysis during broadcast
589+
function Broadcast.broadcast_unalias(dest::StructArray, src::AbstractArray)
590+
if dest === src || any(Base.Fix2(===, src), components(dest))
591+
return src
592+
else
593+
return Base.unalias(dest, src)
594+
end
595+
end
596+
543597
Base.dataids(u::StructArray) = mapreduce(Base.dataids, (a, b) -> (a..., b...), values(components(u)), init=())

test/runtests.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,6 +1183,7 @@ for S in (1, 2, 3)
11831183
Base.setindex!(A::$MyArray, val, i::Int) = A.A[i] = val
11841184
Base.size(A::$MyArray) = Base.size(A.A)
11851185
Base.BroadcastStyle(::Type{<:$MyArray}) = Broadcast.ArrayStyle{$MyArray}()
1186+
StructArrays._use_default_bc(::Broadcast.ArrayStyle{$MyArray}) = true
11861187
end
11871188
end
11881189
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray1}}, ::Type{ElType}) where ElType =
@@ -1247,6 +1248,12 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS
12471248
@test Base.broadcasted(+, reshape(1:2,1,1,2), s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{3}}
12481249
@test Base.broadcasted(+, s, MyArray1(rand(2))) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{Any}}
12491250

1251+
# allocation test for overloaded `broadcast_unalias`
1252+
StructArrays._use_default_bc(::Broadcast.ArrayStyle{MyArray1}) = false
1253+
f(s) = s .+= 1
1254+
f(s)
1255+
@test (@allocated f(s)) == 0
1256+
12501257
# issue #185
12511258
A = StructArray(randn(ComplexF64, 3, 3))
12521259
B = randn(ComplexF64, 3, 3)
@@ -1317,6 +1324,14 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS
13171324
@test backend(bcmul2(sa)) === backend(sa)
13181325
@test (sa .+= 1) === sa
13191326
end
1327+
1328+
@testset "StructSparseArray" begin
1329+
a = sprand(10, 10, 0.5)
1330+
b = sprand(10, 10, 0.5)
1331+
c = StructArray{ComplexF64}((a, b))
1332+
d = identity.(c)
1333+
@test d isa SparseMatrixCSC
1334+
end
13201335
end
13211336

13221337
@testset "map" begin

0 commit comments

Comments
 (0)