@@ -396,82 +396,156 @@ function lmul!(D::Diagonal, T::Tridiagonal)
396396 return T
397397end
398398
399- function __muldiag! (out, D:: Diagonal , B, _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
399+ @inline function __muldiag_nonzeroalpha! (out, D:: Diagonal , B, _add:: MulAddMul )
400+ @inbounds for j in axes (B, 2 )
401+ @simd for i in axes (B, 1 )
402+ _modify! (_add, D. diag[i] * B[i,j], out, (i,j))
403+ end
404+ end
405+ out
406+ end
407+ _has_matching_storage (out:: UpperOrUnitUpperTriangular , A:: UpperOrUnitUpperTriangular ) = true
408+ _has_matching_storage (out:: LowerOrUnitLowerTriangular , A:: LowerOrUnitLowerTriangular ) = true
409+ _has_matching_storage (out, A) = false
410+ function _rowrange_tri_stored (B:: UpperOrUnitUpperTriangular , col)
411+ isunit = B isa UnitUpperTriangular
412+ 1 : min (col- isunit, size (B,1 ))
413+ end
414+ function _rowrange_tri_stored (B:: LowerOrUnitLowerTriangular , col)
415+ isunit = B isa UnitLowerTriangular
416+ col+ isunit: size (B,1 )
417+ end
418+ _rowrange_tri_nonstored (B:: UpperOrUnitUpperTriangular , col) = col+ 1 : size (B,1 )
419+ _rowrange_tri_nonstored (B:: LowerOrUnitLowerTriangular , col) = 1 : col- 1
420+ function __muldiag_nonzeroalpha! (out, D:: Diagonal , B:: UpperOrLowerTriangular , _add:: MulAddMul )
421+ isunit = B isa UnitUpperOrUnitLowerTriangular
422+ out_maybeparent, B_maybeparent = _has_matching_storage (out, B) ? (parent (out), parent (B)) : (out, B)
423+ for j in axes (B, 2 )
424+ # store the diagonal separately for unit triangular matrices
425+ if isunit
426+ @inbounds _modify! (_add, D. diag[j] * B[j,j], out, (j,j))
427+ end
428+ # indices of out corresponding to the stored indices of B
429+ rowrange = _rowrange_tri_stored (B, j)
430+ @inbounds @simd for i in rowrange
431+ _modify! (_add, D. diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j))
432+ end
433+ # indices of out corresponding to the zeros of B
434+ # we only fill these if out and B don't have matching zeros
435+ if ! _has_matching_storage (out, B)
436+ rowrange = _rowrange_tri_nonstored (B, j)
437+ if haszero (eltype (out))
438+ _rmul_or_fill! (@view (out[rowrange,j]), _add. beta)
439+ else
440+ @inbounds @simd for i in rowrange
441+ _modify! (_add, D. diag[i] * B[i,j], out, (i,j))
442+ end
443+ end
444+ end
445+ end
446+ out
447+ end
448+ function __muldiag! (out, D:: Diagonal , B, _add:: MulAddMul )
400449 require_one_based_indexing (out, B)
401450 alpha, beta = _add. alpha, _add. beta
402451 if iszero (alpha)
403452 _rmul_or_fill! (out, beta)
404453 else
405- if bis0
406- @inbounds for j in axes (B, 2 )
407- @simd for i in axes (B, 1 )
408- out[i,j] = D. diag[i] * B[i,j] * alpha
409- end
410- end
411- else
412- @inbounds for j in axes (B, 2 )
413- @simd for i in axes (B, 1 )
414- out[i,j] = D. diag[i] * B[i,j] * alpha + out[i,j] * beta
454+ __muldiag_nonzeroalpha! (out, D, B, _add)
455+ end
456+ return out
457+ end
458+
459+ @inline function __muldiag_nonzeroalpha! (out, A, D:: Diagonal , _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
460+ beta = _add. beta
461+ _add_aisone = MulAddMul {true,bis0,Bool,typeof(beta)} (true , beta)
462+ @inbounds for j in axes (A, 2 )
463+ dja = _add (D. diag[j])
464+ @simd for i in axes (A, 1 )
465+ _modify! (_add_aisone, A[i,j] * dja, out, (i,j))
466+ end
467+ end
468+ out
469+ end
470+ function __muldiag_nonzeroalpha! (out, A:: UpperOrLowerTriangular , D:: Diagonal , _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
471+ isunit = A isa UnitUpperOrUnitLowerTriangular
472+ beta = _add. beta
473+ # since alpha is multiplied to the diagonal element of D,
474+ # we may skip alpha in the second multiplication by setting ais1 to true
475+ _add_aisone = MulAddMul {true,bis0,Bool,typeof(beta)} (true , beta)
476+ # if both A and out have the same upper/lower triangular structure,
477+ # we may directly read and write from the parents
478+ out_maybeparent, A_maybeparent = _has_matching_storage (out, A) ? (parent (out), parent (A)) : (out, A)
479+ for j in axes (A, 2 )
480+ dja = _add (@inbounds D. diag[j])
481+ # store the diagonal separately for unit triangular matrices
482+ if isunit
483+ @inbounds _modify! (_add_aisone, A[j,j] * dja, out, (j,j))
484+ end
485+ # indices of out corresponding to the stored indices of A
486+ rowrange = _rowrange_tri_stored (A, j)
487+ @inbounds @simd for i in rowrange
488+ _modify! (_add_aisone, A_maybeparent[i,j] * dja, out_maybeparent, (i,j))
489+ end
490+ # indices of out corresponding to the zeros of A
491+ # we only fill these if out and A don't have matching zeros
492+ if ! _has_matching_storage (out, A)
493+ rowrange = _rowrange_tri_nonstored (A, j)
494+ if haszero (eltype (out))
495+ _rmul_or_fill! (@view (out[rowrange,j]), _add. beta)
496+ else
497+ @inbounds @simd for i in rowrange
498+ _modify! (_add, A[i,j] * dja, out, (i,j))
415499 end
416500 end
417501 end
418502 end
419- return out
503+ out
420504end
421- function __muldiag! (out, A, D:: Diagonal , _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
505+ function __muldiag! (out, A, D:: Diagonal , _add:: MulAddMul )
422506 require_one_based_indexing (out, A)
423507 alpha, beta = _add. alpha, _add. beta
424508 if iszero (alpha)
425509 _rmul_or_fill! (out, beta)
426510 else
427- if bis0
428- @inbounds for j in axes (A, 2 )
429- dja = D. diag[j] * alpha
430- @simd for i in axes (A, 1 )
431- out[i,j] = A[i,j] * dja
432- end
433- end
434- else
435- @inbounds for j in axes (A, 2 )
436- dja = D. diag[j] * alpha
437- @simd for i in axes (A, 1 )
438- out[i,j] = A[i,j] * dja + out[i,j] * beta
439- end
440- end
441- end
511+ __muldiag_nonzeroalpha! (out, A, D, _add)
442512 end
443513 return out
444514end
445- function __muldiag! (out:: Diagonal , D1:: Diagonal , D2:: Diagonal , _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
515+
516+ @inline function __muldiag_nonzeroalpha! (out:: Diagonal , D1:: Diagonal , D2:: Diagonal , _add:: MulAddMul )
446517 d1 = D1. diag
447518 d2 = D2. diag
519+ outd = out. diag
520+ @inbounds @simd for i in eachindex (d1, d2, outd)
521+ _modify! (_add, d1[i] * d2[i], outd, i)
522+ end
523+ out
524+ end
525+ function __muldiag! (out:: Diagonal , D1:: Diagonal , D2:: Diagonal , _add:: MulAddMul )
448526 alpha, beta = _add. alpha, _add. beta
449527 if iszero (alpha)
450528 _rmul_or_fill! (out. diag, beta)
451529 else
452- if bis0
453- @inbounds @simd for i in eachindex (out. diag)
454- out. diag[i] = d1[i] * d2[i] * alpha
455- end
456- else
457- @inbounds @simd for i in eachindex (out. diag)
458- out. diag[i] = d1[i] * d2[i] * alpha + out. diag[i] * beta
459- end
460- end
530+ __muldiag_nonzeroalpha! (out, D1, D2, _add)
461531 end
462532 return out
463533end
464- function __muldiag! (out, D1:: Diagonal , D2:: Diagonal , _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
465- require_one_based_indexing (out)
466- alpha, beta = _add. alpha, _add. beta
467- mA = size (D1, 1 )
534+ @inline function __muldiag_nonzeroalpha! (out, D1:: Diagonal , D2:: Diagonal , _add:: MulAddMul )
468535 d1 = D1. diag
469536 d2 = D2. diag
537+ @inbounds @simd for i in eachindex (d1, d2)
538+ _modify! (_add, d1[i] * d2[i], out, (i,i))
539+ end
540+ out
541+ end
542+ function __muldiag! (out, D1:: Diagonal , D2:: Diagonal , _add:: MulAddMul{ais1} ) where {ais1}
543+ require_one_based_indexing (out)
544+ alpha, beta = _add. alpha, _add. beta
470545 _rmul_or_fill! (out, beta)
471546 if ! iszero (alpha)
472- @inbounds @simd for i in 1 : mA
473- out[i,i] += d1[i] * d2[i] * alpha
474- end
547+ _add_bis1 = MulAddMul {ais1,false,typeof(alpha),Bool} (alpha,true )
548+ __muldiag_nonzeroalpha! (out, D1, D2, _add_bis1)
475549 end
476550 return out
477551end
@@ -658,31 +732,21 @@ for Tri in (:UpperTriangular, :LowerTriangular)
658732 @eval $ fun (A:: $Tri , D:: Diagonal ) = $ Tri ($ fun (A. data, D))
659733 @eval $ fun (A:: $UTri , D:: Diagonal ) = $ Tri (_setdiag! ($ fun (A. data, D), $ f, D. diag))
660734 end
735+ @eval * (A:: $Tri{<:Any, <:StridedMaybeAdjOrTransMat} , D:: Diagonal ) =
736+ @invoke * (A:: AbstractMatrix , D:: Diagonal )
737+ @eval * (A:: $UTri{<:Any, <:StridedMaybeAdjOrTransMat} , D:: Diagonal ) =
738+ @invoke * (A:: AbstractMatrix , D:: Diagonal )
661739 for (fun, f) in zip ((:* , :lmul! , :ldiv! , :\ ), (:identity , :identity , :inv , :inv ))
662740 @eval $ fun (D:: Diagonal , A:: $Tri ) = $ Tri ($ fun (D, A. data))
663741 @eval $ fun (D:: Diagonal , A:: $UTri ) = $ Tri (_setdiag! ($ fun (D, A. data), $ f, D. diag))
664742 end
743+ @eval * (D:: Diagonal , A:: $Tri{<:Any, <:StridedMaybeAdjOrTransMat} ) =
744+ @invoke * (D:: Diagonal , A:: AbstractMatrix )
745+ @eval * (D:: Diagonal , A:: $UTri{<:Any, <:StridedMaybeAdjOrTransMat} ) =
746+ @invoke * (D:: Diagonal , A:: AbstractMatrix )
665747 # 3-arg ldiv!
666748 @eval ldiv! (C:: $Tri , D:: Diagonal , A:: $Tri ) = $ Tri (ldiv! (C. data, D, A. data))
667749 @eval ldiv! (C:: $Tri , D:: Diagonal , A:: $UTri ) = $ Tri (_setdiag! (ldiv! (C. data, D, A. data), inv, D. diag))
668- # 3-arg mul! is disambiguated in special.jl
669- # 5-arg mul!
670- @eval _mul! (C:: $Tri , D:: Diagonal , A:: $Tri , _add) = $ Tri (mul! (C. data, D, A. data, _add. alpha, _add. beta))
671- @eval function _mul! (C:: $Tri , D:: Diagonal , A:: $UTri , _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
672- α, β = _add. alpha, _add. beta
673- iszero (α) && return _rmul_or_fill! (C, β)
674- diag′ = bis0 ? nothing : diag (C)
675- data = mul! (C. data, D, A. data, α, β)
676- $ Tri (_setdiag! (data, _add, D. diag, diag′))
677- end
678- @eval _mul! (C:: $Tri , A:: $Tri , D:: Diagonal , _add) = $ Tri (mul! (C. data, A. data, D, _add. alpha, _add. beta))
679- @eval function _mul! (C:: $Tri , A:: $UTri , D:: Diagonal , _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
680- α, β = _add. alpha, _add. beta
681- iszero (α) && return _rmul_or_fill! (C, β)
682- diag′ = bis0 ? nothing : diag (C)
683- data = mul! (C. data, A. data, D, α, β)
684- $ Tri (_setdiag! (data, _add, D. diag, diag′))
685- end
686750end
687751
688752@inline function kron! (C:: AbstractMatrix , A:: Diagonal , B:: Diagonal )
0 commit comments