Skip to content

Commit 46184e5

Browse files
committed
Reference variables instead of using strings
1 parent dc99404 commit 46184e5

File tree

10 files changed

+360
-270
lines changed

10 files changed

+360
-270
lines changed

docs/api.rst

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -56,41 +56,41 @@ Variables
5656
.. autosummary::
5757
:toctree: generated/
5858

59-
variables.base_prediction
60-
variables.call_allele_count
61-
variables.call_dosage
62-
variables.call_dosage_mask
63-
variables.call_genotype
64-
variables.call_genotype_mask
65-
variables.call_genotype_phased
66-
variables.call_genotype_probability
67-
variables.call_genotype_probability_mask
68-
variables.covariates
69-
variables.dosage
70-
variables.genotype_counts
71-
variables.loco_prediction
72-
variables.meta_prediction
73-
variables.pc_relate_phi
74-
variables.sample_id
75-
variables.sample_pcs
76-
variables.traits
77-
variables.variant_allele
78-
variables.variant_allele_count
79-
variables.variant_allele_frequency
80-
variables.variant_allele_total
81-
variables.variant_beta
82-
variables.variant_call_rate
83-
variables.variant_contig
84-
variables.variant_hwe_p_value
85-
variables.variant_id
86-
variables.variant_n_called
87-
variables.variant_n_het
88-
variables.variant_n_hom_alt
89-
variables.variant_n_hom_ref
90-
variables.variant_n_non_ref
91-
variables.variant_p_value
92-
variables.variant_position
93-
variables.variant_t_value
59+
variables.base_prediction_spec
60+
variables.call_allele_count_spec
61+
variables.call_dosage_spec
62+
variables.call_dosage_mask_spec
63+
variables.call_genotype_spec
64+
variables.call_genotype_mask_spec
65+
variables.call_genotype_phased_spec
66+
variables.call_genotype_probability_spec
67+
variables.call_genotype_probability_mask_spec
68+
variables.covariates_spec
69+
variables.dosage_spec
70+
variables.genotype_counts_spec
71+
variables.loco_prediction_spec
72+
variables.meta_prediction_spec
73+
variables.pc_relate_phi_spec
74+
variables.sample_id_spec
75+
variables.sample_pcs_spec
76+
variables.traits_spec
77+
variables.variant_allele_spec
78+
variables.variant_allele_count_spec
79+
variables.variant_allele_frequency_spec
80+
variables.variant_allele_total_spec
81+
variables.variant_beta_spec
82+
variables.variant_call_rate_spec
83+
variables.variant_contig_spec
84+
variables.variant_hwe_p_value_spec
85+
variables.variant_id_spec
86+
variables.variant_n_called_spec
87+
variables.variant_n_het_spec
88+
variables.variant_n_hom_alt_spec
89+
variables.variant_n_hom_ref_spec
90+
variables.variant_n_non_ref_spec
91+
variables.variant_p_value_spec
92+
variables.variant_position_spec
93+
variables.variant_t_value_spec
9494

9595
Utilities
9696
=========

sgkit/model.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from . import variables
77
from .typing import ArrayLike
8-
from .utils import check_array_like
98

109
DIM_VARIANT = "variants"
1110
DIM_SAMPLE = "samples"
@@ -70,13 +69,11 @@ def create_genotype_call_dataset(
7069
),
7170
}
7271
if call_genotype_phased is not None:
73-
check_array_like(call_genotype_phased, kind="b", ndim=2)
7472
data_vars["call_genotype_phased"] = (
7573
[DIM_VARIANT, DIM_SAMPLE],
7674
call_genotype_phased,
7775
)
7876
if variant_id is not None:
79-
check_array_like(variant_id, kind={"U", "O"}, ndim=1)
8077
data_vars["variant_id"] = ([DIM_VARIANT], variant_id)
8178
attrs: Dict[Hashable, Any] = {"contigs": variant_contig_names}
8279
return variables.validate(xr.Dataset(data_vars=data_vars, attrs=attrs))
@@ -145,7 +142,6 @@ def create_genotype_dosage_dataset(
145142
),
146143
}
147144
if variant_id is not None:
148-
check_array_like(variant_id, kind={"U", "O"}, ndim=1)
149145
data_vars["variant_id"] = ([DIM_VARIANT], variant_id)
150146
attrs: Dict[Hashable, Any] = {"contigs": variant_contig_names}
151147
return variables.validate(xr.Dataset(data_vars=data_vars, attrs=attrs))

sgkit/stats/aggregation.py

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def count_alleles(g: ArrayLike, _: ArrayLike, out: ArrayLike) -> None:
5353

5454

