|
| 1 | +using BlockArrays: |
| 2 | + BlockArrays, |
| 3 | + Block, |
| 4 | + BlockRange, |
| 5 | + BlockedUnitRange, |
| 6 | + blockaxes, |
| 7 | + blockedrange, |
| 8 | + blockfirsts, |
| 9 | + blocklasts, |
| 10 | + blocklengths, |
| 11 | + findblock |
| 12 | +using Dictionaries: Dictionary |
| 13 | + |
| 14 | +# Fuse two symmetry labels |
| 15 | +fuse(l1, l2) = error("Not implemented") |
| 16 | + |
| 17 | +abstract type AbstractGradedUnitRange{T,G} <: AbstractUnitRange{Int} end |
| 18 | + |
| 19 | +BlockArrays.blockedrange(a::AbstractGradedUnitRange) = error("Not implemented") |
| 20 | +sectors(a::AbstractGradedUnitRange) = error("Not implemented") |
| 21 | +scale_factor(a::AbstractGradedUnitRange) = error("Not implemented") |
| 22 | + |
| 23 | +# BlockArrays block axis interface |
| 24 | +BlockArrays.blockaxes(a::AbstractGradedUnitRange) = blockaxes(blockedrange(a)) |
| 25 | +Base.getindex(a::AbstractGradedUnitRange, b::Block{1}) = blockedrange(a)[b] |
| 26 | +BlockArrays.blockfirsts(a::AbstractGradedUnitRange) = blockfirsts(blockedrange(a)) |
| 27 | +BlockArrays.blocklasts(a::AbstractGradedUnitRange) = blocklasts(blockedrange(a)) |
| 28 | +function BlockArrays.findblock(a::AbstractGradedUnitRange, k::Integer) |
| 29 | + return findblock(blockedrange(a), k) |
| 30 | +end |
| 31 | + |
| 32 | +# Base axis interface |
| 33 | +Base.getindex(a::AbstractGradedUnitRange, I::Integer) = blockedrange(a)[I] |
| 34 | +Base.first(a::AbstractGradedUnitRange) = first(blockedrange(a)) |
| 35 | +Base.last(a::AbstractGradedUnitRange) = last(blockedrange(a)) |
| 36 | +Base.length(a::AbstractGradedUnitRange) = length(blockedrange(a)) |
| 37 | +Base.step(a::AbstractGradedUnitRange) = step(blockedrange(a)) |
| 38 | +Base.unitrange(b::AbstractGradedUnitRange) = first(b):last(b) |
| 39 | + |
| 40 | +sector(a::AbstractGradedUnitRange, b::Block{1}) = sectors(a)[only(b.n)] |
| 41 | +sector(a::AbstractGradedUnitRange, I::Integer) = sector(a, findblock(a, I)) |
| 42 | + |
| 43 | +# Tensor product, no sorting |
| 44 | +function tensor_product(a1::AbstractGradedUnitRange, a2::AbstractGradedUnitRange) |
| 45 | + a = tensor_product(blockedrange(a1), blockedrange(a2)) |
| 46 | + sectors_a = map(Iterators.product(sectors(a1), sectors(a2))) do (l1, l2) |
| 47 | + return fuse(scale_factor(a1) * l1, scale_factor(a2) * l2) |
| 48 | + end |
| 49 | + return gradedrange(a, vec(sectors_a)) |
| 50 | +end |
| 51 | + |
| 52 | +function Base.show(io::IO, mimetype::MIME"text/plain", a::AbstractGradedUnitRange) |
| 53 | + show(io, mimetype, sectors(a)) |
| 54 | + println(io) |
| 55 | + println(io, "Scale factor = ", scale_factor(a)) |
| 56 | + return show(io, mimetype, blockedrange(a)) |
| 57 | +end |
| 58 | + |
| 59 | +function blockmerge(a::AbstractGradedUnitRange, grouped_perm::Vector{Vector{Int}}) |
| 60 | + merged_sectors = map(group -> sector(a, Block(first(group))), grouped_perm) |
| 61 | + lengths = blocklengths(a) |
| 62 | + merged_lengths = map(group -> sum(@view(lengths[group])), grouped_perm) |
| 63 | + return gradedrange(merged_sectors, merged_lengths) |
| 64 | +end |
| 65 | + |
| 66 | +# Sort and merge by the grade of the blocks. |
| 67 | +function blockmergesort(a::AbstractGradedUnitRange) |
| 68 | + grouped_perm = blockmergesortperm(a) |
| 69 | + return blockmerge(a, grouped_perm) |
| 70 | +end |
| 71 | + |
| 72 | +# Get the permutation for sorting, then group by common elements. |
| 73 | +# groupsortperm([2, 1, 2, 3]) == [[2], [1, 3], [4]] |
| 74 | +function blockmergesortperm(a::AbstractGradedUnitRange) |
| 75 | + return groupsortperm(sectors(a)) |
| 76 | +end |
| 77 | + |
| 78 | +function sub_axis(a::AbstractGradedUnitRange, blocks) |
| 79 | + a_sub = sub_axis(blockedrange(a), blocks) |
| 80 | + sectors_sub = map(b -> sector(a, b), Indices(blocks)) |
| 81 | + return AbstractGradedUnitRange(a_sub, sectors_sub) |
| 82 | +end |
| 83 | + |
| 84 | +function fuse(a1::AbstractGradedUnitRange, a2::AbstractGradedUnitRange) |
| 85 | + a = tensor_product(a1, a2) |
| 86 | + return blockmergesort(a) |
| 87 | +end |
| 88 | + |
| 89 | +## TODO: Add this back. |
| 90 | +## # Slicing |
| 91 | +## ## using BlockArrays: BlockRange, _BlockedUnitRange |
| 92 | +## Base.@propagate_inbounds function Base.getindex( |
| 93 | +## b::AbstractGradedUnitRange, KR::BlockRange{1} |
| 94 | +## ) |
| 95 | +## cs = blocklasts(b) |
| 96 | +## isempty(KR) && return _BlockedUnitRange(1, cs[1:0]) |
| 97 | +## K, J = first(KR), last(KR) |
| 98 | +## k, j = Integer(K), Integer(J) |
| 99 | +## bax = blockaxes(b, 1) |
| 100 | +## @boundscheck K in bax || throw(BlockBoundsError(b, K)) |
| 101 | +## @boundscheck J in bax || throw(BlockBoundsError(b, J)) |
| 102 | +## K == first(bax) && return _BlockedUnitRange(first(b), cs[k:j]) |
| 103 | +## return _BlockedUnitRange(cs[k - 1] + 1, cs[k:j]) |
| 104 | +## end |
0 commit comments