Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from scanpy._settings import settings as sc_settings
from spatialdata import get_extent, get_values, join_spatialelement_table
from spatialdata.models import PointsModel, ShapesModel, get_table_keys
from spatialdata.transformations import get_transformation, set_transformation
from spatialdata.transformations import set_transformation
from spatialdata.transformations.transformations import Identity
from xarray import DataTree

Expand All @@ -44,8 +44,6 @@
_get_colors_for_categorical_obs,
_get_extent_and_range_for_datashader_canvas,
_get_linear_colormap,
_get_transformation_matrix_for_datashader,
_hex_no_alpha,
_is_coercable_to_float,
_map_color_seg,
_maybe_set_colors,
Expand Down Expand Up @@ -186,10 +184,9 @@ def _render_shapes(
sdata_filt.shapes[element].loc[is_point, "geometry"] = _geometry[is_point].buffer(scale.to_numpy())

# apply transformations to the individual points
element_trans = get_transformation(sdata_filt.shapes[element], to_coordinate_system=coordinate_system)
tm = _get_transformation_matrix_for_datashader(element_trans)
tm = trans.get_matrix()
transformed_element = sdata_filt.shapes[element].transform(
lambda x: (np.hstack([x, np.ones((x.shape[0], 1))]) @ tm)[:, :2]
lambda x: (np.hstack([x, np.ones((x.shape[0], 1))]) @ tm.T)[:, :2]
)
transformed_element = ShapesModel.parse(
gpd.GeoDataFrame(
Expand Down
140 changes: 1 addition & 139 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import matplotlib.transforms as mtransforms
import numpy as np
import numpy.ma as ma
import numpy.typing as npt
import pandas as pd
import shapely
import spatialdata as sd
Expand Down Expand Up @@ -65,11 +64,8 @@
from spatialdata._core.query.relational_query import _locate_value
from spatialdata._types import ArrayLike
from spatialdata.models import Image2DModel, Labels2DModel, SpatialElement

# from spatialdata.transformations.transformations import Scale
from spatialdata.transformations import Affine, Identity, MapAxis, Scale, Translation
from spatialdata.transformations import Sequence as SDSequence
from spatialdata.transformations.operations import get_transformation
from spatialdata.transformations.transformations import Scale
from xarray import DataArray, DataTree

from spatialdata_plot._logging import logger
Expand Down Expand Up @@ -2379,137 +2375,3 @@ def _prepare_transformation(
trans_data = trans + ax.transData if ax is not None else None

return trans, trans_data


def _get_datashader_trans_matrix_of_single_element(
trans: Identity | Scale | Affine | MapAxis | Translation,
) -> npt.NDArray[Any]:
flip_matrix = np.array([[1, 0, 0], [0, -1, 0], [0, 0, 1]])
tm: npt.NDArray[Any] = trans.to_affine_matrix(("x", "y"), ("x", "y"))

if isinstance(trans, Identity):
return np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
if isinstance(trans, (Scale | Affine)):
# idea: "flip the y-axis", apply transformation, flip back
flip_and_transform: npt.NDArray[Any] = flip_matrix @ tm @ flip_matrix
return flip_and_transform
if isinstance(trans, MapAxis):
# no flipping needed
return tm
# for a Translation, we need the transposed transformation matrix
tm_T = tm.T
assert isinstance(tm_T, np.ndarray)
return tm_T


def _get_transformation_matrix_for_datashader(
trans: Scale | Identity | Affine | MapAxis | Translation | SDSequence,
) -> npt.NDArray[Any]:
"""Get the affine matrix needed to transform shapes for rendering with datashader."""
if isinstance(trans, SDSequence):
tm = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
for x in trans.transformations:
tm = tm @ _get_datashader_trans_matrix_of_single_element(x)
return tm
return _get_datashader_trans_matrix_of_single_element(trans)


def _datashader_map_aggregate_to_color(
agg: DataArray,
cmap: str | list[str] | ListedColormap,
color_key: None | list[str] = None,
min_alpha: float = 40,
span: None | list[float] = None,
clip: bool = True,
) -> ds.tf.Image | np.ndarray[Any, np.dtype[np.uint8]]:
"""ds.tf.shade() part, ensuring correct clipping behavior.

If necessary (norm.clip=False), split shading in 3 parts and in the end, stack results.
This ensures the correct clipping behavior, because else datashader would always automatically clip.
"""
if not clip and isinstance(cmap, Colormap) and span is not None:
# in case we use datashader together with a Normalize object where clip=False
# why we need this is documented in https://github.com/scverse/spatialdata-plot/issues/372
agg_in = agg.where((agg >= span[0]) & (agg <= span[1]))
img_in = ds.tf.shade(
agg_in,
cmap=cmap,
span=(span[0], span[1]),
how="linear",
color_key=color_key,
min_alpha=min_alpha,
)

agg_under = agg.where(agg < span[0])
img_under = ds.tf.shade(
agg_under,
cmap=[to_hex(cmap.get_under())[:7]],
min_alpha=min_alpha,
color_key=color_key,
)

agg_over = agg.where(agg > span[1])
img_over = ds.tf.shade(
agg_over,
cmap=[to_hex(cmap.get_over())[:7]],
min_alpha=min_alpha,
color_key=color_key,
)

# stack the 3 arrays manually: go from under, through in to over and always overlay the values where alpha=0
stack = img_under.to_numpy().base
if stack is None:
stack = img_in.to_numpy().base
else:
stack[stack[:, :, 3] == 0] = img_in.to_numpy().base[stack[:, :, 3] == 0]
img_over = img_over.to_numpy().base
if img_over is not None:
stack[stack[:, :, 3] == 0] = img_over[stack[:, :, 3] == 0]
return stack

return ds.tf.shade(
agg,
cmap=cmap,
color_key=color_key,
min_alpha=min_alpha,
span=span,
how="linear",
)


def _hex_no_alpha(hex: str) -> str:
"""
Return a hex color string without an alpha component.

Parameters
----------
hex : str
The input hex color string. Must be in one of the following formats:
- "#RRGGBB": a hex color without an alpha channel.
- "#RRGGBBAA": a hex color with an alpha channel that will be removed.

Returns
-------
str
The hex color string in "#RRGGBB" format.
"""
if not isinstance(hex, str):
raise TypeError("Input must be a string")
if not hex.startswith("#"):
raise ValueError("Invalid hex color: must start with '#'")

hex_digits = hex[1:]
length = len(hex_digits)

if length == 6:
if not all(c in "0123456789abcdefABCDEF" for c in hex_digits):
raise ValueError("Invalid hex color: contains non-hex characters")
return hex # Already in #RRGGBB format.

if length == 8:
if not all(c in "0123456789abcdefABCDEF" for c in hex_digits):
raise ValueError("Invalid hex color: contains non-hex characters")
# Return only the first 6 characters, stripping the alpha.
return "#" + hex_digits[:6]

raise ValueError("Invalid hex color length: must be either '#RRGGBB' or '#RRGGBBAA'")
Loading