Skip to content
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ ArrayInterfaceCore.promote_eltype
ArrayInterfaceCore.restructure
ArrayInterfaceCore.safevec
ArrayInterfaceCore.zeromatrix
ArrayInterfaceCore.undefmatrix
```

### Types
Expand Down
15 changes: 15 additions & 0 deletions lib/ArrayInterfaceCore/src/ArrayInterfaceCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,21 @@ function zeromatrix(u::Array{T}) where {T}
fill!(out, false)
end

"""
undefmatrix(u::AbstractVector)

Creates the matrix version of `u` with possibly undefined values. Note that this is unique because
`similar(u,length(u),length(u))` returns a mutable type, so it is not type-matching,
while `fill(zero(eltype(u)),length(u),length(u))` doesn't match the array type,
i.e., you'll get a CPU array from a GPU array. The generic fallback is
`u .* u'`, which works on a surprising number of types, but can be broken
with weird (recursive) broadcast overloads. For higher-order tensors, this
returns the matrix linear operator type which acts on the `vec` of the array.
"""
function undefmatrix(u)
similar(u, length(u), length(u))
end

"""
restructure(x,y)

Expand Down
19 changes: 17 additions & 2 deletions lib/ArrayInterfaceCore/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using ArrayInterfaceCore
using ArrayInterfaceCore: zeromatrix
using ArrayInterfaceCore: zeromatrix, undefmatrix
import ArrayInterfaceCore: has_sparsestruct, findstructralnz, fast_scalar_indexing, lu_instance,
parent_type, zeromatrix, IndicesInfo
using Base: setindex
Expand All @@ -11,7 +11,22 @@ using Test
using Aqua
Aqua.test_all(ArrayInterfaceCore)

@test zeromatrix(rand(4,4,4)) == zeros(4*4*4,4*4*4)
@testset "zeromatrix and unsafematrix" begin
for T in (Int, Float32, Float64)
for (vectype, mattype) in ((Vector{T}, Matrix{T}), (SparseVector{T}, SparseMatrixCSC{T, Int}))
v = vectype(rand(T, 4))
um = undefmatrix(v)
@test size(um) == (length(v),length(v))
@test typeof(um) == mattype
@test zeromatrix(v) == zeros(T,length(v),length(v))
end
v = rand(T,4,4,4)
um = undefmatrix(v)
@test size(um) == (length(v),length(v))
@test typeof(um) == Matrix{T}
@test zeromatrix(v) == zeros(T,4*4*4,4*4*4)
end
end

@testset "matrix colors" begin
@test ArrayInterfaceCore.fast_matrix_colors(1) == false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@ import ArrayInterfaceStaticArraysCore

const CanonicalInt = Union{Int,StaticInt}

function ArrayInterface.undefmatrix(::MArray{S, T, N, L}) where {S, T, N, L}
return MMatrix{L, L, T, L*L}(undef)
end
# SArray doesn't have an undef constructor and is going to be small enough that this is fine.
function ArrayInterface.undefmatrix(s::SArray)
v = vec(s)
return v.*v'
end
Comment on lines +12 to +19
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These could probably be staticarrayscore?

ArrayInterface.known_first(::Type{<:StaticArrays.SOneTo}) = 1
ArrayInterface.known_last(::Type{StaticArrays.SOneTo{N}}) where {N} = N
ArrayInterface.known_length(::Type{StaticArrays.SOneTo{N}}) where {N} = N
Expand Down
2 changes: 1 addition & 1 deletion src/ArrayInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using ArrayInterfaceCore
import ArrayInterfaceCore: allowed_getindex, allowed_setindex!, aos_to_soa, buffer,
parent_type, fast_matrix_colors, findstructralnz, has_sparsestruct,
issingular, isstructured, matrix_colors, restructure, lu_instance,
safevec, zeromatrix, ColoringAlgorithm, fast_scalar_indexing, parameterless_type,
safevec, zeromatrix, undefmatrix, ColoringAlgorithm, fast_scalar_indexing, parameterless_type,
ndims_index, ndims_shape, is_splat_index, is_forwarding_wrapper, IndicesInfo, childdims,
parentdims, map_tuple_type, flatten_tuples, GetIndex, SetIndex!, defines_strides,
stride_preserving_index
Expand Down