Skip to content

Commit 710e1f9

Browse files
authored
Reland "Reroute (Upper/Lower)Triangular * Diagonal through __muldiag #55984" (#56270)
This relands #55984 which was reverted in #56267. Previously, in #55984, the destination in multiplying triangular matrices with diagonals was also assumed to be triangular, which is not necessarily the case in `mul!`. Tests for this case, however, were being run non-deterministically, so this wasn't caught by the CI runs. This improves performance: ```julia julia> U = UpperTriangular(rand(100,100)); D = Diagonal(rand(size(U,2))); C = similar(U); julia> @Btime mul!($C, $D, $U); 1.517 μs (0 allocations: 0 bytes) # nightly 1.116 μs (0 allocations: 0 bytes) # This PR ```
1 parent 4c076c8 commit 710e1f9

File tree

4 files changed

+184
-84
lines changed

4 files changed

+184
-84
lines changed

stdlib/LinearAlgebra/src/LinearAlgebra.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -656,6 +656,8 @@ matprod_dest(A::StructuredMatrix, B::Diagonal, TS) = _matprod_dest_diag(A, TS)
656656
matprod_dest(A::Diagonal, B::StructuredMatrix, TS) = _matprod_dest_diag(B, TS)
657657
matprod_dest(A::Diagonal, B::Diagonal, TS) = _matprod_dest_diag(B, TS)
658658
_matprod_dest_diag(A, TS) = similar(A, TS)
659+
_matprod_dest_diag(A::UnitUpperTriangular, TS) = UpperTriangular(similar(parent(A), TS))
660+
_matprod_dest_diag(A::UnitLowerTriangular, TS) = LowerTriangular(similar(parent(A), TS))
659661
function _matprod_dest_diag(A::SymTridiagonal, TS)
660662
n = size(A, 1)
661663
ev = similar(A, TS, max(0, n-1))

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 108 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -397,89 +397,124 @@ function lmul!(D::Diagonal, T::Tridiagonal)
397397
return T
398398
end
399399

400-
function __muldiag!(out, D::Diagonal, B, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
401-
require_one_based_indexing(out, B)
402-
alpha, beta = _add.alpha, _add.beta
403-
if iszero(alpha)
404-
_rmul_or_fill!(out, beta)
405-
else
406-
if bis0
407-
@inbounds for j in axes(B, 2)
408-
@simd for i in axes(B, 1)
409-
out[i,j] = D.diag[i] * B[i,j] * alpha
410-
end
411-
end
412-
else
413-
@inbounds for j in axes(B, 2)
414-
@simd for i in axes(B, 1)
415-
out[i,j] = D.diag[i] * B[i,j] * alpha + out[i,j] * beta
416-
end
417-
end
400+
@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B, _add::MulAddMul)
401+
@inbounds for j in axes(B, 2)
402+
@simd for i in axes(B, 1)
403+
_modify!(_add, D.diag[i] * B[i,j], out, (i,j))
418404
end
419405
end
420-
return out
421-
end
422-
function __muldiag!(out, A, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
423-
require_one_based_indexing(out, A)
424-
alpha, beta = _add.alpha, _add.beta
425-
if iszero(alpha)
426-
_rmul_or_fill!(out, beta)
427-
else
428-
if bis0
429-
@inbounds for j in axes(A, 2)
430-
dja = D.diag[j] * alpha
431-
@simd for i in axes(A, 1)
432-
out[i,j] = A[i,j] * dja
433-
end
434-
end
435-
else
436-
@inbounds for j in axes(A, 2)
437-
dja = D.diag[j] * alpha
438-
@simd for i in axes(A, 1)
439-
out[i,j] = A[i,j] * dja + out[i,j] * beta
440-
end
406+
out
407+
end
408+
_has_matching_zeros(out::UpperOrUnitUpperTriangular, A::UpperOrUnitUpperTriangular) = true
409+
_has_matching_zeros(out::LowerOrUnitLowerTriangular, A::LowerOrUnitLowerTriangular) = true
410+
_has_matching_zeros(out, A) = false
411+
function _rowrange_tri_stored(B::UpperOrUnitUpperTriangular, col)
412+
isunit = B isa UnitUpperTriangular
413+
1:min(col-isunit, size(B,1))
414+
end
415+
function _rowrange_tri_stored(B::LowerOrUnitLowerTriangular, col)
416+
isunit = B isa UnitLowerTriangular
417+
col+isunit:size(B,1)
418+
end
419+
_rowrange_tri_zeros(B::UpperOrUnitUpperTriangular, col) = col+1:size(B,1)
420+
_rowrange_tri_zeros(B::LowerOrUnitLowerTriangular, col) = 1:col-1
421+
function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, _add::MulAddMul)
422+
isunit = B isa UnitUpperOrUnitLowerTriangular
423+
out_maybeparent, B_maybeparent = _has_matching_zeros(out, B) ? (parent(out), parent(B)) : (out, B)
424+
for j in axes(B, 2)
425+
# store the diagonal separately for unit triangular matrices
426+
if isunit
427+
@inbounds _modify!(_add, D.diag[j] * B[j,j], out, (j,j))
428+
end
429+
# The indices of out corresponding to the stored indices of B
430+
rowrange = _rowrange_tri_stored(B, j)
431+
@inbounds @simd for i in rowrange
432+
_modify!(_add, D.diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j))
433+
end
434+
# Fill the indices of out corresponding to the zeros of B
435+
# we only fill these if out and B don't have matching zeros
436+
if !_has_matching_zeros(out, B)
437+
rowrange = _rowrange_tri_zeros(B, j)
438+
@inbounds @simd for i in rowrange
439+
_modify!(_add, D.diag[i] * B[i,j], out, (i,j))
441440
end
442441
end
443442
end
444443
return out
445444
end
446-
function __muldiag!(out::Diagonal, D1::Diagonal, D2::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
447-
d1 = D1.diag
448-
d2 = D2.diag
449-
alpha, beta = _add.alpha, _add.beta
450-
if iszero(alpha)
451-
_rmul_or_fill!(out.diag, beta)
452-
else
453-
if bis0
454-
@inbounds @simd for i in eachindex(out.diag)
455-
out.diag[i] = d1[i] * d2[i] * alpha
456-
end
457-
else
458-
@inbounds @simd for i in eachindex(out.diag)
459-
out.diag[i] = d1[i] * d2[i] * alpha + out.diag[i] * beta
445+
446+
@inline function __muldiag_nonzeroalpha!(out, A, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
447+
beta = _add.beta
448+
_add_aisone = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta)
449+
@inbounds for j in axes(A, 2)
450+
dja = _add(D.diag[j])
451+
@simd for i in axes(A, 1)
452+
_modify!(_add_aisone, A[i,j] * dja, out, (i,j))
453+
end
454+
end
455+
out
456+
end
457+
function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
458+
isunit = A isa UnitUpperOrUnitLowerTriangular
459+
beta = _add.beta
460+
# since alpha is multiplied to the diagonal element of D,
461+
# we may skip alpha in the second multiplication by setting ais1 to true
462+
_add_aisone = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta)
463+
# if both A and out have the same upper/lower triangular structure,
464+
# we may directly read and write from the parents
465+
out_maybeparent, A_maybeparent = _has_matching_zeros(out, A) ? (parent(out), parent(A)) : (out, A)
466+
for j in axes(A, 2)
467+
dja = _add(@inbounds D.diag[j])
468+
# store the diagonal separately for unit triangular matrices
469+
if isunit
470+
@inbounds _modify!(_add_aisone, A[j,j] * dja, out, (j,j))
471+
end
472+
# indices of out corresponding to the stored indices of A
473+
rowrange = _rowrange_tri_stored(A, j)
474+
@inbounds @simd for i in rowrange
475+
_modify!(_add_aisone, A_maybeparent[i,j] * dja, out_maybeparent, (i,j))
476+
end
477+
# Fill the indices of out corresponding to the zeros of A
478+
# we only fill these if out and A don't have matching zeros
479+
if !_has_matching_zeros(out, A)
480+
rowrange = _rowrange_tri_zeros(A, j)
481+
@inbounds @simd for i in rowrange
482+
_modify!(_add_aisone, A[i,j] * dja, out, (i,j))
460483
end
461484
end
462485
end
463-
return out
486+
out
464487
end
465-
function __muldiag!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
466-
require_one_based_indexing(out)
467-
alpha, beta = _add.alpha, _add.beta
468-
mA = size(D1, 1)
488+
489+
@inline function __muldiag_nonzeroalpha!(out::Diagonal, D1::Diagonal, D2::Diagonal, _add::MulAddMul)
469490
d1 = D1.diag
470491
d2 = D2.diag
471-
_rmul_or_fill!(out, beta)
472-
if !iszero(alpha)
473-
@inbounds @simd for i in 1:mA
474-
out[i,i] += d1[i] * d2[i] * alpha
475-
end
492+
outd = out.diag
493+
@inbounds @simd for i in eachindex(d1, d2, outd)
494+
_modify!(_add, d1[i] * d2[i], outd, i)
476495
end
477-
return out
496+
out
497+
end
498+
499+
# ambiguity resolution
500+
@inline function __muldiag_nonzeroalpha!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul)
501+
@inbounds for j in axes(D2, 2), i in axes(D2, 1)
502+
_modify!(_add, D1.diag[i] * D2[i,j], out, (i,j))
503+
end
504+
out
478505
end
479506

