diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 21ed8226..e806bc8e 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -19,7 +19,7 @@ from scanpy._settings import settings as sc_settings from spatialdata import get_extent, get_values, join_spatialelement_table from spatialdata.models import PointsModel, ShapesModel, get_table_keys -from spatialdata.transformations import get_transformation, set_transformation +from spatialdata.transformations import set_transformation from spatialdata.transformations.transformations import Identity from xarray import DataTree @@ -44,8 +44,6 @@ _get_colors_for_categorical_obs, _get_extent_and_range_for_datashader_canvas, _get_linear_colormap, - _get_transformation_matrix_for_datashader, - _hex_no_alpha, _is_coercable_to_float, _map_color_seg, _maybe_set_colors, @@ -186,10 +184,9 @@ def _render_shapes( sdata_filt.shapes[element].loc[is_point, "geometry"] = _geometry[is_point].buffer(scale.to_numpy()) # apply transformations to the individual points - element_trans = get_transformation(sdata_filt.shapes[element], to_coordinate_system=coordinate_system) - tm = _get_transformation_matrix_for_datashader(element_trans) + tm = trans.get_matrix() transformed_element = sdata_filt.shapes[element].transform( - lambda x: (np.hstack([x, np.ones((x.shape[0], 1))]) @ tm)[:, :2] + lambda x: (np.hstack([x, np.ones((x.shape[0], 1))]) @ tm.T)[:, :2] ) transformed_element = ShapesModel.parse( gpd.GeoDataFrame( diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 574ca56b..496db9f2 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -19,7 +19,6 @@ import matplotlib.transforms as mtransforms import numpy as np import numpy.ma as ma -import numpy.typing as npt import pandas as pd import shapely import spatialdata as sd @@ -65,11 +64,8 @@ from spatialdata._core.query.relational_query import _locate_value from spatialdata._types import ArrayLike from spatialdata.models import Image2DModel, Labels2DModel, SpatialElement - -# from spatialdata.transformations.transformations import Scale -from spatialdata.transformations import Affine, Identity, MapAxis, Scale, Translation -from spatialdata.transformations import Sequence as SDSequence from spatialdata.transformations.operations import get_transformation +from spatialdata.transformations.transformations import Scale from xarray import DataArray, DataTree from spatialdata_plot._logging import logger @@ -2379,137 +2375,3 @@ def _prepare_transformation( trans_data = trans + ax.transData if ax is not None else None return trans, trans_data - - -def _get_datashader_trans_matrix_of_single_element( - trans: Identity | Scale | Affine | MapAxis | Translation, -) -> npt.NDArray[Any]: - flip_matrix = np.array([[1, 0, 0], [0, -1, 0], [0, 0, 1]]) - tm: npt.NDArray[Any] = trans.to_affine_matrix(("x", "y"), ("x", "y")) - - if isinstance(trans, Identity): - return np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) - if isinstance(trans, (Scale | Affine)): - # idea: "flip the y-axis", apply transformation, flip back - flip_and_transform: npt.NDArray[Any] = flip_matrix @ tm @ flip_matrix - return flip_and_transform - if isinstance(trans, MapAxis): - # no flipping needed - return tm - # for a Translation, we need the transposed transformation matrix - tm_T = tm.T - assert isinstance(tm_T, np.ndarray) - return tm_T - - -def _get_transformation_matrix_for_datashader( - trans: Scale | Identity | Affine | MapAxis | Translation | SDSequence, -) -> npt.NDArray[Any]: - """Get the affine matrix needed to transform shapes for rendering with datashader.""" - if isinstance(trans, SDSequence): - tm = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) - for x in trans.transformations: - tm = tm @ _get_datashader_trans_matrix_of_single_element(x) - return tm - return _get_datashader_trans_matrix_of_single_element(trans) - - -def _datashader_map_aggregate_to_color( - agg: DataArray, - cmap: str | list[str] | ListedColormap, - color_key: None | list[str] = None, - min_alpha: float = 40, - span: None | list[float] = None, - clip: bool = True, -) -> ds.tf.Image | np.ndarray[Any, np.dtype[np.uint8]]: - """ds.tf.shade() part, ensuring correct clipping behavior. - - If necessary (norm.clip=False), split shading in 3 parts and in the end, stack results. - This ensures the correct clipping behavior, because else datashader would always automatically clip. - """ - if not clip and isinstance(cmap, Colormap) and span is not None: - # in case we use datashader together with a Normalize object where clip=False - # why we need this is documented in https://github.com/scverse/spatialdata-plot/issues/372 - agg_in = agg.where((agg >= span[0]) & (agg <= span[1])) - img_in = ds.tf.shade( - agg_in, - cmap=cmap, - span=(span[0], span[1]), - how="linear", - color_key=color_key, - min_alpha=min_alpha, - ) - - agg_under = agg.where(agg < span[0]) - img_under = ds.tf.shade( - agg_under, - cmap=[to_hex(cmap.get_under())[:7]], - min_alpha=min_alpha, - color_key=color_key, - ) - - agg_over = agg.where(agg > span[1]) - img_over = ds.tf.shade( - agg_over, - cmap=[to_hex(cmap.get_over())[:7]], - min_alpha=min_alpha, - color_key=color_key, - ) - - # stack the 3 arrays manually: go from under, through in to over and always overlay the values where alpha=0 - stack = img_under.to_numpy().base - if stack is None: - stack = img_in.to_numpy().base - else: - stack[stack[:, :, 3] == 0] = img_in.to_numpy().base[stack[:, :, 3] == 0] - img_over = img_over.to_numpy().base - if img_over is not None: - stack[stack[:, :, 3] == 0] = img_over[stack[:, :, 3] == 0] - return stack - - return ds.tf.shade( - agg, - cmap=cmap, - color_key=color_key, - min_alpha=min_alpha, - span=span, - how="linear", - ) - - -def _hex_no_alpha(hex: str) -> str: - """ - Return a hex color string without an alpha component. - - Parameters - ---------- - hex : str - The input hex color string. Must be in one of the following formats: - - "#RRGGBB": a hex color without an alpha channel. - - "#RRGGBBAA": a hex color with an alpha channel that will be removed. - - Returns - ------- - str - The hex color string in "#RRGGBB" format. - """ - if not isinstance(hex, str): - raise TypeError("Input must be a string") - if not hex.startswith("#"): - raise ValueError("Invalid hex color: must start with '#'") - - hex_digits = hex[1:] - length = len(hex_digits) - - if length == 6: - if not all(c in "0123456789abcdefABCDEF" for c in hex_digits): - raise ValueError("Invalid hex color: contains non-hex characters") - return hex # Already in #RRGGBB format. - - if length == 8: - if not all(c in "0123456789abcdefABCDEF" for c in hex_digits): - raise ValueError("Invalid hex color: contains non-hex characters") - # Return only the first 6 characters, stripping the alpha. - return "#" + hex_digits[:6] - - raise ValueError("Invalid hex color length: must be either '#RRGGBB' or '#RRGGBBAA'")