Skip to content

Commit 9ceb747

Browse files
authored
[NDTensors] GradedAxes library (#1271)
1 parent 1722ba4 commit 9ceb747

File tree

26 files changed

+328
-37
lines changed

26 files changed

+328
-37
lines changed

NDTensors/src/NDTensors.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ for lib in [
3131
:SparseArrayDOKs,
3232
:DiagonalArrays,
3333
:BlockSparseArrays,
34+
:GradedAxes,
3435
:NamedDimsArrays,
3536
:SmallVectors,
3637
:SortedSets,

NDTensors/src/arraystorage/diagonalarray/storage/contract.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
using .SparseArrayInterface: densearray
2+
using .DiagonalArrays: DiagIndex, diaglength
3+
14
# TODO: Move to a different file.
25
Unwrap.parenttype(::Type{<:DiagonalArray{<:Any,<:Any,P}}) where {P} = P
36

@@ -99,7 +102,7 @@ function contract!(
99102
coffset += ii * custride[i]
100103
end
101104
c = zero(eltype(C))
102-
for j in 1:DiagonalArrays.diaglength(A)
105+
for j in 1:diaglength(A)
103106
# With α == 0 && β == 1
104107
C[cstart + j * c_cstride + coffset] +=
105108
A[DiagIndex(j)] * B[bstart + j * b_cstride + boffset]

NDTensors/src/dims.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using .DiagonalArrays: DiagonalArrays
2+
13
export dense, dims, dim, mindim, diaglength
24

35
# dim and dims are used in the Tensor interface, overload
@@ -26,7 +28,7 @@ mindim(inds::Tuple) = minimum(dims(inds))
2628

2729
mindim(::Tuple{}) = 1
2830

29-
diaglength(inds::Tuple) = mindim(inds)
31+
DiagonalArrays.diaglength(inds::Tuple) = mindim(inds)
3032

3133
"""
3234
dim_to_strides(ds)
@@ -94,4 +96,3 @@ dim(T::Tensor) = dim(inds(T))
9496
dim(T::Tensor, i::Int) = dim(inds(T), i)
9597
maxdim(T::Tensor) = maxdim(inds(T))
9698
mindim(T::Tensor) = mindim(inds(T))
97-
diaglength(T::Tensor) = mindim(T)

NDTensors/src/lib/DiagonalArrays/README.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
A Julia `DiagonalArray` type.
44

55
````julia
6-
using NDTensors.DiagonalArrays: DiagonalArray, DiagonalMatrix, DiagIndex, DiagIndices, isdiagindex
6+
using NDTensors.DiagonalArrays:
7+
DiagonalArray, DiagonalMatrix, DiagIndex, DiagIndices, diaglength, isdiagindex
78
using Test
89

910
function main()
1011
d = DiagonalMatrix([1.0, 2.0, 3.0])
1112
@test eltype(d) == Float64
13+
@test diaglength(d) == 3
1214
@test size(d) == (3, 3)
1315
@test d[1, 1] == 1
1416
@test d[2, 2] == 2
@@ -17,6 +19,7 @@ function main()
1719

1820
d = DiagonalArray([1.0, 2.0, 3.0], 3, 4, 5)
1921
@test eltype(d) == Float64
22+
@test diaglength(d) == 3
2023
@test d[1, 1, 1] == 1
2124
@test d[2, 2, 2] == 2
2225
@test d[3, 3, 3] == 3

NDTensors/src/lib/DiagonalArrays/examples/README.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
# A Julia `DiagonalArray` type.
44

55
using NDTensors.DiagonalArrays:
6-
DiagonalArray, DiagonalMatrix, DiagIndex, DiagIndices, isdiagindex
6+
DiagonalArray, DiagonalMatrix, DiagIndex, DiagIndices, diaglength, isdiagindex
77
using Test
88

99
function main()
1010
d = DiagonalMatrix([1.0, 2.0, 3.0])
1111
@test eltype(d) == Float64
12+
@test diaglength(d) == 3
1213
@test size(d) == (3, 3)
1314
@test d[1, 1] == 1
1415
@test d[2, 2] == 2
@@ -17,6 +18,7 @@ function main()
1718

1819
d = DiagonalArray([1.0, 2.0, 3.0], 3, 4, 5)
1920
@test eltype(d) == Float64
21+
@test diaglength(d) == 3
2022
@test d[1, 1, 1] == 1
2123
@test d[2, 2, 2] == 2
2224
@test d[3, 3, 3] == 3

NDTensors/src/lib/DiagonalArrays/src/diaginterface.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
using Compat: allequal
22

3+
diaglength(a::AbstractArray{<:Any,0}) = 1
4+
5+
function diaglength(a::AbstractArray)
6+
return minimum(size(a))
7+
end
8+
39
function isdiagindex(a::AbstractArray{<:Any,N}, I::CartesianIndex{N}) where {N}
410
@boundscheck checkbounds(a, I)
511
return allequal(Tuple(I))
@@ -16,8 +22,7 @@ function diagstride(a::AbstractArray)
1622
end
1723

1824
function diagindices(a::AbstractArray)
19-
diaglength = minimum(size(a))
20-
maxdiag = LinearIndices(a)[CartesianIndex(ntuple(Returns(diaglength), ndims(a)))]
25+
maxdiag = LinearIndices(a)[CartesianIndex(ntuple(Returns(diaglength(a)), ndims(a)))]
2126
return 1:diagstride(a):maxdiag
2227
end
2328

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[deps]
2+
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
using Test
2-
using NDTensors.DiagonalArrays
3-
1+
@eval module $(gensym())
2+
using Test: @test, @testset
3+
using NDTensors.DiagonalArrays: DiagonalArrays
44
@testset "Test NDTensors.DiagonalArrays" begin
55
@testset "README" begin
66
@test include(
@@ -9,4 +9,12 @@ using NDTensors.DiagonalArrays
99
),
1010
) isa Any
1111
end
12+
@testset "Basics" begin
13+
using NDTensors.DiagonalArrays: diaglength
14+
a = fill(1.0, 2, 3)
15+
@test diaglength(a) == 2
16+
a = fill(1.0)
17+
@test diaglength(a) == 1
18+
end
19+
end
1220
end
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
module GradedAxes
2+
include("groupsortperm.jl")
3+
include("tensor_product.jl")
4+
include("abstractgradedunitrange.jl")
5+
include("gradedunitrange.jl")
6+
end
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
using BlockArrays:
2+
BlockArrays,
3+
Block,
4+
BlockRange,
5+
BlockedUnitRange,
6+
blockaxes,
7+
blockedrange,
8+
blockfirsts,
9+
blocklasts,
10+
blocklengths,
11+
findblock
12+
using Dictionaries: Dictionary
13+
14+
# Fuse two symmetry labels
15+
fuse(l1, l2) = error("Not implemented")
16+
17+
abstract type AbstractGradedUnitRange{T,G} <: AbstractUnitRange{Int} end
18+
19+
BlockArrays.blockedrange(a::AbstractGradedUnitRange) = error("Not implemented")
20+
sectors(a::AbstractGradedUnitRange) = error("Not implemented")
21+
scale_factor(a::AbstractGradedUnitRange) = error("Not implemented")
22+
23+
# BlockArrays block axis interface
24+
BlockArrays.blockaxes(a::AbstractGradedUnitRange) = blockaxes(blockedrange(a))
25+
Base.getindex(a::AbstractGradedUnitRange, b::Block{1}) = blockedrange(a)[b]
26+
BlockArrays.blockfirsts(a::AbstractGradedUnitRange) = blockfirsts(blockedrange(a))
27+
BlockArrays.blocklasts(a::AbstractGradedUnitRange) = blocklasts(blockedrange(a))
28+
function BlockArrays.findblock(a::AbstractGradedUnitRange, k::Integer)
29+
return findblock(blockedrange(a), k)
30+
end
31+
32+
# Base axis interface
33+
Base.getindex(a::AbstractGradedUnitRange, I::Integer) = blockedrange(a)[I]
34+
Base.first(a::AbstractGradedUnitRange) = first(blockedrange(a))
35+
Base.last(a::AbstractGradedUnitRange) = last(blockedrange(a))
36+
Base.length(a::AbstractGradedUnitRange) = length(blockedrange(a))
37+
Base.step(a::AbstractGradedUnitRange) = step(blockedrange(a))
38+
Base.unitrange(b::AbstractGradedUnitRange) = first(b):last(b)
39+
40+
sector(a::AbstractGradedUnitRange, b::Block{1}) = sectors(a)[only(b.n)]
41+
sector(a::AbstractGradedUnitRange, I::Integer) = sector(a, findblock(a, I))
42+
43+
# Tensor product, no sorting
44+
function tensor_product(a1::AbstractGradedUnitRange, a2::AbstractGradedUnitRange)
45+
a = tensor_product(blockedrange(a1), blockedrange(a2))
46+
sectors_a = map(Iterators.product(sectors(a1), sectors(a2))) do (l1, l2)
47+
return fuse(scale_factor(a1) * l1, scale_factor(a2) * l2)
48+
end
49+
return gradedrange(a, vec(sectors_a))
50+
end
51+
52+
function Base.show(io::IO, mimetype::MIME"text/plain", a::AbstractGradedUnitRange)
53+
show(io, mimetype, sectors(a))
54+
println(io)
55+
println(io, "Scale factor = ", scale_factor(a))
56+
return show(io, mimetype, blockedrange(a))
57+
end
58+
59+
function blockmerge(a::AbstractGradedUnitRange, grouped_perm::Vector{Vector{Int}})
60+
merged_sectors = map(group -> sector(a, Block(first(group))), grouped_perm)
61+
lengths = blocklengths(a)
62+
merged_lengths = map(group -> sum(@view(lengths[group])), grouped_perm)
63+
return gradedrange(merged_sectors, merged_lengths)
64+
end
65+
66+
# Sort and merge by the grade of the blocks.
67+
function blockmergesort(a::AbstractGradedUnitRange)
68+
grouped_perm = blockmergesortperm(a)
69+
return blockmerge(a, grouped_perm)
70+
end
71+
72+
# Get the permutation for sorting, then group by common elements.
73+
# groupsortperm([2, 1, 2, 3]) == [[2], [1, 3], [4]]
74+
function blockmergesortperm(a::AbstractGradedUnitRange)
75+
return groupsortperm(sectors(a))
76+
end
77+
78+
function sub_axis(a::AbstractGradedUnitRange, blocks)
79+
a_sub = sub_axis(blockedrange(a), blocks)
80+
sectors_sub = map(b -> sector(a, b), Indices(blocks))
81+
return AbstractGradedUnitRange(a_sub, sectors_sub)
82+
end
83+
84+
function fuse(a1::AbstractGradedUnitRange, a2::AbstractGradedUnitRange)
85+
a = tensor_product(a1, a2)
86+
return blockmergesort(a)
87+
end
88+
89+
## TODO: Add this back.
90+
## # Slicing
91+
## ## using BlockArrays: BlockRange, _BlockedUnitRange
92+
## Base.@propagate_inbounds function Base.getindex(
93+
## b::AbstractGradedUnitRange, KR::BlockRange{1}
94+
## )
95+
## cs = blocklasts(b)
96+
## isempty(KR) && return _BlockedUnitRange(1, cs[1:0])
97+
## K, J = first(KR), last(KR)
98+
## k, j = Integer(K), Integer(J)
99+
## bax = blockaxes(b, 1)
100+
## @boundscheck K in bax || throw(BlockBoundsError(b, K))
101+
## @boundscheck J in bax || throw(BlockBoundsError(b, J))
102+
## K == first(bax) && return _BlockedUnitRange(first(b), cs[k:j])
103+
## return _BlockedUnitRange(cs[k - 1] + 1, cs[k:j])
104+
## end

0 commit comments

Comments
 (0)