Skip to content

Commit 069aa8a

Browse files
committed
Initial commit
1 parent 0bb5670 commit 069aa8a

10 files changed

Lines changed: 851 additions & 1 deletion

File tree

README.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,8 @@
11
# seqchromloader

seqchromloader aims to provide versatile and ready-to-use writers/loaders for applying deep learning to bioinformatics studies.

Plan to support dataset formats including:

- pytorch dataloader (done)
- webdataset (done)
- tfrecord (not yet)

seqchromloader/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .loader import SeqChromDatasetByBed, SeqChromDatasetByWds, SeqChromDataModule
2+
from .writer import get_data_webdataset
312 Bytes
Binary file not shown.
8.44 KB
Binary file not shown.
10 KB
Binary file not shown.
6.14 KB
Binary file not shown.

seqchromloader/loader.py

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
2+
description = """
3+
Lightning Data Module for model training
4+
Given bed file, return sequence and chromatin info
5+
"""
6+
7+
import torch
8+
import random
9+
import pyfasta
10+
import pyBigWig
11+
import numpy as np
12+
import pandas as pd
13+
import webdataset as wds
14+
from math import sqrt, ceil
15+
from itertools import islice
16+
from torch.utils.data import Dataset, IterableDataset, DataLoader
17+
from pytorch_lightning import LightningDataModule
18+
19+
from seqchromloader import utils
20+
21+
def worker_init_fn(worker_id):
    """Per-worker DataLoader hook.

    Lets the dataset copy owned by this worker open its own file handles
    (fasta/bigwig) via its ``initialize()`` method.
    """
    info = torch.utils.data.get_worker_info()
    info.dataset.initialize()
25+
26+
class SeqChromLoader():
    """Wrap a SeqChrom dataset class so that calling the wrapper yields a DataLoader.

    Instances are callable: positional/keyword arguments are forwarded to the
    wrapped dataset's constructor, while the loader-specific keyword options
    configure the returned torch ``DataLoader``.
    """

    def __init__(self, SeqChromDataset):
        self.SeqChromDataset = SeqChromDataset

    def __call__(self, *args, batch_size=512, num_workers=1, shuffle=False, worker_init_fn=worker_init_fn, **kwargs):
        dataset = self.SeqChromDataset(*args, **kwargs)
        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=shuffle,
            worker_init_fn=worker_init_fn,
        )
        return loader
36+
37+
def seqChromLoaderCurry(SeqChromDataset):
    """Curry *SeqChromDataset* into a ready-to-call DataLoader factory."""
    wrapper = SeqChromLoader(SeqChromDataset)
    return wrapper
