@@ -494,7 +494,8 @@ function Base.showarg(io::IO, s::StructArray{T}, toplevel) where T
494
494
end
495
495
496
496
# broadcast
497
- import Base. Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown
497
+ import Base. Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown, ArrayConflict
498
+ using Base. Broadcast: combine_styles
498
499
499
500
struct StructArrayStyle{S, N} <: AbstractArrayStyle{N} end
500
501
@@ -524,6 +525,49 @@ Base.@pure cst(::Type{SA}) where {SA} = combine_style_types(array_types(SA).para
524
525
525
526
BroadcastStyle (:: Type{SA} ) where {SA<: StructArray } = StructArrayStyle {typeof(cst(SA)), ndims(SA)} ()
526
527
528
+ _use_default_bc (:: Any ) = false
529
+ _use_default_bc (:: DefaultArrayStyle ) = true
530
+ _use_default_bc (:: ArrayConflict ) = true
531
+
532
+ function Base. copy (bc:: Broadcasted{StructArrayStyle{S, N}} ) where {S, N}
533
+ if _use_default_bc (S ())
534
+ return @invoke copy (bc:: Broadcasted )
535
+ else
536
+ return try_struct_copy (replace_structarray (bc))
537
+ end
538
+ end
539
+ try_struct_copy (bc:: Broadcasted ) = copy (bc)
540
+
541
+ """
542
+ replace_structarray(bc::Broadcasted)
543
+
544
+ An internal function transforms the `Broadcasted` with `StructArray` into
545
+ an equivalent one without it. This is not a must if the root `BroadcastStyle`
546
+ supports `AbstractArray`. But some `BroadcastStyle` limits the input array types,
547
+ e.g. `StaticArrayStyle`, thus we have to omit all `StructArray`.
548
+ """
549
+ function replace_structarray (bc:: Broadcasted{Style} ) where {Style}
550
+ args = replace_structarray_args (bc. args)
551
+ Style′ = parant_style (Style ())
552
+ return Broadcasted {Style′} (bc. f, args, nothing )
553
+ end
554
+ function replace_structarray (A:: StructArray )
555
+ f = Instantiator (eltype (A))
556
+ args = Tuple (components (A))
557
+ Style = typeof (combine_styles (args... ))
558
+ return Broadcasted {Style} (f, args, nothing )
559
+ end
560
+ replace_structarray (@nospecialize (A)) = A
561
+
562
+ replace_structarray_args (args:: Tuple ) = (replace_structarray (args[1 ]), replace_structarray_args (tail (args))... )
563
+ replace_structarray_args (:: Tuple{} ) = ()
564
+
565
+ parant_style (@nospecialize (x)) = typeof (x)
566
+ parant_style (:: StructArrayStyle{S, N} ) where {S, N} = S
567
+ parant_style (:: StructArrayStyle{S, N} ) where {N, S<: AbstractArrayStyle{N} } = S
568
+ parant_style (:: StructArrayStyle{S, N} ) where {S<: AbstractArrayStyle{Any} , N} = S
569
+ parant_style (:: StructArrayStyle{S, N} ) where {S<: AbstractArrayStyle , N} = typeof (S (Val (N)))
570
+
527
571
# Here we use `similar` defined for `S` to build the dest Array.
528
572
function Base. similar (bc:: Broadcasted{StructArrayStyle{S, N}} , :: Type{ElType} ) where {S, N, ElType}
529
573
bc′ = convert (Broadcasted{S}, bc)
@@ -532,12 +576,22 @@ end
532
576
533
577
# Unwrapper to recover the behaviour defined by parent style.
534
578
@inline function Base. copyto! (dest:: AbstractArray , bc:: Broadcasted{StructArrayStyle{S, N}} ) where {S, N}
535
- return copyto! (dest, convert (Broadcasted{S}, bc))
579
+ bc′ = _use_default_bc (S ()) ? convert (Broadcasted{S}, bc) : replace_structarray (bc)
580
+ return copyto! (dest, bc′)
536
581
end
537
582
538
583
@inline function Broadcast. materialize! (:: StructArrayStyle{S} , dest, bc:: Broadcasted ) where {S}
539
- return Broadcast. materialize! (S (), dest, bc)
584
+ bc′ = _use_default_bc (S ()) ? bc : replace_structarray (bc)
585
+ return Broadcast. materialize! (S (), dest, bc′)
540
586
end
541
587
542
588
# for aliasing analysis during broadcast
589
+ function Broadcast. broadcast_unalias (dest:: StructArray , src:: AbstractArray )
590
+ if dest === src || any (Base. Fix2 (=== , src), components (dest))
591
+ return src
592
+ else
593
+ return Base. unalias (dest, src)
594
+ end
595
+ end
596
+
543
597
Base. dataids (u:: StructArray ) = mapreduce (Base. dataids, (a, b) -> (a... , b... ), values (components (u)), init= ())
0 commit comments