Skip to content

Commit aa3c052

Browse files
committed
add tests
1 parent 3764cb5 commit aa3c052

File tree

4 files changed

+280
-17
lines changed

4 files changed

+280
-17
lines changed

snp2cell/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from snp2cell.snp2cell_class import SNP2CELL, NCPU
1+
from snp2cell.snp2cell_class import SNP2CELL, SUFFIX, NCPU
22
from snp2cell import util
33
from snp2cell import cli
44
from snp2cell import recipes

snp2cell/snp2cell_class.py

Lines changed: 79 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,11 @@ class SUFFIX(Enum):
3434

3535

3636
class SNP2CELL:
37-
def __init__(self, path: Optional[Union[str, os.PathLike]] = None, seed: Optional[int] = RANDOM_SEED) -> None:
37+
def __init__(
38+
self,
39+
path: Optional[Union[str, os.PathLike]] = None,
40+
seed: Optional[int] = RANDOM_SEED,
41+
) -> None:
3842
"""
3943
Initialize the SNP2CELL object.
4044
@@ -127,7 +131,7 @@ def _add_de_groups(self, groupby: str, groups: List[str]) -> None:
127131
shared_keys = set(groups) & set(v)
128132
if shared_keys:
129133
raise ValueError(f"Groups {shared_keys} already exist in {k}")
130-
134+
131135
self.de_groups[groupby] = groups
132136

133137
def _scale_score(
@@ -212,7 +216,7 @@ def _get_perturbed_stats(self, score_key: str, suffix: SUFFIX) -> pd.DataFrame:
212216
"""
213217
if suffix not in [s.value for s in SUFFIX]:
214218
raise ValueError(
215-
f"Invalid suffix. Must be one of {[s.value for s in SUFFIX]}."
219+
f"Invalid suffix. Must be one of {[s.value for s in SUFFIX]}. Got '{suffix}'."
216220
)
217221
score = self.scores_rand[score_key]
218222
if suffix == SUFFIX.ZSCORE:
@@ -407,14 +411,16 @@ def add_grn_from_pandas(self, adjacency_df: pd.DataFrame) -> None:
407411
"""
408412
raise NotImplementedError("This method is not yet implemented.")
409413

410-
def add_grn_from_networkx(self, nx_grn: nx.Graph, overwrite: bool = False) -> None:
414+
def add_grn_from_networkx(
415+
self, nx_grn: Union[nx.Graph, str, Path], overwrite: bool = False
416+
) -> None:
411417
"""
412418
Add GRN from networkx object to snp2cell object.
413419
414420
Parameters
415421
----------
416-
nx_grn : nx.Graph
417-
Networkx object.
422+
nx_grn : Union[nx.Graph, str, Path]
423+
Networkx object or path to a pickled networkx object.
418424
overwrite : bool, optional
419425
Whether to overwrite existing networkx object, by default False.
420426
@@ -423,7 +429,7 @@ def add_grn_from_networkx(self, nx_grn: nx.Graph, overwrite: bool = False) -> No
423429
IndexError
424430
If existing scores are found and overwrite is False.
425431
"""
426-
if self.scores and not overwrite:
432+
if self.scores is not None and not overwrite:
427433
raise IndexError(
428434
"existing scores found, set overwrite=True to discard them."
429435
)
@@ -506,7 +512,12 @@ def add_score(
506512
)
507513
self.scores_prop[score_key] = self.scores_prop.index.map(p_scr_dct) # type: ignore
508514
if statistics:
509-
self.rand_sim(score_key=score_key, num_cores=num_cores, n=num_rand, reset_seed=reset_seed)
515+
self.rand_sim(
516+
score_key=score_key,
517+
num_cores=num_cores,
518+
n=num_rand,
519+
reset_seed=reset_seed,
520+
)
510521
self.add_score_statistics(score_keys=score_key)
511522
self._defrag_pandas()
512523

@@ -528,7 +539,7 @@ def propagate_score(self, score_key: str = "score") -> Tuple[str, Dict[str, floa
528539
scr_dct = self.scores[score_key].to_dict() # type: ignore
529540
p_scr_dct = self._prop_scr(scr_dct)
530541
return score_key, p_scr_dct
531-
542+
532543
@add_logger()
533544
def propagate_scores(
534545
self,
@@ -992,33 +1003,79 @@ def remove_scores(self, which: str = "propagated", **kwargs: Any) -> None:
9921003
"""
9931004
Delete selected scores from the object.
9941005
1006+
By default, this will delete all scores of the selected type.
1007+
Set `**kwargs` to select specific columns to delete.
1008+
9951009
Parameters
9961010
----------
9971011
which : str, optional
998-
Type of scores to retrieve. Can be "original" (before propagation), "propagated" (after propagation) or "perturbed" (random permutations), by default "propagated".
1012+
Type of scores to delete. Can be "original" (before propagation), "propagated" (after propagation), "perturbed" (random permutations) or "all" (all scores), by default "propagated".
9991013
kwargs : Any
1000-
Options passed to `pd.filter(**kwargs)` for selecting columns to DROP.
1014+
Options passed to `pd.filter(**kwargs)` for selecting columns to DROP. If not set, all columns will be dropped.
1015+
Set `items=[]` to drop columns by name, `like=""` to drop columns by partial name, `regex=""` to drop columns by regex.
10011016
"""
10021017
self._check_init()
1003-
if which == "perturbed":
1004-
self.scores_rand = {}
1005-
elif which == "all":
1018+
if which == "all":
10061019
self._init_scores()
1020+
elif which == "perturbed":
1021+
if kwargs:
1022+
# remove columns from random / perturbed scores
1023+
keys_to_remove = self.scores_rand.keys()
1024+
if "items" in kwargs:
1025+
keys_to_remove = kwargs["items"]
1026+
elif "like" in kwargs:
1027+
keys_to_remove = [
1028+
k for k in self.scores_rand if kwargs["like"] in k
1029+
]
1030+
elif "regex" in kwargs:
1031+
keys_to_remove = [
1032+
k for k in self.scores_rand if re.search(kwargs["regex"], k)
1033+
]
1034+
for key in keys_to_remove:
1035+
self.scores_rand.pop(key, None)
1036+
else:
1037+
self.scores_rand = {}
10071038
elif which == "propagated":
10081039
if kwargs:
1040+
# remove columns from propagated scores
10091041
cols = self.scores_prop.filter(**kwargs).columns # type: ignore
10101042
self.scores_prop = self.scores_prop.drop(columns=cols) # type: ignore
1043+
1044+
# also remove columns from random / perturbed scores
1045+
for key in cols:
1046+
if key in self.scores_rand:
1047+
self.scores_rand.pop(key, None)
10111048
else:
10121049
self.scores_prop = pd.DataFrame(index=list(self.grn.nodes)) # type: ignore
1050+
self.scores_rand = {}
10131051
elif which == "original":
10141052
if kwargs:
1053+
# remove columns from original scores
10151054
cols = self.scores.filter(**kwargs).columns # type: ignore
10161055
self.scores = self.scores.drop(columns=cols) # type: ignore
10171056
for k in self.de_groups:
10181057
self.de_groups[k] = [i for i in self.de_groups[k] if i not in cols]
1058+
1059+
# also remove columns from propagated scores
1060+
self.scores_prop = self.scores_prop.drop(columns=cols, errors="ignore") # type: ignore
1061+
stat_cols = [
1062+
col
1063+
for col in self.scores_prop.columns
1064+
if any(col.startswith(f"{c}__") for c in cols)
1065+
]
1066+
self.scores_prop = self.scores_prop.drop(columns=stat_cols, errors="ignore") # type: ignore
1067+
1068+
# also remove columns from random / perturbed scores
1069+
for key in cols:
1070+
if key in self.scores_rand:
1071+
self.scores_rand.pop(key, None)
10191072
else:
10201073
self.scores = pd.DataFrame(index=list(self.grn.nodes)) # type: ignore
10211074
self.de_groups = {}
1075+
self.scores_prop = pd.DataFrame(index=list(self.grn.nodes)) # type: ignore
1076+
self.scores_rand = {}
1077+
else:
1078+
raise ValueError(f"unknown score type: {which}")
10221079

10231080
def get_components(self, sel_nodes: List[str]) -> Tuple[nx.Graph, List[set]]:
10241081
"""
@@ -1215,5 +1272,11 @@ def plot_group_heatmap(
12151272
sns.heatmap(plt_df, cmap="mako", yticklabels=False)
12161273

12171274

1218-
SNP2CELL._get_perturbed_stats.__doc__ = SNP2CELL._get_perturbed_stats.__doc__.format(_SUFFIX_=str([e.value for e in SUFFIX]))
1219-
SNP2CELL.adata_combine_de_scores.__doc__ = SNP2CELL.adata_combine_de_scores.__doc__.format(_SUFFIX_=str([e.value for e in SUFFIX]))
1275+
SNP2CELL._get_perturbed_stats.__doc__ = SNP2CELL._get_perturbed_stats.__doc__.format(
1276+
_SUFFIX_=str([e.value for e in SUFFIX])
1277+
)
1278+
SNP2CELL.adata_combine_de_scores.__doc__ = (
1279+
SNP2CELL.adata_combine_de_scores.__doc__.format(
1280+
_SUFFIX_=str([e.value for e in SUFFIX])
1281+
)
1282+
)

tests/conftest.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
import pandas as pd
44
import networkx as nx
55
import scanpy as sc
6+
from snp2cell.snp2cell_class import SNP2CELL
67

8+
@pytest.fixture
9+
def snp2cell_instance():
10+
return SNP2CELL(seed=42)
711

812
@pytest.fixture(scope="session")
913
def fake_grn():

tests/test_snp2cell.py

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
import pytest
2+
import networkx as nx
3+
import pandas as pd
4+
import numpy as np
5+
import snp2cell
6+
7+
snp2cell.util.set_num_cpu(1)
8+
9+
10+
def test_initialization_with_path(snp2cell_instance, tmp_path):
11+
# Create a temporary file to simulate the path
12+
path = tmp_path / "test_data.pkl"
13+
snp2cell_instance.save_data(path=str(path))
14+
15+
s2c = snp2cell.SNP2CELL(path=str(path), seed=42)
16+
assert s2c is not None, "snp2cell object was not created"
17+
assert s2c.grn is None, "GRN should be None"
18+
assert s2c.adata is None, "AnnData should be None"
19+
assert s2c.scores is None, "Scores should be None"
20+
21+
22+
def test_init_scores(snp2cell_instance):
23+
G = nx.Graph()
24+
G.add_edges_from([(1, 2), (2, 3)])
25+
snp2cell_instance._set_grn(G)
26+
snp2cell_instance._init_scores()
27+
assert snp2cell_instance.scores is not None, "Scores should be initialized"
28+
assert (
29+
snp2cell_instance.scores_prop is not None
30+
), "Propagated scores should be initialized"
31+
assert snp2cell_instance.scores_rand == {}, "Random scores should be initialized"
32+
assert snp2cell_instance.de_groups == {}, "DE groups should be initialized"
33+
34+
35+
def test_set_grn(snp2cell_instance):
36+
G = nx.Graph()
37+
G.add_edges_from([(1, 2), (2, 3)])
38+
snp2cell_instance._set_grn(G)
39+
assert snp2cell_instance.grn is not None, "GRN should be set"
40+
assert list(snp2cell_instance.grn.edges) == [
41+
(1, 2),
42+
(2, 3),
43+
], "GRN edges should match"
44+
45+
46+
def test_add_de_groups(snp2cell_instance):
47+
snp2cell_instance._add_de_groups("group1", ["A", "B"])
48+
assert "group1" in snp2cell_instance.de_groups, "Group1 should be added"
49+
assert snp2cell_instance.de_groups["group1"] == [
50+
"A",
51+
"B",
52+
], "Group1 values should match"
53+
54+
with pytest.raises(ValueError):
55+
snp2cell_instance._add_de_groups("group1", ["C"])
56+
57+
snp2cell_instance._add_de_groups("group2", ["C"])
58+
with pytest.raises(ValueError):
59+
snp2cell_instance._add_de_groups("group2", ["A"])
60+
61+
62+
def test_get_perturbed_stats(snp2cell_instance):
63+
snp2cell_instance.scores_rand["test_key"] = pd.DataFrame(np.random.randn(10, 3))
64+
65+
for suffix in snp2cell.SUFFIX:
66+
result = snp2cell_instance._get_perturbed_stats("test_key", suffix.value)
67+
assert isinstance(result, pd.DataFrame), "Result should be a DataFrame"
68+
69+
70+
def test_robust_z_score():
71+
series = pd.Series([1, 2, 3, 4, 5])
72+
result = snp2cell.SNP2CELL._robust_z_score(series)
73+
assert isinstance(result, pd.Series), "Result should be a Series"
74+
assert len(result) == 5, "Result length should match input length"
75+
76+
77+
def test_get_scores(snp2cell_instance):
78+
# Add some scores to the instance
79+
snp2cell_instance.add_grn_from_networkx(nx.from_edgelist([(1, 2), (2, 3)]))
80+
snp2cell_instance.scores = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
81+
snp2cell_instance.scores_prop = pd.DataFrame({"A": [7, 8, 9], "B": [10, 11, 12]})
82+
snp2cell_instance.scores_rand = {"test_key": pd.DataFrame(np.random.randn(10, 3))}
83+
84+
# Test retrieving original and propagated scores
85+
for which in ["original", "propagated"]:
86+
scores = snp2cell_instance.get_scores(which=which)
87+
assert scores is not None, "Scores should be retrieved"
88+
assert isinstance(scores, pd.DataFrame), "Scores should be a DataFrame"
89+
assert "A" in scores.columns, "Scores should have column 'A'"
90+
assert "B" in scores.columns, "Scores should have column 'B'"
91+
92+
# Test retrieving perturbed scores
93+
scores = snp2cell_instance.get_scores(which="perturbed")
94+
assert scores is not None, "Scores should be retrieved"
95+
assert isinstance(scores, dict), "Scores should be a dictionary"
96+
assert "test_key" in scores, "Scores should have key 'test_key'"
97+
assert isinstance(scores["test_key"], pd.DataFrame), "Scores should be a DataFrame"
98+
99+
# Test retrieving with query
100+
scores = snp2cell_instance.get_scores(which="propagated", query="A > 7")
101+
assert len(scores) == 2, "Query should filter the DataFrame"
102+
103+
# Test retrieving with sort_key
104+
scores = snp2cell_instance.get_scores(which="propagated", sort_key="A")
105+
assert scores.iloc[0]["A"] == 9, "Scores should be sorted in descending order"
106+
107+
108+
def test_remove_scores(snp2cell_instance):
109+
snp2cell_instance.add_grn_from_networkx(nx.from_edgelist([(1, 2), (2, 3)]))
110+
111+
# Add some scores to the instance
112+
snp2cell_instance.scores = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
113+
snp2cell_instance.scores_prop = pd.DataFrame(
114+
{"A": [7, 8, 9], "A__pval": [7, 8, 9], "B": [10, 11, 12]}
115+
)
116+
snp2cell_instance.scores_rand = {"A": pd.DataFrame(np.random.randn(10, 3))}
117+
118+
# Test removing non-existing scores (should not raise an error)
119+
snp2cell_instance.remove_scores(which="original", items=["C"])
120+
assert snp2cell_instance.scores is not None, "Original scores should not be removed"
121+
assert (
122+
snp2cell_instance.scores.shape[1] == 2
123+
), "Original scores should not be removed"
124+
assert (
125+
snp2cell_instance.scores_prop is not None
126+
), "Propagated scores should not be removed"
127+
assert (
128+
snp2cell_instance.scores_prop.shape[1] == 3
129+
), "Propagated scores should not be removed"
130+
assert "A" in snp2cell_instance.scores_rand, "Random scores should not be removed"
131+
132+
# Test removing original scores
133+
snp2cell_instance.remove_scores(which="original", items=["A"])
134+
assert (
135+
"A" not in snp2cell_instance.scores.columns
136+
), "Original scores should be removed"
137+
assert (
138+
"A" not in snp2cell_instance.scores_prop.columns
139+
), "Propagated scores should also be removed"
140+
assert (
141+
"A__pval" not in snp2cell_instance.scores_prop.columns
142+
), "Corresponding statistics should also be removed"
143+
assert (
144+
"A" not in snp2cell_instance.scores_rand
145+
), "Random scores should also be removed"
146+
147+
# Add scores to the instance
148+
snp2cell_instance.scores = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
149+
snp2cell_instance.scores_prop = pd.DataFrame(
150+
{"A": [7, 8, 9], "A__pval": [7, 8, 9], "B": [10, 11, 12]}
151+
)
152+
snp2cell_instance.scores_rand = {"A": pd.DataFrame(np.random.randn(10, 3))}
153+
154+
# Test removing propagated scores
155+
snp2cell_instance.remove_scores(which="propagated", items=["A"])
156+
assert (
157+
"A" in snp2cell_instance.scores.columns
158+
), "Original scores should not be removed"
159+
assert (
160+
"A" not in snp2cell_instance.scores_prop.columns
161+
), "Propagated scores should be removed"
162+
assert (
163+
"A" not in snp2cell_instance.scores_rand
164+
), "Random scores should also be removed"
165+
166+
# Add scores to the instance
167+
snp2cell_instance.scores = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
168+
snp2cell_instance.scores_prop = pd.DataFrame(
169+
{"A": [7, 8, 9], "A__pval": [7, 8, 9], "B": [10, 11, 12]}
170+
)
171+
snp2cell_instance.scores_rand = {"A": pd.DataFrame(np.random.randn(10, 3))}
172+
173+
# Test removing random scores
174+
snp2cell_instance.remove_scores(which="perturbed", items=["A"])
175+
assert (
176+
"A" in snp2cell_instance.scores.columns
177+
), "Original scores should not be removed"
178+
assert (
179+
"A" in snp2cell_instance.scores_prop.columns
180+
), "Propagated scores should not be removed"
181+
assert "A" not in snp2cell_instance.scores_rand, "Random scores should be removed"
182+
183+
# Test removing all propagated scores
184+
snp2cell_instance.remove_scores(which="propagated")
185+
assert (
186+
snp2cell_instance.scores_prop.shape[1] == 0
187+
), "All propagated scores should be removed"
188+
assert (
189+
len(snp2cell_instance.scores_rand) == 0
190+
), "All random scores should also be removed"
191+
192+
# Test removing all original scores
193+
snp2cell_instance.remove_scores(which="original")
194+
assert (
195+
snp2cell_instance.scores.shape[1] == 0
196+
), "All original scores should be removed"

0 commit comments

Comments
 (0)