40+
41+
class _SeqChromDatasetByWds(IterableDataset):
42+
def __init__(self, wds, seq_transform:list=None, chrom_transform:list=None, target_transform:list=None):
43+
self.seq_transform = seq_transform
44+
self.chrom_transform = chrom_transform
45+
self.target_transform = target_transform
46+
47+
self.wds = wds
48+
49+
def initialize(self):
50+
# this function will be called by worker_init_function in DataLoader
51+
pass
52+
53+
def __iter__(self):
54+
worker_info = torch.utils.data.get_worker_info()
55+
if worker_info is None:
56+
ds = wds.DataPipeline(
57+
wds.SimpleShardList(self.wds),
58+
wds.shuffle(100, rng=random.Random(1)),
59+
wds.tarfile_to_samples(),
60+
wds.shuffle(1000, rng=random.Random(1)),
61+
wds.decode(),
62+
wds.to_tuple("seq.npy", "chrom.npy", "target.npy", "label.npy"),
63+
)
64+
else:
65+
ds = wds.DataPipeline(
66+
wds.SimpleShardList(self.wds),
67+
wds.shuffle(100, rng=random.Random(1)),
68+
wds.split_by_worker,
69+
wds.tarfile_to_samples(),
70+
wds.shuffle(1000, rng=random.Random(1)),
71+
wds.decode(),
72+
wds.to_tuple("seq.npy", "chrom.npy", "target.npy", "label.npy"),
73+
)
74+
75+
return iter(ds)
76+
77+
# Public, DataLoader-producing entry point for webdataset tar shards.
SeqChromDatasetByWds = seqChromLoaderCurry(_SeqChromDatasetByWds)
78+
79+
class _SeqChromDatasetByBed(Dataset):
80+
def __init__(self, bed, fasta, bigwig_files, seq_transform:list=None, chrom_transform:list=None, target_transform:list=None):
81+
self.bed = pd.read_table(bed, header=None, names=['chrom', 'start', 'end', 'name', 'score', 'strand' ])
82+
83+
self.fasta = fasta
84+
self.bigwig_files = bigwig_files
85+
self.seq_transform = [utils.DNA2OneHot()] + seq_transform # prepend default DNA one hot coding transform
86+
self.chrom_transfrom = chrom_transform
87+
self.target_transform = target_transform
88+
89+
def initialize(self):
90+
# this function will be called by worker_init_function in DataLoader
91+
self.genome_pyfasta = pyfasta.Fasta(self.config["train_bichrom"]["fasta"])
92+
#self.tfbam = pysam.AlignmentFile(self.config["train_bichrom"]["tf_bam"])
93+
self.bigwigs = [pyBigWig.open(bw) for bw in self.bigwig_files]
94+
95+
def __len__(self):
96+
return len(self.bed)
97+
98+
def __getitem__(self, idx):
99+
entry = self.bed.iloc[idx,]
100+
# get info in the each entry region
101+
## sequence
102+
sequence = self.genome_pyfasta[entry.chrom][int(entry.start):int(entry.end)]
103+
sequence = self.rev_comp(sequence) if entry.strand=="-" else sequence
104+
## chromatin
105+
ms = []
106+
try:
107+
for idx, bigwig in enumerate(self.bigwigs):
108+
m = (np.nan_to_num(bigwig.values(entry.chrom, entry.start, entry.end))).astype(np.float32)
109+
if entry.strand == "-": m = m[::-1] # reverse if needed
110+
if self.scaler_mean and self.scaler_var:
111+
m = (m - self.scaler_mean[idx])/sqrt(self.scaler_var[idx])
112+
ms.append(m)
113+
except RuntimeError as e:
114+
print(e)
115+
raise Exception(f"Failed to extract chromatin {self.bigwig_files[idx]} information in region {entry}")
116+
ms = np.vstack(ms)
117+
## target: read count in region
118+
#target = self.tfbam.count(entry.chrom, entry.start, entry.end)
119+
120+
# transform
121+
if self.seq_transform:
122+
seq = [t(sequence) for t in self.seq_transform]
123+
if self.chrom_transfrom:
124+
ms = [t(ms) for t in self.chrom_transfrom]
125+
126+
return seq, ms
127+
128+
def rev_comp(self, inp_str):
129+
rc_dict = {'A': 'T', 'G': 'C', 'T': 'A', 'C': 'G', 'c': 'g',
130+
'g': 'c', 't': 'a', 'a': 't', 'n': 'n', 'N': 'N'}
131+
outp_str = list()
132+
for nucl in inp_str:
133+
outp_str.append(rc_dict[nucl])
134+
return ''.join(outp_str)[::-1]
135+
136+
# Public, DataLoader-producing entry point for BED-defined regions.
SeqChromDatasetByBed = seqChromLoaderCurry(_SeqChromDatasetByBed)
137+
138+
def count_lines(fp):
    """Count the lines in text file *fp*.

    Bug fix: the original `enumerate` loop left `count` unbound for an empty
    file, raising UnboundLocalError; this version returns 0 instead.
    """
    with open(fp, 'r') as f:
        return sum(1 for _ in f)
