Skip to content

Commit e7fafc7

Browse files
jishnubKristofferC
authored andcommitted
Fix triu/tril for partly initialized matrices (#55312)
This fixes ```julia julia> using LinearAlgebra, StaticArrays julia> M = Matrix{BigInt}(undef, 2, 2); M[1,1] = M[2,2] = M[1,2] = 3; julia> S = SizedMatrix{2,2}(M) 2×2 SizedMatrix{2, 2, BigInt, 2, Matrix{BigInt}} with indices SOneTo(2)×SOneTo(2): 3 3 #undef 3 julia> triu(S) ERROR: UndefRefError: access to undefined reference Stacktrace: [1] getindex @ ./essentials.jl:907 [inlined] [2] getindex @ ~/.julia/packages/StaticArrays/MSJcA/src/SizedArray.jl:92 [inlined] [3] copyto_unaliased! @ ./abstractarray.jl:1086 [inlined] [4] copyto!(dest::SizedMatrix{2, 2, BigInt, 2, Matrix{BigInt}}, src::SizedMatrix{2, 2, BigInt, 2, Matrix{BigInt}}) @ Base ./abstractarray.jl:1066 [5] copymutable @ ./abstractarray.jl:1200 [inlined] [6] triu(M::SizedMatrix{2, 2, BigInt, 2, Matrix{BigInt}}) @ LinearAlgebra ~/.julia/juliaup/julia-nightly/share/julia/stdlib/v1.12/LinearAlgebra/src/generic.jl:413 [7] top-level scope @ REPL[11]:1 ``` After this PR: ```julia julia> triu(S) 2×2 SizedMatrix{2, 2, BigInt, 2, Matrix{BigInt}} with indices SOneTo(2)×SOneTo(2): 3 3 0 3 ``` Only the indices that need to be copied are accessed, and the others are written to without being read.
1 parent 763c225 commit e7fafc7

File tree

2 files changed

+83
-52
lines changed

2 files changed

+83
-52
lines changed

stdlib/LinearAlgebra/src/generic.jl

Lines changed: 28 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -389,55 +389,7 @@ function cross(a::AbstractVector, b::AbstractVector)
389389
end
390390

391391
"""
392-
triu(M)
393-
394-
Upper triangle of a matrix.
395-
396-
# Examples
397-
```jldoctest
398-
julia> a = fill(1.0, (4,4))
399-
4×4 Matrix{Float64}:
400-
1.0 1.0 1.0 1.0
401-
1.0 1.0 1.0 1.0
402-
1.0 1.0 1.0 1.0
403-
1.0 1.0 1.0 1.0
404-
405-
julia> triu(a)
406-
4×4 Matrix{Float64}:
407-
1.0 1.0 1.0 1.0
408-
0.0 1.0 1.0 1.0
409-
0.0 0.0 1.0 1.0
410-
0.0 0.0 0.0 1.0
411-
```
412-
"""
413-
triu(M::AbstractMatrix) = triu!(copymutable(M))
414-
415-
"""
416-
tril(M)
417-
418-
Lower triangle of a matrix.
419-
420-
# Examples
421-
```jldoctest
422-
julia> a = fill(1.0, (4,4))
423-
4×4 Matrix{Float64}:
424-
1.0 1.0 1.0 1.0
425-
1.0 1.0 1.0 1.0
426-
1.0 1.0 1.0 1.0
427-
1.0 1.0 1.0 1.0
428-
429-
julia> tril(a)
430-
4×4 Matrix{Float64}:
431-
1.0 0.0 0.0 0.0
432-
1.0 1.0 0.0 0.0
433-
1.0 1.0 1.0 0.0
434-
1.0 1.0 1.0 1.0
435-
```
436-
"""
437-
tril(M::AbstractMatrix) = tril!(copymutable(M))
438-
439-
"""
440-
triu(M, k::Integer)
392+
triu(M, k::Integer = 0)
441393
442394
Return the upper triangle of `M` starting from the `k`th superdiagonal.
443395
@@ -465,10 +417,22 @@ julia> triu(a,-3)
465417
1.0 1.0 1.0 1.0
466418
```
467419
"""
468-
triu(M::AbstractMatrix,k::Integer) = triu!(copymutable(M),k)
420+
function triu(M::AbstractMatrix, k::Integer = 0)
421+
d = similar(M)
422+
A = triu!(d,k)
423+
if iszero(k)
424+
copytrito!(A, M, 'U')
425+
else
426+
for col in axes(A,2)
427+
rows = firstindex(A,1):min(col-k, lastindex(A,1))
428+
A[rows, col] = @view M[rows, col]
429+
end
430+
end
431+
return A
432+
end
469433

470434
"""
471-
tril(M, k::Integer)
435+
tril(M, k::Integer = 0)
472436
473437
Return the lower triangle of `M` starting from the `k`th superdiagonal.
474438
@@ -496,7 +460,19 @@ julia> tril(a,-3)
496460
1.0 0.0 0.0 0.0
497461
```
498462
"""
499-
tril(M::AbstractMatrix,k::Integer) = tril!(copymutable(M),k)
463+
function tril(M::AbstractMatrix,k::Integer=0)
464+
d = similar(M)
465+
A = tril!(d,k)
466+
if iszero(k)
467+
copytrito!(A, M, 'L')
468+
else
469+
for col in axes(A,2)
470+
rows = max(firstindex(A,1),col-k):lastindex(A,1)
471+
A[rows, col] = @view M[rows, col]
472+
end
473+
end
474+
return A
475+
end
500476

501477
"""
502478
triu!(M)

stdlib/LinearAlgebra/test/generic.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ using .Main.DualNumbers
1818
isdefined(Main, :FillArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "FillArrays.jl"))
1919
using .Main.FillArrays
2020

21+
isdefined(Main, :SizedArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "SizedArrays.jl"))
22+
using .Main.SizedArrays
23+
2124
Random.seed!(123)
2225

2326
n = 5 # should be odd
@@ -725,4 +728,56 @@ end
725728
@test det(A) == det(M)
726729
end
727730

731+
@testset "tril/triu" begin
732+
@testset "with partly initialized matrices" begin
733+
function test_triu(M, k=nothing)
734+
M[1,1] = M[2,2] = M[1,2] = M[1,3] = M[2,3] = 3
735+
if isnothing(k)
736+
MU = triu(M)
737+
else
738+
MU = triu(M, k)
739+
end
740+
@test iszero(MU[2,1])
741+
@test MU[1,1] == MU[2,2] == MU[1,2] == MU[1,3] == MU[2,3] == 3
742+
end
743+
test_triu(Matrix{BigInt}(undef, 2, 3))
744+
test_triu(Matrix{BigInt}(undef, 2, 3), 0)
745+
test_triu(SizedArrays.SizedArray{(2,3)}(Matrix{BigInt}(undef, 2, 3)))
746+
test_triu(SizedArrays.SizedArray{(2,3)}(Matrix{BigInt}(undef, 2, 3)), 0)
747+
748+
function test_tril(M, k=nothing)
749+
M[1,1] = M[2,2] = M[2,1] = 3
750+
if isnothing(k)
751+
ML = tril(M)
752+
else
753+
ML = tril(M, k)
754+
end
755+
@test ML[1,2] == ML[1,3] == ML[2,3] == 0
756+
@test ML[1,1] == ML[2,2] == ML[2,1] == 3
757+
end
758+
test_tril(Matrix{BigInt}(undef, 2, 3))
759+
test_tril(Matrix{BigInt}(undef, 2, 3), 0)
760+
test_tril(SizedArrays.SizedArray{(2,3)}(Matrix{BigInt}(undef, 2, 3)))
761+
test_tril(SizedArrays.SizedArray{(2,3)}(Matrix{BigInt}(undef, 2, 3)), 0)
762+
end
763+
764+
@testset "block arrays" begin
765+
for nrows in 0:3, ncols in 0:3
766+
M = [randn(2,2) for _ in 1:nrows, _ in 1:ncols]
767+
Mu = triu(M)
768+
for col in axes(M,2)
769+
rowcutoff = min(col, size(M,1))
770+
@test @views Mu[1:rowcutoff, col] == M[1:rowcutoff, col]
771+
@test @views Mu[rowcutoff+1:end, col] == zero.(M[rowcutoff+1:end, col])
772+
end
773+
Ml = tril(M)
774+
for col in axes(M,2)
775+
@test @views Ml[col:end, col] == M[col:end, col]
776+
rowcutoff = min(col-1, size(M,1))
777+
@test @views Ml[1:rowcutoff, col] == zero.(M[1:rowcutoff, col])
778+
end
779+
end
780+
end
781+
end
782+
728783
end # module TestGeneric

0 commit comments

Comments
 (0)