Skip to content

Commit 786ba2f

Browse files
committed
Fix multiplying a triangular matrix and a Diagonal
1 parent cba1cc0 commit 786ba2f

File tree

4 files changed

+190
-64
lines changed

4 files changed

+190
-64
lines changed

stdlib/LinearAlgebra/src/LinearAlgebra.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,8 @@ matprod_dest(A::StructuredMatrix, B::Diagonal, TS) = _matprod_dest_diag(A, TS)
655655
matprod_dest(A::Diagonal, B::StructuredMatrix, TS) = _matprod_dest_diag(B, TS)
656656
matprod_dest(A::Diagonal, B::Diagonal, TS) = _matprod_dest_diag(B, TS)
657657
_matprod_dest_diag(A, TS) = similar(A, TS)
658+
# Destination for the product of a unit triangular matrix and a `Diagonal`:
# a non-unit triangular wrapper of the same orientation around a fresh parent
# with element type `TS`.
# NOTE(review): presumably non-unit because the scaled diagonal is generally
# no longer all ones -- confirm against the callers of `matprod_dest`.
function _matprod_dest_diag(A::UnitUpperTriangular, TS)
    return UpperTriangular(similar(parent(A), TS))
end
function _matprod_dest_diag(A::UnitLowerTriangular, TS)
    return LowerTriangular(similar(parent(A), TS))
end
658660
function _matprod_dest_diag(A::SymTridiagonal, TS)
659661
n = size(A, 1)
660662
ev = similar(A, TS, max(0, n-1))

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 127 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -396,82 +396,156 @@ function lmul!(D::Diagonal, T::Tridiagonal)
396396
return T
397397
end
398398

399-
function __muldiag!(out, D::Diagonal, B, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
399+
# Combine every entry of `D * B` into `out` through `_add`, column by column.
# The only caller visible here, `__muldiag!`, invokes this once it has ruled
# out a zero `alpha`.
@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B, _add::MulAddMul)
    dvals = D.diag
    @inbounds for col in axes(B, 2)
        @simd for row in axes(B, 1)
            # row `row` of `B` is scaled by the matching diagonal entry of `D`
            _modify!(_add, dvals[row] * B[row,col], out, (row,col))
        end
    end
    return out
end
407+
# `true` only when both arguments are triangular wrappers of the same
# orientation (both upper or both lower, unit or not); `false` for every
# other pair of arguments.
_has_matching_storage(::UpperOrUnitUpperTriangular, ::UpperOrUnitUpperTriangular) = true
_has_matching_storage(::LowerOrUnitLowerTriangular, ::LowerOrUnitLowerTriangular) = true
_has_matching_storage(::Any, ::Any) = false
410+
# Rows of column `col` that lie in the stored triangle of an upper triangular
# `B`. For a unit triangular `B` the diagonal row is left out of the range;
# the callers in this file write that entry separately.
function _rowrange_tri_stored(B::UpperOrUnitUpperTriangular, col)
    skipdiag = B isa UnitUpperTriangular
    lastrow = min(col - skipdiag, size(B, 1))
    return 1:lastrow
end
414+
# Rows of column `col` that lie in the stored triangle of a lower triangular
# `B`. For a unit triangular `B` the range starts one row below the diagonal;
# the callers in this file write the diagonal entry separately.
function _rowrange_tri_stored(B::LowerOrUnitLowerTriangular, col)
    skipdiag = B isa UnitLowerTriangular
    firstrow = col + skipdiag
    return firstrow:size(B, 1)
end
418+
# Rows of column `col` strictly on the other side of the diagonal from the
# triangle named by the wrapper type: below it for upper triangular `B`,
# above it for lower triangular `B`.
function _rowrange_tri_nonstored(B::UpperOrUnitUpperTriangular, col)
    return (col + 1):size(B, 1)
end
function _rowrange_tri_nonstored(B::LowerOrUnitLowerTriangular, col)
    return 1:(col - 1)