143+
144+
def _split_by_node(src, global_rank, world_size):
145+
if world_size > 1:
146+
for s in islice(src, global_rank, None, world_size):
147+
yield s
148+
else:
149+
for s in src:
150+
yield s
151+
152+
# Curried webdataset pipeline stage: split shards across distributed ranks.
split_by_node = wds.pipelinefilter(_split_by_node)
153+
154+
def _scale_chrom(sample, scaler_mean, scaler_std):
155+
# standardize chrom by provided mean and std
156+
seq, chrom, target, label = sample
157+
158+
chrom = np.divide(chrom - scaler_mean, scaler_std, dtype=np.float32)
159+
160+
return seq, chrom, target, label
161+
162+
# Curried webdataset pipeline stage: standardize chromatin with precomputed stats.
scale_chrom = wds.pipelinefilter(_scale_chrom)
163+
164+
def _target_vlog(sample):
165+
# take log(n+1) on target
166+
seq, chrom, target, label = sample
167+
168+
target = np.log(target + 1, dtype=np.float32)
169+
170+
return seq, chrom, target, label
171+
172+
# Curried webdataset pipeline stage: log(n+1)-transform the regression target.
target_vlog = wds.pipelinefilter(_target_vlog)
173+
174+
class SeqChromDataModule(LightningDataModule):
    """Lightning data module streaming train/val/test webdataset shards.

    Args:
        train_wds/val_wds/test_wds: shard url(s) for each split.
        train_dataset_size: number of training samples; when provided, one
            training epoch is capped at ceil(size / batch_size) batches.
        transform: optional transform list (stored; not applied here).
        num_workers: workers per WebLoader.
        batch_size: global batch size; divided across ranks in ``setup``.
    """

    def __init__(self, train_wds, val_wds, test_wds, train_dataset_size: int = None, transform: list = None, num_workers=8, batch_size=512):
        super().__init__()
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.train_dataset_size = train_dataset_size

        self.train_wds = train_wds
        self.val_wds = val_wds
        self.test_wds = test_wds

        self.transform = transform

    def prepare_data(self):
        # Shards are expected to exist already; nothing to download/build.
        pass

    def setup(self, stage=None):
        # Discover distributed topology from the attached trainer; fall back
        # to a single-process default when running without one.
        try:
            device_id = self.trainer.device_ids[self.trainer.local_rank]

            global_rank = self.trainer.global_rank
            world_size = self.trainer.world_size
            print(f"device id {device_id}, local rank {self.trainer.local_rank}, global rank {self.trainer.global_rank} in world {world_size}")
        except AttributeError:
            print(f"Error when trying to fetch device and rank info")
            print(f"Assume dataset is being setup without a trainer, set device id as 0, global rank as 0, world size as 1")
            device_id = 0
            global_rank = 0
            world_size = 1

        # Each rank loads batch_size / world_size samples per step.
        self.batch_size_per_rank = int(self.batch_size / world_size)

        if stage in ["fit", "validate", "test"] or stage is None:

            # Training pipeline: deterministic shard- and sample-level shuffles,
            # split across ranks and then across DataLoader workers.
            self.train_loader = wds.DataPipeline(
                wds.SimpleShardList(self.train_wds),
                wds.shuffle(100, rng=random.Random(1)),
                split_by_node(global_rank, world_size),
                wds.split_by_worker,
                wds.tarfile_to_samples(),
                wds.shuffle(1000, rng=random.Random(1)),
                wds.decode(),
                wds.to_tuple("seq.npy", "chrom.npy", "target.npy", "label.npy"),
            )

            # Validation/test pipelines: same splitting, no shuffling.
            self.val_loader = wds.DataPipeline(
                wds.SimpleShardList(self.val_wds),
                split_by_node(global_rank, world_size),
                wds.split_by_worker,
                wds.tarfile_to_samples(),
                wds.decode(),
                wds.to_tuple("seq.npy", "chrom.npy", "target.npy", "label.npy"),
            )

            self.test_loader = wds.DataPipeline(
                wds.SimpleShardList(self.test_wds),
                split_by_node(global_rank, world_size),
                wds.split_by_worker,
                wds.tarfile_to_samples(),
                wds.decode(),
                wds.to_tuple("seq.npy", "chrom.npy", "target.npy", "label.npy"),
            )

    def train_dataloader(self):
        if self.train_dataset_size is not None:
            # repeat(2) + with_epoch pads the last batch if there is remainder,
            # so every rank sees the same number of batches per epoch.
            return wds.WebLoader(self.train_loader.repeat(2),
                                 num_workers=self.num_workers,
                                 batch_size=self.batch_size_per_rank).with_epoch(ceil(self.train_dataset_size / self.batch_size))
        # Bug fix: the original always computed ceil(None / batch_size) and
        # crashed with TypeError when train_dataset_size kept its default None;
        # in that case just stream the shards once with no epoch cap.
        return wds.WebLoader(self.train_loader, num_workers=self.num_workers, batch_size=self.batch_size_per_rank)

    def val_dataloader(self):
        return wds.WebLoader(self.val_loader, num_workers=self.num_workers, batch_size=self.batch_size_per_rank)

    def test_dataloader(self):
        return wds.WebLoader(self.test_loader, num_workers=self.num_workers, batch_size=self.batch_size_per_rank)

0 commit comments

Comments
 (0)