@@ -404,22 +404,44 @@ end
404404 end
405405 out
406406end
407- _maybe_unwrap_tri (out, A) = out, A
408- _maybe_unwrap_tri (out:: UpperTriangular , A:: UpperOrUnitUpperTriangular ) = parent (out), parent (A)
409- _maybe_unwrap_tri (out:: LowerTriangular , A:: LowerOrUnitLowerTriangular ) = parent (out), parent (A)
407+ _has_matching_storage (out:: UpperOrUnitUpperTriangular , A:: UpperOrUnitUpperTriangular ) = true
408+ _has_matching_storage (out:: LowerOrUnitLowerTriangular , A:: LowerOrUnitLowerTriangular ) = true
409+ _has_matching_storage (out, A) = false
410+ function _rowrange_tri_stored (B:: UpperOrUnitUpperTriangular , col)
411+ isunit = B isa UnitUpperTriangular
412+ 1 : min (col- isunit, size (B,1 ))
413+ end
414+ function _rowrange_tri_stored (B:: LowerOrUnitLowerTriangular , col)
415+ isunit = B isa UnitLowerTriangular
416+ col+ isunit: size (B,1 )
417+ end
418+ _rowrange_tri_nonstored (B:: UpperOrUnitUpperTriangular , col) = col+ 1 : size (B,1 )
419+ _rowrange_tri_nonstored (B:: LowerOrUnitLowerTriangular , col) = 1 : col- 1
410420@inline function __muldiag_nonzeroalpha! (out, D:: Diagonal , B:: UpperOrLowerTriangular , _add:: MulAddMul )
411- isunit = B isa Union{UnitUpperTriangular, UnitLowerTriangular}
412- # if both B and out have the same upper/lower triangular structure,
413- # we may directly read and write from the parents
414- out_maybeparent, B_maybeparent = _maybe_unwrap_tri (out, B)
421+ isunit = B isa UnitUpperOrUnitLowerTriangular
422+ out_maybeparent, B_maybeparent = _has_matching_storage (out, B) ? (parent (out), parent (B)) : (out, B)
415423 for j in axes (B, 2 )
424+ # store the diagonal separately for unit triangular matrices
416425 if isunit
417- _modify! (_add, D. diag[j] * B[j,j], out, (j,j))
426+ @inbounds _modify! (_add, D. diag[j] * B[j,j], out, (j,j))
418427 end
419- rowrange = B isa UpperOrUnitUpperTriangular ? (1 : min (j- isunit, size (B,1 ))) : (j+ isunit: size (B,1 ))
428+ # indices of out corresponding to the stored indices of B
429+ rowrange = _rowrange_tri_stored (B, j)
420430 @inbounds @simd for i in rowrange
421431 _modify! (_add, D. diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j))
422432 end
433+ # indices of out corresponding to the zeros of B
434+ # we only fill these if out and B don't have matching zeros
435+ if ! _has_matching_storage (out, B)
436+ rowrange = _rowrange_tri_nonstored (B, j)
437+ if haszero (eltype (out))
438+ _rmul_or_fill! (@view (out[rowrange,j]), _add. beta)
439+ else
440+ @inbounds @simd for i in rowrange
441+ _modify! (_add, D. diag[i] * B[i,j], out, (i,j))
442+ end
443+ end
444+ end
423445 end
424446 out
425447end
@@ -446,23 +468,37 @@ end
446468 out
447469end
448470@inline function __muldiag_nonzeroalpha! (out, A:: UpperOrLowerTriangular , D:: Diagonal , _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
449- isunit = A isa Union{UnitUpperTriangular, UnitLowerTriangular}
471+ isunit = A isa UnitUpperOrUnitLowerTriangular
450472 beta = _add. beta
451473 # since alpha is multiplied to the diagonal element of D,
452474 # we may skip alpha in the second multiplication by setting ais1 to true
453475 _add_aisone = MulAddMul {true,bis0,Bool,typeof(beta)} (true , beta)
454476 # if both A and out have the same upper/lower triangular structure,
455477 # we may directly read and write from the parents
456- out_maybeparent, A_maybeparent = _maybe_unwrap_tri (out, A)
457- @inbounds for j in axes (A, 2 )
458- dja = _add (D. diag[j])
478+ out_maybeparent, A_maybeparent = _has_matching_storage (out, A) ? (parent (out), parent (A)) : (out, A)
479+ for j in axes (A, 2 )
480+ dja = _add (@inbounds D. diag[j])
481+ # store the diagonal separately for unit triangular matrices
459482 if isunit
460- _modify! (_add_aisone, A[j,j] * dja, out, (j,j))
483+ @inbounds _modify! (_add_aisone, A[j,j] * dja, out, (j,j))
461484 end
462- rowrange = A isa UpperOrUnitUpperTriangular ? (1 : min (j- isunit, size (A,1 ))) : (j+ isunit: size (A,1 ))
463- @simd for i in rowrange
485+ # indices of out corresponding to the stored indices of A
486+ rowrange = _rowrange_tri_stored (A, j)
487+ @inbounds @simd for i in rowrange
464488 _modify! (_add_aisone, A_maybeparent[i,j] * dja, out_maybeparent, (i,j))
465489 end
490+ # indices of out corresponding to the zeros of A
491+ # we only fill these if out and A don't have matching zeros
492+ if ! _has_matching_storage (out, A)
493+ rowrange = _rowrange_tri_nonstored (A, j)
494+ if haszero (eltype (out))
495+ _rmul_or_fill! (@view (out[rowrange,j]), _add. beta)
496+ else
497+ @inbounds @simd for i in rowrange
497+ _modify! (_add_aisone, A[i,j] * dja, out, (i,j))
499+ end
500+ end
501+ end
466502 end
467503 out
468504end
0 commit comments