@@ -396,120 +396,82 @@ function lmul!(D::Diagonal, T::Tridiagonal)
396396 return T
397397end
398398
399- @inline function __muldiag_nonzeroalpha! (out, D:: Diagonal , B, _add:: MulAddMul )
400- @inbounds for j in axes (B, 2 )
401- @simd for i in axes (B, 1 )
402- _modify! (_add, D. diag[i] * B[i,j], out, (i,j))
403- end
404- end
405- out
406- end
407- _maybe_unwrap_tri (out, A) = out, A
408- _maybe_unwrap_tri (out:: UpperTriangular , A:: UpperOrUnitUpperTriangular ) = parent (out), parent (A)
409- _maybe_unwrap_tri (out:: LowerTriangular , A:: LowerOrUnitLowerTriangular ) = parent (out), parent (A)
410- @inline function __muldiag_nonzeroalpha! (out, D:: Diagonal , B:: UpperOrLowerTriangular , _add:: MulAddMul )
411- isunit = B isa Union{UnitUpperTriangular, UnitLowerTriangular}
412- # if both B and out have the same upper/lower triangular structure,
413- # we may directly read and write from the parents
414- out_maybeparent, B_maybeparent = _maybe_unwrap_tri (out, B)
415- for j in axes (B, 2 )
416- if isunit
417- _modify! (_add, D. diag[j] * B[j,j], out, (j,j))
418- end
419- rowrange = B isa UpperOrUnitUpperTriangular ? (1 : min (j- isunit, size (B,1 ))) : (j+ isunit: size (B,1 ))
420- @inbounds @simd for i in rowrange
421- _modify! (_add, D. diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j))
422- end
423- end
424- out
425- end
426- function __muldiag! (out, D:: Diagonal , B, _add:: MulAddMul )
399+ function __muldiag! (out, D:: Diagonal , B, _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
427400 require_one_based_indexing (out, B)
428401 alpha, beta = _add. alpha, _add. beta
429402 if iszero (alpha)
430403 _rmul_or_fill! (out, beta)
431404 else
432- __muldiag_nonzeroalpha! (out, D, B, _add)
433- end
434- return out
435- end
436-
437- @inline function __muldiag_nonzeroalpha! (out, A, D:: Diagonal , _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
438- beta = _add. beta
439- _add_aisone = MulAddMul {true,bis0,Bool,typeof(beta)} (true , beta)
440- @inbounds for j in axes (A, 2 )
441- dja = _add (D. diag[j])
442- @simd for i in axes (A, 1 )
443- _modify! (_add_aisone, A[i,j] * dja, out, (i,j))
444- end
445- end
446- out
447- end
448- @inline function __muldiag_nonzeroalpha! (out, A:: UpperOrLowerTriangular , D:: Diagonal , _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
449- isunit = A isa Union{UnitUpperTriangular, UnitLowerTriangular}
450- beta = _add. beta
451- # since alpha is multiplied to the diagonal element of D,
452- # we may skip alpha in the second multiplication by setting ais1 to true
453- _add_aisone = MulAddMul {true,bis0,Bool,typeof(beta)} (true , beta)
454- # if both A and out have the same upper/lower triangular structure,
455- # we may directly read and write from the parents
456- out_maybeparent, A_maybeparent = _maybe_unwrap_tri (out, A)
457- @inbounds for j in axes (A, 2 )
458- dja = _add (D. diag[j])
459- if isunit
460- _modify! (_add_aisone, A[j,j] * dja, out, (j,j))
461- end
462- rowrange = A isa UpperOrUnitUpperTriangular ? (1 : min (j- isunit, size (A,1 ))) : (j+ isunit: size (A,1 ))
463- @simd for i in rowrange
464- _modify! (_add_aisone, A_maybeparent[i,j] * dja, out_maybeparent, (i,j))
405+ if bis0
406+ @inbounds for j in axes (B, 2 )
407+ @simd for i in axes (B, 1 )
408+ out[i,j] = D. diag[i] * B[i,j] * alpha
409+ end
410+ end
411+ else
412+ @inbounds for j in axes (B, 2 )
413+ @simd for i in axes (B, 1 )
414+ out[i,j] = D. diag[i] * B[i,j] * alpha + out[i,j] * beta
415+ end
416+ end
465417 end
466418 end
467- out
419+ return out
468420end
469- function __muldiag! (out, A, D:: Diagonal , _add:: MulAddMul )
421+ function __muldiag! (out, A, D:: Diagonal , _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
470422 require_one_based_indexing (out, A)
471423 alpha, beta = _add. alpha, _add. beta
472424 if iszero (alpha)
473425 _rmul_or_fill! (out, beta)
474426 else
475- __muldiag_nonzeroalpha! (out, A, D, _add)
427+ if bis0
428+ @inbounds for j in axes (A, 2 )
429+ dja = D. diag[j] * alpha
430+ @simd for i in axes (A, 1 )
431+ out[i,j] = A[i,j] * dja
432+ end
433+ end
434+ else
435+ @inbounds for j in axes (A, 2 )
436+ dja = D. diag[j] * alpha
437+ @simd for i in axes (A, 1 )
438+ out[i,j] = A[i,j] * dja + out[i,j] * beta
439+ end
440+ end
441+ end
476442 end
477443 return out
478444end
479-
480- @inline function __muldiag_nonzeroalpha! (out:: Diagonal , D1:: Diagonal , D2:: Diagonal , _add:: MulAddMul )
445+ function __muldiag! (out:: Diagonal , D1:: Diagonal , D2:: Diagonal , _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
481446 d1 = D1. diag
482447 d2 = D2. diag
483- outd = out. diag
484- @inbounds @simd for i in eachindex (d1, d2, outd)
485- _modify! (_add, d1[i] * d2[i], outd, i)
486- end
487- out
488- end
489- function __muldiag! (out:: Diagonal , D1:: Diagonal , D2:: Diagonal , _add:: MulAddMul )
490448 alpha, beta = _add. alpha, _add. beta
491449 if iszero (alpha)
492450 _rmul_or_fill! (out. diag, beta)
493451 else
494- __muldiag_nonzeroalpha! (out, D1, D2, _add)
452+ if bis0
453+ @inbounds @simd for i in eachindex (out. diag)
454+ out. diag[i] = d1[i] * d2[i] * alpha
455+ end
456+ else
457+ @inbounds @simd for i in eachindex (out. diag)
458+ out. diag[i] = d1[i] * d2[i] * alpha + out. diag[i] * beta
459+ end
460+ end
495461 end
496462 return out
497463end
498- @inline function __muldiag_nonzeroalpha! (out, D1:: Diagonal , D2:: Diagonal , _add:: MulAddMul )
499- d1 = D1. diag
500- d2 = D2. diag
501- @inbounds @simd for i in eachindex (d1, d2)
502- _modify! (_add, d1[i] * d2[i], out, (i,i))
503- end
504- out
505- end
506- function __muldiag! (out, D1:: Diagonal , D2:: Diagonal , _add:: MulAddMul{ais1} ) where {ais1}
464+ function __muldiag! (out, D1:: Diagonal , D2:: Diagonal , _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
507465 require_one_based_indexing (out)
508466 alpha, beta = _add. alpha, _add. beta
467+ mA = size (D1, 1 )
468+ d1 = D1. diag
469+ d2 = D2. diag
509470 _rmul_or_fill! (out, beta)
510471 if ! iszero (alpha)
511- _add_bis1 = MulAddMul {ais1,false,typeof(alpha),Bool} (alpha,true )
512- __muldiag_nonzeroalpha! (out, D1, D2, _add_bis1)
472+ @inbounds @simd for i in 1 : mA
473+ out[i,i] += d1[i] * d2[i] * alpha
474+ end
513475 end
514476 return out
515477end
@@ -696,21 +658,31 @@ for Tri in (:UpperTriangular, :LowerTriangular)
696658 @eval $ fun (A:: $Tri , D:: Diagonal ) = $ Tri ($ fun (A. data, D))
697659 @eval $ fun (A:: $UTri , D:: Diagonal ) = $ Tri (_setdiag! ($ fun (A. data, D), $ f, D. diag))
698660 end
699- @eval * (A:: $Tri{<:Any, <:StridedMaybeAdjOrTransMat} , D:: Diagonal ) =
700- @invoke * (A:: AbstractMatrix , D:: Diagonal )
701- @eval * (A:: $UTri{<:Any, <:StridedMaybeAdjOrTransMat} , D:: Diagonal ) =
702- @invoke * (A:: AbstractMatrix , D:: Diagonal )
703661 for (fun, f) in zip ((:* , :lmul! , :ldiv! , :\ ), (:identity , :identity , :inv , :inv ))
704662 @eval $ fun (D:: Diagonal , A:: $Tri ) = $ Tri ($ fun (D, A. data))
705663 @eval $ fun (D:: Diagonal , A:: $UTri ) = $ Tri (_setdiag! ($ fun (D, A. data), $ f, D. diag))
706664 end
707- @eval * (D:: Diagonal , A:: $Tri{<:Any, <:StridedMaybeAdjOrTransMat} ) =
708- @invoke * (D:: Diagonal , A:: AbstractMatrix )
709- @eval * (D:: Diagonal , A:: $UTri{<:Any, <:StridedMaybeAdjOrTransMat} ) =
710- @invoke * (D:: Diagonal , A:: AbstractMatrix )
711665 # 3-arg ldiv!
712666 @eval ldiv! (C:: $Tri , D:: Diagonal , A:: $Tri ) = $ Tri (ldiv! (C. data, D, A. data))
713667 @eval ldiv! (C:: $Tri , D:: Diagonal , A:: $UTri ) = $ Tri (_setdiag! (ldiv! (C. data, D, A. data), inv, D. diag))
668+ # 3-arg mul! is disambiguated in special.jl
669+ # 5-arg mul!
670+ @eval _mul! (C:: $Tri , D:: Diagonal , A:: $Tri , _add) = $ Tri (mul! (C. data, D, A. data, _add. alpha, _add. beta))
671+ @eval function _mul! (C:: $Tri , D:: Diagonal , A:: $UTri , _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
672+ α, β = _add. alpha, _add. beta
673+ iszero (α) && return _rmul_or_fill! (C, β)
674+ diag′ = bis0 ? nothing : diag (C)
675+ data = mul! (C. data, D, A. data, α, β)
676+ $ Tri (_setdiag! (data, _add, D. diag, diag′))
677+ end
678+ @eval _mul! (C:: $Tri , A:: $Tri , D:: Diagonal , _add) = $ Tri (mul! (C. data, A. data, D, _add. alpha, _add. beta))
679+ @eval function _mul! (C:: $Tri , A:: $UTri , D:: Diagonal , _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
680+ α, β = _add. alpha, _add. beta
681+ iszero (α) && return _rmul_or_fill! (C, β)
682+ diag′ = bis0 ? nothing : diag (C)
683+ data = mul! (C. data, A. data, D, α, β)
684+ $ Tri (_setdiag! (data, _add, D. diag, diag′))
685+ end
714686end
715687
716688@inline function kron! (C:: AbstractMatrix , A:: Diagonal , B:: Diagonal )
0 commit comments