end
420+
# Combine `D * B` into `out` through `_add` for a triangular `B`.
# The stored triangle of `B` is walked column by column; the remaining entries
# of `out` are only touched when `out` is not a triangular wrapper of the same
# orientation as `B`.
function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, _add::MulAddMul)
    unitdiag = B isa UnitUpperOrUnitLowerTriangular
    samestorage = _has_matching_storage(out, B)
    # with matching triangular structure, read and write the parents directly
    dest, src = samestorage ? (parent(out), parent(B)) : (out, B)
    for col in axes(B, 2)
        # the stored range below skips the diagonal of a unit triangular
        # matrix, so that entry is written on its own, through the wrappers
        if unitdiag
            @inbounds _modify!(_add, D.diag[col] * B[col,col], out, (col,col))
        end
        # entries of out corresponding to the stored indices of B
        storedrows = _rowrange_tri_stored(B, col)
        @inbounds @simd for row in storedrows
            _modify!(_add, D.diag[row] * src[row,col], dest, (row,col))
        end
        # entries of out corresponding to the zeros of B need no work
        # when out and B have matching zeros
        samestorage && continue
        zerorows = _rowrange_tri_nonstored(B, col)
        if haszero(eltype(out))
            _rmul_or_fill!(@view(out[zerorows,col]), _add.beta)
        else
            @inbounds @simd for row in zerorows
                _modify!(_add, D.diag[row] * B[row,col], out, (row,col))
            end
        end
    end
    return out
end
448+
# Entry point for `out = D * B * alpha + out * beta` with `alpha`/`beta` taken
# from `_add`. A zero `alpha` only rescales (or fills) `out` by `beta` and
# never reads `D` or `B`; otherwise the work is delegated to
# `__muldiag_nonzeroalpha!`.
function __muldiag!(out, D::Diagonal, B, _add::MulAddMul)
    require_one_based_indexing(out, B)
    if iszero(_add.alpha)
        _rmul_or_fill!(out, _add.beta)
        return out
    end
    __muldiag_nonzeroalpha!(out, D, B, _add)
    return out
end
458+
459+
# Combine every entry of `A * D` into `out` through `_add`, column by column.
# `alpha` is folded into the diagonal entry once per column, so the per-entry
# update uses a `MulAddMul` whose alpha is the literal `true`.
@inline function __muldiag_nonzeroalpha!(out, A, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
    beta = _add.beta
    betaonly = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta)
    dvals = D.diag
    @inbounds for col in axes(A, 2)
        # one-argument call of `_add`; the result takes the place of
        # `D.diag[col] * alpha` from the previous implementation
        scaled = _add(dvals[col])
        @simd for row in axes(A, 1)
            _modify!(betaonly, A[row,col] * scaled, out, (row,col))
        end
    end
    return out
end
470+
# Combine `A * D` into `out` through `_add` for a triangular `A`.
# The stored triangle of `A` is walked column by column; the remaining entries
# of `out` are only touched when `out` is not a triangular wrapper of the same
# orientation as `A`.
function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
    isunit = A isa UnitUpperOrUnitLowerTriangular
    beta = _add.beta
    # since alpha is multiplied to the diagonal element of D,
    # we may skip alpha in the second multiplication by setting ais1 to true
    _add_aisone = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta)
    # if both A and out have the same upper/lower triangular structure,
    # we may directly read and write from the parents
    matching = _has_matching_storage(out, A)
    out_maybeparent, A_maybeparent = matching ? (parent(out), parent(A)) : (out, A)
    for j in axes(A, 2)
        # dja already carries alpha, so every update below must go through
        # _add_aisone rather than _add
        dja = _add(@inbounds D.diag[j])
        # store the diagonal separately for unit triangular matrices
        if isunit
            @inbounds _modify!(_add_aisone, A[j,j] * dja, out, (j,j))
        end
        # indices of out corresponding to the stored indices of A
        rowrange = _rowrange_tri_stored(A, j)
        @inbounds @simd for i in rowrange
            _modify!(_add_aisone, A_maybeparent[i,j] * dja, out_maybeparent, (i,j))
        end
        # indices of out corresponding to the zeros of A
        # we only fill these if out and A don't have matching zeros
        if !matching
            rowrange = _rowrange_tri_nonstored(A, j)
            if haszero(eltype(out))
                _rmul_or_fill!(@view(out[rowrange,j]), beta)
            else
                @inbounds @simd for i in rowrange
                    # fixed: this loop used `_add`, which multiplied by alpha
                    # a second time on top of the alpha already in dja
                    _modify!(_add_aisone, A[i,j] * dja, out, (i,j))
                end
            end
        end
    end
    return out
