Skip to content

Commit fad82a0

Browse files
Backport PR #1583: Allow plots to use adata.obs index as groupby (#1617)
Co-authored-by: Fidel Ramirez <[email protected]>
1 parent e1ed250 commit fad82a0

File tree

6 files changed

+250
-85
lines changed

6 files changed

+250
-85
lines changed

scanpy/get.py

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33

44
import numpy as np
55
import pandas as pd
6-
from scipy.sparse import spmatrix, issparse
6+
from scipy.sparse import spmatrix
77

88
from anndata import AnnData
9+
import warnings
910

1011
# --------------------------------------------------------------------------------
1112
# Plotting data helpers
@@ -166,11 +167,46 @@ def obs_df(
166167
var_names = []
167168
var_symbol = []
168169
not_found = []
169-
for key in keys:
170+
171+
# check that adata.obs does not contain duplicated columns
172+
# if duplicated columns names are present, they will
173+
# be further duplicated when selecting them.
174+
if not adata.obs.columns.is_unique:
175+
dup_obs = adata.obs.columns[adata.obs.columns.duplicated()].tolist()
176+
raise ValueError(
177+
"adata.obs contains duplicated columns. Please rename or remove "
178+
"these columns first.\n`"
179+
f"Duplicated columns {dup_obs}"
180+
)
181+
182+
# check that adata.var does not contain duplicated indices
183+
# If duplicated indices are present the selection of var by numeric
184+
# index
185+
if not adata.var_names.is_unique:
186+
raise ValueError(
187+
"adata.var contains duplicated var names\n"
188+
"Please rename these var names first for example using "
189+
"`adata.var_names_make_unique()`"
190+
)
191+
# use only unique keys, otherwise duplicated keys will
192+
# further duplicate when reordering the keys later in the function
193+
for key in np.unique(keys):
170194
if key in adata.obs.columns:
171195
obs_names.append(key)
196+
if key in gene_names.index:
197+
raise KeyError(
198+
f'The key `{key}` is found in both adata.obs and adata.var_names.'
199+
)
172200
elif key in gene_names.index:
173-
var_names.append(gene_names[key])
201+
val = gene_names[key]
202+
if isinstance(val, pd.Series):
203+
# while var_names must be unique, adata.var[gene_symbols] does not
204+
# It's still ambiguous to refer to a duplicated entry though.
205+
assert gene_symbols is not None
206+
raise KeyError(
207+
f"Found duplicate entries for '{key}' in adata.var['{gene_symbols}']."
208+
)
209+
var_names.append(val)
174210
var_symbol.append(key)
175211
else:
176212
not_found.append(key)
@@ -216,13 +252,16 @@ def obs_df(
216252

217253
if issparse(matrix):
218254
matrix = matrix.toarray()
219-
df = df.join(pd.DataFrame(matrix, columns=var_symbol, index=adata.obs.index))
255+
df = pd.concat(
256+
[df, pd.DataFrame(matrix, columns=var_symbol, index=adata.obs.index)],
257+
axis=1,
258+
)
220259

221260
# add obs values
222261
if len(obs_names) > 0:
223-
df = df.join(adata.obs[obs_names])
262+
df = pd.concat([df, adata.obs[obs_names]], axis=1)
224263

225-
# reorder columns to given order
264+
# reorder columns to given order (including duplicates keys if present)
226265
df = df[keys]
227266
for k, idx in obsm_keys:
228267
added_k = f"{k}-{idx}"
@@ -233,6 +272,7 @@ def obs_df(
233272
df[added_k] = np.ravel(val[:, idx].toarray())
234273
elif isinstance(val, pd.DataFrame):
235274
df[added_k] = val.loc[:, idx]
275+
236276
return df
237277

238278

@@ -266,7 +306,10 @@ def var_df(
266306
obs_names = []
267307
var_names = []
268308
not_found = []
269-
for key in keys:
309+
310+
# use only unique keys, otherwise duplicated keys will
311+
# further duplicate when reordering the keys later in the function
312+
for key in np.unique(keys):
270313
if key in adata.obs_names:
271314
obs_names.append(key)
272315
elif key in adata.var.columns:
@@ -298,11 +341,14 @@ def var_df(
298341
if issparse(matrix):
299342
matrix = matrix.toarray()
300343

301-
df = df.join(pd.DataFrame(matrix.T, columns=obs_names, index=adata.var.index))
344+
df = pd.concat(
345+
[df, pd.DataFrame(matrix.T, columns=obs_names, index=adata.var.index)],
346+
axis=1,
347+
)
302348

303349
# add obs values
304350
if len(var_names) > 0:
305-
df = df.join(adata.var[var_names])
351+
df = pd.concat([df, adata.var[var_names]], axis=1)
306352

307353
# reorder columns to given order
308354
df = df[keys]

scanpy/plotting/_anndata.py

Lines changed: 48 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from anndata import AnnData
1313
from cycler import Cycler
1414
from matplotlib.axes import Axes
15-
from pandas.api.types import is_categorical_dtype
15+
from pandas.api.types import is_categorical_dtype, is_numeric_dtype
1616
from scipy.sparse import issparse
1717
from matplotlib import pyplot as pl
1818
from matplotlib import rcParams
@@ -1792,7 +1792,9 @@ def _prepare_dataframe(
17921792
use_raw
17931793
Whether to use `raw` attribute of `adata`. Defaults to `True` if `.raw` is present.
17941794
log
1795-
Use the log of the values
1795+
Use the log of the values.
1796+
layer
1797+
AnnData layer to use. Takes precedence over `use_raw`
17961798
num_categories
17971799
Only used if groupby observation is not categorical. This value
17981800
determines the number of groups into which the groupby observation
@@ -1804,90 +1806,72 @@ def _prepare_dataframe(
18041806
-------
18051807
Tuple of `pandas.DataFrame` and list of categories.
18061808
"""
1807-
from scipy.sparse import issparse
18081809

18091810
sanitize_anndata(adata)
18101811
use_raw = _check_use_raw(adata, use_raw)
1812+
if layer is not None:
1813+
use_raw = False
18111814
if isinstance(var_names, str):
18121815
var_names = [var_names]
18131816

1817+
groupby_index = None
18141818
if groupby is not None:
18151819
if isinstance(groupby, str):
18161820
# if not a list, turn into a list
18171821
groupby = [groupby]
18181822
for group in groupby:
1819-
if group not in adata.obs_keys():
1823+
if group not in list(adata.obs_keys()) + [adata.obs.index.name]:
1824+
if adata.obs.index.name is not None:
1825+
msg = f' or index name "{adata.obs.index.name}"'
1826+
else:
1827+
msg = ''
18201828
raise ValueError(
18211829
'groupby has to be a valid observation. '
1822-
f'Given {group}, is not in observations: {adata.obs_keys()}'
1830+
f'Given {group}, is not in observations: {adata.obs_keys()}' + msg
18231831
)
1824-
1825-
if gene_symbols is not None and gene_symbols in adata.var.columns:
1826-
# translate gene_symbols to var_names
1827-
# slow method but gives a meaningful error if no gene symbol is found:
1828-
translated_var_names = []
1829-
# if we're using raw to plot, we should also do gene symbol translations
1830-
# using raw
1831-
if use_raw:
1832-
adata_or_raw = adata.raw
1833-
else:
1834-
adata_or_raw = adata
1835-
for symbol in var_names:
1836-
if symbol not in adata_or_raw.var[gene_symbols].values:
1837-
logg.error(
1838-
f"Gene symbol {symbol!r} not found in given "
1839-
f"gene_symbols column: {gene_symbols!r}"
1832+
if group in adata.obs.keys() and group == adata.obs.index.name:
1833+
raise ValueError(
1834+
f'Given group {group} is both and index and a column level, '
1835+
'which is ambiguous.'
18401836
)
1841-
return
1842-
translated_var_names.append(
1843-
adata_or_raw.var[adata_or_raw.var[gene_symbols] == symbol].index[0]
1844-
)
1845-
symbols = var_names
1846-
var_names = translated_var_names
1847-
if layer is not None:
1848-
if layer not in adata.layers.keys():
1849-
raise KeyError(
1850-
f'Selected layer: {layer} is not in the layers list. '
1851-
f'The list of valid layers is: {adata.layers.keys()}'
1852-
)
1853-
matrix = adata[:, var_names].layers[layer]
1854-
elif use_raw:
1855-
matrix = adata.raw[:, var_names].X
1856-
else:
1857-
matrix = adata[:, var_names].X
1837+
if group == adata.obs.index.name:
1838+
groupby_index = group
1839+
if groupby_index is not None:
1840+
# obs_tidy contains adata.obs.index
1841+
# and does not need to be given
1842+
groupby = groupby.copy() # copy to not modify user passed parameter
1843+
groupby.remove(groupby_index)
1844+
keys = list(groupby) + list(np.unique(var_names))
1845+
obs_tidy = get.obs_df(
1846+
adata, keys=keys, layer=layer, use_raw=use_raw, gene_symbols=gene_symbols
1847+
)
1848+
assert np.all(np.array(keys) == np.array(obs_tidy.columns))
18581849

1859-
if issparse(matrix):
1860-
matrix = matrix.toarray()
1861-
if log:
1862-
matrix = np.log1p(matrix)
1850+
if groupby_index is not None:
1851+
# reset index to treat all columns the same way.
1852+
obs_tidy.reset_index(inplace=True)
1853+
groupby.append(groupby_index)
18631854

1864-
obs_tidy = pd.DataFrame(matrix, columns=var_names)
18651855
if groupby is None:
1866-
groupby = ''
18671856
categorical = pd.Series(np.repeat('', len(obs_tidy))).astype('category')
1857+
elif len(groupby) == 1 and is_numeric_dtype(obs_tidy[groupby[0]]):
1858+
# if the groupby column is not categorical, turn it into one
1859+
# by subdividing into `num_categories` categories
1860+
categorical = pd.cut(obs_tidy[groupby[0]], num_categories)
1861+
elif len(groupby) == 1:
1862+
categorical = obs_tidy[groupby[0]].astype('category')
1863+
categorical.name = groupby[0]
18681864
else:
1869-
if len(groupby) == 1 and not is_categorical_dtype(adata.obs[groupby[0]]):
1870-
# if the groupby column is not categorical, turn it into one
1871-
# by subdividing into `num_categories` categories
1872-
categorical = pd.cut(adata.obs[groupby[0]], num_categories)
1873-
else:
1874-
categorical = adata.obs[groupby[0]]
1875-
if len(groupby) > 1:
1876-
for group in groupby[1:]:
1877-
# create new category by merging the given groupby categories
1878-
categorical = (
1879-
categorical.astype(str) + "_" + adata.obs[group].astype(str)
1880-
).astype('category')
1881-
categorical.name = "_".join(groupby)
1882-
obs_tidy.set_index(categorical, inplace=True)
1883-
if gene_symbols is not None:
1884-
# translate the column names to the symbol names
1885-
obs_tidy.rename(
1886-
columns={var_names[x]: symbols[x] for x in range(len(var_names))},
1887-
inplace=True,
1888-
)
1865+
# join the groupby values using "_" to make a new 'category'
1866+
categorical = obs_tidy[groupby].agg('_'.join, axis=1).astype('category')
1867+
categorical.name = "_".join(groupby)
1868+
1869+
obs_tidy = obs_tidy[var_names].set_index(categorical)
18891870
categories = obs_tidy.index.categories
18901871

1872+
if log:
1873+
obs_tidy = np.log1p(obs_tidy)
1874+
18911875
return categories, obs_tidy
18921876

18931877

scanpy/plotting/_baseplot_class.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from matplotlib.axes import Axes
1111
from matplotlib import pyplot as pl
1212
from matplotlib import gridspec
13+
from warnings import warn
1314

1415
from .. import logging as logg
1516
from .._compat import Literal
@@ -66,6 +67,8 @@ class BasePlot(object):
6667
DEFAULT_LEGENDS_WIDTH = 1.5
6768
DEFAULT_COLOR_LEGEND_TITLE = 'Expression\nlevel in group'
6869

70+
MAX_NUM_CATEGORIES = 500 # maximum number of categories allowed to be plotted
71+
6972
def __init__(
7073
self,
7174
adata: AnnData,
@@ -109,6 +112,11 @@ def __init__(
109112
layer=layer,
110113
gene_symbols=gene_symbols,
111114
)
115+
if len(self.categories) > self.MAX_NUM_CATEGORIES:
116+
warn(
117+
f"Over {self.MAX_NUM_CATEGORIES} categories found. "
118+
"Plot would be very large."
119+
)
112120

113121
if categories_order is not None:
114122
if set(self.obs_tidy.index.categories) != set(categories_order):
16.3 KB
Loading

0 commit comments

Comments
 (0)