Skip to content

Commit 5229375

Browse files
committed
Fix multiplying a triangular matrix and a Diagonal
1 parent 04259da commit 5229375

File tree

2 files changed

+74
-16
lines changed

2 files changed

+74
-16
lines changed

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 52 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -404,22 +404,44 @@ end
404404
end
405405
out
406406
end
407-
_maybe_unwrap_tri(out, A) = out, A
408-
_maybe_unwrap_tri(out::UpperTriangular, A::UpperOrUnitUpperTriangular) = parent(out), parent(A)
409-
_maybe_unwrap_tri(out::LowerTriangular, A::LowerOrUnitLowerTriangular) = parent(out), parent(A)
407+
_has_matching_storage(out::UpperOrUnitUpperTriangular, A::UpperOrUnitUpperTriangular) = true
408+
_has_matching_storage(out::LowerOrUnitLowerTriangular, A::LowerOrUnitLowerTriangular) = true
409+
_has_matching_storage(out, A) = false
410+
function _rowrange_tri_stored(B::UpperOrUpperTriangular, col)
411+
isunit = B isa UnitUpperTriangular
412+
1:min(col-isunit, size(B,1))
413+
end
414+
function _rowrange_tri_stored(B::LowerOrUnitLowerTriangular, col)
415+
isunit = B isa UnitLowerTriangular
416+
col+isunit:size(B,1)
417+
end
418+
_rowrange_tri_nonstored(B::UpperOrUnitUpperTriangular, col) = col+1:size(B,1)
419+
_rowrange_tri_nonstored(B::LowerOrUnitLowerTriangular, col) = 1:col-1
410420
@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, _add::MulAddMul)
411-
isunit = B isa Union{UnitUpperTriangular, UnitLowerTriangular}
412-
# if both B and out have the same upper/lower triangular structure,
413-
# we may directly read and write from the parents
414-
out_maybeparent, B_maybeparent = _maybe_unwrap_tri(out, B)
421+
isunit = B isa UnitUpperOrUnitLowerTriangular
422+
out_maybeparent, B_maybeparent = _has_matching_storage(out, B) ? (parent(out), parent(B)) : (out, B)
415423
for j in axes(B, 2)
424+
# store the diagonal separately for unit triangular matrices
416425
if isunit
417-
_modify!(_add, D.diag[j] * B[j,j], out, (j,j))
426+
@inbounds _modify!(_add, D.diag[j] * B[j,j], out, (j,j))
418427
end
419-
rowrange = B isa UpperOrUnitUpperTriangular ? (1:min(j-isunit, size(B,1))) : (j+isunit:size(B,1))
428+
# indices of out corresponding to the stored indices of B
429+
rowrange = _rowrange_tri_stored(B, j)
420430
@inbounds @simd for i in rowrange
421431
_modify!(_add, D.diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j))
422432
end
433+
# indices of out corresponding to the zeros of B
434+
# we only fill these if out and B don't have matching zeros
435+
if !_has_matching_storage(out, B)
436+
rowrange = _rowrange_tri_nonstored(B, j)
437+
if haszero(eltype(out))
438+
_rmul_or_fill!(@view(out[rowrange,j]), _add.beta)
439+
else
440+
@inbounds @simd for i in rowrange
441+
_modify!(_add, D.diag[i] * B[i,j], out, (i,j))
442+
end
443+
end
444+
end
423445
end
424446
out
425447
end
@@ -446,23 +468,37 @@ end
446468
out
447469
end
448470
@inline function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
449-
isunit = A isa Union{UnitUpperTriangular, UnitLowerTriangular}
471+
isunit = A isa UnitUpperOrUnitLowerTriangular
450472
beta = _add.beta
451473
# since alpha is multiplied to the diagonal element of D,
452474
# we may skip alpha in the second multiplication by setting ais1 to true
453475
_add_aisone = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta)
454476
# if both A and out have the same upper/lower triangular structure,
455477
# we may directly read and write from the parents
456-
out_maybeparent, A_maybeparent = _maybe_unwrap_tri(out, A)
457-
@inbounds for j in axes(A, 2)
458-
dja = _add(D.diag[j])
478+
out_maybeparent, A_maybeparent = _has_matching_storage(out, A) ? (parent(out), parent(A)) : (out, A)
479+
for j in axes(A, 2)
480+
dja = _add(@inbounds D.diag[j])
481+
# store the diagonal separately for unit triangular matrices
459482
if isunit
460-
_modify!(_add_aisone, A[j,j] * dja, out, (j,j))
483+
@inbounds _modify!(_add_aisone, A[j,j] * dja, out, (j,j))
461484
end
462-
rowrange = A isa UpperOrUnitUpperTriangular ? (1:min(j-isunit, size(A,1))) : (j+isunit:size(A,1))
463-
@simd for i in rowrange
485+
# indices of out corresponding to the stored indices of A
486+
rowrange = _rowrange_tri_stored(A, j)
487+
@inbounds @simd for i in rowrange
464488
_modify!(_add_aisone, A_maybeparent[i,j] * dja, out_maybeparent, (i,j))
465489
end
490+
# indices of out corresponding to the zeros of A
491+
# we only fill these if out and A don't have matching zeros
492+
if !_has_matching_storage(out, A)
493+
rowrange = _rowrange_tri_nonstored(A, j)
494+
if haszero(eltype(out))
495+
_rmul_or_fill!(@view(out[rowrange,j]), _add.beta)
496+
else
497+
@inbounds @simd for i in rowrange
498+
_modify!(_add, A[i,j] * dja, out, (i,j))
499+
end
500+
end
501+
end
466502
end
467503
out
468504
end

stdlib/LinearAlgebra/test/addmul.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,4 +239,26 @@ end
239239
end
240240
end
241241

242+
@testset "Diagonal scaling of a triangular matrix with a non-triangular destination" begin
243+
for MT in (UpperTriangular, UnitUpperTriangular, LowerTriangular, UnitLowerTriangular)
244+
U = MT(reshape([1:9;],3,3))
245+
M = Array(U)
246+
D = Diagonal(1:3)
247+
A = reshape([1:9;],3,3)
248+
@test mul!(copy(A), U, D, 2, 3) == M * D * 2 + A * 3
249+
@test mul!(copy(A), D, U, 2, 3) == D * M * 2 + A * 3
250+
251+
# nan values with iszero(alpha)
252+
D = Diagonal(fill(NaN,3))
253+
@test mul!(copy(A), U, D, 0, 3) == A * 3
254+
@test mul!(copy(A), D, U, 0, 3) == A * 3
255+
256+
# nan values with iszero(beta)
257+
A = fill(NaN,3,3)
258+
D = Diagonal(1:3)
259+
@test mul!(copy(A), U, D, 2, 0) == M * D * 2
260+
@test mul!(copy(A), D, U, 2, 0) == D * M * 2
261+
end
262+
end
263+
242264
end # module

0 commit comments

Comments
 (0)