1+
2+ description = """
3+ Lightning Data Module for model training
4+ Given bed file, return sequence and chromatin info
5+ """
6+
7+ import torch
8+ import random
9+ import pyfasta
10+ import pyBigWig
11+ import numpy as np
12+ import pandas as pd
13+ import webdataset as wds
14+ from math import sqrt , ceil
15+ from itertools import islice
16+ from torch .utils .data import Dataset , IterableDataset , DataLoader
17+ from pytorch_lightning import LightningDataModule
18+
19+ from seqchromloader import utils
20+
def worker_init_fn(worker_id):
    """DataLoader per-worker init hook.

    Fetches this worker's dataset copy and calls its ``initialize()`` so that
    non-picklable resources (file handles etc.) are opened inside the worker
    process rather than in the parent.
    """
    info = torch.utils.data.get_worker_info()
    info.dataset.initialize()
25+
class SeqChromLoader:
    """Factory wrapper: holds a dataset class and, when called, constructs the
    dataset and wraps it in a ``DataLoader``."""

    def __init__(self, SeqChromDataset):
        # dataset class (not instance) to be constructed on each call
        self.SeqChromDataset = SeqChromDataset

    def __call__(self, *args, batch_size=512, num_workers=1, shuffle=False,
                 worker_init_fn=worker_init_fn, **kwargs):
        """Build the dataset from *args/**kwargs and return a DataLoader over it."""
        dataset = self.SeqChromDataset(*args, **kwargs)
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=shuffle,
            worker_init_fn=worker_init_fn,
        )
36+
def seqChromLoaderCurry(SeqChromDataset):
    """Curry helper: wrap *SeqChromDataset* in a callable DataLoader factory."""
    factory = SeqChromLoader(SeqChromDataset)
    return factory
40+
class _SeqChromDatasetByWds(IterableDataset):
    """Iterable dataset streaming (seq, chrom, target, label) tuples out of
    webdataset tar shards.

    NOTE(review): seq_transform / chrom_transform / target_transform are stored
    but never applied in __iter__ — presumably consumed downstream; confirm.
    """

    def __init__(self, wds, seq_transform: list = None, chrom_transform: list = None,
                 target_transform: list = None):
        # `wds` parameter is the shard url(s)/spec; it shadows the module alias
        # `wds` (webdataset) only inside __init__'s scope.
        self.seq_transform = seq_transform
        self.chrom_transform = chrom_transform
        self.target_transform = target_transform

        self.wds = wds

    def initialize(self):
        # Called by worker_init_fn in DataLoader; nothing to open lazily here.
        pass

    def __iter__(self):
        # Build the pipeline once; the only difference between the worker and
        # non-worker case is the shard-splitting stage (previously the whole
        # pipeline was duplicated in both branches).
        worker_info = torch.utils.data.get_worker_info()
        stages = [
            wds.SimpleShardList(self.wds),
            wds.shuffle(100, rng=random.Random(1)),  # shard-level shuffle, fixed seed
        ]
        if worker_info is not None:
            # Split shards across DataLoader workers so samples are not duplicated.
            stages.append(wds.split_by_worker)
        stages += [
            wds.tarfile_to_samples(),
            wds.shuffle(1000, rng=random.Random(1)),  # sample-level shuffle, fixed seed
            wds.decode(),
            wds.to_tuple("seq.npy", "chrom.npy", "target.npy", "label.npy"),
        ]
        return iter(wds.DataPipeline(*stages))
