Skip to content

Commit 3f0dbda

Browse files
committed
Improve performance of variant_stats sgkit-dev#1116
* Add count_variant_alleles option to calculate directly from calls * Improve performance of variant_stats using gufuncs * Raise error is variant_stats used on mixed-ploidy data * Document behavior of variant_stats with partial genotype calls
1 parent cc04858 commit 3f0dbda

File tree

3 files changed

+261
-55
lines changed

3 files changed

+261
-55
lines changed

sgkit/stats/aggregation.py

Lines changed: 107 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,9 @@ def count_call_alleles(
9696
def count_variant_alleles(
9797
ds: Dataset,
9898
*,
99+
call_genotype: Hashable = variables.call_genotype,
99100
call_allele_count: Hashable = variables.call_allele_count,
101+
from_call_allele_count: bool = True,
100102
merge: bool = True,
101103
) -> Dataset:
102104
"""Compute allele count from per-sample allele counts, or genotype calls.
@@ -105,11 +107,22 @@ def count_variant_alleles(
105107
----------
106108
ds
107109
Dataset containing genotype calls.
110+
call_genotype
111+
Input variable name holding call_genotype as defined by
112+
:data:`sgkit.variables.call_genotype_spec`.
113+
Must be present in ``ds`` unless from_call_allele_count is True.
108114
call_allele_count
109115
Input variable name holding call_allele_count as defined by
110116
:data:`sgkit.variables.call_allele_count_spec`.
111117
If the variable is not present in ``ds``, it will be computed
112118
using :func:`count_call_alleles`.
119+
This variable is only used if from_call_allele_count is True.
120+
from_call_allele_count
121+
if True (the default), the result will be calculated from the
122+
call_allele_count variable rather than the call_genotype variable.
123+
If False, the result will be calculated directly from the
124+
call_genotype variable without computing the call_allele_count
125+
variable as an intermediate.
113126
merge
114127
If True (the default), merge the input dataset and the computed
115128
output variables into a single dataset, otherwise return only
@@ -141,14 +154,25 @@ def count_variant_alleles(
141154
[2, 2],
142155
[4, 0]], dtype=uint64)
143156
"""
144-
ds = define_variable_if_absent(
145-
ds, variables.call_allele_count, call_allele_count, count_call_alleles
146-
)
147-
variables.validate(ds, {call_allele_count: variables.call_allele_count_spec})
148-
149-
new_ds = create_dataset(
150-
{variables.variant_allele_count: ds[call_allele_count].sum(dim="samples")}
151-
)
157+
if from_call_allele_count:
158+
ds = define_variable_if_absent(
159+
ds, variables.call_allele_count, call_allele_count, count_call_alleles
160+
)
161+
variables.validate(ds, {call_allele_count: variables.call_allele_count_spec})
162+
AC = ds[call_allele_count].sum(dim="samples")
163+
else:
164+
from .aggregation_numba_fns import count_alleles
165+
166+
variables.validate(ds, {call_genotype: variables.call_genotype_spec})
167+
n_alleles = ds.dims["alleles"]
168+
n_variant = ds.dims["variants"]
169+
G = da.asarray(ds[call_genotype]).reshape((n_variant, -1))
170+
shape = (G.chunks[0], n_alleles)
171+
# use uint64 dummy array to return uin64 counts array
172+
N = np.empty(n_alleles, dtype=np.uint64)
173+
AC = da.map_blocks(count_alleles, G, N, chunks=shape, drop_axis=1, new_axis=1)
174+
AC = xr.DataArray(AC, dims=["variants", "alleles"])
175+
new_ds = create_dataset({variables.variant_allele_count: AC})
152176
return conditional_merge_datasets(ds, new_ds, merge)
153177

154178

@@ -629,7 +653,6 @@ def allele_frequency(
629653
def variant_stats(
630654
ds: Dataset,
631655
*,
632-
call_genotype_mask: Hashable = variables.call_genotype_mask,
633656
call_genotype: Hashable = variables.call_genotype,
634657
variant_allele_count: Hashable = variables.variant_allele_count,
635658
merge: bool = True,
@@ -644,10 +667,6 @@ def variant_stats(
644667
Input variable name holding call_genotype.
645668
Defined by :data:`sgkit.variables.call_genotype_spec`.
646669
Must be present in ``ds``.
647-
call_genotype_mask
648-
Input variable name holding call_genotype_mask.
649-
Defined by :data:`sgkit.variables.call_genotype_mask_spec`
650-
Must be present in ``ds``.
651670
variant_allele_count
652671
Input variable name holding variant_allele_count,
653672
as defined by :data:`sgkit.variables.variant_allele_count_spec`.
@@ -681,31 +700,85 @@ def variant_stats(
681700
The number of occurrences of all alleles.
682701
- :data:`sgkit.variables.variant_allele_frequency_spec` (variants, alleles):
683702
The frequency of occurrence of each allele.
703+
704+
Note
705+
----
706+
If the dataset contains partial genotype calls (i.e., genotype calls with
707+
a mixture of called and missing alleles), these genotypes will be ignored
708+
when counting the number of homozygous, heterozygous or total genotype calls.
709+
However, the called alleles will be counted when calculating allele counts
710+
and frequencies using :func:`count_variant_alleles`.
711+
712+
Note
713+
----
714+
When used on autopolyploid genotypes, this method treats genotypes calls
715+
with any level of heterozygosity as 'heterozygous'. Only fully homozygous
716+
genotype calls (e.g. 0/0/0/0) will be classified as 'homozygous'.
717+
718+
Warnings
719+
--------
720+
This method does not support mixed-ploidy datasets.
721+
722+
Raises
723+
------
724+
ValueError
725+
If the dataset contains mixed-ploidy genotype calls.
726+
727+
See Also
728+
--------
729+
:func:`count_variant_genotypes`
684730
"""
685-
variables.validate(
731+
from .aggregation_numba_fns import count_hom
732+
733+
variables.validate(ds, {call_genotype: variables.call_genotype_spec})
734+
mixed_ploidy = ds[call_genotype].attrs.get("mixed_ploidy", False)
735+
if mixed_ploidy:
736+
raise ValueError("Mixed-ploidy dataset")
737+
AC = define_variable_if_absent(
686738
ds,
687-
{
688-
call_genotype: variables.call_genotype_spec,
689-
call_genotype_mask: variables.call_genotype_mask_spec,
690-
},
739+
variables.variant_allele_count,
740+
variant_allele_count,
741+
count_variant_alleles,
742+
from_call_allele_count=False,
743+
merge=False,
744+
)[variant_allele_count]
745+
G = da.array(ds[call_genotype].data)
746+
H = xr.DataArray(
747+
da.map_blocks(
748+
count_hom,
749+
G,
750+
np.zeros(3, np.uint64),
751+
drop_axis=(1, 2),
752+
new_axis=1,
753+
dtype=np.int64,
754+
chunks=(G.chunks[0], 3),
755+
),
756+
dims=["variants", "categories"],
691757
)
692-
new_ds = xr.merge(
693-
[
694-
call_rate(ds, dim="samples", call_genotype_mask=call_genotype_mask),
695-
count_genotypes(
696-
ds,
697-
dim="samples",
698-
call_genotype=call_genotype,
699-
call_genotype_mask=call_genotype_mask,
700-
merge=False,
701-
),
702-
allele_frequency(
703-
ds,
704-
call_genotype_mask=call_genotype_mask,
705-
variant_allele_count=variant_allele_count,
706-
),
707-
]
758+
_, n_sample, _ = G.shape
759+
n_called = H.sum(axis=-1)
760+
call_rate = n_called / n_sample
761+
n_hom_ref = H[:, 0]
762+
n_hom_alt = H[:, 1]
763+
n_het = H[:, 2]
764+
n_non_ref = n_called - n_hom_ref
765+
allele_total = AC.sum(axis=-1).astype(int) # backwards compatibility
766+
new_ds = xr.Dataset(
767+
{
768+
variables.variant_n_called: n_called,
769+
variables.variant_call_rate: call_rate,
770+
variables.variant_n_het: n_het,
771+
variables.variant_n_hom_ref: n_hom_ref,
772+
variables.variant_n_hom_alt: n_hom_alt,
773+
variables.variant_n_non_ref: n_non_ref,
774+
variables.variant_allele_count: AC,
775+
variables.variant_allele_total: allele_total,
776+
variables.variant_allele_frequency: AC / allele_total,
777+
}
708778
)
779+
# for backwards compatible behavior
780+
if (variant_allele_count in ds) and merge:
781+
new_ds = new_ds.drop_vars(variant_allele_count)
709782
return conditional_merge_datasets(ds, variables.validate(new_ds), merge)
710783

711784

sgkit/stats/aggregation_numba_fns.py

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# in a separate file here, and imported dynamically to avoid
33
# initial compilation overhead.
44

5-
from sgkit.accelerate import numba_guvectorize
5+
from sgkit.accelerate import numba_guvectorize, numba_jit
66
from sgkit.typing import ArrayLike
77

88

@@ -12,6 +12,10 @@
1212
"void(int16[:], uint8[:], uint8[:])",
1313
"void(int32[:], uint8[:], uint8[:])",
1414
"void(int64[:], uint8[:], uint8[:])",
15+
"void(int8[:], uint64[:], uint64[:])",
16+
"void(int16[:], uint64[:], uint64[:])",
17+
"void(int32[:], uint64[:], uint64[:])",
18+
"void(int64[:], uint64[:], uint64[:])",
1519
],
1620
"(k),(n)->(n)",
1721
)
@@ -26,9 +30,10 @@ def count_alleles(
2630
Genotype call of shape (ploidy,) containing alleles encoded as
2731
type `int` with values < 0 indicating a missing allele.
2832
_
29-
Dummy variable of type `uint8` and shape (alleles,) used to
30-
define the number of unique alleles to be counted in the
31-
return value.
33+
Dummy variable of type `uint8` or `uint64` and shape (alleles,)
34+
used to define the number of unique alleles to be counted in the
35+
return value. The dtype of this array determines the dtype of the
36+
returned array.
3237
3338
Returns
3439
-------
@@ -43,3 +48,57 @@ def count_alleles(
4348
a = g[i]
4449
if a >= 0:
4550
out[a] += 1
51+
52+
53+
@numba_jit(nogil=True)
54+
def _classify_hom(genotype: ArrayLike) -> int:
55+
a0 = genotype[0]
56+
cat = min(a0, 1) # -1, 0, 1
57+
for i in range(1, len(genotype)):
58+
if cat < 0:
59+
break
60+
a = genotype[i]
61+
if a != a0:
62+
cat = 2 # het
63+
if a < 0:
64+
cat = -1
65+
return cat
66+
67+
68+
@numba_guvectorize( # type: ignore
69+
[
70+
"void(int8[:,:], uint64[:], int64[:])",
71+
"void(int16[:,:], uint64[:], int64[:])",
72+
"void(int32[:,:], uint64[:], int64[:])",
73+
"void(int64[:,:], uint64[:], int64[:])",
74+
],
75+
"(n, k),(c)->(c)",
76+
)
77+
def count_hom(
78+
genotypes: ArrayLike, _: ArrayLike, out: ArrayLike
79+
) -> None: # pragma: no cover
80+
"""Generalized U-function for counting homozygous and heterozygous genotypes.
81+
82+
Parameters
83+
----------
84+
g
85+
Genotype call of shape (ploidy,) containing alleles encoded as
86+
type `int` with values < 0 indicating a missing allele.
87+
_
88+
Dummy variable of type `uint64` with length 3 which determines the
89+
number of categories returned (this should always be 3).
90+
91+
Note
92+
----
93+
This method is not suitable for mixed-ploidy genotypes.
94+
95+
Returns
96+
-------
97+
counts : ndarray
98+
Counts of homozygous reference, homozygous alternate, and heterozygous genotypes.
99+
"""
100+
out[:] = 0
101+
for i in range(len(genotypes)):
102+
index = _classify_hom(genotypes[i])
103+
if index >= 0:
104+
out[index] += 1

0 commit comments

Comments
 (0)