Skip to content

Commit 7bca709

Browse files
committed
Fix cat_shape for julia < 1.6
1 parent 772d3be commit 7bca709

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

src/fillcat.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ function Base.cat_t(::Type{T}, fs::Fill...; dims) where T
99
# There might be some cases when it does not get padded which are not considered here
1010
allvals[] !== zero(T) && sum(catdims) > 1 && return Base._cat_t(dims, T, fs...)
1111

12-
shape = Base.cat_shape(catdims, map(Base.cat_size, fs)::Tuple{Vararg{Union{Int,Dims}}})::Dims
12+
shape = cat_shape_fill(catdims, fs)
1313
return Fill(convert(T, fs[1].value), shape)
1414
end
1515

@@ -19,7 +19,7 @@ Base.hcat(vs::Fill...) = cat(vs...;dims=Val(2))
1919

2020
function Base.cat_t(::Type{T}, fs::Zeros...; dims) where T
2121
catdims = Base.dims2cat(dims)
22-
shape = Base.cat_shape(catdims, map(Base.cat_size, fs)::Tuple{Vararg{Union{Int,Dims}}})::Dims
22+
shape = cat_shape_fill(catdims, fs)
2323
return Zeros{T}(shape)
2424
end
2525

@@ -34,10 +34,16 @@ function Base.cat_t(::Type{T}, fs::Ones...; dims) where T
3434
# There might be some cases when it does not get padded which are not considered here
3535
sum(catdims) > 1 && return Base._cat_t(dims, T, fs...)
3636

37-
shape = Base.cat_shape(catdims, map(Base.cat_size, fs)::Tuple{Vararg{Union{Int,Dims}}})::Dims
37+
shape = cat_shape_fill(catdims, fs)
3838
return Ones{T}(shape)
3939
end
4040

4141
Base.vcat(vs::Ones...) = cat(vs...;dims=Val(1))
4242
Base.hcat(vs::Ones...) = cat(vs...;dims=Val(2))
4343

44+
45+
if VERSION < v"1.6-"
46+
cat_shape_fill(catdims, fs) = Base.cat_shape(catdims, (), map(Base.cat_size, fs)...)
47+
else
48+
cat_shape_fill(catdims, fs) = Base.cat_shape(catdims, map(Base.cat_size, fs)::Tuple{Vararg{Union{Int,Dims}}})::Dims
49+
end

0 commit comments

Comments
 (0)