76+
# Public factory: SeqChromDatasetByWds(...) builds a DataLoader over _SeqChromDatasetByWds.
SeqChromDatasetByWds = seqChromLoaderCurry(_SeqChromDatasetByWds)
78+
79+ class _SeqChromDatasetByBed (Dataset ):
80+ def __init__ (self , bed , fasta , bigwig_files , seq_transform :list = None , chrom_transform :list = None , target_transform :list = None ):
81+ self .bed = pd .read_table (bed , header = None , names = ['chrom' , 'start' , 'end' , 'name' , 'score' , 'strand' ])
82+
83+ self .fasta = fasta
84+ self .bigwig_files = bigwig_files
85+ self .seq_transform = [utils .DNA2OneHot ()] + seq_transform # prepend default DNA one hot coding transform
86+ self .chrom_transfrom = chrom_transform
87+ self .target_transform = target_transform
88+
89+ def initialize (self ):
90+ # this function will be called by worker_init_function in DataLoader
91+ self .genome_pyfasta = pyfasta .Fasta (self .config ["train_bichrom" ]["fasta" ])
92+ #self.tfbam = pysam.AlignmentFile(self.config["train_bichrom"]["tf_bam"])
93+ self .bigwigs = [pyBigWig .open (bw ) for bw in self .bigwig_files ]
94+
95+ def __len__ (self ):
96+ return len (self .bed )
97+
98+ def __getitem__ (self , idx ):
99+ entry = self .bed .iloc [idx ,]
100+ # get info in the each entry region
101+ ## sequence
102+ sequence = self .genome_pyfasta [entry .chrom ][int (entry .start ):int (entry .end )]
103+ sequence = self .rev_comp (sequence ) if entry .strand == "-" else sequence
104+ ## chromatin
105+ ms = []
106+ try :
107+ for idx , bigwig in enumerate (self .bigwigs ):
108+ m = (np .nan_to_num (bigwig .values (entry .chrom , entry .start , entry .end ))).astype (np .float32 )
109+ if entry .strand == "-" : m = m [::- 1 ] # reverse if needed
110+ if self .scaler_mean and self .scaler_var :
111+ m = (m - self .scaler_mean [idx ])/ sqrt (self .scaler_var [idx ])
112+ ms .append (m )
113+ except RuntimeError as e :
114+ print (e )
115+ raise Exception (f"Failed to extract chromatin { self .bigwig_files [idx ]} information in region { entry } " )
116+ ms = np .vstack (ms )
117+ ## target: read count in region
118+ #target = self.tfbam.count(entry.chrom, entry.start, entry.end)
119+
120+ # transform
121+ if self .seq_transform :
122+ seq = [t (sequence ) for t in self .seq_transform ]
123+ if self .chrom_transfrom :
124+ ms = [t (ms ) for t in self .chrom_transfrom ]
125+
126+ return seq , ms
127+
128+ def rev_comp (self , inp_str ):
129+ rc_dict = {'A' : 'T' , 'G' : 'C' , 'T' : 'A' , 'C' : 'G' , 'c' : 'g' ,
130+ 'g' : 'c' , 't' : 'a' , 'a' : 't' , 'n' : 'n' , 'N' : 'N' }
131+ outp_str = list ()
132+ for nucl in inp_str :
133+ outp_str .append (rc_dict [nucl ])
134+ return '' .join (outp_str )[::- 1 ]
135+
# Public factory: SeqChromDatasetByBed(...) builds a DataLoader over _SeqChromDatasetByBed.
SeqChromDatasetByBed = seqChromLoaderCurry(_SeqChromDatasetByBed)
137+
def count_lines(fp):
    """Return the number of lines in text file *fp*.

    Returns 0 for an empty file — the original left `count` unbound in that
    case and crashed with UnboundLocalError on `count + 1`.
    """
    count = 0
    with open(fp, 'r') as f:
        for count, _ in enumerate(f, start=1):
            pass
    return count
