Skip to content

Commit 2a3ea0e

Browse files
committed
update plotting methods
1 parent 2237090 commit 2a3ea0e

File tree

1 file changed

+44
-12
lines changed

1 file changed

+44
-12
lines changed

snp2cell/snp2cell_class.py

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1218,7 +1218,7 @@ def plot_group_summary(
12181218
self,
12191219
score_key: str = "score",
12201220
plt_df: Optional[pd.DataFrame] = None,
1221-
topn: Optional[int] = None,
1221+
topn: Optional[int] = 20,
12221222
errorbar: Literal["pi", "std", "ci", "se"] = "ci",
12231223
row_pattern: str = ".*DE_(?P<rowname>.+?)__",
12241224
figsize: Tuple[int, int] = (7, 5),
@@ -1230,8 +1230,14 @@ def plot_group_summary(
12301230
12311231
There are three ways to select scores for plotting:
12321232
1. Set `plt_df` to a data frame with scores. This will plot all scores in the data frame.
1233-
2. Set `score_key` to a key of a score to plot. This will plot all combinations of this score with other scores (`^min.*{score_key}.*zscore_mad$`).
1233+
2. Set `score_key` to a key of a score to plot. This will plot all combinations of this score with other scores.
12341234
`**kwargs` will be passed to `get_scores(**kwargs)`. If `query` is not in `kwargs`, it will be set to `f"{score_key}__pval < 0.05"`.
1235+
If `regex` is not in `kwargs`, it will be set to `f"^min.*{score_key}.*zscore_mad$"`.
1236+
1237+
E.g. `regex="^min.*{score_key}.*zscore_mad$"` means that only combinations of score `score_key` with any other scores will be plotted.
1238+
If this score has been combined with DE scores, for example, this will be the combinations for all cell types.
1239+
Here `min(...,...)__zscore_mad` means the combined score is the minimum of the two scores, normalized as a robust z-score.
1240+
The complicated column names are simplified by extracting the cell type names with `row_pattern`, to use for the x-axis of the plot.
12351241
12361242
Parameters
12371243
----------
@@ -1241,7 +1247,7 @@ def plot_group_summary(
12411247
plt_df : Optional[pd.DataFrame], optional
12421248
Data frame with scores to plot (as retrieved by `snp2cell.get_scores()`), by default None. If not set, scores in the object will be plotted.
12431249
topn : Optional[int], optional
1244-
Number of top scores to plot, by default None.
1250+
Number of top scores to plot, by default 20. If None, all scores will be plotted.
12451251
errorbar : Literal["pi", "std", "ci", "se"], optional
12461252
Type of error bar to plot, by default "ci".
12471253
Options are: "pi" (percentile interval), "std" (standard deviation), "ci" (confidence interval, 95%), "se" (standard error).
@@ -1261,7 +1267,8 @@ def plot_group_summary(
12611267
"""
12621268
if plt_df is None:
12631269
if score_key:
1264-
kwargs["regex"] = f"^min.*{score_key}.*zscore_mad$"
1270+
if "regex" not in kwargs:
1271+
kwargs["regex"] = f"^min.*{score_key}.*zscore_mad$"
12651272
if "query" not in kwargs:
12661273
kwargs["query"] = f"{score_key}__pval < 0.05"
12671274
plt_df = self.get_scores(**kwargs)
@@ -1296,7 +1303,9 @@ def plot_group_heatmap(
12961303
self,
12971304
score_key: str = "score",
12981305
plt_df: Optional[pd.DataFrame] = None,
1299-
topn: int = 5,
1306+
genes_per_score: int = 5,
1307+
n_col: Optional[int] = 30,
1308+
asinh_transform: bool = False,
13001309
row_pattern: str = ".*DE_(?P<rowname>.+?)__",
13011310
figsize: Tuple[int, int] = (7, 7),
13021311
dendrogram_ratio: Tuple[float, float] = (0.1, 0.1),
@@ -1307,8 +1316,14 @@ def plot_group_heatmap(
13071316
13081317
There are two ways to select scores for plotting:
13091318
1. Set `plt_df` to a data frame with scores. This will plot all scores in the data frame.
1310-
2. Set `score_key` to a key of a score to plot. This will plot all combinations of this score with other scores (`^min.*{score_key}.*zscore_mad$`).
1311-
`**kwargs` will be passed to `get_scores(**kwargs)`. If `query` is not in `kwargs`, it will be set to `f"{score_key}__pval < 0.05"`.
1319+
2. Set `score_key` to a key of a score to plot. This will plot all combinations of this score with other scores.
1320+
`**kwargs` will be passed to `get_scores(**kwargs)`. If `query` is not in `kwargs`, it will be set to `f"~index.str.startswith('chr') and {score_key}__pval < 0.05"`.
1321+
If `regex` is not in `kwargs`, it will be set to `f"^min.*{score_key}.*zscore_mad$"`.
1322+
1323+
E.g. `regex="^min.*{score_key}.*zscore_mad$"` means that only combinations of score `score_key` with any other scores will be plotted.
1324+
If this score has been combined with DE scores, for example, this will be the combinations for all cell types.
1325+
Here `min(...,...)__zscore_mad` means the combined score is the minimum of the two scores, normalized as a robust z-score.
1326+
The complicated column names are simplified by extracting the cell type names with `row_pattern`, to use for the x-axis of the plot.
13121327
13131328
Parameters
13141329
----------
@@ -1317,8 +1332,12 @@ def plot_group_heatmap(
13171332
plt_df : Optional[pd.DataFrame], optional
13181333
Data frame with scores to plot (as retrieved by `snp2cell.get_scores()`), by default None.
13191334
If not set, scores in the object will be plotted.
1320-
topn : int, optional
1321-
Number of top scores to plot, by default 5.
1335+
genes_per_score : int
1336+
Number of top genes to plot per score, by default 5.
1337+
n_col : Optional[int], optional
1338+
Number of top scores / columns to plot, ranked by their mean value over all rows, by default 30. If None, all columns are kept.
1339+
asinh_transform : bool
1340+
Whether to apply an arcsinh transformation to the scores to reduce the effect of outliers, by default False.
13221341
row_pattern : str, optional
13231342
Regex for extracting names for plotting from the score names, by default ".*DE_(?P<rowname>.+?)__".
13241343
figsize : Tuple[int, int], optional
@@ -1334,19 +1353,32 @@ def plot_group_heatmap(
13341353
"""
13351354
if plt_df is None:
13361355
if score_key:
1337-
kwargs["regex"] = f"^min.*{score_key}.*zscore_mad$"
1356+
if "regex" not in kwargs:
1357+
kwargs["regex"] = f"^min.*{score_key}.*zscore_mad$"
13381358
if "query" not in kwargs:
1339-
kwargs["query"] = f"{score_key}__pval < 0.05"
1359+
kwargs["query"] = (
1360+
f"~index.str.startswith('chr') and {score_key}__pval < 0.05"
1361+
)
13401362
plt_df = self.get_scores(**kwargs)
13411363

13421364
if row_pattern:
13431365
plt_df = plt_df.rename(
13441366
columns=lambda c: self.rename_column(c, row_pattern=row_pattern)
13451367
)
13461368

1369+
if n_col is not None:
1370+
plt_df = plt_df.loc[
1371+
:, plt_df.mean(axis=0).nlargest(min(n_col, plt_df.shape[1])).index
1372+
]
1373+
1374+
if asinh_transform:
1375+
plt_df = np.arcsinh(plt_df)
1376+
13471377
rows = []
13481378
for c in plt_df:
1349-
rows.extend(plt_df.sort_values(c, ascending=False)[:topn].index.tolist())
1379+
rows.extend(
1380+
plt_df.sort_values(c, ascending=False)[:genes_per_score].index.tolist()
1381+
)
13501382
plt_df = plt_df.loc[list(set(rows)), :]
13511383

13521384
# optimal leaf ordering for cols

0 commit comments

Comments
 (0)