Skip to content

Commit e55a3f6

Browse files
committed
Modifications for API consistency
- Consistent way to pass transform function in dictionary - Clean the code
1 parent 93f2b1e commit e55a3f6

3 files changed

Lines changed: 194 additions & 234 deletions

File tree

seqchromloader/loader.py

Lines changed: 86 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import logging
88
import torch
99
import random
10+
import pysam
1011
import pyfasta
1112
import pyBigWig
1213
import numpy as np
class SeqChromLoader():
    """Factory that wraps a SeqChromDataset class into a DataLoader builder."""

    def __init__(self, SeqChromDataset):
        # Dataset class (not instance) to construct on each call.
        self.SeqChromDataset = SeqChromDataset

    def __call__(self, *args, dataloader_kws: dict = None, **kwargs):
        """Instantiate the dataset and return a torch DataLoader.

        Args:
            *args, **kwargs: forwarded to the SeqChromDataset constructor.
            dataloader_kws: optional extra keyword arguments for DataLoader;
                "worker_init_fn" (default: module-level worker_init_fn) and
                "num_workers" (default: 1) are popped out explicitly.

        Returns:
            torch.utils.data.DataLoader over the freshly built dataset.
        """
        # Copy into a fresh dict so (a) a None value still expands cleanly
        # via ** below — the original raised TypeError on **None — and
        # (b) the caller's dict is not mutated by the pops.
        dataloader_kws = dict(dataloader_kws) if dataloader_kws else {}
        wif = dataloader_kws.pop("worker_init_fn", worker_init_fn)
        num_workers = dataloader_kws.pop("num_workers", 1)

        return DataLoader(self.SeqChromDataset(*args, **kwargs),
                          worker_init_fn=wif,
                          num_workers=num_workers,
                          **dataloader_kws)
3743

3844
def seqChromLoaderCurry(SeqChromDataset):
    """Curry helper: bind *SeqChromDataset* into a SeqChromLoader factory."""
    return SeqChromLoader(SeqChromDataset)
4147

4248
class _SeqChromDatasetByWds(IterableDataset):
43-
def __init__(self, wds, seq_transform:list=None, chrom_transform:list=None, target_transform:list=None):
44-
self.seq_transform = seq_transform
45-
self.chrom_transform = chrom_transform
46-
self.target_transform = target_transform
49+
def __init__(self, wds, transforms:dict=None):
50+
self.transforms = transforms
4751

4852
self.wds = wds
4953

@@ -58,78 +62,71 @@ def __iter__(self):
5862
wds.tarfile_to_samples(),
5963
wds.split_by_worker,
6064
wds.decode(),
61-
wds.to_tuple("seq.npy", "chrom.npy", "target.npy", "label.npy"),
65+
wds.rename(seq="seq.npy",
66+
chrom="chrom.npy",
67+
target="target.npy",
68+
label="label.npy")
6269
]
6370
if worker_info is None:
6471
logging.info("Worker info not found, won't split dataset across subprocesses, are you using custom dataloader?")
6572
logging.info("Ignore the message if you are not using multiprocessing on data loading")
6673
del pipeline[2]
6774

6875
# transform
69-
if self.seq_transform is not None: pipeline.extend(self.seq_transform)
70-
if self.chrom_transform is not None: pipeline.extend(self.chrom_transform)
71-
if self.target_transform is not None: pipeline.extend(self.target_transform)
76+
if self.transforms is not None:
77+
pipeline.append(wds.map_dict(**self.transforms))
7278

79+
pipeline.append(wds.to_tuple("seq", "chrom", "target", "label"))
80+
7381
ds = wds.DataPipeline(*pipeline)
7482

7583
return iter(ds)
7684

7785
SeqChromDatasetByWds = seqChromLoaderCurry(_SeqChromDatasetByWds)
7886

