Skip to content

sq.pp.filter_cells for SpatialData #1011

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
12 changes: 12 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,18 @@ Plotting
pl.extract
pl.var_by_distance

Preprocessing
~~~~~~~~~~~~~

.. module:: squidpy.pp
.. currentmodule:: squidpy

.. autosummary::
:toctree: api

pp.filter_cells


Reading
~~~~~~~

Expand Down
2 changes: 1 addition & 1 deletion docs/notebooks
2 changes: 1 addition & 1 deletion src/squidpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from importlib import metadata
from importlib.metadata import PackageMetadata

from squidpy import datasets, gr, im, pl, read, tl
from squidpy import datasets, gr, im, pl, pp, read, tl

try:
md: PackageMetadata = metadata.metadata(__name__)
Expand Down
5 changes: 5 additions & 0 deletions src/squidpy/pp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Basic pre-processing functions adapted from scanpy."""

from __future__ import annotations

from squidpy.pp._simple import filter_cells
127 changes: 127 additions & 0 deletions src/squidpy/pp/_simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from __future__ import annotations

import anndata as ad
import numpy as np
import pandas as pd
import scanpy as sc
import spatialdata as sd
import xarray as xr
from spatialdata import subset_sdata_by_table_mask
from spatialdata._logging import logger as logg
from spatialdata.models import Labels2DModel, PointsModel, ShapesModel, get_model
from spatialdata.transformations import get_transformation
from xarray import DataTree


def filter_cells(
    data: sd.SpatialData,
    tables: list[str] | str | None = None,
    min_counts: int | None = None,
    min_genes: int | None = None,
    max_counts: int | None = None,
    max_genes: int | None = None,
    inplace: bool = True,
    filter_labels: bool = True,
) -> sd.SpatialData | None:
    """
    Filter the cells of a :class:`spatialdata.SpatialData` object.

    This is Squidpy's counterpart of :func:`scanpy.pp.filter_cells`. The cell mask is
    computed per table by :func:`scanpy.pp.filter_cells` and then propagated to the
    spatial elements annotated by that table:

    - labels: filtered based on the values of the images which are assumed to be the instance_id.
    - points: filtered based on the index which is assumed to be the instance_id.
    - shapes: filtered based on the instance_id column.

    See :func:`scanpy.pp.filter_cells` for more details regarding the filtering
    behavior of the thresholds.

    Parameters
    ----------
    data
        :class:`spatialdata.SpatialData` object. Any other type (including
        :class:`anndata.AnnData`) is rejected; use :func:`scanpy.pp.filter_cells` for those.
    tables
        Name or names of the tables to filter. If `None`, all tables are filtered.
    min_counts
        Minimum number of counts required for a cell to pass filtering.
    min_genes
        Minimum number of genes expressed required for a cell to pass filtering.
    max_counts
        Maximum number of counts allowed for a cell to pass filtering.
    max_genes
        Maximum number of genes expressed allowed for a cell to pass filtering.
    inplace
        Perform computation inplace or return result.
    filter_labels
        Whether to also replace the labels elements by their filtered version.
        If `False`, labels are left untouched.

    Returns
    -------
    If `inplace` then returns `None`, otherwise returns the filtered :class:`spatialdata.SpatialData` object.

    Raises
    ------
    ValueError
        If `data` is not a :class:`spatialdata.SpatialData` object.
    """
    if isinstance(data, sd.SpatialData):
        # All real work lives in the SpatialData-specific helper; arguments are
        # forwarded by name so the mapping is explicit.
        return _filter_cells_spatialdata(
            data=data,
            tables=tables,
            min_counts=min_counts,
            min_genes=min_genes,
            max_counts=max_counts,
            max_genes=max_genes,
            inplace=inplace,
            filter_labels=filter_labels,
        )

    raise ValueError(
        f"Expected `SpatialData`, found `{type(data)}` instead. Perhaps you want to use `scanpy.pp.filter_cells` instead."
    )


def _filter_cells_spatialdata(
    data: sd.SpatialData,
    tables: list[str] | str | None = None,
    min_counts: int | None = None,
    min_genes: int | None = None,
    max_counts: int | None = None,
    max_genes: int | None = None,
    inplace: bool = True,
    filter_labels: bool = True,
) -> sd.SpatialData | None:
    """
    Filter the selected tables of a :class:`spatialdata.SpatialData` object and the elements they annotate.

    All cell masks are computed and validated *before* anything is copied or
    modified, so a failing filter never leaves `data` partially filtered and
    never pays for an unnecessary deep copy.

    Parameters
    ----------
    data
        The :class:`spatialdata.SpatialData` object to filter.
    tables
        Name or names of the tables to filter. If `None`, all tables are filtered.
        Duplicate names are processed once.
    min_counts, min_genes, max_counts, max_genes
        Thresholds forwarded unchanged to :func:`scanpy.pp.filter_cells`.
    inplace
        If `True`, modify `data` and return `None`; otherwise work on a deep copy and return it.
    filter_labels
        Whether the labels elements are replaced by their filtered version as well.

    Returns
    -------
    `None` if `inplace`, otherwise the filtered deep copy of `data`.

    Raises
    ------
    ValueError
        If no table is selected, if a selected table is missing from `data`,
        if a table lacks ``'spatialdata_attrs'`` in ``.uns``, or if a filter
        would remove every cell of a table.
    """
    if isinstance(tables, str):
        tables = [tables]
    elif tables is None:
        tables = list(data.tables.keys())
    else:
        # Drop repeated names (order preserved): a mask computed on the original
        # table must not be re-applied to an already filtered one.
        tables = list(dict.fromkeys(tables))

    if len(tables) == 0:
        raise ValueError("Expected at least one table to be filtered, found `0`")

    if not all(t in data.tables for t in tables):
        raise ValueError(f"Expected all tables to be in `{data.tables.keys()}`.")

    for t in tables:
        if "spatialdata_attrs" not in data.tables[t].uns:
            raise ValueError(f"Table `{t}` does not have 'spatialdata_attrs' to indicate what it annotates.")

    # Compute every mask up front on the untouched input. Each table is only
    # rewritten when its own name is processed below, so these masks are the
    # same ones a compute-then-apply loop would have produced.
    masks = {}
    for t in tables:
        mask_filtered, _ = sc.pp.filter_cells(
            data.tables[t],
            min_counts=min_counts,
            min_genes=min_genes,
            max_counts=max_counts,
            max_genes=max_genes,
            inplace=False,
        )
        if mask_filtered.sum() == 0:
            raise ValueError(f"Filter results in empty table when filtering table `{t}`.")
        masks[t] = mask_filtered

    if not inplace:
        logg.warning(
            "Creating a deepcopy of the SpatialData object, depending on the size of the object this can take a while."
        )
        data_out = sd.deepcopy(data)
    else:
        data_out = data

    for t, mask_filtered in masks.items():
        sdata_filtered = subset_sdata_by_table_mask(sdata=data_out, table_name=t, mask=mask_filtered)
        # Copy the filtered table and elements back into the output object.
        data_out.tables[t] = sdata_filtered.tables[t]
        for k in list(sdata_filtered.points.keys()):
            data_out.points[k] = sdata_filtered.points[k]
        for k in list(sdata_filtered.shapes.keys()):
            data_out.shapes[k] = sdata_filtered.shapes[k]
        if filter_labels:
            for k in list(sdata_filtered.labels.keys()):
                data_out.labels[k] = sdata_filtered.labels[k]

    if inplace:
        return None
    return data_out
60 changes: 60 additions & 0 deletions tests/preprocessing/test_simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from __future__ import annotations

import anndata as ad
import numpy as np
import pytest
import scanpy as sc
from spatialdata.datasets import blobs_annotating_element

import squidpy as sq


def _make_sdata(name: str, num_counts: int, count_value: int):
    """
    Build a blobs ``SpatialData`` whose ``"table"`` has a controlled number of non-zero cells.

    The table's ``X`` is replaced by a square ``(n_obs, n_obs)`` matrix of zeros in which
    ``num_counts`` distinct rows each receive a single entry equal to ``count_value``
    (in distinct columns). ``obs`` and ``uns`` of the original table are kept.

    Parameters
    ----------
    name
        Name of the element the table annotates, forwarded to ``blobs_annotating_element``.
    num_counts
        Number of cells (rows) that get a non-zero count; must be at most 5.
    count_value
        The count written into each selected cell.

    Returns
    -------
    The ``SpatialData`` object with the replaced table.
    """
    assert num_counts <= 5, "num_counts must be at most 5"
    sdata_temp = blobs_annotating_element(name)
    m, _ = sdata_temp.tables["table"].shape
    n = m
    X = np.zeros((m, n))
    # Random, non-repeating rows and columns; they are paired element-wise below,
    # so exactly `num_counts` cells end up with a single non-zero entry.
    row_indices = np.random.choice(m, num_counts, replace=False)
    col_indices = np.random.choice(n, num_counts, replace=False)
    X[row_indices, col_indices] = count_value

    sdata_temp.tables["table"] = ad.AnnData(
        X=X,
        obs=sdata_temp.tables["table"].obs,
        var={"gene": ["gene"] * n},
        uns=sdata_temp.tables["table"].uns,
    )
    return sdata_temp


@pytest.mark.parametrize("name", ["blobs_labels", "blobs_circles", "blobs_points", "blobs_multiscale_labels"])
def test_filter_cells(name: str):
    """Check that ``sq.pp.filter_cells`` filters the table exactly like ``sc.pp.filter_cells``.

    Three cells are given a count of 100 and then removed with ``max_counts=50``;
    the resulting table must match scanpy's result on a copy of the same table.
    For ``blobs_labels`` the remaining label values must also match the remaining
    instance ids.
    """
    filtered_cells = 3
    sdata = _make_sdata(name, num_counts=filtered_cells, count_value=100)
    num_cells = sdata.tables["table"].shape[0]
    # Reference result: scanpy applied to an independent copy of the table.
    adata_copy = sdata.tables["table"].copy()
    sc.pp.filter_cells(adata_copy, max_counts=50, inplace=True)
    sq.pp.filter_cells(sdata, max_counts=50, inplace=True, filter_labels=True)

    assert np.all(sdata.tables["table"].X == adata_copy.X), "Filtered cells are not the same as scanpy"
    assert np.all(sdata.tables["table"].obs["instance_id"] == adata_copy.obs["instance_id"]), (
        "Filtered cells are not the same as scanpy"
    )
    assert sdata.tables["table"].shape[0] == (num_cells - filtered_cells), (
        f"Expected {num_cells - filtered_cells} cells, got {sdata.tables['table'].shape[0]}"
    )

    if name == "blobs_labels":
        unique_labels = np.unique(adata_copy.obs["instance_id"])
        unique_labels_sdata = np.unique(sdata.labels["blobs_labels"].data.compute())
        # 0 is excluded from the label image values before comparing
        # (presumably the background label — confirm against the blobs dataset).
        assert set(unique_labels) == set(unique_labels_sdata).difference([0]), (
            f"Filtered labels {unique_labels_sdata} are not the same as scanpy {unique_labels}"
        )


def test_filter_cells_empty_fail():
    """Filtering must raise a ``ValueError`` when the requested filter empties the table."""
    expected_message = "Filter results in empty table when filtering table `table`."
    sdata = _make_sdata("blobs_labels", num_counts=5, count_value=200)

    with pytest.raises(ValueError, match=expected_message):
        sq.pp.filter_cells(sdata, max_counts=100, inplace=True)
2 changes: 1 addition & 1 deletion tests/utils/test_parallelize.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def func(request) -> Callable:
# in case of failure.


@pytest.mark.timeout(30)
@pytest.mark.timeout(50)
@pytest.mark.parametrize(
"backend",
[
Expand Down
Loading