77import logging
88import torch
99import random
10+ import pysam
1011import pyfasta
1112import pyBigWig
1213import numpy as np
@@ -28,22 +29,25 @@ class SeqChromLoader():
2829 def __init__ (self , SeqChromDataset ):
2930 self .SeqChromDataset = SeqChromDataset
3031
31- def __call__ (self , * args , worker_init_fn = worker_init_fn , dataloader_kws :dict = None , ** kwargs ):
32+ def __call__ (self , * args , dataloader_kws :dict = None , ** kwargs ):
3233 # default dataloader kws
33- wif = dataloader_kws .pop ("worker_init_fn" , worker_init_fn ) if dataloader_kws is not None else worker_init_fn
34+ if dataloader_kws is not None :
35+ wif = dataloader_kws .pop ("worker_init_fn" , worker_init_fn )
36+ num_workers = dataloader_kws .pop ("num_workers" , 1 )
37+ else :
38+ wif = worker_init_fn
39+ num_workers = 1
3440
3541 return DataLoader (self .SeqChromDataset (* args , ** kwargs ),
36- worker_init_fn = wif , ** dataloader_kws )
42+ worker_init_fn = wif , num_workers = num_workers , ** dataloader_kws )
3743
def seqChromLoaderCurry(SeqChromDataset):
    """Bind a SeqChromDataset class into a SeqChromLoader factory.

    The returned object, when called, builds the dataset and wraps it in a
    DataLoader (see SeqChromLoader.__call__).
    """
    return SeqChromLoader(SeqChromDataset)
4147
4248class _SeqChromDatasetByWds (IterableDataset ):
43- def __init__ (self , wds , seq_transform :list = None , chrom_transform :list = None , target_transform :list = None ):
44- self .seq_transform = seq_transform
45- self .chrom_transform = chrom_transform
46- self .target_transform = target_transform
49+ def __init__ (self , wds , transforms :dict = None ):
50+ self .transforms = transforms
4751
4852 self .wds = wds
4953
@@ -58,78 +62,71 @@ def __iter__(self):
5862 wds .tarfile_to_samples (),
5963 wds .split_by_worker ,
6064 wds .decode (),
61- wds .to_tuple ("seq.npy" , "chrom.npy" , "target.npy" , "label.npy" ),
65+ wds .rename (seq = "seq.npy" ,
66+ chrom = "chrom.npy" ,
67+ target = "target.npy" ,
68+ label = "label.npy" )
6269 ]
6370 if worker_info is None :
6471 logging .info ("Worker info not found, won't split dataset across subprocesses, are you using custom dataloader?" )
6572 logging .info ("Ignore the message if you are not using multiprocessing on data loading" )
6673 del pipeline [2 ]
6774
6875 # transform
69- if self .seq_transform is not None : pipeline .extend (self .seq_transform )
70- if self .chrom_transform is not None : pipeline .extend (self .chrom_transform )
71- if self .target_transform is not None : pipeline .extend (self .target_transform )
76+ if self .transforms is not None :
77+ pipeline .append (wds .map_dict (** self .transforms ))
7278
79+ pipeline .append (wds .to_tuple ("seq" , "chrom" , "target" , "label" ))
80+
7381 ds = wds .DataPipeline (* pipeline )
7482
7583 return iter (ds )
7684
# Webdataset-backed variant: calling SeqChromDatasetByWds(...) returns a DataLoader.
SeqChromDatasetByWds = seqChromLoaderCurry(_SeqChromDatasetByWds)
7886
class _SeqChromDatasetByBed(Dataset):
    """Map-style dataset yielding (seq, chrom, target, label) for BED regions.

    File handles (fasta / bigwig / bam) are opened lazily in ``initialize`` so
    the dataset can be sent to DataLoader worker processes before any handle
    exists; the DataLoader's worker_init_fn is expected to call ``initialize``
    in each worker. Pass ``initialize_first=True`` to open them immediately.
    """

    def __init__(self, bed, genome_fasta, bigwig_filelist: list, target_bam=None,
                 transforms: dict = None, initialize_first=False):
        """
        Args:
            bed: path to a 6-column BED file
                (chrom, start, end, label, score, strand).
            genome_fasta: path to the genome fasta file.
            bigwig_filelist: paths of bigwig tracks with chromatin signal.
            target_bam: optional path to a BAM file with target reads.
            transforms: optional dict of per-key callables forwarded to
                utils.extract_info.
            initialize_first: open all file handles now instead of deferring
                to worker initialization.
        """
        self.bed = pd.read_table(
            bed, header=None,
            names=['chrom', 'start', 'end', 'label', 'score', 'strand'])

        self.genome_fasta = genome_fasta
        self.genome_pyfasta = None   # opened in initialize()
        self.bigwig_filelist = bigwig_filelist
        self.bigwigs = None          # opened in initialize()
        self.target_bam = target_bam
        self.target_pysam = None     # opened in initialize()

        self.transforms = transforms

        if initialize_first:
            self.initialize()

    def initialize(self):
        """Open fasta/bigwig/bam handles; called by worker_init_fn in each worker."""
        self.genome_pyfasta = pyfasta.Fasta(self.genome_fasta)
        self.bigwigs = [pyBigWig.open(bw) for bw in self.bigwig_filelist]
        if self.target_bam is not None:
            self.target_pysam = pysam.AlignmentFile(self.target_bam)

    def __len__(self):
        return len(self.bed)

    def __getitem__(self, idx):
        """Extract features for the idx-th BED entry.

        Raises:
            utils.BigWigInaccessible: propagated unchanged from
                utils.extract_info. The previous ``except ... as e: raise e``
                wrapper added nothing, so no try/except is needed here.
        """
        item = self.bed.iloc[idx, ]
        feature = utils.extract_info(
            item.chrom,
            item.start,
            item.end,
            item.label,
            genome_pyfasta=self.genome_pyfasta,
            bigwigs=self.bigwigs,
            target_bam=self.target_pysam,
            strand=item.strand,
            transforms=self.transforms,
        )
        return feature['seq'], feature['chrom'], feature['target'], feature['label']