7987
class _SeqChromDatasetByBed(Dataset):
80-
def __init__(self, bed, fasta, bigwig_files, seq_transform:list=None, chrom_transform:list=None, target_transform:list=None):
81-
self.bed = pd.read_table(bed, header=None, names=['chrom', 'start', 'end', 'name', 'score', 'strand' ])
82-
83-
self.fasta = fasta
84-
self.bigwig_files = bigwig_files
85-
self.seq_transform = [utils.DNA2OneHot()] + seq_transform # prepend default DNA one hot coding transform
86-
self.chrom_transfrom = chrom_transform
87-
self.target_transform = target_transform
88+
def __init__(self, bed, genome_fasta, bigwig_filelist:list, target_bam=None, transforms:dict=None, initialize_first=False):
89+
self.bed = pd.read_table(bed, header=None, names=['chrom', 'start', 'end', 'label', 'score', 'strand' ])
90+
91+
self.genome_fasta = genome_fasta
92+
self.genome_pyfasta = None
93+
self.bigwig_filelist = bigwig_filelist
94+
self.bigwigs = None
95+
self.target_bam = target_bam
96+
self.target_pysam = None
97+
98+
self.transforms = transforms
99+
100+
if initialize_first: self.initialize()
88101

89102
def initialize(self):
90103
# this function will be called by worker_init_function in DataLoader
91-
self.genome_pyfasta = pyfasta.Fasta(self.config["train_bichrom"]["fasta"])
92-
#self.tfbam = pysam.AlignmentFile(self.config["train_bichrom"]["tf_bam"])
93-
self.bigwigs = [pyBigWig.open(bw) for bw in self.bigwig_files]
104+
self.genome_pyfasta = pyfasta.Fasta(self.genome_fasta)
105+
self.bigwigs = [pyBigWig.open(bw) for bw in self.bigwig_filelist]
106+
if self.target_bam is not None:
107+
self.target_pysam = pysam.AlignmentFile(self.target_bam)
94108

95109
def __len__(self):
96110
return len(self.bed)
97111

98112
def __getitem__(self, idx):
99-
entry = self.bed.iloc[idx,]
100-
# get info in the each entry region
101-
## sequence
102-
sequence = self.genome_pyfasta[entry.chrom][int(entry.start):int(entry.end)]
103-
sequence = self.rev_comp(sequence) if entry.strand=="-" else sequence
104-
## chromatin
105-
ms = []
113+
item = self.bed.iloc[idx,]
106114
try:
107-
for idx, bigwig in enumerate(self.bigwigs):
108-
m = (np.nan_to_num(bigwig.values(entry.chrom, entry.start, entry.end))).astype(np.float32)
109-
if entry.strand == "-": m = m[::-1] # reverse if needed
110-
ms.append(m)
111-
except RuntimeError as e:
112-
print(e)
113-
raise Exception(f"Failed to extract chromatin {self.bigwig_files[idx]} information in region {entry}")
114-
ms = np.vstack(ms)
115-
## target: read count in region
116-
#target = self.tfbam.count(entry.chrom, entry.start, entry.end)
117-
118-
# transform
119-
if self.seq_transform:
120-
seq = [t(sequence) for t in self.seq_transform]
121-
if self.chrom_transfrom:
122-
ms = [t(ms) for t in self.chrom_transfrom]
123-
124-
return seq, ms
115+
feature = utils.extract_info(
116+
item.chrom,
117+
item.start,
118+
item.end,
119+
item.label,
120+
genome_pyfasta=self.genome_pyfasta,
121+
bigwigs=self.bigwigs,
122+
target_bam=self.target_pysam,
123+
strand=item.strand,
124+
transforms=self.transforms
125+
)
126+
except utils.BigWigInaccessible as e:
127+
raise e
125128

126-
def rev_comp(self, inp_str):
127-
rc_dict = {'A': 'T', 'G': 'C', 'T': 'A', 'C': 'G', 'c': 'g',
128-
'g': 'c', 't': 'a', 'a': 't', 'n': 'n', 'N': 'N'}
129-
outp_str = list()
130-
for nucl in inp_str:
131-
outp_str.append(rc_dict[nucl])
132-
return ''.join(outp_str)[::-1]
129+
return feature['seq'], feature['chrom'], feature['target'], feature['label']
133130

134131
SeqChromDatasetByBed = seqChromLoaderCurry(_SeqChromDatasetByBed)
135132

@@ -170,17 +167,18 @@ def _target_vlog(sample):
170167
target_vlog = wds.pipelinefilter(_target_vlog)
171168

172169
class SeqChromDataModule(LightningDataModule):
173-
def __init__(self, train_wds, val_wds, test_wds, train_dataset_size:int=None, transform:list=None, num_workers=8, batch_size=512):
170+
def __init__(self, train_wds, val_wds, test_wds, train_dataset_size:int=None, transforms:dict=None, num_workers=1, batch_size=512, patch_last=True):
174171
super().__init__()
175172
self.num_workers = num_workers
176173
self.batch_size = batch_size
177174
self.train_dataset_size = train_dataset_size
175+
self.patch_last = patch_last
178176