507+
# muldiag mainly handles the zero-alpha case, so that we need only
508+
# specialize the non-trivial case
480509
function _mul_diag!(out, A, B, _add)
510+
require_one_based_indexing(out, A, B)
481511
_muldiag_size_check(size(out), size(A), size(B))
482-
__muldiag!(out, A, B, _add)
512+
alpha, beta = _add.alpha, _add.beta
513+
if iszero(alpha)
514+
_rmul_or_fill!(out, beta)
515+
else
516+
__muldiag_nonzeroalpha!(out, A, B, _add)
517+
end
483518
return out
484519
end
485520

@@ -659,31 +694,21 @@ for Tri in (:UpperTriangular, :LowerTriangular)
659694
@eval $fun(A::$Tri, D::Diagonal) = $Tri($fun(A.data, D))
660695
@eval $fun(A::$UTri, D::Diagonal) = $Tri(_setdiag!($fun(A.data, D), $f, D.diag))
661696
end
697+
@eval *(A::$Tri{<:Any, <:StridedMaybeAdjOrTransMat}, D::Diagonal) =
698+
@invoke *(A::AbstractMatrix, D::Diagonal)
699+
@eval *(A::$UTri{<:Any, <:StridedMaybeAdjOrTransMat}, D::Diagonal) =
700+
@invoke *(A::AbstractMatrix, D::Diagonal)
662701
for (fun, f) in zip((:*, :lmul!, :ldiv!, :\), (:identity, :identity, :inv, :inv))
663702
@eval $fun(D::Diagonal, A::$Tri) = $Tri($fun(D, A.data))
664703
@eval $fun(D::Diagonal, A::$UTri) = $Tri(_setdiag!($fun(D, A.data), $f, D.diag))
665704
end
705+
@eval *(D::Diagonal, A::$Tri{<:Any, <:StridedMaybeAdjOrTransMat}) =
706+
@invoke *(D::Diagonal, A::AbstractMatrix)
707+
@eval *(D::Diagonal, A::$UTri{<:Any, <:StridedMaybeAdjOrTransMat}) =
708+
@invoke *(D::Diagonal, A::AbstractMatrix)
666709
# 3-arg ldiv!
667710
@eval ldiv!(C::$Tri, D::Diagonal, A::$Tri) = $Tri(ldiv!(C.data, D, A.data))
668711
@eval ldiv!(C::$Tri, D::Diagonal, A::$UTri) = $Tri(_setdiag!(ldiv!(C.data, D, A.data), inv, D.diag))
669-
# 3-arg mul! is disambiguated in special.jl
670-
# 5-arg mul!
671-
@eval _mul!(C::$Tri, D::Diagonal, A::$Tri, _add) = $Tri(mul!(C.data, D, A.data, _add.alpha, _add.beta))
672-
@eval function _mul!(C::$Tri, D::Diagonal, A::$UTri, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
673-
α, β = _add.alpha, _add.beta
674-
iszero(α) && return _rmul_or_fill!(C, β)
675-
diag′ = bis0 ? nothing : diag(C)
676-
data = mul!(C.data, D, A.data, α, β)
677-
$Tri(_setdiag!(data, _add, D.diag, diag′))
678-
end
679-
@eval _mul!(C::$Tri, A::$Tri, D::Diagonal, _add) = $Tri(mul!(C.data, A.data, D, _add.alpha, _add.beta))
680-
@eval function _mul!(C::$Tri, A::$UTri, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
681-
α, β = _add.alpha, _add.beta
682-
iszero(α) && return _rmul_or_fill!(C, β)
683-
diag′ = bis0 ? nothing : diag(C)
684-
data = mul!(C.data, A.data, D, α, β)
685-
$Tri(_setdiag!(data, _add, D.diag, diag′))
686-
end
687712
end
688713

689714
@inline function kron!(C::AbstractMatrix, A::Diagonal, B::Diagonal)

stdlib/LinearAlgebra/test/addmul.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,4 +239,26 @@ end
239239
end
240240
end
241241

242+
@testset "Diagonal scaling of a triangular matrix with a non-triangular destination" begin
243+
for MT in (UpperTriangular, UnitUpperTriangular, LowerTriangular, UnitLowerTriangular)
244+
U = MT(reshape([1:9;],3,3))
245+
M = Array(U)
246+
D = Diagonal(1:3)
247+
A = reshape([1:9;],3,3)
248+
@test mul!(copy(A), U, D, 2, 3) == M * D * 2 + A * 3
249+
@test mul!(copy(A), D, U, 2, 3) == D * M * 2 + A * 3
250+
251+
# nan values with iszero(alpha)
252+
D = Diagonal(fill(NaN,3))
253+
@test mul!(copy(A), U, D, 0, 3) == A * 3
254+
@test mul!(copy(A), D, U, 0, 3) == A * 3
255+
256+
# nan values with iszero(beta)
257+
A = fill(NaN,3,3)
258+
D = Diagonal(1:3)
259+
@test mul!(copy(A), U, D, 2, 0) == M * D * 2
260+
@test mul!(copy(A), D, U, 2, 0) == D * M * 2
261+
end
262+
end
263+
242264
end # module

stdlib/LinearAlgebra/test/diagonal.jl

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -822,6 +822,19 @@ end
822822
@test @inferred(D[1,2]) isa typeof(s)
823823
@test all(iszero, D[1,2])
824824
end
825+
826+
@testset "mul!" begin
827+
D1 = Diagonal(fill(ones(2,3), 2))
828+
D2 = Diagonal(fill(ones(3,2), 2))
829+
C = similar(D1, size(D1))
830+
mul!(C, D1, D2)
831+
@test all(x -> size(x) == (2,2), C)
832+
@test C == D1 * D2
833+
D = similar(D1)
834+
mul!(D, D1, D2)
835+
@test all(x -> size(x) == (2,2), D)
836+
@test D == D1 * D2
837+
end
825838
end
826839

827840
@testset "Eigensystem for block diagonal (issue #30681)" begin
@@ -1188,7 +1201,7 @@ end
11881201
@test oneunit(D3) isa typeof(D3)
11891202
end
11901203

1191-
@testset "AbstractTriangular" for (Tri, UTri) in ((UpperTriangular, UnitUpperTriangular), (LowerTriangular, UnitLowerTriangular))
1204+
@testset "$Tri" for (Tri, UTri) in ((UpperTriangular, UnitUpperTriangular), (LowerTriangular, UnitLowerTriangular))
11921205
A = randn(4, 4)
11931206
TriA = Tri(A)
11941207
UTriA = UTri(A)
@@ -1218,6 +1231,44 @@ end
12181231
@test outTri === mul!(outTri, D, UTriA, 2, 1)::Tri == mul!(out, D, Matrix(UTriA), 2, 1)
12191232
@test outTri === mul!(outTri, TriA, D, 2, 1)::Tri == mul!(out, Matrix(TriA), D, 2, 1)
12201233
@test outTri === mul!(outTri, UTriA, D, 2, 1)::Tri == mul!(out, Matrix(UTriA), D, 2, 1)
1234+
1235+
# we may write to a Unit triangular if the diagonal is preserved
1236+
ID = Diagonal(ones(size(UTriA,2)))
1237+
@test mul!(copy(UTriA), UTriA, ID) == UTriA
1238+
@test mul!(copy(UTriA), ID, UTriA) == UTriA
1239+
1240+
@testset "partly filled parents" begin
1241+
M = Matrix{BigFloat}(undef, 2, 2)
1242+
M[1,1] = M[2,2] = 3
1243+
isupper = Tri == UpperTriangular
1244+
M[1+!isupper, 1+isupper] = 3
1245+
D = Diagonal(1:2)
1246+
T = Tri(M)
1247+
TA = Array(T)
1248+
@test T * D == TA * D
1249+
@test D * T == D * TA
1250+
@test mul!(copy(T), T, D, 2, 3) == 2T * D + 3T
1251+
@test mul!(copy(T), D, T, 2, 3) == 2D * T + 3T
1252+
1253+
U = UTri(M)
1254+
UA = Array(U)
1255+
@test U * D == UA * D
1256+
@test D * U == D * UA
1257+
@test mul!(copy(T), U, D, 2, 3) == 2 * UA * D + 3TA
1258+
@test mul!(copy(T), D, U, 2, 3) == 2 * D * UA + 3TA
1259+
1260+
M2 = Matrix{BigFloat}(undef, 2, 2)
1261+
M2[1+!isupper, 1+isupper] = 3
1262+
U = UTri(M2)
1263+
UA = Array(U)
1264+
@test U * D == UA * D
1265+
@test D * U == D * UA
1266+
ID = Diagonal(ones(size(U,2)))
1267+
@test mul!(copy(U), U, ID) == U
1268+
@test mul!(copy(U), ID, U) == U
1269+
@test mul!(copy(U), U, ID, 2, -1) == U
1270+
@test mul!(copy(U), ID, U, 2, -1) == U
1271+
end
12211272
end
12221273

12231274
struct SMatrix1{T} <: AbstractArray{T,2}

0 commit comments

Comments
 (0)