Skip to content

Commit 2ca9e00

Browse files
authored
LinearAlgebra: copyto! between banded matrix types (#54041)
This specialized `copyto!` for combinations of banded structured matrix types so that the copy may be O(N) instead of the fallback O(N^2) implementation. E.g.: ```julia julia> T = Tridiagonal(zeros(999), zeros(1000), zeros(999)); julia> B = Bidiagonal(ones(1000), fill(2.0, 999), :U); julia> @Btime copyto!($T, $B); 1.927 ms (0 allocations: 0 bytes) # master 229.870 ns (0 allocations: 0 bytes) # PR ``` This also changes the `copyto!` implementation for mismatched matrix sizes, bringing it closer to the docstring. So, the following works on master: ```julia julia> Ddest = Diagonal(zeros(4)); julia> Dsrc = Diagonal(ones(2)); julia> copyto!(Ddest, Dsrc) 4×4 Diagonal{Float64, Vector{Float64}}: 1.0 ⋅ ⋅ ⋅ ⋅ 1.0 ⋅ ⋅ ⋅ ⋅ 0.0 ⋅ ⋅ ⋅ ⋅ 0.0 ``` but this won't work anymore with this PR. This was inconsistent anyway, as materializing the matrices produces a different result, which shouldn't be the case: ```julia julia> copyto!(Matrix(Ddest), Dsrc) 4×4 Matrix{Float64}: 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 ``` After this PR, the way to carry out the copy would be ```julia julia> copyto!(Ddest, CartesianIndices(Dsrc), Dsrc, CartesianIndices(Dsrc)) 4×4 Diagonal{Float64, Vector{Float64}}: 1.0 ⋅ ⋅ ⋅ ⋅ 1.0 ⋅ ⋅ ⋅ ⋅ 0.0 ⋅ ⋅ ⋅ ⋅ 0.0 ``` This change fixes https://github.com/JuliaLang/julia/issues/46005. Also fixes https://github.com/JuliaLang/julia/issues/53997 After this, ```julia julia> @Btime copyto!(C, B) setup=(n = 1_000; B = Bidiagonal(randn(n), randn(n-1), :L); C = Bidiagonal(randn(n), randn(n-1), :L)); 158.405 ns (0 allocations: 0 bytes) julia> @Btime copyto!(C, B) setup=(n = 10_000; B = Bidiagonal(randn(n), randn(n-1), :L); C = Bidiagonal(randn(n), randn(n-1), :L)); 4.706 μs (0 allocations: 0 bytes) julia> @Btime copyto!(C, B) setup=(n = 100_000; B = Bidiagonal(randn(n), randn(n-1), :L); C = Bidiagonal(randn(n), randn(n-1), :L)); 120.880 μs (0 allocations: 0 bytes) ``` which is roughly linear scaling. Taken along with #54027, the speed-ups would also apply to the adjoints of banded matrices.
1 parent 2b878f0 commit 2ca9e00

File tree

7 files changed

+406
-3
lines changed

7 files changed

+406
-3
lines changed

stdlib/LinearAlgebra/src/bidiag.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,22 @@ function Base.copy(tB::Transpose{<:Any,<:Bidiagonal})
291291
return Bidiagonal(map(x -> copy.(transpose.(x)), (B.dv, B.ev))..., B.uplo == 'U' ? :L : :U)
292292
end
293293

294+
# copyto! for matching axes
295+
function _copyto_banded!(A::Bidiagonal, B::Bidiagonal)
296+
A.dv .= B.dv
297+
if A.uplo == B.uplo
298+
A.ev .= B.ev
299+
elseif iszero(B.ev) # diagonal source
300+
A.ev .= zero.(A.ev)
301+
else
302+
zeroband = istriu(A) ? "lower" : "upper"
303+
uplo = A.uplo
304+
throw(ArgumentError(string("cannot set the ",
305+
zeroband, " bidiagonal band to a nonzero value for uplo=:", uplo)))
306+
end
307+
return A
308+
end
309+
294310
iszero(M::Bidiagonal) = iszero(M.dv) && iszero(M.ev)
295311
isone(M::Bidiagonal) = all(isone, M.dv) && iszero(M.ev)
296312
function istriu(M::Bidiagonal, k::Integer=0)
@@ -332,6 +348,8 @@ function istril(M::Bidiagonal, k::Integer=0)
332348
end
333349
end
334350
isdiag(M::Bidiagonal) = iszero(M.ev)
351+
issymmetric(M::Bidiagonal) = isdiag(M) && all(issymmetric, M.dv)
352+
ishermitian(M::Bidiagonal) = isdiag(M) && all(ishermitian, M.dv)
335353

336354
function tril!(M::Bidiagonal{T}, k::Integer=0) where T
337355
n = length(M.dv)

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,8 @@ Diagonal{T}(::UndefInitializer, n::Integer) where T = Diagonal(Vector{T}(undef,
136136
similar(D::Diagonal, ::Type{T}) where {T} = Diagonal(similar(D.diag, T))
137137
similar(D::Diagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = similar(D.diag, T, dims)
138138

139-
copyto!(D1::Diagonal, D2::Diagonal) = (copyto!(D1.diag, D2.diag); D1)
139+
# copyto! for matching axes
140+
_copyto_banded!(D1::Diagonal, D2::Diagonal) = (copyto!(D1.diag, D2.diag); D1)
140141

141142
size(D::Diagonal) = (n = length(D.diag); (n,n))
142143

stdlib/LinearAlgebra/src/special.jl

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,92 @@ isdiag(A::HermOrSym{<:Any,<:Diagonal}) = isdiag(parent(A))
307307
dot(x::AbstractVector, A::RealHermSymComplexSym{<:Real,<:Diagonal}, y::AbstractVector) =
308308
dot(x, A.data, y)
309309

310+
# O(N) implementations using the banded structure
311+
function copyto!(dest::BandedMatrix, src::BandedMatrix)
312+
if axes(dest) == axes(src)
313+
_copyto_banded!(dest, src)
314+
else
315+
@invoke copyto!(dest::AbstractMatrix, src::AbstractMatrix)
316+
end
317+
return dest
318+
end
319+
function _copyto_banded!(T::Tridiagonal, D::Diagonal)
320+
T.d .= D.diag
321+
T.dl .= zero.(T.dl)
322+
T.du .= zero.(T.du)
323+
return T
324+
end
325+
function _copyto_banded!(SymT::SymTridiagonal, D::Diagonal)
326+
issymmetric(D) || throw(ArgumentError("cannot copy a non-symmetric Diagonal matrix to a SymTridiagonal"))
327+
SymT.dv .= D.diag
328+
_ev = _evview(SymT)
329+
_ev .= zero.(_ev)
330+
return SymT
331+
end
332+
function _copyto_banded!(B::Bidiagonal, D::Diagonal)
333+
B.dv .= D.diag
334+
B.ev .= zero.(B.ev)
335+
return B
336+
end
337+
function _copyto_banded!(D::Diagonal, B::Bidiagonal)
338+
isdiag(B) ||
339+
throw(ArgumentError("cannot copy a Bidiagonal with a non-zero off-diagonal band to a Diagonal"))
340+
D.diag .= B.dv
341+
return D
342+
end
343+
function _copyto_banded!(D::Diagonal, T::Tridiagonal)
344+
isdiag(T) ||
345+
throw(ArgumentError("cannot copy a Tridiagonal with a non-zero off-diagonal band to a Diagonal"))
346+
D.diag .= T.d
347+
return D
348+
end
349+
function _copyto_banded!(D::Diagonal, SymT::SymTridiagonal)
350+
isdiag(SymT) ||
351+
throw(ArgumentError("cannot copy a SymTridiagonal with a non-zero off-diagonal band to a Diagonal"))
352+
# we broadcast identity for numbers using the fact that symmetric(x::Number) = x
353+
# this potentially allows us to access faster copyto! paths
354+
_symmetric = eltype(SymT) <: Number ? identity : symmetric
355+
D.diag .= _symmetric.(SymT.dv)
356+
return D
357+
end
358+
function _copyto_banded!(T::Tridiagonal, B::Bidiagonal)
359+
T.d .= B.dv
360+
if B.uplo == 'U'
361+
T.du .= B.ev
362+
T.dl .= zero.(T.dl)
363+
else
364+
T.dl .= B.ev
365+
T.du .= zero.(T.du)
366+
end
367+
return T
368+
end
369+
function _copyto_banded!(SymT::SymTridiagonal, B::Bidiagonal)
370+
issymmetric(B) || throw(ArgumentError("cannot copy a non-symmetric Bidiagonal matrix to a SymTridiagonal"))
371+
SymT.dv .= B.dv
372+
_ev = _evview(SymT)
373+
_ev .= zero.(_ev)
374+
return SymT
375+
end
376+
function _copyto_banded!(B::Bidiagonal, T::Tridiagonal)
377+
if B.uplo == 'U' && !iszero(T.dl)
378+
throw(ArgumentError("cannot copy a Tridiagonal with a non-zero subdiagonal to a Bidiagonal with uplo=:U"))
379+
elseif B.uplo == 'L' && !iszero(T.du)
380+
throw(ArgumentError("cannot copy a Tridiagonal with a non-zero superdiagonal to a Bidiagonal with uplo=:L"))
381+
end
382+
B.dv .= T.d
383+
B.ev .= B.uplo == 'U' ? T.du : T.dl
384+
return B
385+
end
386+
function _copyto_banded!(B::Bidiagonal, SymT::SymTridiagonal)
387+
isdiag(SymT) ||
388+
throw(ArgumentError("cannot copy a SymTridiagonal with a non-zero off-diagonal band to a Bidiagonal"))
389+
# we broadcast identity for numbers using the fact that symmetric(x::Number) = x
390+
# this potentially allows us to access faster copyto! paths
391+
_symmetric = eltype(SymT) <: Number ? identity : symmetric
392+
B.dv .= _symmetric.(SymT.dv)
393+
return B
394+
end
395+
310396
# equals and approx equals methods for structured matrices
311397
# SymTridiagonal == Tridiagonal is already defined in tridiag.jl
312398

stdlib/LinearAlgebra/src/tridiag.jl

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,8 @@ axes(M::SymTridiagonal) = (ax = axes(M.dv, 1); (ax, ax))
157157
similar(S::SymTridiagonal, ::Type{T}) where {T} = SymTridiagonal(similar(S.dv, T), similar(S.ev, T))
158158
similar(S::SymTridiagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = similar(S.dv, T, dims)
159159

160-
copyto!(dest::SymTridiagonal, src::SymTridiagonal) =
160+
# copyto! for matching axes
161+
_copyto_banded!(dest::SymTridiagonal, src::SymTridiagonal) =
161162
(copyto!(dest.dv, src.dv); copyto!(dest.ev, _evview(src)); dest)
162163

163164
#Elementary operations
@@ -606,7 +607,13 @@ similar(M::Tridiagonal, ::Type{T}) where {T} = Tridiagonal(similar(M.dl, T), sim
606607
similar(M::Tridiagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = similar(M.d, T, dims)
607608

608609
# Operations on Tridiagonal matrices
609-
copyto!(dest::Tridiagonal, src::Tridiagonal) = (copyto!(dest.dl, src.dl); copyto!(dest.d, src.d); copyto!(dest.du, src.du); dest)
610+
# copyto! for matching axes
611+
function _copyto_banded!(dest::Tridiagonal, src::Tridiagonal)
612+
copyto!(dest.dl, src.dl)
613+
copyto!(dest.d, src.d)
614+
copyto!(dest.du, src.du)
615+
dest
616+
end
610617

611618
#Elementary operations
612619
for func in (:conj, :copy, :real, :imag)
@@ -981,3 +988,21 @@ function ldiv!(A::Tridiagonal, B::AbstractVecOrMat)
981988
end
982989
return B
983990
end
991+
992+
# combinations of Tridiagonal and Symtridiagonal
993+
# copyto! for matching axes
994+
function _copyto_banded!(A::Tridiagonal, B::SymTridiagonal)
995+
Bev = _evview(B)
996+
A.du .= Bev
997+
# Broadcast identity for numbers to access the faster copyto! path
998+
# This uses the fact that transpose(x::Number) = x and symmetric(x::Number) = x
999+
A.dl .= (eltype(B) <: Number ? identity : transpose).(Bev)
1000+
A.d .= (eltype(B) <: Number ? identity : symmetric).(B.dv)
1001+
return A
1002+
end
1003+
function _copyto_banded!(A::SymTridiagonal, B::Tridiagonal)
1004+
issymmetric(B) || throw(ArgumentError("cannot copy a non-symmetric Tridiagonal matrix to a SymTridiagonal"))
1005+
A.dv .= B.d
1006+
_evview(A) .= B.du
1007+
return A
1008+
end

stdlib/LinearAlgebra/test/bidiag.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -834,6 +834,26 @@ end
834834
end
835835
end
836836

837+
@testset "copyto!" begin
838+
ev, dv = [1:4;], [1:5;]
839+
B = Bidiagonal(dv, ev, :U)
840+
B2 = copyto!(zero(B), B)
841+
@test B2 == B
842+
for (ul1, ul2) in ((:U, :L), (:L, :U))
843+
B3 = Bidiagonal(dv, zero(ev), ul1)
844+
B2 = Bidiagonal(zero(dv), zero(ev), ul2)
845+
@test copyto!(B2, B3) == B3
846+
end
847+
848+
@testset "mismatched sizes" begin
849+
dv2 = [4; @view dv[2:end]]
850+
@test copyto!(B, Bidiagonal([4], Int[], :U)) == Bidiagonal(dv2, ev, :U)
851+
@test copyto!(B, Bidiagonal([4], Int[], :L)) == Bidiagonal(dv2, ev, :U)
852+
@test copyto!(B, Bidiagonal(Int[], Int[], :U)) == Bidiagonal(dv, ev, :U)
853+
@test copyto!(B, Bidiagonal(Int[], Int[], :L)) == Bidiagonal(dv, ev, :U)
854+
end
855+
end
856+
837857
@testset "copyto! with UniformScaling" begin
838858
@testset "Fill" begin
839859
for len in (4, InfiniteArrays.Infinity())

0 commit comments

Comments
 (0)