Skip to content

Commit 1f254d3

Browse files
Fix Broadcast.broadcast_shape inference
1 parent ab51dbe commit 1f254d3

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

src/blockbroadcast.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,14 @@ BroadcastStyle(::PseudoBlockStyle{M}, ::BlockStyle{N}) where {M,N} = BlockStyle(
3030

3131
# sortedunion can assume inputs are already sorted so this could be improved
3232
sortedunion(a,b) = sort!(union(a,b))
33+
sortedunion(a::Tuple, b::Tuple) = (a..., b...)
3334
sortedunion(a::Base.OneTo, b::Base.OneTo) = Base.OneTo(max(last(a),last(b)))
3435
sortedunion(a::AbstractUnitRange, b::AbstractUnitRange) = min(first(a),first(b)):max(last(a),last(b))
3536
combine_blockaxes(a, b) = _BlockedUnitRange(sortedunion(blocklasts(a), blocklasts(b)))
3637

37-
Base.Broadcast.axistype(a::BlockedUnitRange, b::BlockedUnitRange) = length(b) == 1 ? a : combine_blockaxes(a, b)
38-
Base.Broadcast.axistype(a::BlockedUnitRange, b) = length(b) == 1 ? a : combine_blockaxes(a, b)
39-
Base.Broadcast.axistype(a, b::BlockedUnitRange) = length(b) == 1 ? a : combine_blockaxes(a, b)
38+
Base.Broadcast.axistype(a::BlockedUnitRange, b::BlockedUnitRange) = combine_blockaxes(a, b)
39+
Base.Broadcast.axistype(a::BlockedUnitRange, b) = combine_blockaxes(a, b)
40+
Base.Broadcast.axistype(a, b::BlockedUnitRange) = combine_blockaxes(a, b)
4041

4142

4243
similar(bc::Broadcasted{<:AbstractBlockStyle{N}}, ::Type{T}) where {T,N} =

test/test_blockbroadcast.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,19 @@ import BlockArrays: SubBlockIterator, BlockIndexRange, Diagonal
182182
u = BlockArray(randn(5), [2,3]);
183183
@inferred(copyto!(similar(u), Base.broadcasted(exp, u)))
184184
@test exp.(u) == exp.(Vector(u))
185+
186+
function test_allocation!(shape1, shape2)
187+
x = Base.Broadcast.broadcast_shape(shape1, shape2)
188+
return nothing
189+
end
190+
shape1 = (BlockArrays._BlockedUnitRange((2,)),);
191+
shape2 = (BlockArrays._BlockedUnitRange((2,)),);
192+
@inferred Base.Broadcast.axistype(shape1[1], shape2[1])
193+
@inferred BlockArrays.combine_blockaxes(shape1[1], shape2[1])
194+
@inferred Base.Broadcast.broadcast_shape(shape1, shape2)
195+
test_allocation!(shape1, shape2) # compile first
196+
p = @allocated test_allocation!(shape1, shape2)
197+
@test p == 0
185198
end
186199

187200
@testset "adjtrans" begin

0 commit comments

Comments
 (0)