"""
Plotting functions for TrackCell package.
This module provides functions for visualizing spatial transcriptomics data,
including cell polygon visualization.
"""
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize, ListedColormap
from typing import Optional, Union, List, Dict
import warnings
try:
import geopandas as gpd
from shapely import wkt
HAS_GEOPANDAS = True
except ImportError:
HAS_GEOPANDAS = False
warnings.warn("geopandas and shapely are required for spatial_cell function")
def _process_background_image(spatial_info, img_key, data_coords_range=None):
"""
Process background image similar to scanpy's approach.
Parameters
----------
spatial_info : dict
Spatial information dictionary from adata.uns['spatial'][library_id]
img_key : str or None
Key for the image to use
data_coords_range : tuple, optional
Tuple of (x_min, y_min, x_max, y_max) for data coordinate range
Returns
-------
img : numpy.ndarray or None
Background image array
img_extent : list or None
Image extent [left, right, bottom, top] in data coordinates
"""
if img_key is None:
img_key = "hires" if "hires" in spatial_info.get("images", {}) else None
if not img_key or "images" not in spatial_info or img_key not in spatial_info["images"]:
return None, None
img = spatial_info["images"][img_key]
scalefactors = spatial_info.get("scalefactors", {})
# Get scale factor for the image (scanpy way)
# For hires: tissue_hires_scalef (~0.05-0.2)
# For lowres: tissue_lowres_scalef (~0.03)
# For fullres: scale_factor = 1.0
scale_key = f"tissue_{img_key}_scalef"
scale_factor = scalefactors.get(scale_key, 1.0)
# Calculate image extent in data coordinates
# Image shape is (height, width) in pixels
img_height, img_width = img.shape[:2]
if scale_factor < 1.0:
# Image is downscaled, need to scale up the extent
# The image pixels represent a smaller region in full-res coordinates
img_extent = [0, img_width / scale_factor, img_height / scale_factor, 0]
else:
# Full resolution image
img_extent = [0, img_width, img_height, 0]
return img, img_extent
[docs]
def spatial_cell(
adata,
color: Optional[Union[str, List[str]]] = None,
groups: Optional[List[str]] = None,
groupby: Optional[str] = None,
library_id: Optional[str] = None,
size: float = 1.0,
figsize: Optional[tuple] = None,
cmap: str = "viridis",
palette: Optional[Union[dict, list, np.ndarray]] = None,
vmin: Optional[float] = None,
vmax: Optional[float] = None,
img_key: Optional[str] = None,
basis: str = "spatial",
edges_width: float = 0.5,
edges_color: str = "black",
alpha: float = 0.8,
alpha_img: float = 0.5,
show: bool = True,
ax: Optional[plt.Axes] = None,
legend: bool = True,
xlabel: Optional[str] = "spatial 1",
ylabel: Optional[str] = "spatial 2",
show_ticks: bool = False,
**kwargs
):
"""
Plot spatial transcriptomics data with cell polygons instead of points.
This function visualizes cells as polygons (from cell segmentation) rather than
simple points, providing a more accurate representation of cell boundaries.
Uses GeoDataFrame.plot() for efficient rendering and automatic legend generation.
Parameters
----------
adata : AnnData
Annotated data object with spatial information and cell geometries.
color : str or list of str, optional
Keys for observation/categorical or continuous variables to color cells.
Can be a single key or a list of keys for multiple plots.
Can be:
- A column name in `adata.obs` (metadata)
- A gene name in `adata.var_names` (gene expression)
- None: Only display the H&E background image without cell polygons.
When None, axis ticks and labels are automatically shown.
groups : list of str, optional
Subset of groups to plot. If None, plots all groups.
Requires either `color` to be a categorical column in `adata.obs` or `groupby` to be specified.
groupby : str, optional
Column name in `adata.obs` to use for filtering with `groups` parameter.
If None and `groups` is specified, will use `color` if it's a categorical column in `adata.obs`.
This is useful when `color` is a continuous variable (e.g., gene expression) but you want
to filter by a categorical column (e.g., 'classification').
library_id : str, optional
Key in adata.uns["spatial"] containing spatial information.
If None, uses the first available library_id (similar to sc.pl.spatial).
Default is None, which will auto-select the first library_id.
size : float, default 1.0
Size scaling factor for cells (not used for polygons, kept for API compatibility).
This parameter is currently not implemented for polygon-based visualization.
figsize : tuple, optional
Figure size (width, height) in inches.
cmap : str, default "viridis"
Colormap for continuous values.
vmin : float, optional
Minimum value for colormap normalization. If None, uses the minimum value in the data.
vmax : float, optional
Maximum value for colormap normalization. If None, uses the maximum value in the data.
palette : dict, list, or array, optional
Color palette for categorical variables. Can be:
- A dictionary mapping category names to colors (e.g., {'A': 'red', 'B': 'blue'})
- A list/array of colors that will be assigned to categories in sorted order
(e.g., ['red', 'blue', 'green'] will assign colors to categories alphabetically)
img_key : str, optional
Key in adata.uns["spatial"][library_id]["images"] for background image.
If None, uses "hires" if available.
basis : str, default "spatial"
Key in adata.obsm containing spatial coordinates (for fallback).
edges_width : float, default 0.5
Width of cell polygon edges.
edges_color : str, default "black"
Color of cell polygon edges.
alpha : float, default 0.8
Transparency of cell polygons.
alpha_img : float, default 0.5
Transparency of background image.
show : bool, default True
Whether to display the plot.
ax : matplotlib.Axes, optional
Axes object to plot on. If None, creates a new figure.
legend : bool, default True
Whether to show legend for categorical values or colorbar for continuous values.
xlabel : str, optional, default "spatial 1"
Label for the x-axis. Set to None to hide the label.
ylabel : str, optional, default "spatial 2"
Label for the y-axis. Set to None to hide the label.
show_ticks : bool, default False
Whether to show axis ticks and tick labels. If False, ticks are hidden.
Note: When `color=None`, ticks are automatically shown regardless of this setting.
**kwargs
Additional arguments passed to GeoDataFrame.plot().
Returns
-------
matplotlib.Axes or list of matplotlib.Axes
Axes object(s) containing the plot.
Examples
--------
>>> import trackcell as tcl
>>> adata = tcl.io.read_hd_cellseg("path/to/data", sample="sample1")
>>> # Plot by metadata (categorical)
>>> tcl.pl.spatial_cell(adata, color="classification")
>>> # Plot by metadata (continuous)
>>> tcl.pl.spatial_cell(adata, color="Cluster-2_dist", cmap="Reds")
>>> # Plot by gene expression
>>> tcl.pl.spatial_cell(adata, color="PDPN", cmap="viridis")
>>> # Plot with groups filter
>>> tcl.pl.spatial_cell(adata, color="classification", groups=["Cluster-1", "Cluster-2"])
>>> # Plot only H&E image (no cell polygons)
>>> tcl.pl.spatial_cell(adata, color=None)
"""
if not HAS_GEOPANDAS:
raise ImportError("geopandas and shapely are required for spatial_cell function")
# Check if geometries are available
if "spatial" not in adata.uns:
raise ValueError("`adata.uns['spatial']` is required but missing.")
# Auto-select library_id if not provided (similar to sc.pl.spatial)
if library_id is None:
available_library_ids = list(adata.uns["spatial"].keys())
if len(available_library_ids) == 0:
raise ValueError("No library_id found in `adata.uns['spatial']`.")
library_id = available_library_ids[0]
if len(available_library_ids) > 1:
warnings.warn(
f"Multiple library_ids found: {available_library_ids}. "
f"Using '{library_id}'. Specify `library_id` explicitly to use a different one."
)
if library_id not in adata.uns["spatial"]:
raise ValueError(
f"`library_id` '{library_id}' not found in `adata.uns['spatial']`. "
f"Available library_ids: {list(adata.uns['spatial'].keys())}"
)
spatial_info = adata.uns["spatial"][library_id]
# Get geometries - try GeoDataFrame first, then fallback to WKT strings in obs
geometries = None
use_wkt = False
if "geometries" in spatial_info:
geometries = spatial_info["geometries"]
elif "geometry" in adata.obs.columns:
# Fallback: use WKT strings from obs
use_wkt = True
warnings.warn(
"Using geometry from adata.obs['geometry'] (WKT strings). "
"For better performance, use adata.uns['spatial']['geometries'] (GeoDataFrame)."
)
else:
raise ValueError(
f"Cell geometries not found. Expected either:\n"
f" - adata.uns['spatial']['{library_id}']['geometries'] (GeoDataFrame), or\n"
f" - adata.obs['geometry'] (WKT strings).\n"
"Please use `tcl.io.read_hd_cellseg()` to load data with geometries."
)
# Handle color parameter (can be single string or list)
if color is None:
colors_to_plot = [None]
elif isinstance(color, str):
colors_to_plot = [color]
else:
colors_to_plot = color
# Validate all color keys upfront (before creating figures)
# This avoids repeated checks later and provides early error feedback
for color_key in colors_to_plot:
if color_key is not None:
# Check if color_key exists in obs.columns or var_names
if color_key not in adata.obs.columns and color_key not in adata.var_names:
# Check if it's in layers (not supported yet)
if hasattr(adata, 'layers') and color_key in adata.layers:
raise ValueError(
f"`color` key '{color_key}' found in `adata.layers`, but gene expression "
f"from layers is not yet supported. Please use gene names from `adata.var_names` "
f"or metadata from `adata.obs.columns`."
)
else:
raise ValueError(
f"`color` key '{color_key}' not found in `adata.obs.columns` or `adata.var_names`. "
f"Available obs keys: {list(adata.obs.columns[:10])}... "
f"Available var names (genes): {list(adata.var_names[:10])}..."
)
# Create figure if needed
if ax is None:
if figsize is None:
if len(colors_to_plot) > 1:
figsize = (5 * len(colors_to_plot), 5)
else:
figsize = (10, 10)
if len(colors_to_plot) > 1:
fig, axes = plt.subplots(1, len(colors_to_plot), figsize=figsize, sharex=True, sharey=True)
if len(colors_to_plot) == 1:
axes = [axes]
else:
fig, axes = plt.subplots(1, 1, figsize=figsize)
axes = [axes]
else:
fig = ax.figure
axes = [ax]
if len(colors_to_plot) > 1:
warnings.warn("Multiple colors specified but single ax provided. Only first color will be plotted.")
colors_to_plot = [colors_to_plot[0]]
axes_list = []
for idx, color_key in enumerate(colors_to_plot):
if idx < len(axes):
current_ax = axes[idx]
else:
# Should not happen, but handle gracefully
continue
# Filter cells if groups is specified
if groups is not None:
# Determine which column to use for filtering
filter_column = None
if groupby is not None:
# Use explicitly specified groupby column
if groupby not in adata.obs.columns:
raise ValueError(
f"`groupby` column '{groupby}' not found in `adata.obs.columns`. "
f"Available columns: {list(adata.obs.columns[:10])}..."
)
filter_column = groupby
elif color_key is not None and color_key in adata.obs.columns:
# Use color_key if it's a categorical column in obs
filter_column = color_key
else:
# Cannot determine filter column
raise ValueError(
"`groups` parameter requires either:\n"
" - `color` to be a categorical column in `adata.obs`, or\n"
" - `groupby` to specify the column name for filtering.\n"
f"Current `color` value: {color_key}"
)
# Apply groups filter
mask = adata.obs[filter_column].isin(groups)
else:
mask = pd.Series(True, index=adata.obs_names)
# Get cell indices to plot
cells_to_plot = adata.obs_names[mask]
if len(cells_to_plot) == 0:
warnings.warn("No cells to plot after filtering.")
axes_list.append(current_ax)
continue
# Calculate coordinate range from actual data to be plotted
# This ensures image extent matches the data range, especially for subset data
coords_list = []
for cell_id in cells_to_plot:
if use_wkt:
if cell_id in adata.obs.index and pd.notna(adata.obs.loc[cell_id, "geometry"]):
try:
geom = wkt.loads(adata.obs.loc[cell_id, "geometry"])
if hasattr(geom, 'bounds'):
coords_list.append(geom.bounds) # (minx, miny, maxx, maxy)
except Exception:
continue
else:
if cell_id in geometries.index:
geom = geometries.loc[cell_id, "geometry"]
if geom is not None and hasattr(geom, 'bounds'):
coords_list.append(geom.bounds)
if coords_list:
# Calculate overall bounds from all geometries
all_bounds = np.array(coords_list)
x_min = all_bounds[:, 0].min()
y_min = all_bounds[:, 1].min()
x_max = all_bounds[:, 2].max()
y_max = all_bounds[:, 3].max()
else:
# Fallback to spatial coordinates if available
if basis in adata.obsm and len(adata.obsm[basis]) > 0:
spatial_coords = adata.obsm[basis][mask]
if len(spatial_coords) > 0:
x_min, y_min = spatial_coords.min(axis=0)
x_max, y_max = spatial_coords.max(axis=0)
else:
x_min = y_min = 0
x_max = y_max = 1
else:
x_min = y_min = 0
x_max = y_max = 1
data_coords_range = (x_min, y_min, x_max, y_max)
# Process and draw background image (scanpy way)
img, img_extent = _process_background_image(spatial_info, img_key, data_coords_range)
if img is not None and img_extent is not None:
# If color is None, use full opacity for the image
img_alpha = 1.0 if color_key is None else alpha_img
current_ax.imshow(img, extent=img_extent, origin='upper', alpha=img_alpha)
# If color is None, skip geometry plotting and only show HE image
if color_key is None:
# Set axis limits based on image extent or data coordinates
if img_extent is not None:
current_ax.set_xlim(img_extent[0], img_extent[1])
current_ax.set_ylim(img_extent[2], img_extent[3])
else:
# Fallback to data coordinates
current_ax.set_xlim(x_min, x_max)
current_ax.set_ylim(y_max, y_min) # Inverted for y-axis
# Set axis properties
current_ax.set_aspect('equal')
current_ax.invert_yaxis() # Match image coordinates
# Set axis labels
if xlabel is not None:
current_ax.set_xlabel(xlabel)
if ylabel is not None:
current_ax.set_ylabel(ylabel)
# Always show ticks when color is None
current_ax.tick_params(axis='both', which='major', labelsize=10)
axes_list.append(current_ax)
continue # Skip to next color or finish
# Create temporary GeoDataFrame for plotting
# This combines geometry and color data in one structure
if use_wkt:
# Convert WKT strings to geometries
geom_list = []
valid_cells = []
for cell_id in cells_to_plot:
if cell_id in adata.obs.index and pd.notna(adata.obs.loc[cell_id, "geometry"]):
try:
geom = wkt.loads(adata.obs.loc[cell_id, "geometry"])
geom_list.append(geom)
valid_cells.append(cell_id)
except Exception:
continue
temp_geometries = gpd.GeoSeries(geom_list, index=valid_cells)
else:
# Use geometries from GeoDataFrame
# Filter to only include cells that exist in geometries index
cells_in_geometries = cells_to_plot[cells_to_plot.isin(geometries.index)]
if len(cells_in_geometries) == 0:
warnings.warn(
"No cells found in geometries index. "
"Geometries may not be synchronized after subset. "
"Consider using tcl.io.sync_geometries_after_subset() after subsetting."
)
axes_list.append(current_ax)
continue
# Get geometries for cells that exist in the geometries index
temp_geometries_raw = geometries.loc[cells_in_geometries, "geometry"]
# Filter out invalid geometries (None, NaN, or invalid geometry objects)
valid_cells = []
for cell_id in cells_in_geometries:
if cell_id not in temp_geometries_raw.index:
continue
geom = temp_geometries_raw.loc[cell_id]
# Check if geometry is valid
if geom is None or pd.isna(geom):
continue
if not hasattr(geom, 'bounds'):
continue
# Check if geometry is valid shapely object
if hasattr(geom, 'is_valid') and not geom.is_valid:
continue
# Verify bounds are finite
try:
bounds = geom.bounds
if not all(np.isfinite(bounds)):
continue
# Check bounds are valid (max > min)
if bounds[2] <= bounds[0] or bounds[3] <= bounds[1]:
continue
except Exception:
continue
valid_cells.append(cell_id)
if len(valid_cells) == 0:
warnings.warn(
"No valid geometries found after filtering invalid geometries. "
"This may be due to geometries not being synchronized after subset. "
"Consider using tcl.io.sync_geometries_after_subset() after subsetting."
)
axes_list.append(current_ax)
continue
temp_geometries = temp_geometries_raw.loc[valid_cells]
if len(temp_geometries) == 0:
warnings.warn("No valid geometries found for plotting.")
axes_list.append(current_ax)
continue
# Additional validation: check if total_bounds would be valid
try:
test_gdf = gpd.GeoDataFrame(geometry=temp_geometries)
bounds = test_gdf.total_bounds
if not all(np.isfinite(bounds)) or bounds[2] <= bounds[0] or bounds[3] <= bounds[1]:
warnings.warn(
f"Invalid geometry bounds detected. This may cause plotting errors. "
f"Bounds: {bounds}. Consider using sync_geometries_after_subset() after subsetting."
)
except Exception as e:
warnings.warn(
f"Could not validate geometry bounds: {e}. "
f"This may cause plotting errors. Consider using sync_geometries_after_subset() after subsetting."
)
# Create GeoDataFrame with color data
# color_key has already been validated upfront, so we can safely proceed
if color_key is None:
# No coloring, use default
temp_gdf = gpd.GeoDataFrame(
geometry=temp_geometries,
index=valid_cells
)
plot_column = None
elif color_key in adata.obs.columns:
# Get color data from obs (metadata)
color_data = adata.obs.loc[valid_cells, color_key]
temp_gdf = gpd.GeoDataFrame(
{color_key: color_data},
geometry=temp_geometries,
index=valid_cells
)
plot_column = color_key
else:
# color_key must be in var_names (already validated upfront)
# Get color data from gene expression (var)
gene_idx = adata.var_names.get_loc(color_key)
# Get expression values for valid cells
# Handle both sparse and dense matrices
if hasattr(adata.X, 'toarray'):
# Sparse matrix
expression_values = adata.X[adata.obs_names.get_indexer(valid_cells), gene_idx].toarray().flatten()
else:
# Dense matrix
expression_values = adata.X[adata.obs_names.get_indexer(valid_cells), gene_idx]
# Create Series with valid cells as index
color_data = pd.Series(expression_values, index=valid_cells, name=color_key)
temp_gdf = gpd.GeoDataFrame(
{color_key: color_data},
geometry=temp_geometries,
index=valid_cells
)
plot_column = color_key
# Validate and fix total_bounds after creating temp_gdf
# This is critical for subset data where geometries may not be properly synchronized
max_retries = 2
retry_count = 0
while retry_count <= max_retries:
try:
# Validate total_bounds
bounds = temp_gdf.total_bounds
if not all(np.isfinite(bounds)) or bounds[2] <= bounds[0] or bounds[3] <= bounds[1]:
# Invalid bounds detected - filter out problematic geometries
if retry_count == 0:
warnings.warn(
f"Invalid geometry bounds detected (bounds: {bounds}). "
f"Filtering out problematic geometries. "
f"Consider using tcl.io.sync_geometries_after_subset() after subsetting."
)
# Filter geometries with valid bounds
valid_mask = pd.Series(True, index=temp_gdf.index)
for idx in temp_gdf.index:
try:
geom = temp_gdf.loc[idx, 'geometry']
if geom is None or pd.isna(geom):
valid_mask.loc[idx] = False
continue
geom_bounds = geom.bounds
if not all(np.isfinite(geom_bounds)) or geom_bounds[2] <= geom_bounds[0] or geom_bounds[3] <= geom_bounds[1]:
valid_mask.loc[idx] = False
except Exception:
valid_mask.loc[idx] = False
temp_gdf = temp_gdf[valid_mask]
if len(temp_gdf) == 0:
warnings.warn("No valid geometries remaining after filtering. Skipping plot.")
axes_list.append(current_ax)
break # Exit the retry loop and continue to next color
# Update valid_cells to match filtered temp_gdf
valid_cells = temp_gdf.index.tolist()
# Re-validate bounds
bounds = temp_gdf.total_bounds
if not all(np.isfinite(bounds)) or bounds[2] <= bounds[0] or bounds[3] <= bounds[1]:
# Still invalid after filtering - will use aspect='equal' as fallback
if retry_count == 0:
warnings.warn(
f"Still invalid bounds after filtering (bounds: {bounds}). "
f"Will use aspect='equal' to avoid errors."
)
retry_count += 1
continue
else:
# Bounds are now valid
break
else:
# Bounds are valid
break
except Exception as e:
if retry_count == 0:
warnings.warn(
f"Error validating geometry bounds: {e}. "
f"Will use aspect='equal' to avoid errors."
)
retry_count += 1
continue
# If we still have invalid bounds after retries, we'll set aspect='equal' later
use_equal_aspect = False
if retry_count > max_retries:
try:
bounds = temp_gdf.total_bounds
if not all(np.isfinite(bounds)) or bounds[2] <= bounds[0] or bounds[3] <= bounds[1]:
use_equal_aspect = True
except Exception:
use_equal_aspect = True
# Determine if continuous or categorical immediately after creating temp_gdf
# This allows all subsequent logic to use is_categorical directly
is_categorical = False
use_custom_palette = False
custom_cmap = None
categories = None
color_list_for_legend = None # Store color list for legend creation
if color_key is not None:
# plot_column is guaranteed to be color_key at this point (already validated)
is_categorical = not pd.api.types.is_numeric_dtype(temp_gdf[plot_column])
# Process categorical data (palette, categories, etc.)
if color_key is not None and is_categorical:
# Get categories (prioritize Categorical's original order)
if pd.api.types.is_categorical_dtype(temp_gdf[plot_column]):
# Categorical type: use original order
categories = temp_gdf[plot_column].cat.categories.tolist()
else:
# Regular column: use sorted unique values
categories = sorted(temp_gdf[plot_column].dropna().unique())
color_list = []
if palette is not None:
use_custom_palette = True
# Convert palette to color list (in categories order)
if isinstance(palette, dict):
# Dictionary: map categories to colors
# Check for missing categories and warn
missing_cats = [cat for cat in categories if cat not in palette]
if missing_cats:
warnings.warn(
f"Palette dictionary is missing colors for {len(missing_cats)} categories: "
f"{missing_cats[:5]}{'...' if len(missing_cats) > 5 else ''}. "
f"Using 'gray' as default color for missing categories."
)
color_list = [palette.get(cat, 'gray') for cat in categories]
elif isinstance(palette, (list, np.ndarray)):
# List/array: assign colors sequentially
palette_array = np.asarray(palette)
if len(palette_array) < len(categories):
warnings.warn(
f"Palette has {len(palette_array)} colors but there are "
f"{len(categories)} categories. Colors will be cycled."
)
color_list = [
palette_array[i % len(palette_array)]
for i in range(len(categories))
]
else:
raise ValueError(f"Unsupported palette type: {type(palette)}")
# Create custom colormap from color list
custom_cmap = ListedColormap(color_list)
color_list_for_legend = color_list # Store for legend
# Convert column to Categorical type with specified categories order
# This ensures GeoPandas uses the correct order for color mapping
# GeoPandas will automatically use cat.categories for color assignment
# Prepare plot arguments for GeoDataFrame.plot()
plot_kwargs = {
'ax': current_ax,
'edgecolor': edges_color,
'linewidth': edges_width,
'alpha': alpha,
**kwargs
}
if color_key is not None:
# plot_column is guaranteed to be color_key at this point (already validated)
plot_kwargs['column'] = plot_column
if is_categorical:
# Categorical values
plot_kwargs['legend'] = False # Disable automatic legend, we'll add it manually
# Only set categorical=True if column is not already Categorical type
# GeoPandas can automatically detect Categorical columns, so we don't need
# to set categorical=True for Categorical columns (as shown in the example)
#if not pd.api.types.is_categorical_dtype(temp_gdf[plot_column]):
# plot_kwargs['categorical'] = True
# If using custom palette, use custom colormap
# GeoPandas will automatically use the Categorical column's cat.categories
# for color assignment, so we don't need to set categories parameter
# Reference: https://geopandas.org/en/stable/docs/reference/api/geopandas.GeoDataFrame.plot.html
if use_custom_palette:
# Use custom colormap
# The column is already converted to Categorical with correct order
# GeoPandas will use cat.categories automatically
plot_kwargs['cmap'] = custom_cmap
# else: use default GeoPandas categorical plotting with column
# Set aspect='equal' if bounds are invalid to avoid calculation errors
if use_equal_aspect:
plot_kwargs['aspect'] = 'equal'
# Plot using GeoDataFrame.plot() - uses column + categorical + cmap
# Use try-except to catch aspect calculation errors and retry with aspect='equal'
try:
temp_gdf.plot(**plot_kwargs)
except ValueError as e:
if "aspect must be finite and positive" in str(e):
# Retry with explicit aspect='equal'
plot_kwargs['aspect'] = 'equal'
warnings.warn(
f"Aspect calculation failed. Using aspect='equal' instead. "
f"Consider using tcl.io.sync_geometries_after_subset() after subsetting."
)
temp_gdf.plot(**plot_kwargs)
else:
raise
# Create legend manually from categories
if legend:
from matplotlib.patches import Patch
if use_custom_palette:
# Use colors from stored color list (same order as categories)
legend_elements = [
Patch(facecolor=color_list_for_legend[i], label=str(cat))
for i, cat in enumerate(categories)
]
else:
# Generate colors matching GeoPandas default
n_cats = len(categories)
default_cmap = plt.get_cmap('tab20' if n_cats <= 20 else 'tab20b')
legend_elements = [
Patch(facecolor=default_cmap(i / n_cats), label=str(cat))
for i, cat in enumerate(categories)
]
if legend_elements:
current_ax.legend(handles=legend_elements,
bbox_to_anchor=(1.05, 1),
loc='upper left',
frameon=True)
else:
# Continuous values
plot_kwargs['cmap'] = cmap
# Use GeoPandas automatic colorbar for continuous values
plot_kwargs['legend'] = legend
# Handle vmin and vmax for continuous values
# If not provided, GeoPandas will use data min/max automatically
if vmin is not None:
plot_kwargs['vmin'] = vmin
if vmax is not None:
plot_kwargs['vmax'] = vmax
# Set aspect='equal' if bounds are invalid to avoid calculation errors
if use_equal_aspect:
plot_kwargs['aspect'] = 'equal'
# Plot using GeoDataFrame.plot()
# GeoPandas will automatically create colorbar if legend=True
# Use try-except to catch aspect calculation errors and retry with aspect='equal'
try:
temp_gdf.plot(**plot_kwargs)
except ValueError as e:
if "aspect must be finite and positive" in str(e):
# Retry with explicit aspect='equal'
plot_kwargs['aspect'] = 'equal'
warnings.warn(
f"Aspect calculation failed. Using aspect='equal' instead. "
f"Consider using tcl.io.sync_geometries_after_subset() after subsetting."
)
temp_gdf.plot(**plot_kwargs)
else:
raise
else:
# No coloring - only show HE image (background image), no cell polygons
# This is useful for viewing just the tissue image with coordinates
pass # Skip geometry plotting, only background image will be shown
# Set axis properties
current_ax.set_aspect('equal')
current_ax.invert_yaxis() # Match image coordinates
# Set axis limits based on actual data range
# Add small padding (5% of range) for better visualization
x_range = x_max - x_min
y_range = y_max - y_min
x_padding = x_range * 0.05 if x_range > 0 else 1
y_padding = y_range * 0.05 if y_range > 0 else 1
current_ax.set_xlim(x_min - x_padding, x_max + x_padding)
current_ax.set_ylim(y_max + y_padding, y_min - y_padding) # Inverted for y-axis
# Set axis labels
if xlabel is not None:
current_ax.set_xlabel(xlabel)
if ylabel is not None:
current_ax.set_ylabel(ylabel)
# Control ticks visibility
# If color is None, always show ticks to display coordinates
if color_key is None:
# When color=None, show ticks and labels by default
current_ax.tick_params(axis='both', which='major', labelsize=10)
elif not show_ticks:
current_ax.set_xticks([])
current_ax.set_yticks([])
if color_key:
current_ax.set_title(color_key)
axes_list.append(current_ax)
if show:
# Adjust layout to make room for legend/colorbar on the right
# This is similar to scanpy's approach
if ax is None: # Only adjust if we created the figure
fig.tight_layout(rect=[0, 0, 0.95, 1]) # Leave 5% space on the right
else:
fig.tight_layout()
plt.show()
if len(axes_list) == 1:
return axes_list[0]
else:
return axes_list
[docs]
def mark_region(
ax: plt.Axes,
xlim: Optional[tuple] = None,
ylim: Optional[tuple] = None,
edges_color: str = 'red',
edges_width: float = 1.0
):
"""
Mark a rectangular region on a spatial plot by drawing a rectangle.
This function draws a rectangle on the given axes to highlight a specific
spatial region. It can be used with any spatial plot.
Parameters
----------
ax : matplotlib.axes.Axes
Axes object to draw the rectangle on.
xlim : tuple, optional
Tuple of (x_min, x_max) to define the x-range of the region.
If None, uses the current x-axis limits.
ylim : tuple, optional
Tuple of (y_min, y_max) to define the y-range of the region.
If None, uses the current y-axis limits.
edges_color : str, default 'red'
Color of the rectangle edges.
edges_width : float, default 1.0
Width of the rectangle edges.
Returns
-------
matplotlib.patches.Rectangle
The rectangle patch object that was added to the axes.
Examples
--------
>>> import trackcell as tcl
>>> import matplotlib.pyplot as plt
>>>
>>> # Plot with spatial_cell and mark a region
>>> fig, ax = plt.subplots(figsize=(10, 10))
>>> tcl.pl.spatial_cell(adata, color="CellType", ax=ax)
>>> tcl.pl.mark_region(ax, xlim=(54500, 56000), ylim=(15000, 16000))
>>>
>>> # Mark a region on any plot
>>> fig, ax = plt.subplots(figsize=(10, 10))
>>> # ... create your plot on ax ...
>>> tcl.pl.mark_region(ax, xlim=(54500, 56000), ylim=(15000, 16000),
... edges_color='blue', edges_width=2.0)
"""
from matplotlib.patches import Rectangle
# Get current axis limits if xlim/ylim are not provided
if xlim is None:
xlim = ax.get_xlim()
if ylim is None:
ylim = ax.get_ylim()
x_min, x_max = xlim
y_min, y_max = ylim
# Create rectangle
rect = Rectangle(
(x_min, y_min),
x_max - x_min,
y_max - y_min,
linewidth=edges_width,
edgecolor=edges_color,
facecolor='none'
)
# Add rectangle to axes
ax.add_patch(rect)
return rect