143+
144+ def _split_by_node (src , global_rank , world_size ):
145+ if world_size > 1 :
146+ for s in islice (src , global_rank , None , world_size ):
147+ yield s
148+ else :
149+ for s in src :
150+ yield s
151+
# Curried pipeline-stage form of _split_by_node, usable inside wds.DataPipeline.
split_by_node = wds.pipelinefilter(_split_by_node)
153+
154+ def _scale_chrom (sample , scaler_mean , scaler_std ):
155+ # standardize chrom by provided mean and std
156+ seq , chrom , target , label = sample
157+
158+ chrom = np .divide (chrom - scaler_mean , scaler_std , dtype = np .float32 )
159+
160+ return seq , chrom , target , label
161+
# Curried pipeline-stage form of _scale_chrom, usable inside wds.DataPipeline.
scale_chrom = wds.pipelinefilter(_scale_chrom)
163+
164+ def _target_vlog (sample ):
165+ # take log(n+1) on target
166+ seq , chrom , target , label = sample
167+
168+ target = np .log (target + 1 , dtype = np .float32 )
169+
170+ return seq , chrom , target , label
171+
# Curried pipeline-stage form of _target_vlog, usable inside wds.DataPipeline.
target_vlog = wds.pipelinefilter(_target_vlog)
173+
class SeqChromDataModule(LightningDataModule):
    """Lightning data module serving train/val/test webdataset shards of
    (seq, chrom, target, label) samples, sharded across ranks and workers."""

    def __init__(self, train_wds, val_wds, test_wds, train_dataset_size: int = None,
                 transform: list = None, num_workers=8, batch_size=512):
        super().__init__()
        self.num_workers = num_workers
        self.batch_size = batch_size  # global batch size, divided across ranks in setup()
        self.train_dataset_size = train_dataset_size

        self.train_wds = train_wds
        self.val_wds = val_wds
        self.test_wds = test_wds

        self.transform = transform

    def prepare_data(self):
        pass

    def setup(self, stage=None):
        device_id, global_rank, world_size = self._rank_info()
        self.batch_size_per_rank = int(self.batch_size / world_size)

        if stage is None or stage in ("fit", "validate", "test"):
            self.train_loader = self._build_pipeline(
                self.train_wds, global_rank, world_size, shuffled=True)
            self.val_loader = self._build_pipeline(
                self.val_wds, global_rank, world_size, shuffled=False)
            self.test_loader = self._build_pipeline(
                self.test_wds, global_rank, world_size, shuffled=False)

    def _rank_info(self):
        # Pull distributed placement from the attached trainer; fall back to a
        # single-process default when no trainer is attached.
        try:
            device_id = self.trainer.device_ids[self.trainer.local_rank]
            global_rank = self.trainer.global_rank
            world_size = self.trainer.world_size
            print(f"device id {device_id}, local rank {self.trainer.local_rank}, global rank {self.trainer.global_rank} in world {world_size}")
        except AttributeError:
            print(f"Error when trying to fetch device and rank info")
            print(f"Assume dataset is being setup without a trainer, set device id as 0, global rank as 0, world size as 1")
            device_id, global_rank, world_size = 0, 0, 1
        return device_id, global_rank, world_size

    def _build_pipeline(self, shards, global_rank, world_size, shuffled):
        # Shared pipeline skeleton; the train split additionally shuffles both
        # shards (before node-splitting) and samples, each with a fixed seed.
        stages = [wds.SimpleShardList(shards)]
        if shuffled:
            stages.append(wds.shuffle(100, rng=random.Random(1)))
        stages += [
            split_by_node(global_rank, world_size),
            wds.split_by_worker,
            wds.tarfile_to_samples(),
        ]
        if shuffled:
            stages.append(wds.shuffle(1000, rng=random.Random(1)))
        stages += [
            wds.decode(),
            wds.to_tuple("seq.npy", "chrom.npy", "target.npy", "label.npy"),
        ]
        return wds.DataPipeline(*stages)

    def train_dataloader(self):
        # repeat(2) + with_epoch pads the final batch if there is a remainder
        return wds.WebLoader(
            self.train_loader.repeat(2),
            num_workers=self.num_workers,
            batch_size=self.batch_size_per_rank,
        ).with_epoch(ceil(self.train_dataset_size / self.batch_size))

    def val_dataloader(self):
        return wds.WebLoader(self.val_loader,
                             num_workers=self.num_workers,
                             batch_size=self.batch_size_per_rank)

    def test_dataloader(self):
        return wds.WebLoader(self.test_loader,
                             num_workers=self.num_workers,
                             batch_size=self.batch_size_per_rank)