Skip to content

Commit 6170c4b

Browse files
authored
Improve type-stability in SymTridiagonal triu!/tril! (#55646)
Changing the final `elseif` branch to an `else` makes it clear that the method definite returns a value, and the returned type is now a `Tridiagonal` instead of a `Union{Nothing, Tridiagonal}`
1 parent 58d5263 commit 6170c4b

File tree

2 files changed

+34
-18
lines changed

2 files changed

+34
-18
lines changed

stdlib/LinearAlgebra/src/tridiag.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,7 @@ function tril!(M::SymTridiagonal{T}, k::Integer=0) where T
372372
return Tridiagonal(M.ev,M.dv,zero(M.ev))
373373
elseif k == 0
374374
return Tridiagonal(M.ev,M.dv,zero(M.ev))
375-
elseif k >= 1
375+
else # if k >= 1
376376
return Tridiagonal(M.ev,M.dv,copy(M.ev))
377377
end
378378
end
@@ -391,7 +391,7 @@ function triu!(M::SymTridiagonal{T}, k::Integer=0) where T
391391
return Tridiagonal(zero(M.ev),M.dv,M.ev)
392392
elseif k == 0
393393
return Tridiagonal(zero(M.ev),M.dv,M.ev)
394-
elseif k <= -1
394+
else # if k <= -1
395395
return Tridiagonal(M.ev,M.dv,copy(M.ev))
396396
end
397397
end

stdlib/LinearAlgebra/test/tridiag.jl

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -135,27 +135,43 @@ end
135135
@test_throws ArgumentError tril!(SymTridiagonal(d, dl), n)
136136
@test_throws ArgumentError tril!(Tridiagonal(dl, d, du), -n - 2)
137137
@test_throws ArgumentError tril!(Tridiagonal(dl, d, du), n)
138-
@test tril(SymTridiagonal(d,dl)) == Tridiagonal(dl,d,zerosdl)
139-
@test tril(SymTridiagonal(d,dl),1) == Tridiagonal(dl,d,dl)
140-
@test tril(SymTridiagonal(d,dl),-1) == Tridiagonal(dl,zerosd,zerosdl)
141-
@test tril(SymTridiagonal(d,dl),-2) == Tridiagonal(zerosdl,zerosd,zerosdl)
142-
@test tril(Tridiagonal(dl,d,du)) == Tridiagonal(dl,d,zerosdu)
143-
@test tril(Tridiagonal(dl,d,du),1) == Tridiagonal(dl,d,du)
144-
@test tril(Tridiagonal(dl,d,du),-1) == Tridiagonal(dl,zerosd,zerosdu)
145-
@test tril(Tridiagonal(dl,d,du),-2) == Tridiagonal(zerosdl,zerosd,zerosdu)
138+
@test @inferred(tril(SymTridiagonal(d,dl))) == Tridiagonal(dl,d,zerosdl)
139+
@test @inferred(tril(SymTridiagonal(d,dl),1)) == Tridiagonal(dl,d,dl)
140+
@test @inferred(tril(SymTridiagonal(d,dl),-1)) == Tridiagonal(dl,zerosd,zerosdl)
141+
@test @inferred(tril(SymTridiagonal(d,dl),-2)) == Tridiagonal(zerosdl,zerosd,zerosdl)
142+
@test @inferred(tril(Tridiagonal(dl,d,du))) == Tridiagonal(dl,d,zerosdu)
143+
@test @inferred(tril(Tridiagonal(dl,d,du),1)) == Tridiagonal(dl,d,du)
144+
@test @inferred(tril(Tridiagonal(dl,d,du),-1)) == Tridiagonal(dl,zerosd,zerosdu)
145+
@test @inferred(tril(Tridiagonal(dl,d,du),-2)) == Tridiagonal(zerosdl,zerosd,zerosdu)
146+
@test @inferred(tril!(copy(SymTridiagonal(d,dl)))) == Tridiagonal(dl,d,zerosdl)
147+
@test @inferred(tril!(copy(SymTridiagonal(d,dl)),1)) == Tridiagonal(dl,d,dl)
148+
@test @inferred(tril!(copy(SymTridiagonal(d,dl)),-1)) == Tridiagonal(dl,zerosd,zerosdl)
149+
@test @inferred(tril!(copy(SymTridiagonal(d,dl)),-2)) == Tridiagonal(zerosdl,zerosd,zerosdl)
150+
@test @inferred(tril!(copy(Tridiagonal(dl,d,du)))) == Tridiagonal(dl,d,zerosdu)
151+
@test @inferred(tril!(copy(Tridiagonal(dl,d,du)),1)) == Tridiagonal(dl,d,du)
152+
@test @inferred(tril!(copy(Tridiagonal(dl,d,du)),-1)) == Tridiagonal(dl,zerosd,zerosdu)
153+
@test @inferred(tril!(copy(Tridiagonal(dl,d,du)),-2)) == Tridiagonal(zerosdl,zerosd,zerosdu)
146154

147155
@test_throws ArgumentError triu!(SymTridiagonal(d, dl), -n)
148156
@test_throws ArgumentError triu!(SymTridiagonal(d, dl), n + 2)
149157
@test_throws ArgumentError triu!(Tridiagonal(dl, d, du), -n)
150158
@test_throws ArgumentError triu!(Tridiagonal(dl, d, du), n + 2)
151-
@test triu(SymTridiagonal(d,dl)) == Tridiagonal(zerosdl,d,dl)
152-
@test triu(SymTridiagonal(d,dl),-1) == Tridiagonal(dl,d,dl)
153-
@test triu(SymTridiagonal(d,dl),1) == Tridiagonal(zerosdl,zerosd,dl)
154-
@test triu(SymTridiagonal(d,dl),2) == Tridiagonal(zerosdl,zerosd,zerosdl)
155-
@test triu(Tridiagonal(dl,d,du)) == Tridiagonal(zerosdl,d,du)
156-
@test triu(Tridiagonal(dl,d,du),-1) == Tridiagonal(dl,d,du)
157-
@test triu(Tridiagonal(dl,d,du),1) == Tridiagonal(zerosdl,zerosd,du)
158-
@test triu(Tridiagonal(dl,d,du),2) == Tridiagonal(zerosdl,zerosd,zerosdu)
159+
@test @inferred(triu(SymTridiagonal(d,dl))) == Tridiagonal(zerosdl,d,dl)
160+
@test @inferred(triu(SymTridiagonal(d,dl),-1)) == Tridiagonal(dl,d,dl)
161+
@test @inferred(triu(SymTridiagonal(d,dl),1)) == Tridiagonal(zerosdl,zerosd,dl)
162+
@test @inferred(triu(SymTridiagonal(d,dl),2)) == Tridiagonal(zerosdl,zerosd,zerosdl)
163+
@test @inferred(triu(Tridiagonal(dl,d,du))) == Tridiagonal(zerosdl,d,du)
164+
@test @inferred(triu(Tridiagonal(dl,d,du),-1)) == Tridiagonal(dl,d,du)
165+
@test @inferred(triu(Tridiagonal(dl,d,du),1)) == Tridiagonal(zerosdl,zerosd,du)
166+
@test @inferred(triu(Tridiagonal(dl,d,du),2)) == Tridiagonal(zerosdl,zerosd,zerosdu)
167+
@test @inferred(triu!(copy(SymTridiagonal(d,dl)))) == Tridiagonal(zerosdl,d,dl)
168+
@test @inferred(triu!(copy(SymTridiagonal(d,dl)),-1)) == Tridiagonal(dl,d,dl)
169+
@test @inferred(triu!(copy(SymTridiagonal(d,dl)),1)) == Tridiagonal(zerosdl,zerosd,dl)
170+
@test @inferred(triu!(copy(SymTridiagonal(d,dl)),2)) == Tridiagonal(zerosdl,zerosd,zerosdl)
171+
@test @inferred(triu!(copy(Tridiagonal(dl,d,du)))) == Tridiagonal(zerosdl,d,du)
172+
@test @inferred(triu!(copy(Tridiagonal(dl,d,du)),-1)) == Tridiagonal(dl,d,du)
173+
@test @inferred(triu!(copy(Tridiagonal(dl,d,du)),1)) == Tridiagonal(zerosdl,zerosd,du)
174+
@test @inferred(triu!(copy(Tridiagonal(dl,d,du)),2)) == Tridiagonal(zerosdl,zerosd,zerosdu)
159175

160176
@test !istril(SymTridiagonal(d,dl))
161177
@test istril(SymTridiagonal(d,zerosdl))

0 commit comments

Comments
 (0)