
Commit 22ce611

Merge pull request #1 from domonik/main
Bug fixes and multicore support for generate_kmer_features.py
2 parents 98cef8b + a0a014c commit 22ce611

File tree

2 files changed (+131, -29 lines)

conda-environment.yml

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+name:
+  BrainDead
+channels:
+  - conda-forge
+  - defaults
+  - bioconda
+dependencies:
+  - viennarna=2.4.18
+  - intarna
+  - biopython
+  - pandas
+  - scikit-learn
+
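For reference, the environment can be created with the standard conda workflow, e.g. conda env create -f conda-environment.yml followed by conda activate BrainDead (the environment name comes from the file itself).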

src/generate_kmer_features.py

Lines changed: 118 additions & 29 deletions
@@ -7,13 +7,17 @@
 import re
 import argparse
 import os.path
+from multiprocessing import Pool
+from tempfile import TemporaryDirectory
+import pickle
+from typing import List
 import sys
 BINDIR = os.path.dirname(os.path.realpath(__file__))
 def find_kmer_hits(sequence, kmer):
     return [m.start() for m in re.finditer('(?='+kmer+')', sequence)] # re with look-ahead for overlaps

 def call_command (cmd):
-    p = subprocess.Popen(cmd,shell=True,stdin=None, stdout=PIPE)
+    p = subprocess.Popen(cmd,shell=True,stdin=None, stdout=PIPE, stderr=PIPE)
     (result, error) = p.communicate()
     if error:
         raise RuntimeError("Error in calling cmd or perl script\ncmd:{}\nstdout:{}\nstderr:{}".format(cmd, result, error))
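Note on the call_command fix above: Popen.communicate() returns None for stderr unless stderr=PIPE is requested, so before this change the error check after the call could never trigger. A minimal standalone illustration of the difference (not part of the repository, assumes a POSIX shell):

import subprocess
from subprocess import PIPE

# a command that writes to stderr only
p = subprocess.Popen("echo oops 1>&2", shell=True, stdin=None, stdout=PIPE, stderr=PIPE)
result, error = p.communicate()
print(repr(result))   # b''        -- nothing arrived on stdout
print(repr(error))    # b'oops\n'  -- captured only because stderr=PIPE was requested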
@@ -63,6 +67,50 @@ def is_valid_file(file_name):
         return os.path.abspath(file_name)
     else:
         raise FileNotFoundError(os.path.abspath(file_name))
+
+def multicore_wrapper(seq_record, args):
+    out_csv_str = seq_record.id
+    print(seq_record.id)
+
+    seq_subopt, seq_intarna = get_subopt_intarna_strs(str(seq_record.seq),
+                                                      minE_subopt=args.minE_subopt,
+                                                      minE_intarna=args.minE_intarna)
+    for kmer in kmers_list:
+        hseq, hsubopt, hintarna, hsubopt_intarna = find_hits_all(str(seq_record.seq),
+                                                                 seq_subopt,
+                                                                 seq_intarna,
+                                                                 kmer)
+        cseq, csubopt, cintarna, csubopt_intarna = len(hseq), len(
+            hsubopt), len(hintarna), len(hsubopt_intarna)
+        array_features = []
+        if "a" in args.feature_context.lower():
+            array_features.append(cseq)
+        if "s" in args.feature_context.lower():
+            array_features.append(csubopt)
+        if "h" in args.feature_context.lower():
+            array_features.append(cintarna)
+        if "u" in args.feature_context.lower():
+            array_features.append(csubopt_intarna)
+
+        if args.report_counts is True:
+            out_csv_str += ''.join([',{}'.format(f) for f in array_features])
+        else:
+            binary_hits = ['0' if c == 0 else '1' for c in array_features]
+            out_csv_str += "," + ','.join(binary_hits)
+    return out_csv_str
+
+
+def write_pickled_output(files: List[str], outfile: str, csv_header: str):
+    with open(outfile, "w") as of:
+        of.write(csv_header)
+        for file in files:
+            with open(file, "rb") as handle:
+                data = pickle.load(handle)
+            of.write("\n".join(data) + "\n")
+            del data
+
+
+

 if __name__ == '__main__':
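The new multicore_wrapper takes (seq_record, args) as two positional arguments because it is dispatched via Pool.starmap further down in this diff; starmap unpacks each tuple of the call list into the worker's arguments. A minimal sketch of that pattern with illustrative names (not from the repository):

from multiprocessing import Pool

def describe(name, length):
    # stands in for multicore_wrapper(seq_record, args): returns one result row
    return "{},{}".format(name, length)

if __name__ == '__main__':
    calls = [("seq1", 120), ("seq2", 87), ("seq3", 301)]  # stands in for the (seq_record, args) pairs
    with Pool(processes=2) as pool:
        rows = pool.starmap(describe, calls)              # calls describe("seq1", 120), ...
    print("\n".join(rows))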

@@ -72,6 +120,8 @@ def is_valid_file(file_name):
 
     parser.add_argument('--kmers', required=True, type=str, help='List of kmers as a comma separated string e.g. \"AGG,GA,GG\"')
     parser.add_argument('--fasta', required=True, type=is_valid_file, help='Sequences to extract features from as a FASTA file')
+    parser.add_argument('--threads', type=int, default=1, help='Number of threads used for processing (default: 1) (WARNING: threads > 1 will impair stdout prints)')
+    parser.add_argument('--batchsize', type=int, default=10000, help='If the number of processed fasta sequences is greater than batch size, batch processing will be applied. This will lower memory consumption (default: 10000)')
     parser.add_argument('--report-counts', action='store_true', help='Whether to report counts as integer, default is binary nohit(0)-hit(1)'),
     parser.add_argument('--out-csv', type=str, default='stdout', help='CSV File name to write counts, pass "stdout" for stdout ')
     parser.add_argument('--minE-subopt', default=-3, type=int, help='Minimum free energy of the position on RNAsubopt result')
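With the two new options, a multicore run of the script might look like: python src/generate_kmer_features.py --kmers AGG,GA,GG --fasta sequences.fa --threads 4 --batchsize 5000 --out-csv features.csv (the FASTA and CSV file names here are placeholders; the flags are the ones defined above).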
@@ -97,33 +147,72 @@ def is_valid_file(file_name):
         out_csv_str += ",{}_free".format(kmer)

     out_csv_str += '\n'
-    for r in SeqIO.parse(args.fasta, format='fasta'):
-        print(r.id)
-        out_csv_str += r.id
-        seq_subopt, seq_intarna = get_subopt_intarna_strs(str(r.seq), minE_subopt=args.minE_subopt, minE_intarna=args.minE_intarna)
-        for kmer in kmers_list:
-            hseq, hsubopt, hintarna, hsubopt_intarna = find_hits_all(str(r.seq),seq_subopt,seq_intarna, kmer)
-            print(kmer, hseq, hsubopt, hintarna, hsubopt_intarna)
-            cseq, csubopt, cintarna, csubopt_intarna = len(hseq), len(hsubopt), len(hintarna), len(hsubopt_intarna)
-            array_features = []
-            if "a" in args.feature_context.lower():
-                array_features.append(cseq)
-            if "s" in args.feature_context.lower():
-                array_features.append(csubopt)
-            if "h" in args.feature_context.lower():
-                array_features.append(cintarna)
-            if "u" in args.feature_context.lower():
-                array_features.append(csubopt_intarna)
-
-            if args.report_counts is True:
-                out_csv_str += ''.join([',{}'.format(f) for f in array_features])
-            else:
-                binary_hits = ['0' if c==0 else '1' for c in array_features]
-                out_csv_str += ","+','.join(binary_hits)
-        out_csv_str += '\n'
+    if args.threads == 1:
+        for r in SeqIO.parse(args.fasta, format='fasta'):
+            print(r.id)
+            out_csv_str += r.id
+            seq_subopt, seq_intarna = get_subopt_intarna_strs(str(r.seq), minE_subopt=args.minE_subopt, minE_intarna=args.minE_intarna)
+            for kmer in kmers_list:
+                hseq, hsubopt, hintarna, hsubopt_intarna = find_hits_all(str(r.seq),seq_subopt,seq_intarna, kmer)
+                print(kmer, hseq, hsubopt, hintarna, hsubopt_intarna)
+                cseq, csubopt, cintarna, csubopt_intarna = len(hseq), len(hsubopt), len(hintarna), len(hsubopt_intarna)
+                array_features = []
+                if "a" in args.feature_context.lower():
+                    array_features.append(cseq)
+                if "s" in args.feature_context.lower():
+                    array_features.append(csubopt)
+                if "h" in args.feature_context.lower():
+                    array_features.append(cintarna)
+                if "u" in args.feature_context.lower():
+                    array_features.append(csubopt_intarna)

-    if args.out_csv == "stdout":
-        print(out_csv_str)
+                if args.report_counts is True:
+                    out_csv_str += ''.join([',{}'.format(f) for f in array_features])
+                else:
+                    binary_hits = ['0' if c==0 else '1' for c in array_features]
+                    out_csv_str += ","+','.join(binary_hits)
+            out_csv_str += '\n'
+
+        if args.out_csv == "stdout":
+            print(out_csv_str)
+        else:
+            with open(args.out_csv, 'w') as outfile:
+                outfile.write(out_csv_str)
     else:
-        with open(args.out_csv, 'w') as outfile:
-            outfile.write(out_csv_str)
+
+        calls = []
+        for seq_record in SeqIO.parse(args.fasta, format='fasta'):
+            calls.append((seq_record, args))
+
+        if args.batchsize < len(calls):
+            tmp_dir = TemporaryDirectory(prefix="BrainDead")
+            files = []
+            batch_calls = [calls[x:x+args.batchsize] for x in range(0, len(calls), args.batchsize)]
+            for x, batch in enumerate(batch_calls):
+                with Pool(processes=args.threads) as pool:
+                    outstrings = pool.starmap(multicore_wrapper, batch)
+                file = os.path.join(tmp_dir.name, f"batch_{x}.pckl")
+                files.append(file)
+                with open(file, "wb") as handle:
+                    pickle.dump(outstrings, handle)
+            write_pickled_output(files=files,
+                                 outfile=args.out_csv,
+                                 csv_header=out_csv_str)
+
+
+        else:
+            with Pool(processes=args.threads) as pool:
+                outstrings = pool.starmap(multicore_wrapper, calls)
+
+            out_csv_str += "\n".join(outstrings) + "\n"
+
+            if args.out_csv == "stdout":
+                print(out_csv_str)
+            else:
+                with open(args.out_csv, 'w') as outfile:
+                    outfile.write(out_csv_str)
+
+
+
+
+
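The batched branch above bounds memory use by pickling each batch of result rows to a temporary file and then streaming the batches back into the final CSV via write_pickled_output. A compact, self-contained sketch of the same pattern with illustrative names (not part of the commit):

import os
import pickle
from multiprocessing import Pool
from tempfile import TemporaryDirectory

def work(item, tag):
    # stands in for multicore_wrapper: returns one CSV row as a string
    return "{},{}".format(tag, item)

if __name__ == '__main__':
    calls = [(i, "demo") for i in range(25)]
    batchsize = 10
    tmp_dir = TemporaryDirectory(prefix="demo")
    files = []
    batches = [calls[x:x + batchsize] for x in range(0, len(calls), batchsize)]
    for x, batch in enumerate(batches):
        with Pool(processes=2) as pool:
            rows = pool.starmap(work, batch)          # only one batch of rows lives in memory
        path = os.path.join(tmp_dir.name, "batch_{}.pckl".format(x))
        files.append(path)
        with open(path, "wb") as handle:
            pickle.dump(rows, handle)
    with open("demo_out.csv", "w") as of:             # stream the batches into one output file
        of.write("tag,item\n")
        for path in files:
            with open(path, "rb") as handle:
                of.write("\n".join(pickle.load(handle)) + "\n")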
