diff --git a/Project.toml b/Project.toml index ee1afdfc..8065d156 100644 --- a/Project.toml +++ b/Project.toml @@ -15,7 +15,7 @@ ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [extensions] @@ -27,7 +27,7 @@ ArrayInterfaceChainRulesExt = "ChainRules" ArrayInterfaceGPUArraysCoreExt = "GPUArraysCore" ArrayInterfaceReverseDiffExt = "ReverseDiff" ArrayInterfaceSparseArraysExt = "SparseArrays" -ArrayInterfaceStaticArraysCoreExt = "StaticArraysCore" +ArrayInterfaceStaticArraysExt = "StaticArrays" ArrayInterfaceTrackerExt = "Tracker" [compat] @@ -66,4 +66,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [targets] -test = ["SafeTestsets", "Pkg", "Test", "Aqua", "Random", "SparseArrays", "SuiteSparse", "BandedMatrices", "BlockBandedMatrices", "GPUArraysCore", "StaticArrays", "StaticArraysCore", "Tracker", "ReverseDiff", "ChainRules", "FillArrays", "ComponentArrays"] +test = ["SafeTestsets", "Pkg", "Test", "Aqua", "Random", "SparseArrays", "SuiteSparse", "BandedMatrices", "BlockBandedMatrices", "GPUArraysCore", "StaticArrays", "Tracker", "ReverseDiff", "ChainRules", "FillArrays", "ComponentArrays"] diff --git a/ext/ArrayInterfaceStaticArraysCoreExt.jl b/ext/ArrayInterfaceStaticArraysCoreExt.jl deleted file mode 100644 index acf06adf..00000000 --- a/ext/ArrayInterfaceStaticArraysCoreExt.jl +++ /dev/null @@ -1,30 +0,0 @@ -module ArrayInterfaceStaticArraysCoreExt - -import ArrayInterface -using LinearAlgebra -import StaticArraysCore - -function ArrayInterface.undefmatrix(::StaticArraysCore.MArray{S, T, N, L}) where {S, T, N, L} - return StaticArraysCore.MMatrix{L, L, T, L*L}(undef) -end -# SArray doesn't have an undef constructor and is going to be small enough that this is fine. -function ArrayInterface.undefmatrix(s::StaticArraysCore.SArray) - v = vec(s) - return v.*v' -end - -ArrayInterface.ismutable(::Type{<:StaticArraysCore.StaticArray}) = false -ArrayInterface.ismutable(::Type{<:StaticArraysCore.MArray}) = true -ArrayInterface.ismutable(::Type{<:StaticArraysCore.SizedArray}) = true - -ArrayInterface.can_setindex(::Type{<:StaticArraysCore.StaticArray}) = false -ArrayInterface.can_setindex(::Type{<:StaticArraysCore.MArray}) = true -ArrayInterface.buffer(A::Union{StaticArraysCore.SArray,StaticArraysCore.MArray}) = getfield(A, :data) - -function ArrayInterface.lu_instance(_A::StaticArraysCore.StaticMatrix{N,N}) where {N} - lu(one(_A)) -end - -ArrayInterface.restructure(x::StaticArraysCore.SArray{S}, y) where {S} = StaticArraysCore.SArray{S}(y) - -end diff --git a/ext/ArrayInterfaceStaticArraysExt.jl b/ext/ArrayInterfaceStaticArraysExt.jl new file mode 100644 index 00000000..80e8e4a7 --- /dev/null +++ b/ext/ArrayInterfaceStaticArraysExt.jl @@ -0,0 +1,34 @@ +module ArrayInterfaceStaticArraysExt + +import ArrayInterface +using LinearAlgebra +import StaticArrays: SArray, SMatrix, SVector, StaticMatrix, StaticArray, SizedArray, MArray, MMatrix, LU + +function ArrayInterface.undefmatrix(::MArray{S, T, N, L}) where {S, T, N, L} + return MMatrix{L, L, T, L*L}(undef) +end +# SArray doesn't have an undef constructor and is going to be small enough that this is fine. +function ArrayInterface.undefmatrix(s::SArray) + v = vec(s) + return v.*v' +end + +ArrayInterface.ismutable(::Type{<:StaticArray}) = false +ArrayInterface.ismutable(::Type{<:MArray}) = true +ArrayInterface.ismutable(::Type{<:SizedArray}) = true + +ArrayInterface.can_setindex(::Type{<:StaticArray}) = false +ArrayInterface.can_setindex(::Type{<:MArray}) = true +ArrayInterface.buffer(A::Union{SArray, MArray}) = getfield(A, :data) + +function ArrayInterface.lu_instance(A::SMatrix{N,N}) where {N} + LU(LowerTriangular(A), UpperTriangular(A), SVector{N}(1:N)) +end + +function ArrayInterface.lu_instance(A::StaticMatrix{N,N}) where {N} + lu(one(A)) +end + +ArrayInterface.restructure(x::SArray{S}, y) where {S} = SArray{S}(y) + +end diff --git a/test/runtests.jl b/test/runtests.jl index 7999dbec..8a5d7b36 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,7 +14,7 @@ end @time @safetestset "BlockBandedMatrices" begin include("blockbandedmatrices.jl") end @time @safetestset "Core" begin include("core.jl") end @time @safetestset "AD Integration" begin include("ad.jl") end - @time @safetestset "StaticArraysCore" begin include("staticarrayscore.jl") end + @time @safetestset "StaticArrays" begin include("staticarrays.jl") end @time @safetestset "ChainRules" begin include("chainrules.jl") end end diff --git a/test/staticarrayscore.jl b/test/staticarrays.jl similarity index 100% rename from test/staticarrayscore.jl rename to test/staticarrays.jl