Skip to content

Commit 6fcb0ce

Browse files
committed
Use numpy.typing.DTypeLike
1 parent 70a91ba commit 6fcb0ce

File tree

9 files changed

+36
-29
lines changed

9 files changed

+36
-29
lines changed

sgkit/io/bgen/bgen_reader.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,13 @@
2323
import xarray as xr
2424
import zarr
2525
from cbgen import bgen_file, bgen_metafile
26+
from numpy.typing import DTypeLike
2627
from rechunker import api as rechunker_api
2728
from xarray import Dataset
2829

2930
from sgkit import create_genotype_dosage_dataset
3031
from sgkit.io.utils import dataframe_to_dict, encode_contigs
31-
from sgkit.typing import ArrayLike, DType, PathType
32+
from sgkit.typing import ArrayLike, PathType
3233

3334
logger = logging.getLogger(__name__)
3435

@@ -60,7 +61,7 @@ def __init__(
6061
self,
6162
path: PathType,
6263
metafile_path: Optional[PathType] = None,
63-
dtype: DType = "float32",
64+
dtype: DTypeLike = "float32",
6465
) -> None:
6566
self.path = Path(path)
6667
self.metafile_path = (
@@ -202,8 +203,8 @@ def read_bgen(
202203
chunks: Union[str, int, Tuple[int, int, int]] = "auto",
203204
lock: bool = False,
204205
persist: bool = True,
205-
contig_dtype: DType = "str",
206-
gp_dtype: DType = "float32",
206+
contig_dtype: DTypeLike = "str",
207+
gp_dtype: DTypeLike = "float32",
207208
) -> Dataset:
208209
"""Read BGEN dataset.
209210
@@ -394,7 +395,7 @@ def pack_variables(ds: Dataset) -> Dataset:
394395
return ds
395396

396397

397-
def unpack_variables(ds: Dataset, dtype: DType = "float32") -> Dataset:
398+
def unpack_variables(ds: Dataset, dtype: DTypeLike = "float32") -> Dataset:
398399
# Restore homozygous reference GP
399400
gp = ds["call_genotype_probability"].astype(dtype)
400401
if gp.sizes["genotypes"] != 2:
@@ -423,7 +424,7 @@ def rechunk_bgen(
423424
chunk_length: int = 10_000,
424425
chunk_width: int = 1_000,
425426
compressor: Optional[Any] = zarr.Blosc(cname="zstd", clevel=7, shuffle=2),
426-
probability_dtype: Optional[DType] = "uint8",
427+
probability_dtype: Optional[DTypeLike] = "uint8",
427428
max_mem: str = "4GB",
428429
pack: bool = True,
429430
tempdir: Optional[PathType] = None,
@@ -533,7 +534,7 @@ def bgen_to_zarr(
533534
chunk_width: int = 1_000,
534535
temp_chunk_length: int = 100,
535536
compressor: Optional[Any] = zarr.Blosc(cname="zstd", clevel=7, shuffle=2),
536-
probability_dtype: Optional[DType] = "uint8",
537+
probability_dtype: Optional[DTypeLike] = "uint8",
537538
max_mem: str = "4GB",
538539
pack: bool = True,
539540
tempdir: Optional[PathType] = None,

sgkit/io/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@
66
import numpy as np
77
import xarray as xr
88
import zarr
9+
from numpy.typing import DTypeLike
910

10-
from ..typing import ArrayLike, DType
11+
from ..typing import ArrayLike
1112
from ..utils import encode_array, max_str_len
1213

1314

1415
def dataframe_to_dict(
15-
df: dd.DataFrame, dtype: Optional[Mapping[str, DType]] = None
16+
df: dd.DataFrame, dtype: Optional[Mapping[str, DTypeLike]] = None
1617
) -> Mapping[str, ArrayLike]:
1718
""" Convert dask dataframe to dictionary of arrays """
1819
arrs = {}
@@ -110,7 +111,7 @@ def zarrs_to_dataset(
110111
def concatenate_and_rechunk(
111112
zarrs: Sequence[zarr.Array],
112113
chunks: Optional[Tuple[int, ...]] = None,
113-
dtype: DType = None,
114+
dtype: DTypeLike = None,
114115
) -> da.Array:
115116
"""Perform a concatenate and rechunk operation on a collection of Zarr arrays
116117
to produce an array with a uniform chunking, suitable for saving as

sgkit/io/vcf/vcf_reader.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,14 @@
2020
import numpy as np
2121
import xarray as xr
2222
from cyvcf2 import VCF, Variant
23+
from numpy.typing import DTypeLike
2324

2425
from sgkit.io.utils import zarrs_to_dataset
2526
from sgkit.io.vcf import partition_into_regions
2627
from sgkit.io.vcf.utils import build_url, chunks, temporary_directory, url_filename
2728
from sgkit.io.vcfzarr_reader import vcf_number_to_dimension_and_size
2829
from sgkit.model import DIM_SAMPLE, DIM_VARIANT, create_genotype_call_dataset
29-
from sgkit.typing import ArrayLike, DType, PathType
30+
from sgkit.typing import ArrayLike, PathType
3031
from sgkit.utils import max_str_len
3132

3233
DEFAULT_MAX_ALT_ALLELES = (
@@ -104,7 +105,7 @@ def _normalize_fields(vcf: VCF, fields: Sequence[str]) -> Sequence[str]:
104105

105106
def _vcf_type_to_numpy_type_and_fill_value(
106107
vcf_type: str, category: str, key: str
107-
) -> Tuple[DType, Any]:
108+
) -> Tuple[DTypeLike, Any]:
108109
"""Convert the VCF Type to a NumPy dtype and fill value."""
109110
if vcf_type == "Flag":
110111
return "bool", False

sgkit/stats/ld.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88
import pandas as pd
99
from dask.dataframe import DataFrame
1010
from numba import njit
11+
from numpy.typing import DTypeLike
1112
from xarray import Dataset
1213

1314
from sgkit import variables
14-
from sgkit.typing import ArrayLike, DType
15+
from sgkit.typing import ArrayLike
1516
from sgkit.window import _get_chunked_windows, _sizes_to_start_offsets, has_windows
1617

1718

@@ -205,8 +206,8 @@ def _ld_matrix_jit(
205206
chunk_window_stops: ArrayLike,
206207
abs_chunk_start: int,
207208
chunk_max_window_start: int,
208-
index_dtype: DType,
209-
value_dtype: DType,
209+
index_dtype: DTypeLike,
210+
value_dtype: DTypeLike,
210211
threshold: float,
211212
scores: ArrayLike,
212213
) -> List[Any]: # pragma: no cover
@@ -246,7 +247,7 @@ def _ld_matrix_jit(
246247

247248
if no_threshold or (res >= threshold and np.isfinite(res)):
248249
rows.append(
249-
(index_dtype(index), index_dtype(other), value_dtype(res), cmp)
250+
(index_dtype(index), index_dtype(other), value_dtype(res), cmp) # type: ignore
250251
)
251252

252253
return rows
@@ -258,8 +259,8 @@ def _ld_matrix(
258259
chunk_window_stops: ArrayLike,
259260
abs_chunk_start: int,
260261
chunk_max_window_start: int,
261-
index_dtype: DType,
262-
value_dtype: DType,
262+
index_dtype: DTypeLike,
263+
value_dtype: DTypeLike,
263264
threshold: float = np.nan,
264265
scores: Optional[ArrayLike] = None,
265266
) -> ArrayLike:

sgkit/stats/pca.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44
import numpy as np
55
import xarray as xr
66
from dask_ml.decomposition import TruncatedSVD
7+
from numpy.typing import DTypeLike
78
from sklearn.base import BaseEstimator
89
from sklearn.pipeline import Pipeline
910
from typing_extensions import Literal
1011
from xarray import DataArray, Dataset
1112

1213
from sgkit import variables
1314

14-
from ..typing import ArrayLike, DType, RandomStateType
15+
from ..typing import ArrayLike, RandomStateType
1516
from ..utils import conditional_merge_datasets
1617
from .aggregation import count_call_alleles
1718
from .preprocessing import PattersonScaler
@@ -331,7 +332,7 @@ def _allele_counts(
331332
ds: Dataset,
332333
variable: str,
333334
check_missing: bool = True,
334-
dtype: DType = "float32",
335+
dtype: DTypeLike = "float32",
335336
) -> DataArray:
336337
if variable not in ds:
337338
ds = count_call_alternate_alleles(ds)

sgkit/tests/test_preprocessing.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,19 @@
55
import numpy as np
66
import pytest
77
import xarray as xr
8+
from numpy.typing import DTypeLike
89

910
import sgkit.stats.preprocessing
1011
from sgkit import simulate_genotype_call_dataset
11-
from sgkit.typing import ArrayLike, DType
12+
from sgkit.typing import ArrayLike
1213

1314

1415
def simulate_alternate_allele_counts(
1516
n_variant: int,
1617
n_sample: int,
1718
ploidy: int,
1819
chunks: Any = (10, 10),
19-
dtype: DType = "i",
20+
dtype: DTypeLike = "i",
2021
seed: int = 0,
2122
) -> ArrayLike:
2223
rs = da.random.RandomState(seed)

sgkit/typing.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
from pathlib import Path
2-
from typing import Any, Union
2+
from typing import Union
33

44
import dask.array as da
55
import numpy as np
66

77
ArrayLike = Union[np.ndarray, da.Array]
8-
DType = Any
98
PathType = Union[str, Path]
109
RandomStateType = Union[np.random.RandomState, da.random.RandomState, int]

sgkit/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,16 @@
33

44
import numpy as np
55
from numba import guvectorize
6+
from numpy.typing import DTypeLike
67
from xarray import Dataset
78

89
from . import variables
9-
from .typing import ArrayLike, DType
10+
from .typing import ArrayLike
1011

1112

1213
def check_array_like(
1314
a: Any,
14-
dtype: Union[None, DType, Set[DType]] = None,
15+
dtype: Union[None, DTypeLike, Set[DTypeLike]] = None,
1516
kind: Union[None, str, Set[str]] = None,
1617
ndim: Union[None, int, Set[int]] = None,
1718
) -> None:

sgkit/window.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22

33
import dask.array as da
44
import numpy as np
5+
from numpy.typing import DTypeLike
56
from xarray import Dataset
67

78
from sgkit.utils import conditional_merge_datasets, create_dataset
89
from sgkit.variables import window_contig, window_start, window_stop
910

10-
from .typing import ArrayLike, DType
11+
from .typing import ArrayLike
1112

1213
# Window definition (user code)
1314

@@ -110,7 +111,7 @@ def moving_statistic(
110111
statistic: Callable[..., ArrayLike],
111112
size: int,
112113
step: int,
113-
dtype: DType,
114+
dtype: DTypeLike,
114115
**kwargs: Any,
115116
) -> da.Array:
116117
"""A Dask implementation of scikit-allel's moving_statistic function."""
@@ -135,7 +136,7 @@ def window_statistic(
135136
statistic: Callable[..., ArrayLike],
136137
window_starts: ArrayLike,
137138
window_stops: ArrayLike,
138-
dtype: DType,
139+
dtype: DTypeLike,
139140
chunks: Any = None,
140141
new_axis: Union[None, int, Iterable[int]] = None,
141142
**kwargs: Any,

0 commit comments

Comments
 (0)