Skip to content

Commit 86ab75e

Browse files
authored
Add upper/lowertriangular functions and use in applytri (#53573)
We may use the fact that a `Diagonal` is already triangular to avoid adding a wrapper. Fixes the specific example in https://github.com/JuliaLang/julia/issues/53564, although not the broader issue. This is because it changes the operation from a `UpperTriangular + UpperTriangular` to a `UpperTriangular + Diagonal`, which uses broadcasting. The latter operation may also allow one to define more efficient methods.
1 parent 3265387 commit 86ab75e

File tree

6 files changed

+54
-6
lines changed

6 files changed

+54
-6
lines changed

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -947,3 +947,6 @@ end
947947
function Base.muladd(A::Diagonal, B::Diagonal, z::Diagonal)
948948
Diagonal(A.diag .* B.diag .+ z.diag)
949949
end
950+
951+
uppertriangular(D::Diagonal) = D
952+
lowertriangular(D::Diagonal) = D

stdlib/LinearAlgebra/src/symmetric.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -281,21 +281,21 @@ diag(A::Hermitian) = hermitian.(diag(parent(A)), sym_uplo(A.uplo))
281281

282282
function applytri(f, A::HermOrSym)
283283
if A.uplo == 'U'
284-
f(UpperTriangular(A.data))
284+
f(uppertriangular(A.data))
285285
else
286-
f(LowerTriangular(A.data))
286+
f(lowertriangular(A.data))
287287
end
288288
end
289289

290290
function applytri(f, A::HermOrSym, B::HermOrSym)
291291
if A.uplo == B.uplo == 'U'
292-
f(UpperTriangular(A.data), UpperTriangular(B.data))
292+
f(uppertriangular(A.data), uppertriangular(B.data))
293293
elseif A.uplo == B.uplo == 'L'
294-
f(LowerTriangular(A.data), LowerTriangular(B.data))
294+
f(lowertriangular(A.data), lowertriangular(B.data))
295295
elseif A.uplo == 'U'
296-
f(UpperTriangular(A.data), UpperTriangular(_conjugation(B)(B.data)))
296+
f(uppertriangular(A.data), uppertriangular(_conjugation(B)(B.data)))
297297
else # A.uplo == 'L'
298-
f(UpperTriangular(_conjugation(A)(A.data)), UpperTriangular(B.data))
298+
f(uppertriangular(_conjugation(A)(A.data)), uppertriangular(B.data))
299299
end
300300
end
301301
parentof_applytri(f, args...) = applytri(parent f, args...)

stdlib/LinearAlgebra/src/triangular.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,12 @@ const UpperOrUnitUpperTriangular{T,S} = Union{UpperTriangular{T,S}, UnitUpperTri
154154
const LowerOrUnitLowerTriangular{T,S} = Union{LowerTriangular{T,S}, UnitLowerTriangular{T,S}}
155155
const UpperOrLowerTriangular{T,S} = Union{UpperOrUnitUpperTriangular{T,S}, LowerOrUnitLowerTriangular{T,S}}
156156

157+
uppertriangular(M) = UpperTriangular(M)
158+
lowertriangular(M) = LowerTriangular(M)
159+
160+
uppertriangular(U::UpperOrUnitUpperTriangular) = U
161+
lowertriangular(U::LowerOrUnitLowerTriangular) = U
162+
157163
Base.dataids(A::UpperOrLowerTriangular) = Base.dataids(A.data)
158164

159165
imag(A::UpperTriangular) = UpperTriangular(imag(A.data))

stdlib/LinearAlgebra/test/diagonal.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1281,6 +1281,12 @@ end
12811281
@test c == Diagonal([2,2,2,2])
12821282
end
12831283

1284+
@testset "uppertriangular/lowertriangular" begin
1285+
D = Diagonal([1,2])
1286+
@test LinearAlgebra.uppertriangular(D) === D
1287+
@test LinearAlgebra.lowertriangular(D) === D
1288+
end
1289+
12841290
@testset "mul/div with an adjoint vector" begin
12851291
A = [1.0;;]
12861292
x = [1.0]

stdlib/LinearAlgebra/test/symmetric.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,31 @@ end
532532
@test kron(Sl,Su) == kron(MSl,MSu)
533533
end
534534
end
535+
@testset "non-strided" begin
536+
@testset "diagonal" begin
537+
for ST1 in (Symmetric, Hermitian), uplo1 in (:L, :U)
538+
m = ST1(Matrix{BigFloat}(undef,2,2), uplo1)
539+
m.data[1,1] = 1
540+
m.data[2,2] = 3
541+
m.data[1+(uplo1==:L), 1+(uplo1==:U)] = 2
542+
A = Array(m)
543+
for ST2 in (Symmetric, Hermitian), uplo2 in (:L, :U)
544+
id = ST2(I(2), uplo2)
545+
@test m + id == id + m == A + id
546+
end
547+
end
548+
end
549+
@testset "unit triangular" begin
550+
for ST1 in (Symmetric, Hermitian), uplo1 in (:L, :U)
551+
H1 = ST1(UnitUpperTriangular(big.(rand(Int8,4,4))), uplo1)
552+
M1 = Matrix(H1)
553+
for ST2 in (Symmetric, Hermitian), uplo2 in (:L, :U)
554+
H2 = ST2(UnitUpperTriangular(big.(rand(Int8,4,4))), uplo2)
555+
@test H1 + H2 == M1 + Matrix(H2)
556+
end
557+
end
558+
end
559+
end
535560
end
536561

537562
# bug identified in PR #52318: dot products of quaternionic Hermitian matrices,

stdlib/LinearAlgebra/test/triangular.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -998,6 +998,14 @@ end
998998
end
999999
end
10001000

1001+
@testset "uppertriangular/lowertriangular" begin
1002+
M = rand(2,2)
1003+
@test LinearAlgebra.uppertriangular(M) === UpperTriangular(M)
1004+
@test LinearAlgebra.lowertriangular(M) === LowerTriangular(M)
1005+
@test LinearAlgebra.uppertriangular(UnitUpperTriangular(M)) === UnitUpperTriangular(M)
1006+
@test LinearAlgebra.lowertriangular(UnitLowerTriangular(M)) === UnitLowerTriangular(M)
1007+
end
1008+
10011009
@testset "arithmetic with partly uninitialized matrices" begin
10021010
@testset "$(typeof(A))" for A in (Matrix{BigFloat}(undef,2,2), Matrix{Complex{BigFloat}}(undef,2,2)')
10031011
A[2,1] = eltype(A) <: Complex ? 4 + 3im : 4

0 commit comments

Comments
 (0)