5555
def count_call_alleles(
56-
ds: Dataset, *, call_genotype: str = "call_genotype", merge: bool = True
56+
ds: Dataset, *, call_genotype: str = variables.call_genotype, merge: bool = True
5757
) -> Dataset:
5858
"""Compute per sample allele counts from genotype calls.
5959
@@ -64,7 +64,7 @@ def count_call_alleles(
6464
:func:`sgkit.create_genotype_call_dataset`.
6565
call_genotype
6666
Input variable name holding call_genotype as defined by
67-
:data:`sgkit.variables.call_genotype`
67+
:data:`sgkit.variables.call_genotype_spec`
6868
merge
6969
If True (the default), merge the input dataset and the computed
7070
output variables into a single dataset, otherwise return only
@@ -104,14 +104,14 @@ def count_call_alleles(
104104
[[2, 0],
105105
[2, 0]]], dtype=uint8)
106106
"""
107-
variables.validate(ds, {call_genotype: variables.call_genotype})
107+
variables.validate(ds, {call_genotype: variables.call_genotype_spec})
108108
n_alleles = ds.dims["alleles"]
109109
G = da.asarray(ds[call_genotype])
110110
shape = (G.chunks[0], G.chunks[1], n_alleles)
111111
N = da.empty(n_alleles, dtype=np.uint8)
112112
new_ds = Dataset(
113113
{
114-
"call_allele_count": (
114+
variables.call_allele_count: (
115115
("variants", "samples", "alleles"),
116116
da.map_blocks(
117117
count_alleles, G, N, chunks=shape, drop_axis=2, new_axis=2
@@ -123,7 +123,7 @@ def count_call_alleles(
123123

124124

125125
def count_variant_alleles(
126-
ds: Dataset, *, call_genotype: str = "call_genotype", merge: bool = True
126+
ds: Dataset, *, call_genotype: str = variables.call_genotype, merge: bool = True
127127
) -> Dataset:
128128
"""Compute allele count from genotype calls.
129129
@@ -134,7 +134,7 @@ def count_variant_alleles(
134134
:func:`sgkit.create_genotype_call_dataset`.
135135
call_genotype
136136
Input variable name holding call_genotype as defined by
137-
:data:`sgkit.variables.call_genotype`
137+
:data:`sgkit.variables.call_genotype_spec`
138138
merge
139139
If True (the default), merge the input dataset and the computed
140140
output variables into a single dataset, otherwise return only
@@ -169,10 +169,10 @@ def count_variant_alleles(
169169
"""
170170
new_ds = Dataset(
171171
{
172-
"variant_allele_count": (
172+
variables.variant_allele_count: (
173173
("variants", "alleles"),
174174
count_call_alleles(ds, call_genotype=call_genotype)[
175-
"call_allele_count"
175+
variables.call_allele_count
176176
].sum(dim="samples"),
177177
)
178178
}
@@ -222,28 +222,30 @@ def allele_frequency(
222222
data_vars: Dict[Hashable, Any] = {}
223223
# only compute variant allele count if not already in dataset
224224
if variant_allele_count is not None:
225-
variables.validate(ds, {variant_allele_count: variables.variant_allele_count})
225+
variables.validate(
226+
ds, {variant_allele_count: variables.variant_allele_count_spec}
227+
)
226228
AC = ds[variant_allele_count]
227229
else:
228230
AC = count_variant_alleles(ds, merge=False, call_genotype=call_genotype)[
229-
"variant_allele_count"
231+
variables.variant_allele_count
230232
]
231-
data_vars["variant_allele_count"] = AC
233+
data_vars[variables.variant_allele_count] = AC
232234

233235
M = ds[call_genotype_mask].stack(calls=("samples", "ploidy"))
234236
AN = (~M).sum(dim="calls") # type: ignore
235237
assert AN.shape == (ds.dims["variants"],)
236238

237-
data_vars["variant_allele_total"] = AN
238-
data_vars["variant_allele_frequency"] = AC / AN
239+
data_vars[variables.variant_allele_total] = AN
240+
data_vars[variables.variant_allele_frequency] = AC / AN
239241
return Dataset(data_vars)
240242

241243

242244
def variant_stats(
243245
ds: Dataset,
244246
*,
245-
call_genotype_mask: str = "call_genotype_mask",
246-
call_genotype: str = "call_genotype",
247+
call_genotype_mask: str = variables.call_genotype_mask,
248+
call_genotype: str = variables.call_genotype,
247249
variant_allele_count: Optional[str] = None,
248250
merge: bool = True,
249251
) -> Dataset:
@@ -256,13 +258,13 @@ def variant_stats(
256258
:func:`sgkit.create_genotype_call_dataset`.
257259
call_genotype
258260
Input variable name holding call_genotype.
259-
Defined by :data:`sgkit.variables.call_genotype`.
261+
Defined by :data:`sgkit.variables.call_genotype_spec`.
260262
call_genotype_mask
261263
Input variable name holding call_genotype_mask.
262-
Defined by :data:`sgkit.variables.call_genotype_mask`
264+
Defined by :data:`sgkit.variables.call_genotype_mask_spec`
263265
variant_allele_count
264266
Optional name of the input variable holding variant_allele_count,
265-
as defined by :data:`sgkit.variables.variant_allele_count`.
267+
as defined by :data:`sgkit.variables.variant_allele_count_spec`.
266268
merge
267269
If True (the default), merge the input dataset and the computed
268270
output variables into a single dataset, otherwise return only
@@ -273,30 +275,30 @@ def variant_stats(
273275
-------
274276
A dataset containing the following variables:
275277
276-
- :data:`sgkit.variables.variant_n_called` (variants):
278+
- :data:`sgkit.variables.variant_n_called_spec` (variants):
277279
The number of samples with called genotypes.
278-
- :data:`sgkit.variables.variant_call_rate` (variants):
280+
- :data:`sgkit.variables.variant_call_rate_spec` (variants):
279281
The fraction of samples with called genotypes.
280-
- :data:`sgkit.variables.variant_n_het` (variants):
282+
- :data:`sgkit.variables.variant_n_het_spec` (variants):
281283
The number of samples with heterozygous calls.
282-
- :data:`sgkit.variables.variant_n_hom_ref` (variants):
284+
- :data:`sgkit.variables.variant_n_hom_ref_spec` (variants):
283285
The number of samples with homozygous reference calls.
284-
- :data:`sgkit.variables.variant_n_hom_alt` (variants):
286+
- :data:`sgkit.variables.variant_n_hom_alt_spec` (variants):
285287
The number of samples with homozygous alternate calls.
286-
- :data:`sgkit.variables.variant_n_non_ref` (variants):
288+
- :data:`sgkit.variables.variant_n_non_ref_spec` (variants):
287289
The number of samples that are not homozygous reference calls.
288-
- :data:`sgkit.variables.variant_allele_count` (variants, alleles):
290+
- :data:`sgkit.variables.variant_allele_count_spec` (variants, alleles):
289291
The number of occurrences of each allele.
290-
- :data:`sgkit.variables.variant_allele_total` (variants):
292+
- :data:`sgkit.variables.variant_allele_total_spec` (variants):
291293
The number of occurrences of all alleles.
292-
- :data:`sgkit.variables.variant_allele_frequency` (variants, alleles):
294+
- :data:`sgkit.variables.variant_allele_frequency_spec` (variants, alleles):
293295
The frequency of occurrence of each allele.
294296
"""
295297
variables.validate(
296298
ds,
297299
{
298-
call_genotype: variables.call_genotype,
299-
call_genotype_mask: variables.call_genotype_mask,
300+
call_genotype: variables.call_genotype_spec,
301+
call_genotype_mask: variables.call_genotype_mask_spec,
300302
},
301303
)
302304
new_ds = xr.merge(

sgkit/stats/association.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,13 @@ def linear_regression(
104104
return LinearRegressionResult(beta=B, t_value=T, p_value=P)
105105

106106

107-
def _get_loop_covariates(ds: Dataset, dosage: Optional[str] = None) -> Array:
107+
def _get_loop_covariates(
108+
ds: Dataset, call_genotype: str, dosage: Optional[str] = None
109+
) -> Array:
108110
if dosage is None:
109111
# TODO: This should be (probably gwas-specific) allele
110112
# count with sex chromosome considerations
111-
G = ds["call_genotype"].sum(dim="ploidy") # pragma: no cover
113+
G = ds[call_genotype].sum(dim="ploidy") # pragma: no cover
112114
else:
113115
G = ds[dosage]
114116
return da.asarray(G.data)
@@ -121,6 +123,7 @@ def gwas_linear_regression(
121123
covariates: Union[str, Sequence[str]],
122124
traits: Union[str, Sequence[str]],
123125
add_intercept: bool = True,
126+
call_genotype: str = variables.call_genotype,
124127
merge: bool = True,
125128
) -> Dataset:
126129
"""Run linear regression to identify continuous trait associations with genetic variants.
@@ -138,15 +141,18 @@ def gwas_linear_regression(
138141
Dataset containing necessary dependent and independent variables.
139142
dosage
140143
Name of genetic dosage variable.
141-
Defined by :data:`sgkit.variables.dosage`.
144+
Defined by :data:`sgkit.variables.dosage_spec`.
142145
covariates
143146
Names of covariate variables (1D or 2D).
144-
Defined by :data:`sgkit.variables.covariates`.
147+
Defined by :data:`sgkit.variables.covariates_spec`.
145148
traits
146149
Names of trait variables (1D or 2D).
147-
Defined by :data:`sgkit.variables.traits`.
150+
Defined by :data:`sgkit.variables.traits_spec`.
148151
add_intercept
149152
Add intercept term to covariate set, by default True.
153+
call_genotype
154+
Input variable name holding call_genotype.
155+
Defined by :data:`sgkit.variables.call_genotype_spec`.
150156
merge
151157
If True (the default), merge the input dataset and the computed
152158
output variables into a single dataset, otherwise return only
@@ -193,12 +199,12 @@ def gwas_linear_regression(
193199

194200
variables.validate(
195201
ds,
196-
{dosage: variables.dosage},
197-
{c: variables.covariates for c in covariates},
198-
{t: variables.traits for t in traits},
202+
{dosage: variables.dosage_spec},
203+
{c: variables.covariates_spec for c in covariates},
204+
{t: variables.traits_spec for t in traits},
199205
)
200206

201-
G = _get_loop_covariates(ds, dosage=dosage)
207+
G = _get_loop_covariates(ds, dosage=dosage, call_genotype=call_genotype)
202208

203209
X = da.asarray(concat_2d(ds[list(covariates)], dims=("samples", "covariates")))
204210
if add_intercept:
@@ -216,9 +222,9 @@ def gwas_linear_regression(
216222
res = linear_regression(G.T, X, Y)
217223
new_ds = xr.Dataset(
218224
{
219-
"variant_beta": (("variants", "traits"), res.beta),
220-
"variant_t_value": (("variants", "traits"), res.t_value),
221-
"variant_p_value": (("variants", "traits"), res.p_value),
225+
variables.variant_beta: (("variants", "traits"), res.beta),
226+
variables.variant_t_value: (("variants", "traits"), res.t_value),
227+
variables.variant_p_value: (("variants", "traits"), res.p_value),
222228
}
223229
)
224230
return conditional_merge_datasets(ds, variables.validate(new_ds), merge)

sgkit/stats/hwe.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,8 @@ def hardy_weinberg_test(
127127
ds: Dataset,
128128
*,
129129
genotype_counts: Optional[Hashable] = None,
130-
call_genotype: str = "call_genotype",
131-
call_genotype_mask: str = "call_genotype_mask",
130+
call_genotype: str = variables.call_genotype,
131+
call_genotype_mask: str = variables.call_genotype_mask,
132132
merge: bool = True,
133133
) -> Dataset:
134134
"""Exact test for HWE as described in Wigginton et al. 2005 [1].
@@ -146,10 +146,10 @@ def hardy_weinberg_test(
146146
(in that order) across all samples for a variant.
147147
call_genotype
148148
Input variable name holding call_genotype.
149-
Defined by :data:`sgkit.variables.call_genotype`.
149+
Defined by :data:`sgkit.variables.call_genotype_spec`.
150150
call_genotype_mask
151151
Input variable name holding call_genotype_mask.
152-
Defined by :data:`sgkit.variables.call_genotype_mask`
152+
Defined by :data:`sgkit.variables.call_genotype_mask_spec`
153153
merge
154154
If True (the default), merge the input dataset and the computed
155155
output variables into a single dataset, otherwise return only
@@ -185,15 +185,15 @@ def hardy_weinberg_test(
185185
raise NotImplementedError("HWE test only implemented for biallelic genotypes")
186186
# Use precomputed genotype counts if provided
187187
if genotype_counts is not None:
188-
variables.validate(ds, {genotype_counts: variables.genotype_counts})
188+
variables.validate(ds, {genotype_counts: variables.genotype_counts_spec})
189189
obs = list(da.asarray(ds[genotype_counts]).T)
190190
# Otherwise compute genotype counts from calls
191191
else:
192192
variables.validate(
193193
ds,
194194
{
195-
call_genotype_mask: variables.call_genotype_mask,
196-
call_genotype: variables.call_genotype,
195+
call_genotype_mask: variables.call_genotype_mask_spec,
196+
call_genotype: variables.call_genotype_spec,
197197
},
198198
)
199199
# TODO: Use API genotype counting function instead, e.g.
@@ -203,5 +203,5 @@ def hardy_weinberg_test(
203203
cts = [1, 0, 2] # arg order: hets, hom1, hom2
204204
obs = [da.asarray((AC == ct).sum(dim="samples")) for ct in cts]
205205
p = da.map_blocks(hardy_weinberg_p_value_vec_jit, *obs)
206-
new_ds = xr.Dataset({"variant_hwe_p_value": ("variants", p)})
206+
new_ds = xr.Dataset({variables.variant_hwe_p_value: ("variants", p)})
207207
return conditional_merge_datasets(ds, variables.validate(new_ds), merge)

0 commit comments

Comments
 (0)