12
12
from anndata import AnnData
13
13
from cycler import Cycler
14
14
from matplotlib .axes import Axes
15
- from pandas .api .types import is_categorical_dtype
15
+ from pandas .api .types import is_categorical_dtype , is_numeric_dtype
16
16
from scipy .sparse import issparse
17
17
from matplotlib import pyplot as pl
18
18
from matplotlib import rcParams
@@ -1792,7 +1792,9 @@ def _prepare_dataframe(
1792
1792
use_raw
1793
1793
Whether to use `raw` attribute of `adata`. Defaults to `True` if `.raw` is present.
1794
1794
log
1795
- Use the log of the values
1795
+ Use the log of the values.
1796
+ layer
1797
+ AnnData layer to use. Takes precedence over `use_raw`.
1796
1798
num_categories
1797
1799
Only used if groupby observation is not categorical. This value
1798
1800
determines the number of groups into which the groupby observation
@@ -1804,90 +1806,72 @@ def _prepare_dataframe(
1804
1806
-------
1805
1807
Tuple of `pandas.DataFrame` and list of categories.
1806
1808
"""
1807
- from scipy .sparse import issparse
1808
1809
1809
1810
sanitize_anndata (adata )
1810
1811
use_raw = _check_use_raw (adata , use_raw )
1812
+ if layer is not None :
1813
+ use_raw = False
1811
1814
if isinstance (var_names , str ):
1812
1815
var_names = [var_names ]
1813
1816
1817
+ groupby_index = None
1814
1818
if groupby is not None :
1815
1819
if isinstance (groupby , str ):
1816
1820
# if not a list, turn into a list
1817
1821
groupby = [groupby ]
1818
1822
for group in groupby :
1819
- if group not in adata .obs_keys ():
1823
+ if group not in list (adata .obs_keys ()) + [adata .obs .index .name ]:
1824
+ if adata .obs .index .name is not None :
1825
+ msg = f' or index name "{ adata .obs .index .name } "'
1826
+ else :
1827
+ msg = ''
1820
1828
raise ValueError (
1821
1829
'groupby has to be a valid observation. '
1822
- f'Given { group } , is not in observations: { adata .obs_keys ()} '
1830
+ f'Given { group } , is not in observations: { adata .obs_keys ()} ' + msg
1823
1831
)
1824
-
1825
- if gene_symbols is not None and gene_symbols in adata .var .columns :
1826
- # translate gene_symbols to var_names
1827
- # slow method but gives a meaningful error if no gene symbol is found:
1828
- translated_var_names = []
1829
- # if we're using raw to plot, we should also do gene symbol translations
1830
- # using raw
1831
- if use_raw :
1832
- adata_or_raw = adata .raw
1833
- else :
1834
- adata_or_raw = adata
1835
- for symbol in var_names :
1836
- if symbol not in adata_or_raw .var [gene_symbols ].values :
1837
- logg .error (
1838
- f"Gene symbol { symbol !r} not found in given "
1839
- f"gene_symbols column: { gene_symbols !r} "
1832
+ if group in adata .obs .keys () and group == adata .obs .index .name :
1833
+ raise ValueError (
1834
+ f'Given group { group } is both and index and a column level, '
1835
+ 'which is ambiguous.'
1840
1836
)
1841
- return
1842
- translated_var_names .append (
1843
- adata_or_raw .var [adata_or_raw .var [gene_symbols ] == symbol ].index [0 ]
1844
- )
1845
- symbols = var_names
1846
- var_names = translated_var_names
1847
- if layer is not None :
1848
- if layer not in adata .layers .keys ():
1849
- raise KeyError (
1850
- f'Selected layer: { layer } is not in the layers list. '
1851
- f'The list of valid layers is: { adata .layers .keys ()} '
1852
- )
1853
- matrix = adata [:, var_names ].layers [layer ]
1854
- elif use_raw :
1855
- matrix = adata .raw [:, var_names ].X
1856
- else :
1857
- matrix = adata [:, var_names ].X
1837
+ if group == adata .obs .index .name :
1838
+ groupby_index = group
1839
+ if groupby_index is not None :
1840
+ # obs_tidy contains adata.obs.index
1841
+ # and does not need to be given
1842
+ groupby = groupby .copy () # copy to not modify user passed parameter
1843
+ groupby .remove (groupby_index )
1844
+ keys = list (groupby ) + list (np .unique (var_names ))
1845
+ obs_tidy = get .obs_df (
1846
+ adata , keys = keys , layer = layer , use_raw = use_raw , gene_symbols = gene_symbols
1847
+ )
1848
+ assert np .all (np .array (keys ) == np .array (obs_tidy .columns ))
1858
1849
1859
- if issparse ( matrix ) :
1860
- matrix = matrix . toarray ()
1861
- if log :
1862
- matrix = np . log1p ( matrix )
1850
+ if groupby_index is not None :
1851
+ # reset index to treat all columns the same way.
1852
+ obs_tidy . reset_index ( inplace = True )
1853
+ groupby . append ( groupby_index )
1863
1854
1864
- obs_tidy = pd .DataFrame (matrix , columns = var_names )
1865
1855
if groupby is None :
1866
- groupby = ''
1867
1856
categorical = pd .Series (np .repeat ('' , len (obs_tidy ))).astype ('category' )
1857
+ elif len (groupby ) == 1 and is_numeric_dtype (obs_tidy [groupby [0 ]]):
1858
+ # if the groupby column is numeric, turn it into a categorical one
1859
+ # by subdividing into `num_categories` categories
1860
+ categorical = pd .cut (obs_tidy [groupby [0 ]], num_categories )
1861
+ elif len (groupby ) == 1 :
1862
+ categorical = obs_tidy [groupby [0 ]].astype ('category' )
1863
+ categorical .name = groupby [0 ]
1868
1864
else :
1869
- if len (groupby ) == 1 and not is_categorical_dtype (adata .obs [groupby [0 ]]):
1870
- # if the groupby column is not categorical, turn it into one
1871
- # by subdividing into `num_categories` categories
1872
- categorical = pd .cut (adata .obs [groupby [0 ]], num_categories )
1873
- else :
1874
- categorical = adata .obs [groupby [0 ]]
1875
- if len (groupby ) > 1 :
1876
- for group in groupby [1 :]:
1877
- # create new category by merging the given groupby categories
1878
- categorical = (
1879
- categorical .astype (str ) + "_" + adata .obs [group ].astype (str )
1880
- ).astype ('category' )
1881
- categorical .name = "_" .join (groupby )
1882
- obs_tidy .set_index (categorical , inplace = True )
1883
- if gene_symbols is not None :
1884
- # translate the column names to the symbol names
1885
- obs_tidy .rename (
1886
- columns = {var_names [x ]: symbols [x ] for x in range (len (var_names ))},
1887
- inplace = True ,
1888
- )
1865
+ # join the groupby values using "_" to make a new 'category'
1866
+ categorical = obs_tidy [groupby ].agg ('_' .join , axis = 1 ).astype ('category' )
1867
+ categorical .name = "_" .join (groupby )
1868
+
1869
+ obs_tidy = obs_tidy [var_names ].set_index (categorical )
1889
1870
categories = obs_tidy .index .categories
1890
1871
1872
+ if log :
1873
+ obs_tidy = np .log1p (obs_tidy )
1874
+
1891
1875
return categories , obs_tidy
1892
1876
1893
1877
0 commit comments