@@ -397,89 +397,124 @@ function lmul!(D::Diagonal, T::Tridiagonal)
397397 return T
398398end
399399
400- function __muldiag! (out, D:: Diagonal , B, _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
401- require_one_based_indexing (out, B)
402- alpha, beta = _add. alpha, _add. beta
403- if iszero (alpha)
404- _rmul_or_fill! (out, beta)
405- else
406- if bis0
407- @inbounds for j in axes (B, 2 )
408- @simd for i in axes (B, 1 )
409- out[i,j] = D. diag[i] * B[i,j] * alpha
410- end
411- end
412- else
413- @inbounds for j in axes (B, 2 )
414- @simd for i in axes (B, 1 )
415- out[i,j] = D. diag[i] * B[i,j] * alpha + out[i,j] * beta
416- end
417- end
400+ @inline function __muldiag_nonzeroalpha! (out, D:: Diagonal , B, _add:: MulAddMul )
401+ @inbounds for j in axes (B, 2 )
402+ @simd for i in axes (B, 1 )
403+ _modify! (_add, D. diag[i] * B[i,j], out, (i,j))
418404 end
419405 end
420- return out
421- end
422- function __muldiag! (out, A, D:: Diagonal , _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
423- require_one_based_indexing (out, A)
424- alpha, beta = _add. alpha, _add. beta
425- if iszero (alpha)
426- _rmul_or_fill! (out, beta)
427- else
428- if bis0
429- @inbounds for j in axes (A, 2 )
430- dja = D. diag[j] * alpha
431- @simd for i in axes (A, 1 )
432- out[i,j] = A[i,j] * dja
433- end
434- end
435- else
436- @inbounds for j in axes (A, 2 )
437- dja = D. diag[j] * alpha
438- @simd for i in axes (A, 1 )
439- out[i,j] = A[i,j] * dja + out[i,j] * beta
440- end
406+ out
407+ end
408+ _has_matching_zeros (out:: UpperOrUnitUpperTriangular , A:: UpperOrUnitUpperTriangular ) = true
409+ _has_matching_zeros (out:: LowerOrUnitLowerTriangular , A:: LowerOrUnitLowerTriangular ) = true
410+ _has_matching_zeros (out, A) = false
411+ function _rowrange_tri_stored (B:: UpperOrUnitUpperTriangular , col)
412+ isunit = B isa UnitUpperTriangular
413+ 1 : min (col- isunit, size (B,1 ))
414+ end
415+ function _rowrange_tri_stored (B:: LowerOrUnitLowerTriangular , col)
416+ isunit = B isa UnitLowerTriangular
417+ col+ isunit: size (B,1 )
418+ end
419+ _rowrange_tri_zeros (B:: UpperOrUnitUpperTriangular , col) = col+ 1 : size (B,1 )
420+ _rowrange_tri_zeros (B:: LowerOrUnitLowerTriangular , col) = 1 : col- 1
421+ function __muldiag_nonzeroalpha! (out, D:: Diagonal , B:: UpperOrLowerTriangular , _add:: MulAddMul )
422+ isunit = B isa UnitUpperOrUnitLowerTriangular
423+ out_maybeparent, B_maybeparent = _has_matching_zeros (out, B) ? (parent (out), parent (B)) : (out, B)
424+ for j in axes (B, 2 )
425+ # store the diagonal separately for unit triangular matrices
426+ if isunit
427+ @inbounds _modify! (_add, D. diag[j] * B[j,j], out, (j,j))
428+ end
429+ # The indices of out corresponding to the stored indices of B
430+ rowrange = _rowrange_tri_stored (B, j)
431+ @inbounds @simd for i in rowrange
432+ _modify! (_add, D. diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j))
433+ end
434+ # Fill the indices of out corresponding to the zeros of B
435+ # we only fill these if out and B don't have matching zeros
436+ if ! _has_matching_zeros (out, B)
437+ rowrange = _rowrange_tri_zeros (B, j)
438+ @inbounds @simd for i in rowrange
439+ _modify! (_add, D. diag[i] * B[i,j], out, (i,j))
441440 end
442441 end
443442 end
444443 return out
445444end
446- function __muldiag! (out:: Diagonal , D1:: Diagonal , D2:: Diagonal , _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
447- d1 = D1. diag
448- d2 = D2. diag
449- alpha, beta = _add. alpha, _add. beta
450- if iszero (alpha)
451- _rmul_or_fill! (out. diag, beta)
452- else
453- if bis0
454- @inbounds @simd for i in eachindex (out. diag)
455- out. diag[i] = d1[i] * d2[i] * alpha
456- end
457- else
458- @inbounds @simd for i in eachindex (out. diag)
459- out. diag[i] = d1[i] * d2[i] * alpha + out. diag[i] * beta
445+
446+ @inline function __muldiag_nonzeroalpha! (out, A, D:: Diagonal , _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
447+ beta = _add. beta
448+ _add_aisone = MulAddMul {true,bis0,Bool,typeof(beta)} (true , beta)
449+ @inbounds for j in axes (A, 2 )
450+ dja = _add (D. diag[j])
451+ @simd for i in axes (A, 1 )
452+ _modify! (_add_aisone, A[i,j] * dja, out, (i,j))
453+ end
454+ end
455+ out
456+ end
457+ function __muldiag_nonzeroalpha! (out, A:: UpperOrLowerTriangular , D:: Diagonal , _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
458+ isunit = A isa UnitUpperOrUnitLowerTriangular
459+ beta = _add. beta
460+ # since alpha is multiplied to the diagonal element of D,
461+ # we may skip alpha in the second multiplication by setting ais1 to true
462+ _add_aisone = MulAddMul {true,bis0,Bool,typeof(beta)} (true , beta)
463+ # if both A and out have the same upper/lower triangular structure,
464+ # we may directly read and write from the parents
465+ out_maybeparent, A_maybeparent = _has_matching_zeros (out, A) ? (parent (out), parent (A)) : (out, A)
466+ for j in axes (A, 2 )
467+ dja = _add (@inbounds D. diag[j])
468+ # store the diagonal separately for unit triangular matrices
469+ if isunit
470+ @inbounds _modify! (_add_aisone, A[j,j] * dja, out, (j,j))
471+ end
472+ # indices of out corresponding to the stored indices of A
473+ rowrange = _rowrange_tri_stored (A, j)
474+ @inbounds @simd for i in rowrange
475+ _modify! (_add_aisone, A_maybeparent[i,j] * dja, out_maybeparent, (i,j))
476+ end
477+ # Fill the indices of out corresponding to the zeros of A
478+ # we only fill these if out and A don't have matching zeros
479+ if ! _has_matching_zeros (out, A)
480+ rowrange = _rowrange_tri_zeros (A, j)
481+ @inbounds @simd for i in rowrange
482+ _modify! (_add_aisone, A[i,j] * dja, out, (i,j))
460483 end
461484 end
462485 end
463- return out
486+ out
464487end
465- function __muldiag! (out, D1:: Diagonal , D2:: Diagonal , _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
466- require_one_based_indexing (out)
467- alpha, beta = _add. alpha, _add. beta
468- mA = size (D1, 1 )
488+
489+ @inline function __muldiag_nonzeroalpha! (out:: Diagonal , D1:: Diagonal , D2:: Diagonal , _add:: MulAddMul )
469490 d1 = D1. diag
470491 d2 = D2. diag
471- _rmul_or_fill! (out, beta)
472- if ! iszero (alpha)
473- @inbounds @simd for i in 1 : mA
474- out[i,i] += d1[i] * d2[i] * alpha
475- end
492+ outd = out. diag
493+ @inbounds @simd for i in eachindex (d1, d2, outd)
494+ _modify! (_add, d1[i] * d2[i], outd, i)
476495 end
477- return out
496+ out
497+ end
498+
499+ # ambiguity resolution
500+ @inline function __muldiag_nonzeroalpha! (out, D1:: Diagonal , D2:: Diagonal , _add:: MulAddMul )
501+ @inbounds for j in axes (D2, 2 ), i in axes (D2, 1 )
502+ _modify! (_add, D1. diag[i] * D2[i,j], out, (i,j))
503+ end
504+ out
478505end
479506
507+ # muldiag mainly handles the zero-alpha case, so that we need only
508+ # specialize the non-trivial case
480509function _mul_diag! (out, A, B, _add)
510+ require_one_based_indexing (out, A, B)
481511 _muldiag_size_check (size (out), size (A), size (B))
482- __muldiag! (out, A, B, _add)
512+ alpha, beta = _add. alpha, _add. beta
513+ if iszero (alpha)
514+ _rmul_or_fill! (out, beta)
515+ else
516+ __muldiag_nonzeroalpha! (out, A, B, _add)
517+ end
483518 return out
484519end
485520
@@ -659,31 +694,21 @@ for Tri in (:UpperTriangular, :LowerTriangular)
659694 @eval $ fun (A:: $Tri , D:: Diagonal ) = $ Tri ($ fun (A. data, D))
660695 @eval $ fun (A:: $UTri , D:: Diagonal ) = $ Tri (_setdiag! ($ fun (A. data, D), $ f, D. diag))
661696 end
697+ @eval * (A:: $Tri{<:Any, <:StridedMaybeAdjOrTransMat} , D:: Diagonal ) =
698+ @invoke * (A:: AbstractMatrix , D:: Diagonal )
699+ @eval * (A:: $UTri{<:Any, <:StridedMaybeAdjOrTransMat} , D:: Diagonal ) =
700+ @invoke * (A:: AbstractMatrix , D:: Diagonal )
662701 for (fun, f) in zip ((:* , :lmul! , :ldiv! , :\ ), (:identity , :identity , :inv , :inv ))
663702 @eval $ fun (D:: Diagonal , A:: $Tri ) = $ Tri ($ fun (D, A. data))
664703 @eval $ fun (D:: Diagonal , A:: $UTri ) = $ Tri (_setdiag! ($ fun (D, A. data), $ f, D. diag))
665704 end
705+ @eval * (D:: Diagonal , A:: $Tri{<:Any, <:StridedMaybeAdjOrTransMat} ) =
706+ @invoke * (D:: Diagonal , A:: AbstractMatrix )
707+ @eval * (D:: Diagonal , A:: $UTri{<:Any, <:StridedMaybeAdjOrTransMat} ) =
708+ @invoke * (D:: Diagonal , A:: AbstractMatrix )
666709 # 3-arg ldiv!
667710 @eval ldiv! (C:: $Tri , D:: Diagonal , A:: $Tri ) = $ Tri (ldiv! (C. data, D, A. data))
668711 @eval ldiv! (C:: $Tri , D:: Diagonal , A:: $UTri ) = $ Tri (_setdiag! (ldiv! (C. data, D, A. data), inv, D. diag))
669- # 3-arg mul! is disambiguated in special.jl
670- # 5-arg mul!
671- @eval _mul! (C:: $Tri , D:: Diagonal , A:: $Tri , _add) = $ Tri (mul! (C. data, D, A. data, _add. alpha, _add. beta))
672- @eval function _mul! (C:: $Tri , D:: Diagonal , A:: $UTri , _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
673- α, β = _add. alpha, _add. beta
674- iszero (α) && return _rmul_or_fill! (C, β)
675- diag′ = bis0 ? nothing : diag (C)
676- data = mul! (C. data, D, A. data, α, β)
677- $ Tri (_setdiag! (data, _add, D. diag, diag′))
678- end
679- @eval _mul! (C:: $Tri , A:: $Tri , D:: Diagonal , _add) = $ Tri (mul! (C. data, A. data, D, _add. alpha, _add. beta))
680- @eval function _mul! (C:: $Tri , A:: $UTri , D:: Diagonal , _add:: MulAddMul{ais1,bis0} ) where {ais1,bis0}
681- α, β = _add. alpha, _add. beta
682- iszero (α) && return _rmul_or_fill! (C, β)
683- diag′ = bis0 ? nothing : diag (C)
684- data = mul! (C. data, A. data, D, α, β)
685- $ Tri (_setdiag! (data, _add, D. diag, diag′))
686- end
687712end
688713
689714@inline function kron! (C:: AbstractMatrix , A:: Diagonal , B:: Diagonal )
0 commit comments