Skip to content
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -491,8 +491,8 @@ julia> reshape(kron(v,w), (length(w), length(v)))
```
"""
function kron(A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S}) where {T,S}
R = Matrix{promote_op(*,T,S)}(undef, _kronsize(A, B))
return kron!(R, A, B)
C = Matrix{promote_op(*,T,S)}(undef, _kronsize(A, B))
return kron!(C, A, B)
end
function kron(a::AbstractVector{T}, b::AbstractVector{S}) where {T,S}
c = Vector{promote_op(*,T,S)}(undef, length(a)*length(b))
Expand Down
124 changes: 124 additions & 0 deletions stdlib/LinearAlgebra/src/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,130 @@ for (T, trans, real) in [(:Symmetric, :transpose, :identity), (:(Hermitian{<:Uni
end
end

function kron(A::Hermitian{T}, B::Hermitian{S}) where {T<:Union{Real,Complex},S<:Union{Real,Complex}}
resultuplo = A.uplo == 'U' || B.uplo == 'U' ? :U : :L
C = Hermitian(Matrix{promote_op(*, T, S)}(undef, _kronsize(A, B)), resultuplo)
return kron!(C, A, B)
end

function kron(A::Symmetric{T}, B::Symmetric{S}) where {T<:Number,S<:Number}
resultuplo = A.uplo == 'U' || B.uplo == 'U' ? :U : :L
C = Symmetric(Matrix{promote_op(*, T, S)}(undef, _kronsize(A, B)), resultuplo)
return kron!(C, A, B)
end

function kron!(C::Hermitian{<:Union{Real,Complex}}, A::Hermitian{<:Union{Real,Complex}}, B::Hermitian{<:Union{Real,Complex}})
size(C) == _kronsize(A, B) || throw(DimensionMismatch("kron!"))
if ((A.uplo == 'U' || B.uplo == 'U') && C.uplo != 'U') || ((A.uplo == 'L' && B.uplo == 'L') && C.uplo != 'L')
throw(ArgumentError("C.uplo must match A.uplo and B.uplo, got $(C.uplo) $(A.uplo) $(B.uplo)"))
end
_hermkron!(C.data, A.data, B.data, conj, real, A.uplo, B.uplo)
return C
end

function kron!(C::Symmetric{<:Number}, A::Symmetric{<:Number}, B::Symmetric{<:Number})
size(C) == _kronsize(A, B) || throw(DimensionMismatch("kron!"))
if ((A.uplo == 'U' || B.uplo == 'U') && C.uplo != 'U') || ((A.uplo == 'L' && B.uplo == 'L') && C.uplo != 'L')
throw(ArgumentError("C.uplo must match A.uplo and B.uplo, got $(C.uplo) $(A.uplo) $(B.uplo)"))
end
_hermkron!(C.data, A.data, B.data, identity, identity, A.uplo, B.uplo)
return C
end

function _hermkron!(C, A, B, conj::TC, real::TR, Auplo, Buplo) where {TC,TR}
n_A = size(A, 1)
n_B = size(B, 1)
@inbounds if Auplo == 'U' && Buplo == 'U'
for j = 1:n_A
jnB = (j - 1) * n_B
for i = 1:(j-1)
Aij = A[i, j]
inB = (i - 1) * n_B
for l = 1:n_B
for k = 1:(l-1)
C[inB+k, jnB+l] = Aij * B[k, l]
C[inB+l, jnB+k] = Aij * conj(B[k, l])
end
C[inB+l, jnB+l] = Aij * real(B[l, l])
end
end
Ajj = real(A[j, j])
for l = 1:n_B
for k = 1:(l-1)
C[jnB+k, jnB+l] = Ajj * B[k, l]
end
C[jnB+l, jnB+l] = Ajj * real(B[l, l])
end
end
elseif Auplo == 'U' && Buplo == 'L'
for j = 1:n_A
jnB = (j - 1) * n_B
for i = 1:(j-1)
Aij = A[i, j]
inB = (i - 1) * n_B
for l = 1:n_B
C[inB+l, jnB+l] = Aij * real(B[l, l])
for k = (l+1):n_B
C[inB+l, jnB+k] = Aij * conj(B[k, l])
C[inB+k, jnB+l] = Aij * B[k, l]
end
end
end
Ajj = real(A[j, j])
for l = 1:n_B
C[jnB+l, jnB+l] = Ajj * real(B[l, l])
for k = (l+1):n_B
C[jnB+l, jnB+k] = Ajj * conj(B[k, l])
end
end
end
elseif Auplo == 'L' && Buplo == 'U'
for j = 1:n_A
jnB = (j - 1) * n_B
Ajj = real(A[j, j])
for l = 1:n_B
for k = 1:(l-1)
C[jnB+k, jnB+l] = Ajj * B[k, l]
end
C[jnB+l, jnB+l] = Ajj * real(B[l, l])
end
for i = (j+1):n_A
conjAij = conj(A[i, j])
inB = (i - 1) * n_B
for l = 1:n_B
for k = 1:(l-1)
C[jnB+k, inB+l] = conjAij * B[k, l]
C[jnB+l, inB+k] = conjAij * conj(B[k, l])
end
C[jnB+l, inB+l] = conjAij * real(B[l, l])
end
end
end
else #if Auplo == 'L' && Buplo == 'L'
for j = 1:n_A
jnB = (j - 1) * n_B
Ajj = real(A[j, j])
for l = 1:n_B
C[jnB+l, jnB+l] = Ajj * real(B[l, l])
for k = (l+1):n_B
C[jnB+k, jnB+l] = Ajj * B[k, l]
end
end
for i = (j+1):n_A
Aij = A[i, j]
inB = (i - 1) * n_B
for l = 1:n_B
C[inB+l, jnB+l] = Aij * real(B[l, l])
for k = (l+1):n_B
C[inB+k, jnB+l] = Aij * B[k, l]
C[inB+l, jnB+k] = Aij * conj(B[k, l])
end
end
end
end
end
end

(-)(A::Symmetric) = Symmetric(parentof_applytri(-, A), sym_uplo(A.uplo))
(-)(A::Hermitian) = Hermitian(parentof_applytri(-, A), sym_uplo(A.uplo))

Expand Down
74 changes: 74 additions & 0 deletions stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,80 @@ for op in (:+, :-)
end
end

function kron(A::UpperTriangular{T}, B::UpperTriangular{S}) where {T<:Number,S<:Number}
C = UpperTriangular(Matrix{promote_op(*, T, S)}(undef, _kronsize(A, B)))
return kron!(C, A, B)
end

function kron(A::LowerTriangular{T}, B::LowerTriangular{S}) where {T<:Number,S<:Number}
C = LowerTriangular(Matrix{promote_op(*, T, S)}(undef, _kronsize(A, B)))
return kron!(C, A, B)
end

function kron!(C::UpperTriangular{<:Number}, A::UpperTriangular{<:Number}, B::UpperTriangular{<:Number})
size(C) == _kronsize(A, B) || throw(DimensionMismatch("kron!"))
_triukron!(C.data, A.data, B.data)
return C
end

function kron!(C::LowerTriangular{<:Number}, A::LowerTriangular{<:Number}, B::LowerTriangular{<:Number})
size(C) == _kronsize(A, B) || throw(DimensionMismatch("kron!"))
_trilkron!(C.data, A.data, B.data)
return C
end

function _triukron!(C, A, B)
n_A = size(A, 1)
n_B = size(B, 1)
@inbounds for j = 1:n_A
jnB = (j - 1) * n_B
for i = 1:(j-1)
Aij = A[i, j]
inB = (i - 1) * n_B
for l = 1:n_B
for k = 1:l
C[inB+k, jnB+l] = Aij * B[k, l]
end
for k = 1:(l-1)
C[inB+l, jnB+k] = zero(eltype(C))
end
end
end
Ajj = A[j, j]
for l = 1:n_B
for k = 1:l
C[jnB+k, jnB+l] = Ajj * B[k, l]
end
end
end
end

function _trilkron!(C, A, B)
n_A = size(A, 1)
n_B = size(B, 1)
@inbounds for j = 1:n_A
jnB = (j - 1) * n_B
Ajj = A[j, j]
for l = 1:n_B
for k = l:n_B
C[jnB+k, jnB+l] = Ajj * B[k, l]
end
end
for i = (j+1):n_A
Aij = A[i, j]
inB = (i - 1) * n_B
for l = 1:n_B
for k = l:n_B
C[inB+k, jnB+l] = Aij * B[k, l]
end
for k = (l+1):n_B
C[inB+l, jnB+k] = zero(eltype(C))
end
end
end
end
end

######################
# BlasFloat routines #
######################
Expand Down
35 changes: 35 additions & 0 deletions stdlib/LinearAlgebra/test/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,28 @@ end
@test dot(symblockml, symblockml) ≈ dot(msymblockml, msymblockml)
end
end

@testset "kronecker product of symmetric and Hermitian matrices" begin
for mtype in (Symmetric, Hermitian)
symau = mtype(a, :U)
symal = mtype(a, :L)
msymau = Matrix(symau)
msymal = Matrix(symal)
for eltyc in (Float32, Float64, ComplexF32, ComplexF64, BigFloat, Int)
creal = randn(n, n)/2
cimag = randn(n, n)/2
c = eltya == Int ? rand(1:7, n, n) : convert(Matrix{eltya}, eltya <: Complex ? complex.(creal, cimag) : creal)
symcu = mtype(c, :U)
symcl = mtype(c, :L)
msymcu = Matrix(symcu)
msymcl = Matrix(symcl)
@test kron(symau, symcu) ≈ kron(msymau, msymcu)
@test kron(symau, symcl) ≈ kron(msymau, msymcl)
@test kron(symal, symcu) ≈ kron(msymal, msymcu)
@test kron(symal, symcl) ≈ kron(msymal, msymcl)
end
end
end
end
end

Expand All @@ -487,6 +509,7 @@ end
@test S - S == MS - MS
@test S*2 == 2*S == 2*MS
@test S/2 == MS/2
@test kron(S,S) == kron(MS,MS)
end
@testset "mixed uplo" begin
Mu = Matrix{Complex{BigFloat}}(undef,2,2)
Expand All @@ -502,6 +525,8 @@ end
MSl = Matrix(Sl)
@test Su + Sl == Sl + Su == MSu + MSl
@test Su - Sl == -(Sl - Su) == MSu - MSl
@test kron(Su,Sl) == kron(MSu,MSl)
@test kron(Sl,Su) == kron(MSl,MSu)
end
end
end
Expand All @@ -517,6 +542,16 @@ end
@test dot(A, B) ≈ dot(Symmetric(A), Symmetric(B))
end

# let's make sure the analogous bug will not show up with kronecker products
@testset "kron Hermitian quaternion #52318" begin
A, B = [Quaternion.(randn(3,3), randn(3, 3), randn(3, 3), randn(3,3)) |> t -> t + t' for i in 1:2]
@test A == Hermitian(A) && B == Hermitian(B)
@test kron(A, B) ≈ kron(Hermitian(A), Hermitian(B))
A, B = [Quaternion.(randn(3,3), randn(3, 3), randn(3, 3), randn(3,3)) |> t -> t + transpose(t) for i in 1:2]
@test A == Symmetric(A) && B == Symmetric(B)
@test kron(A, B) ≈ kron(Symmetric(A), Symmetric(B))
end

#Issue #7647: test xsyevr, xheevr, xstevr drivers.
@testset "Eigenvalues in interval for $(typeof(Mi7647))" for Mi7647 in
(Symmetric(diagm(0 => 1.0:3.0)),
Expand Down
3 changes: 3 additions & 0 deletions stdlib/LinearAlgebra/test/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ debug && println("Test basic type functionality")
# Binary operations
@test A1 + A2 == M1 + M2
@test A1 - A2 == M1 - M2
@test kron(A1,A2) == kron(M1,M2)

# Triangular-Triangular multiplication and division
@test A1*A2 ≈ M1*M2
Expand Down Expand Up @@ -1014,6 +1015,7 @@ end
@test 2\L == 2\B
@test real(L) == real(B)
@test imag(L) == imag(B)
@test kron(L,L) == kron(B,B)
@test transpose!(MT(copy(A))) == transpose(L) broken=!(A isa Matrix)
@test adjoint!(MT(copy(A))) == adjoint(L) broken=!(A isa Matrix)
end
Expand All @@ -1035,6 +1037,7 @@ end
@test 2\U == 2\B
@test real(U) == real(B)
@test imag(U) == imag(B)
@test kron(U,U) == kron(B,B)
@test transpose!(MT(copy(A))) == transpose(U) broken=!(A isa Matrix)
@test adjoint!(MT(copy(A))) == adjoint(U) broken=!(A isa Matrix)
end
Expand Down