179177
self.train_wds = train_wds
180178
self.val_wds = val_wds
181179
self.test_wds = test_wds
182180

183-
self.transform = transform
181+
self.transforms = transforms
184182

185183
def prepare_data(self):
186184
pass
@@ -210,7 +208,10 @@ def setup(self, stage=None):
210208
wds.tarfile_to_samples(),
211209
wds.shuffle(1000, rng=random.Random(1)),
212210
wds.decode(),
213-
wds.to_tuple("seq.npy", "chrom.npy", "target.npy", "label.npy"),
211+
wds.rename(seq="seq.npy",
212+
chrom="chrom.npy",
213+
target="target.npy",
214+
label="label.npy")
214215
]
215216

216217
val_pipeline = [
@@ -219,7 +220,10 @@ def setup(self, stage=None):
219220
wds.split_by_worker,
220221
wds.tarfile_to_samples(),
221222
wds.decode(),
222-
wds.to_tuple("seq.npy", "chrom.npy", "target.npy", "label.npy"),
223+
wds.rename(seq="seq.npy",
224+
chrom="chrom.npy",
225+
target="target.npy",
226+
label="label.npy")
223227
]
224228

225229
test_pipeline = [
@@ -228,28 +232,37 @@ def setup(self, stage=None):
228232
wds.split_by_worker,
229233
wds.tarfile_to_samples(),
230234
wds.decode(),
231-
wds.to_tuple("seq.npy", "chrom.npy", "target.npy", "label.npy"),
235+
wds.rename(seq="seq.npy",
236+
chrom="chrom.npy",
237+
target="target.npy",
238+
label="label.npy")
232239
]
233240

234-
if self.transform is not None:
235-
train_pipeline.extend(self.transform)
236-
val_pipeline.extend(self.transform)
237-
test_pipeline.extend(self.transform)
241+
if self.transforms is not None:
242+
train_pipeline.append(wds.map_dict(**self.transforms))
243+
val_pipeline.append(wds.map_dict(**self.transforms))
244+
test_pipeline.append(wds.map_dict(**self.transforms))
238245

239-
self.train_loader = wds.DataPipeline(
240-
*train_pipeline
241-
)
246+
self.train_loader = wds.DataPipeline([
247+
*train_pipeline,
248+
wds.to_tuple("seq", "chrom", "target", "label")
249+
])
242250

243-
self.val_loader = wds.DataPipeline(
244-
*val_pipeline
245-
)
251+
self.val_loader = wds.DataPipeline([
252+
*val_pipeline,
253+
wds.to_tuple("seq", "chrom", "target", "label"),
254+
])
246255

247-
self.test_loader = wds.DataPipeline(
248-
*test_pipeline
249-
)
256+
self.test_loader = wds.DataPipeline([
257+
*test_pipeline,
258+
wds.to_tuple("seq", "chrom", "target", "label"),
259+
])
250260

251261
def train_dataloader(self):
252-
return wds.WebLoader(self.train_loader.repeat(2), num_workers=self.num_workers, batch_size=self.batch_size_per_rank).with_epoch(ceil(self.train_dataset_size/self.batch_size)) # pad the last batch if there is remainder
262+
if self.patch_last:
263+
return wds.WebLoader(self.train_loader.repeat(2), num_workers=self.num_workers, batch_size=self.batch_size_per_rank).with_epoch(ceil(self.train_dataset_size/self.batch_size)) # pad the last batch if there is remainder
264+
else:
265+
return wds.WebLoader(self.train_loader, num_workers=self.num_workers, batch_size=self.batch_size_per_rank)
253266

254267
def val_dataloader(self):
255268
return wds.WebLoader(self.val_loader, num_workers=self.num_workers, batch_size=self.batch_size_per_rank)

seqchromloader/utils.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55
import pandas as pd
66
import numpy as np
7+
import logging
78
from multiprocessing import Pool
89
from pybedtools import Interval, BedTool
910

@@ -267,13 +268,83 @@ def __call__(self, dnaSeq):
267268
continue
268269
return seqMatrix
269270