133130
# BED-backed variant: calling SeqChromDatasetByBed(...) returns a DataLoader.
SeqChromDatasetByBed = seqChromLoaderCurry(_SeqChromDatasetByBed)
135132
@@ -170,17 +167,18 @@ def _target_vlog(sample):
# Curried webdataset pipeline stage built from _target_vlog (defined above);
# presumably log-transforms the target values — confirm against _target_vlog.
target_vlog = wds.pipelinefilter(_target_vlog)
171168
172169class SeqChromDataModule (LightningDataModule ):
173- def __init__ (self , train_wds , val_wds , test_wds , train_dataset_size :int = None , transform : list = None , num_workers = 8 , batch_size = 512 ):
170+ def __init__ (self , train_wds , val_wds , test_wds , train_dataset_size :int = None , transforms : dict = None , num_workers = 1 , batch_size = 512 , patch_last = True ):
174171 super ().__init__ ()
175172 self .num_workers = num_workers
176173 self .batch_size = batch_size
177174 self .train_dataset_size = train_dataset_size
175+ self .patch_last = patch_last
178176
179177 self .train_wds = train_wds
180178 self .val_wds = val_wds
181179 self .test_wds = test_wds
182180
183- self .transform = transform
181+ self .transforms = transforms
184182
185183 def prepare_data (self ):
186184 pass
@@ -210,7 +208,10 @@ def setup(self, stage=None):
210208 wds .tarfile_to_samples (),
211209 wds .shuffle (1000 , rng = random .Random (1 )),
212210 wds .decode (),
213- wds .to_tuple ("seq.npy" , "chrom.npy" , "target.npy" , "label.npy" ),
211+ wds .rename (seq = "seq.npy" ,
212+ chrom = "chrom.npy" ,
213+ target = "target.npy" ,
214+ label = "label.npy" )
214215 ]
215216
216217 val_pipeline = [
@@ -219,7 +220,10 @@ def setup(self, stage=None):
219220 wds .split_by_worker ,
220221 wds .tarfile_to_samples (),
221222 wds .decode (),
222- wds .to_tuple ("seq.npy" , "chrom.npy" , "target.npy" , "label.npy" ),
223+ wds .rename (seq = "seq.npy" ,
224+ chrom = "chrom.npy" ,
225+ target = "target.npy" ,
226+ label = "label.npy" )
223227 ]
224228
225229 test_pipeline = [
@@ -228,28 +232,37 @@ def setup(self, stage=None):
228232 wds .split_by_worker ,
229233 wds .tarfile_to_samples (),
230234 wds .decode (),
231- wds .to_tuple ("seq.npy" , "chrom.npy" , "target.npy" , "label.npy" ),
235+ wds .rename (seq = "seq.npy" ,
236+ chrom = "chrom.npy" ,
237+ target = "target.npy" ,
238+ label = "label.npy" )
232239 ]
233240
234- if self .transform is not None :
235- train_pipeline .extend ( self .transform )
236- val_pipeline .extend ( self .transform )
237- test_pipeline .extend ( self .transform )
241+ if self .transforms is not None :
242+ train_pipeline .append ( wds . map_dict ( ** self .transforms ) )
243+ val_pipeline .append ( wds . map_dict ( ** self .transforms ) )
244+ test_pipeline .append ( wds . map_dict ( ** self .transforms ) )
238245
239- self .train_loader = wds .DataPipeline (
240- * train_pipeline
241- )
246+ self .train_loader = wds .DataPipeline ([
247+ * train_pipeline ,
248+ wds .to_tuple ("seq" , "chrom" , "target" , "label" )
249+ ])
242250
243- self .val_loader = wds .DataPipeline (
244- * val_pipeline
245- )
251+ self .val_loader = wds .DataPipeline ([
252+ * val_pipeline ,
253+ wds .to_tuple ("seq" , "chrom" , "target" , "label" ),
254+ ])
246255
247- self .test_loader = wds .DataPipeline (
248- * test_pipeline
249- )
256+ self .test_loader = wds .DataPipeline ([
257+ * test_pipeline ,
258+ wds .to_tuple ("seq" , "chrom" , "target" , "label" ),
259+ ])
250260
251261 def train_dataloader (self ):
252- return wds .WebLoader (self .train_loader .repeat (2 ), num_workers = self .num_workers , batch_size = self .batch_size_per_rank ).with_epoch (ceil (self .train_dataset_size / self .batch_size )) # pad the last batch if there is remainder
262+ if self .patch_last :
263+ return wds .WebLoader (self .train_loader .repeat (2 ), num_workers = self .num_workers , batch_size = self .batch_size_per_rank ).with_epoch (ceil (self .train_dataset_size / self .batch_size )) # pad the last batch if there is remainder
264+ else :
265+ return wds .WebLoader (self .train_loader , num_workers = self .num_workers , batch_size = self .batch_size_per_rank )
253266
254267 def val_dataloader (self ):
255268 return wds .WebLoader (self .val_loader , num_workers = self .num_workers , batch_size = self .batch_size_per_rank )
0 commit comments