Skip to content

Commit cf17b14

Browse files
authored
fix: reduce memory usage of em step
- Use two passes to build matrix, then calculate coverage after EM. This avoid storing reduced representations of all alignments in memory. - Consolidate BAM parsing into common interface.
1 parent 2340210 commit cf17b14

File tree

14 files changed

+1815
-1149
lines changed

14 files changed

+1815
-1149
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ authors = [
1818

1919
[dependency-groups]
2020
dev = [
21-
"maturin>=1.9.2",
21+
"maturin[patchelf]>=1.9.2",
2222
"pysam>=0.22.0",
2323
"pytest>=8.4.1,<9.0.0",
2424
"pytest-asyncio<=1.1.0,<2.0.0",

python/workflow_pathoscope/rust.pyi

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,8 @@ class PathoscopeResults:
2424
coverage: dict[str, list[int]]
2525

2626
def run_expectation_maximization(
27-
alignment_path: str,
27+
bam_path: str,
2828
p_score_cutoff: float,
29-
ref_lengths: dict[str, int],
3029
) -> PathoscopeResults:
3130
"""Run Pathoscope expectation maximization algorithm using Rust on SAM/BAM files."""
3231

@@ -45,6 +44,5 @@ def find_candidate_otus_from_bytes(
4544
def calculate_coverage_from_em_results(
4645
alignment_path: str,
4746
p_score_cutoff: float,
48-
ref_lengths: dict[str, int],
4947
) -> dict[str, list[int]]:
5048
"""Calculate coverage directly from EM results and alignment data."""

python/workflow_pathoscope/utils.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -140,20 +140,17 @@ def write_report(
140140

141141

142142
def run_pathoscope(
143-
alignment_path: Path,
143+
bam_path: Path,
144144
p_score_cutoff: float,
145-
ref_lengths: dict[str, int],
146145
):
147146
"""Run Pathoscope on an alignment file.
148147
149148
Returns PathoscopeResults containing EM results and coverage data.
150149
151150
:param alignment_path: The path to the SAM or BAM file.
152151
:param p_score_cutoff: The minimum allowed ``p_score`` for an alignment.
153-
:param ref_lengths: Dictionary mapping reference IDs to their lengths.
154152
"""
155153
return run_expectation_maximization(
156-
str(alignment_path),
154+
str(bam_path),
157155
p_score_cutoff,
158-
ref_lengths,
159156
)

rustfmt.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
max_width = 88

src/candidates.rs

Lines changed: 61 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1-
use std::collections::HashSet;
21
use log::info;
3-
use pyo3::prelude::*;
42
use pyo3::exceptions::PyIOError;
3+
use pyo3::prelude::*;
4+
use std::collections::HashSet;
55

66
const AS_TAG_PREFIX: &str = "AS:i:";
77

88
/// Extract AS:i alignment score from SAM optional fields
9-
///
9+
///
1010
/// # Arguments
1111
/// * `fields` - SAM fields starting from the optional fields (field 11+)
12-
///
12+
///
1313
/// # Returns
1414
/// Option containing the AS:i score as f64, None if not found or invalid
1515
fn extract_as_score(fields: &[&str]) -> Option<f64> {
@@ -24,14 +24,14 @@ fn extract_as_score(fields: &[&str]) -> Option<f64> {
2424
}
2525

2626
/// Parse a single SAM line and extract candidate OTU information
27-
///
27+
///
2828
/// This function processes one SAM line and determines if the read meets the score cutoff.
2929
/// Used for testing and by the streaming functions.
3030
///
3131
/// # Arguments
3232
/// * `line` - A SAM format line as string
3333
/// * `p_score_cutoff` - Minimum score threshold (AS:i score + read length)
34-
///
34+
///
3535
/// # Returns
3636
/// Option containing the reference name if the read meets the cutoff, None otherwise
3737
pub fn parse_sam_line(line: &str, p_score_cutoff: f64) -> Option<String> {
@@ -42,14 +42,14 @@ pub fn parse_sam_line(line: &str, p_score_cutoff: f64) -> Option<String> {
4242

4343
// Parse SAM line - tab-separated format
4444
let fields: Vec<&str> = line.split('\t').collect();
45-
45+
4646
// SAM format requires at least 11 fields
4747
if fields.len() < 11 {
4848
return None;
4949
}
5050

5151
// Extract key fields:
52-
// 1: FLAG
52+
// 1: FLAG
5353
// 2: RNAME (reference name)
5454
// 9: SEQ (read sequence)
5555
let flag: u16 = fields[1].parse().unwrap_or(4); // Default to unmapped if parse fails
@@ -75,19 +75,18 @@ pub fn parse_sam_line(line: &str, p_score_cutoff: f64) -> Option<String> {
7575
None
7676
}
7777

78-
7978
/// Extract candidate OTU reference IDs by running bowtie2 directly with streaming
80-
///
79+
///
8180
/// This function spawns a bowtie2 process directly from Rust and streams its output
8281
/// to avoid memory issues with large SAM files. It processes SAM lines as they arrive
8382
/// and returns only the unique reference IDs that meet the score cutoff.
84-
///
83+
///
8584
/// # Arguments
8685
/// * `bowtie_index_path` - Path to the bowtie2 index
8786
/// * `read_paths` - List of paths to the input read files
8887
/// * `proc` - Number of processor threads for bowtie2
8988
/// * `p_score_cutoff` - Minimum score threshold (AS:i score + read length)
90-
///
89+
///
9190
/// # Returns
9291
/// Set of reference IDs that have reads meeting the score cutoff
9392
pub fn find_candidate_otus_with_bowtie2(
@@ -97,52 +96,57 @@ pub fn find_candidate_otus_with_bowtie2(
9796
proc: i32,
9897
p_score_cutoff: f64,
9998
) -> PyResult<HashSet<String>> {
100-
use std::process::{Command, Stdio};
10199
use std::io::{BufRead, BufReader};
102-
103-
info!("running bowtie2: index={}, reads={:?}, cutoff={}",
104-
bowtie_index_path, read_paths, p_score_cutoff);
100+
use std::process::{Command, Stdio};
101+
102+
info!(
103+
"running bowtie2: index={}, reads={:?}, cutoff={}",
104+
bowtie_index_path, read_paths, p_score_cutoff
105+
);
105106
py.allow_threads(|| {
106107
let mut cmd = Command::new("bowtie2");
107-
cmd.arg("-p").arg(proc.to_string())
108-
.arg("--local")
109-
.arg("--no-unal")
110-
.arg("--score-min").arg("L,20,1.0")
111-
.arg("-N").arg("0")
112-
.arg("-L").arg("15")
113-
.arg("-x").arg(bowtie_index_path)
114-
.arg("-U").arg(read_paths.join(","))
115-
.stdout(Stdio::piped())
116-
.stderr(Stdio::piped());
117-
108+
cmd.arg("-p")
109+
.arg(proc.to_string())
110+
.arg("--local")
111+
.arg("--no-unal")
112+
.arg("--score-min")
113+
.arg("L,20,1.0")
114+
.arg("-N")
115+
.arg("0")
116+
.arg("-L")
117+
.arg("15")
118+
.arg("-x")
119+
.arg(bowtie_index_path)
120+
.arg("-U")
121+
.arg(read_paths.join(","))
122+
.stdout(Stdio::piped())
123+
.stderr(Stdio::piped());
124+
118125
info!("spawning bowtie2 process");
119-
let mut child = cmd.spawn()
120-
.map_err(|e| PyErr::new::<PyIOError, _>(format!("Failed to spawn bowtie2: {}", e)))?;
121-
126+
let mut child = cmd.spawn()?;
127+
122128
let stdout = child.stdout.take().unwrap();
123129
let reader = BufReader::new(stdout);
124-
130+
125131
let mut candidate_otus = HashSet::new();
126132
let mut line_count = 0u64;
127133
let mut passing_count = 0u64;
128-
134+
129135
for line_result in reader.lines() {
130-
let line = line_result
131-
.map_err(|e| PyErr::new::<PyIOError, _>(format!("Error reading bowtie2 output: {}", e)))?;
132-
136+
let line = line_result?;
137+
133138
line_count += 1;
134-
139+
135140
// Use the extracted SAM parsing function
136141
if let Some(ref_name) = parse_sam_line(&line, p_score_cutoff) {
137142
candidate_otus.insert(ref_name);
138143
passing_count += 1;
139144
}
140145
}
141-
146+
142147
// Wait for bowtie2 to finish and check exit status
143-
let status = child.wait()
144-
.map_err(|e| PyErr::new::<PyIOError, _>(format!("Error waiting for bowtie2: {}", e)))?;
145-
148+
let status = child.wait()?;
149+
146150
if !status.success() {
147151
// Read stderr for error details
148152
let stderr_output = if let Some(mut stderr) = child.stderr.take() {
@@ -152,22 +156,25 @@ pub fn find_candidate_otus_with_bowtie2(
152156
} else {
153157
"Unknown error".to_string()
154158
};
155-
159+
156160
return Err(PyErr::new::<PyIOError, _>(format!(
157-
"bowtie2 failed with exit code {:?}: {}",
158-
status.code(),
161+
"bowtie2 failed with exit code {:?}: {}",
162+
status.code(),
159163
stderr_output
160164
)));
161165
}
162-
163-
info!("processed {} sam lines, {} passed cutoff, found {} unique otus",
164-
line_count, passing_count, candidate_otus.len());
165-
166+
167+
info!(
168+
"processed {} sam lines, {} passed cutoff, found {} unique otus",
169+
line_count,
170+
passing_count,
171+
candidate_otus.len()
172+
);
173+
166174
Ok(candidate_otus)
167175
})
168176
}
169177

170-
171178
#[cfg(test)]
172179
mod tests {
173180
use super::*;
@@ -176,7 +183,7 @@ mod tests {
176183
fn test_parse_sam_line_basic() {
177184
let line = "read1\t0\tref1\t100\t255\t50M\t*\t0\t0\tAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\t*\tAS:i:45";
178185
let result = parse_sam_line(line, 0.01);
179-
186+
180187
// AS:i:45 + seq_len(50) = 95.0, should pass cutoff of 0.01
181188
assert_eq!(result, Some("ref1".to_string()));
182189
}
@@ -185,7 +192,7 @@ mod tests {
185192
fn test_parse_sam_line_below_cutoff() {
186193
let line = "read1\t0\tref1\t100\t255\t50M\t*\t0\t0\tAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\t*\tAS:i:45";
187194
let result = parse_sam_line(line, 100.0);
188-
195+
189196
// AS:i:45 + seq_len(50) = 95.0, should not pass cutoff of 100.0
190197
assert_eq!(result, None);
191198
}
@@ -194,7 +201,7 @@ mod tests {
194201
fn test_parse_sam_line_unmapped() {
195202
let line = "read1\t4\t*\t0\t0\t*\t*\t0\t0\tAAAAA\t*";
196203
let result = parse_sam_line(line, 0.01);
197-
204+
198205
// Unmapped read (flag & 4 != 0), should return None
199206
assert_eq!(result, None);
200207
}
@@ -203,7 +210,7 @@ mod tests {
203210
fn test_parse_sam_line_no_as_score() {
204211
let line = "read1\t0\tref1\t100\t255\t50M\t*\t0\t0\tAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\t*";
205212
let result = parse_sam_line(line, 0.01);
206-
213+
207214
// No AS:i score, should return None
208215
assert_eq!(result, None);
209216
}
@@ -212,10 +219,8 @@ mod tests {
212219
fn test_parse_sam_line_header() {
213220
let line = "@HD\tVN:1.0\tSO:unsorted";
214221
let result = parse_sam_line(line, 0.01);
215-
222+
216223
// Header line, should return None
217224
assert_eq!(result, None);
218225
}
219-
220-
221-
}
226+
}

0 commit comments

Comments
 (0)