271+
def dna2OneHot(dnaSeq):
    """One-hot encode a DNA string into a float32 matrix of shape (4, len).

    Rows are A, C, G, T in that order. 'N' positions are left all-zero;
    any other unexpected character prints a warning and is also skipped.
    """
    base_index = {"A": 0, "C": 1, "G": 2, "T": 3}

    seq = dnaSeq.upper()
    one_hot = np.zeros((4, len(seq)), dtype=np.float32)
    for pos, base in enumerate(seq):
        if base == "N":
            continue
        row = base_index.get(base)
        if row is None:
            # Same diagnostic as a KeyError would have produced.
            print(f"Keyerror happened at position {pos}: {base}, legal keys are: [A, C, G, T, N]")
            continue
        one_hot[row, pos] = 1
    return one_hot
292+
270293
def rev_comp(inp_str):
    """Return the reverse complement of a DNA string, preserving case.

    'N'/'n' map to themselves; any other character raises KeyError.
    """
    complement = {'A': 'T', 'G': 'C', 'T': 'A', 'C': 'G', 'c': 'g',
                  'g': 'c', 't': 'a', 'a': 't', 'n': 'n', 'N': 'N'}
    return ''.join(complement[base] for base in reversed(inp_str))
300+
301+
class BigWigInaccessible(Exception):
    """Raised when chromatin info cannot be read from a bigwig for a region."""

    def __init__(self, chrom, start, end, *args):
        # BUG FIX: original called `super.__init__(*args)` — missing the
        # call parentheses on `super` — so every construction raised
        # TypeError instead of building the exception.
        super().__init__(*args)
        self.chrom = chrom
        self.start = start
        self.end = end

    def __str__(self) -> str:
        return f'Chromatin Info Inaccessible in region {self.chrom}:{self.start}-{self.end}'
311+
def extract_info(chrom, start, end, label, genome_pyfasta, bigwigs, target_bam, strand="+", transforms:dict=None):
    """Assemble one training sample for a genomic region.

    Args:
        chrom, start, end: region coordinates.
        label: integer class label for the region.
        genome_pyfasta: open pyfasta.Fasta genome handle.
        bigwigs: iterable of open pyBigWig handles (chromatin tracks).
        target_bam: open pysam.AlignmentFile or None; when None the target
            value is NaN.
        strand: "+" or "-"; "-" reverse-complements the sequence and
            mirrors the chromatin tracks to match.
        transforms: optional dict mapping feature key -> callable applied
            to that feature before returning.

    Returns:
        dict with keys 'seq' (4 x L one-hot), 'chrom' (num_tracks x L
        float32), 'target' (1-element float32 array), 'label' (1-element
        int32 array).

    Raises:
        BigWigInaccessible: when any bigwig lacks data for the region.
    """
    seq = genome_pyfasta[chrom][int(start):int(end)]
    if strand == "-":
        seq = rev_comp(seq)
    seq_array = dna2OneHot(seq)

    # chromatin tracks
    chroms_array = []
    try:
        # Original enumerated here but never used the index — dropped.
        for bigwig in bigwigs:
            c = (np.nan_to_num(bigwig.values(chrom, start, end))).astype(np.float32)
            if strand == "-":
                c = c[::-1]  # mirror so tracks align with the rev-comp sequence
            chroms_array.append(c)
    except RuntimeError as e:
        logging.warning(e)
        logging.warning(f"RuntimeError happened when accessing {chrom}:{start}-{end}, it's probably due to at least one chromatin track bigwig doesn't have information in this region")
        # Chain the cause so the pyBigWig RuntimeError stays in the traceback.
        raise BigWigInaccessible(chrom, start, end) from e
    chroms_array = np.vstack(chroms_array)  # shape (num_tracks, length)
    # label
    label_array = np.array(label, dtype=np.int32)[np.newaxis]
    # counts: read count in region, NaN when no target bam supplied
    target_array = target_bam.count(chrom, start, end) if target_bam is not None else np.nan
    target_array = np.array(target_array, dtype=np.float32)[np.newaxis]

    feature = {
        'seq': seq_array,
        'chrom': chroms_array,
        'target': target_array,
        'label': label_array
    }

    if transforms is not None:
        for k, t in transforms.items():
            feature[k] = t(feature[k])

    return feature
277348

278349
if __name__ == "__main__":
279350
chip_seq_coordinates = load_chipseq_data("Bichrom/sample_data/Ascl1.peaks", genome_sizes_file="Bichrom/sample_data/mm10.info",

0 commit comments

Comments
 (0)