end
421-
function __muldiag!(out, A, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
505+
# Entry point for `out = A * D * alpha + out * beta` with `alpha`/`beta` taken
# from `_add`. A zero `alpha` only rescales (or fills) `out` by `beta` and
# never reads `A` or `D`; otherwise the work is delegated to
# `__muldiag_nonzeroalpha!`.
function __muldiag!(out, A, D::Diagonal, _add::MulAddMul)
    require_one_based_indexing(out, A)
    if iszero(_add.alpha)
        _rmul_or_fill!(out, _add.beta)
        return out
    end
    __muldiag_nonzeroalpha!(out, A, D, _add)
    return out
end
445-
function __muldiag!(out::Diagonal, D1::Diagonal, D2::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
515+
516+
# Combine the entrywise product of the diagonals of `D1` and `D2` into the
# diagonal of `out` through `_add`.
@inline function __muldiag_nonzeroalpha!(out::Diagonal, D1::Diagonal, D2::Diagonal, _add::MulAddMul)
    lhs, rhs, dest = D1.diag, D2.diag, out.diag
    @inbounds @simd for k in eachindex(lhs, rhs, dest)
        _modify!(_add, lhs[k] * rhs[k], dest, k)
    end
    return out
end
525+
# Entry point for `out = D1 * D2 * alpha + out * beta` when the destination is
# itself a `Diagonal`. A zero `alpha` only rescales (or fills) `out.diag` by
# `beta`; otherwise the work is delegated to `__muldiag_nonzeroalpha!`.
function __muldiag!(out::Diagonal, D1::Diagonal, D2::Diagonal, _add::MulAddMul)
    if iszero(_add.alpha)
        _rmul_or_fill!(out.diag, _add.beta)
        return out
    end
    __muldiag_nonzeroalpha!(out, D1, D2, _add)
    return out
end
464-
function __muldiag!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
465-
require_one_based_indexing(out)
466-
alpha, beta = _add.alpha, _add.beta
467-
mA = size(D1, 1)
534+
# Combine the entrywise product of the diagonals of `D1` and `D2` into the
# diagonal entries `out[i,i]` of a general destination through `_add`.
# Off-diagonal entries of `out` are not touched here.
@inline function __muldiag_nonzeroalpha!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul)
    lhs, rhs = D1.diag, D2.diag
    @inbounds @simd for k in eachindex(lhs, rhs)
        _modify!(_add, lhs[k] * rhs[k], out, (k,k))
    end
    return out
end
542+
# Entry point for `out = D1 * D2 * alpha + out * beta` with a general
# destination. All of `out` is first rescaled (or filled) by `beta`; the
# diagonal contribution is then added on top using a `MulAddMul` whose beta is
# the literal `true`, so the rescaled entries are kept as they are.
function __muldiag!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul{ais1}) where {ais1}
    require_one_based_indexing(out)
    alpha, beta = _add.alpha, _add.beta
    _rmul_or_fill!(out, beta)
    # with a zero alpha the product contributes nothing further
    iszero(alpha) && return out
    _add_bis1 = MulAddMul{ais1,false,typeof(alpha),Bool}(alpha, true)
    __muldiag_nonzeroalpha!(out, D1, D2, _add_bis1)
    return out
