Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/squidpy/_constants/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,6 @@ class TenxVersions(str, ModeEnum):
class NicheDefinitions(ModeEnum):
NEIGHBORHOOD = "neighborhood"
UTAG = "utag"
CELLCHARTER = "cellcharter"
CELLCHARTER = "cellcharter_simple"
SPOT = "spot"
BANKSY = "banksy"
28 changes: 15 additions & 13 deletions src/squidpy/gr/_niche.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
@inject_docs(fla=NicheDefinitions)
def calculate_niche(
data: AnnData | SpatialData,
flavor: Literal["neighborhood", "utag", "cellcharter"],
flavor: Literal["neighborhood", "utag", "cellcharter_simple"],
library_key: str | None = None,
table_key: str | None = None,
mask: pd.core.series.Series = None,
Expand Down Expand Up @@ -58,7 +58,7 @@ def calculate_niche(
Method to use for niche calculation. Available options are:
- `{fla.NEIGHBORHOOD.s!r}` - cluster the neighborhood profile.
- `{fla.UTAG.s!r}` - use utag algorithm (matrix multiplication).
- `{fla.CELLCHARTER.s!r}` - cluster adjacency matrix with Gaussian Mixture Model (GMM) using CellCharter's approach.
- `{fla.CELLCHARTER.s!r}` - a simplified version of CellCharter's approach, using PCA instead of scVI for dimensionality reduction.
%(library_key)s
If provided, niches will be calculated separately for each unique value in this column.
Each niche will be prefixed with the library identifier.
Expand Down Expand Up @@ -108,7 +108,7 @@ def calculate_niche(
If 'False', return a new AnnData object with the niche labels.
"""

if flavor == "cellcharter" and aggregation is None:
if flavor == "cellcharter_simple" and aggregation is None:
aggregation = "mean"

_validate_niche_args(
Expand All @@ -134,7 +134,7 @@ def calculate_niche(
resolutions = [0.5]

if distance is None:
distance = 1
distance = 3 if flavor == "cellcharter_simple" else 1

if isinstance(data, SpatialData):
orig_adata = data.tables[table_key]
Expand Down Expand Up @@ -187,7 +187,7 @@ def calculate_niche(
mask=lib_mask,
groups=groups,
n_neighbors=n_neighbors,
resolutions=None if flavor == "cellcharter" else resolutions,
resolutions=None if flavor == "cellcharter_simple" else resolutions,
min_niche_size=min_niche_size,
scale=scale,
abs_nhood=abs_nhood,
Expand Down Expand Up @@ -258,7 +258,7 @@ def _get_result_columns(

library_str = f"_{library_key}" if library_key is not None else ""

if flavor == "cellcharter":
if flavor == "cellcharter_simple":
base_column = "cellcharter_niche"
if library_key is None:
return [base_column]
Expand Down Expand Up @@ -311,7 +311,7 @@ def _calculate_niches(
)
elif flavor == "utag":
_get_utag_niches(adata, n_neighbors, resolutions, spatial_connectivities_key)
elif flavor == "cellcharter":
elif flavor == "cellcharter_simple":
assert isinstance(aggregation, str) # for mypy
assert isinstance(n_components, int) # for mypy
_get_cellcharter_niches(
Expand Down Expand Up @@ -667,7 +667,7 @@ def _jensen_shannon_divergence(adata: AnnData, niche_key: str, library_key: str)

def _validate_niche_args(
data: AnnData | SpatialData,
flavor: Literal["neighborhood", "utag", "cellcharter"],
flavor: Literal["neighborhood", "utag", "cellcharter_simple"],
library_key: str | None,
table_key: str | None,
groups: str | None,
Expand Down Expand Up @@ -697,8 +697,10 @@ def _validate_niche_args(
if not isinstance(data, AnnData | SpatialData):
raise TypeError(f"'data' must be an AnnData or SpatialData object, got {type(data).__name__}")

if flavor not in ["neighborhood", "utag", "cellcharter"]:
raise ValueError(f"Invalid flavor '{flavor}'. Please choose one of 'neighborhood', 'utag', 'cellcharter'.")
if flavor not in ["neighborhood", "utag", "cellcharter_simple"]:
raise ValueError(
f"Invalid flavor '{flavor}'. Please choose one of 'neighborhood', 'utag', 'cellcharter_simple'."
)

if library_key is not None:
if not isinstance(library_key, str):
Expand Down Expand Up @@ -760,7 +762,7 @@ def _validate_niche_args(
"random_state",
],
},
"cellcharter": {
"cellcharter_simple": {
"required": ["distance", "aggregation", "n_components", "random_state"],
"optional": [],
"unused": [
Expand Down Expand Up @@ -809,7 +811,7 @@ def _validate_niche_args(
if distance is not None and isinstance(distance, int) and distance < 1:
raise ValueError(f"'distance' must be at least 1, got {distance}")

elif flavor == "cellcharter":
elif flavor == "cellcharter_simple":
if distance is not None and not isinstance(distance, int):
raise TypeError(f"'distance' must be an integer, got {type(distance).__name__}")
if distance is not None and distance < 1:
Expand Down Expand Up @@ -843,7 +845,7 @@ def _check_unnecessary_args(flavor: str, param_dict: dict[str, Any], param_specs
Parameters
----------
flavor
The flavor being used ('neighborhood', 'utag', or 'cellcharter')
The flavor being used ('neighborhood', 'utag', or 'cellcharter_simple')
param_dict
Dictionary of parameter names to their values
param_specs
Expand Down
Loading