end
@@ -658,31 +732,21 @@ for Tri in (:UpperTriangular, :LowerTriangular)
658732
@eval $fun(A::$Tri, D::Diagonal) = $Tri($fun(A.data, D))
659733
@eval $fun(A::$UTri, D::Diagonal) = $Tri(_setdiag!($fun(A.data, D), $f, D.diag))
660734
end
735+
@eval *(A::$Tri{<:Any, <:StridedMaybeAdjOrTransMat}, D::Diagonal) =
736+
@invoke *(A::AbstractMatrix, D::Diagonal)
737+
@eval *(A::$UTri{<:Any, <:StridedMaybeAdjOrTransMat}, D::Diagonal) =
738+
@invoke *(A::AbstractMatrix, D::Diagonal)
661739
for (fun, f) in zip((:*, :lmul!, :ldiv!, :\), (:identity, :identity, :inv, :inv))
662740
@eval $fun(D::Diagonal, A::$Tri) = $Tri($fun(D, A.data))
663741
@eval $fun(D::Diagonal, A::$UTri) = $Tri(_setdiag!($fun(D, A.data), $f, D.diag))
664742
end
743+
@eval *(D::Diagonal, A::$Tri{<:Any, <:StridedMaybeAdjOrTransMat}) =
744+
@invoke *(D::Diagonal, A::AbstractMatrix)
745+
@eval *(D::Diagonal, A::$UTri{<:Any, <:StridedMaybeAdjOrTransMat}) =
746+
@invoke *(D::Diagonal, A::AbstractMatrix)
665747
# 3-arg ldiv!
666748
@eval ldiv!(C::$Tri, D::Diagonal, A::$Tri) = $Tri(ldiv!(C.data, D, A.data))
667749
@eval ldiv!(C::$Tri, D::Diagonal, A::$UTri) = $Tri(_setdiag!(ldiv!(C.data, D, A.data), inv, D.diag))
668-
# 3-arg mul! is disambiguated in special.jl
669-
# 5-arg mul!
670-
@eval _mul!(C::$Tri, D::Diagonal, A::$Tri, _add) = $Tri(mul!(C.data, D, A.data, _add.alpha, _add.beta))
671-
@eval function _mul!(C::$Tri, D::Diagonal, A::$UTri, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
672-
α, β = _add.alpha, _add.beta
673-
iszero(α) && return _rmul_or_fill!(C, β)
674-
diag′ = bis0 ? nothing : diag(C)
675-
data = mul!(C.data, D, A.data, α, β)
676-
$Tri(_setdiag!(data, _add, D.diag, diag′))
677-
end
678-
@eval _mul!(C::$Tri, A::$Tri, D::Diagonal, _add) = $Tri(mul!(C.data, A.data, D, _add.alpha, _add.beta))
679-
@eval function _mul!(C::$Tri, A::$UTri, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
680-
α, β = _add.alpha, _add.beta
681-
iszero(α) && return _rmul_or_fill!(C, β)
682-
diag′ = bis0 ? nothing : diag(C)
683-
data = mul!(C.data, A.data, D, α, β)
684-
$Tri(_setdiag!(data, _add, D.diag, diag′))
685-
end
686750
end
687751

688752
@inline function kron!(C::AbstractMatrix, A::Diagonal, B::Diagonal)

stdlib/LinearAlgebra/test/addmul.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,4 +239,26 @@ end
239239
end
240240
end
241241

242+
@testset "Diagonal scaling of a triangular matrix with a non-triangular destination" begin
    for MT in (UpperTriangular, UnitUpperTriangular, LowerTriangular, UnitLowerTriangular)
        tri = MT(reshape([1:9;],3,3))
        dense = Array(tri)
        scale = Diagonal(1:3)
        dest = reshape([1:9;],3,3)
        # general 5-arg mul! into a dense destination, both operand orders
        @test mul!(copy(dest), tri, scale, 2, 3) == dense * scale * 2 + dest * 3
        @test mul!(copy(dest), scale, tri, 2, 3) == scale * dense * 2 + dest * 3

        # NaN values in the Diagonal must not leak into the result when iszero(alpha)
        nanscale = Diagonal(fill(NaN,3))
        @test mul!(copy(dest), tri, nanscale, 0, 3) == dest * 3
        @test mul!(copy(dest), nanscale, tri, 0, 3) == dest * 3

        # NaN values in the destination must be overwritten when iszero(beta)
        nandest = fill(NaN,3,3)
        @test mul!(copy(nandest), tri, scale, 2, 0) == dense * scale * 2
        @test mul!(copy(nandest), scale, tri, 2, 0) == scale * dense * 2
    end
end
263+
242264
end # module

stdlib/LinearAlgebra/test/diagonal.jl

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1188,7 +1188,7 @@ end
11881188
@test oneunit(D3) isa typeof(D3)
11891189
end
11901190

1191-
@testset "AbstractTriangular" for (Tri, UTri) in ((UpperTriangular, UnitUpperTriangular), (LowerTriangular, UnitLowerTriangular))
1191+
@testset "$Tri" for (Tri, UTri) in ((UpperTriangular, UnitUpperTriangular), (LowerTriangular, UnitLowerTriangular))
11921192
A = randn(4, 4)
11931193
TriA = Tri(A)
11941194
UTriA = UTri(A)
@@ -1218,6 +1218,44 @@ end
12181218
@test outTri === mul!(outTri, D, UTriA, 2, 1)::Tri == mul!(out, D, Matrix(UTriA), 2, 1)
12191219
@test outTri === mul!(outTri, TriA, D, 2, 1)::Tri == mul!(out, Matrix(TriA), D, 2, 1)
12201220
@test outTri === mul!(outTri, UTriA, D, 2, 1)::Tri == mul!(out, Matrix(UTriA), D, 2, 1)
1221+
1222+
# we may write to a Unit triangular if the diagonal is preserved
1223+
ID = Diagonal(ones(size(UTriA,2)))
1224+
@test mul!(copy(UTriA), UTriA, ID) == UTriA
1225+
@test mul!(copy(UTriA), ID, UTriA) == UTriA
1226+
1227+
@testset "partly filled parents" begin
1228+
M = Matrix{BigFloat}(undef, 2, 2)
1229+
M[1,1] = M[2,2] = 3
1230+
isupper = Tri == UpperTriangular
1231+
M[1+!isupper, 1+isupper] = 3
1232+
D = Diagonal(1:2)
1233+
T = Tri(M)
1234+
TA = Array(T)
1235+
@test T * D == TA * D
1236+
@test D * T == D * TA
1237+
@test mul!(copy(T), T, D, 2, 3) == 2T * D + 3T
1238+
@test mul!(copy(T), D, T, 2, 3) == 2D * T + 3T
1239+
1240+
U = UTri(M)
1241+
UA = Array(U)
1242+
@test U * D == UA * D
1243+
@test D * U == D * UA
1244+
@test mul!(copy(T), U, D, 2, 3) == 2 * UA * D + 3TA
1245+
@test mul!(copy(T), D, U, 2, 3) == 2 * D * UA + 3TA
1246+
1247+
M2 = Matrix{BigFloat}(undef, 2, 2)
1248+
M2[1+!isupper, 1+isupper] = 3
1249+
U = UTri(M2)
1250+
UA = Array(U)
1251+
@test U * D == UA * D
1252+
@test D * U == D * UA
1253+
ID = Diagonal(ones(size(U,2)))
1254+
@test mul!(copy(U), U, ID) == U
1255+
@test mul!(copy(U), ID, U) == U
1256+
@test mul!(copy(U), U, ID, 2, -1) == U
1257+
@test mul!(copy(U), ID, U, 2, -1) == U
1258+
end
12211259
end
12221260

12231261
struct SMatrix1{T} <: AbstractArray{T,2}

0 commit comments

Comments
 (0)