Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions config/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ default: `50000`
Useful for low resource utilization. This ensures all data is stored in multiple chunks of approximately `sample_chunksize` samples each. It does not affect any algorithm logic; it simply ensures that the entire dataset is never loaded into RAM all at once.
A `null` value disables this optimization.

**num_workers** {int}: `int | null`
default: `1`
This parameter uses multiple workers in parallel to speed up writing data to disk. Please set it
with careful consideration of the number of cores available on the device. *Note that this does not increase the memory usage of the pipeline*. The best speed-up in practice was observed at `num_workers = 3`.

**train_val_test** {dict}:
This section splits the data using the mentioned splitting technique mentioned in `splitter_config` & required params like `split_ratio` and `stratify` options. Example below.

Expand Down
2 changes: 2 additions & 0 deletions config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ experiment:
# DATA CONFIG.
data:
sample_chunksize: 20000
num_workers: 1

train_val_test:
full_datapath: '/path/to/anndata.h5ad'
Expand Down Expand Up @@ -42,6 +43,7 @@ feature_selection:

# score_matrix: '/path/to/matrix'
feature_subsetsize: 5000
num_workers: 1

model:
name: SequentialModel
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
anndata==0.10.9
isort==5.13.2
loky==3.4.1
memory-profiler==0.61.0
pillow==10.4.0
pre_commit==4.0.1
Expand Down
11 changes: 8 additions & 3 deletions scalr/data/preprocess/_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,19 @@ def fit(
"""
pass

def process_data(self, full_data: Union[AnnData, AnnCollection],
sample_chunksize: int, dirpath: str):
def process_data(self,
full_data: Union[AnnData, AnnCollection],
sample_chunksize: int,
dirpath: str,
num_workers: int = 1):
"""A function to process the entire data chunkwise and write the processed data
to disk.

Args:
full_data (Union[AnnData, AnnCollection]): Full data for transformation.
sample_chunksize (int): Number of samples in one chunk.
dirpath (str): Path to write the data to.
num_workers (int): number of jobs to run in parallel for data writing.
"""
if not sample_chunksize:
# TODO
Expand All @@ -68,7 +72,8 @@ def process_data(self, full_data: Union[AnnData, AnnCollection],
write_chunkwise_data(full_data,
sample_chunksize,
dirpath,
transform=self.transform)
transform=self.transform,
num_workers=num_workers)


def build_preprocessor(
Expand Down
5 changes: 0 additions & 5 deletions scalr/data/preprocess/sample_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import numpy as np

from scalr.data.preprocess import PreprocessorBase
from scalr.utils import EventLogger


class SampleNorm(PreprocessorBase):
Expand All @@ -20,8 +19,6 @@ def __init__(self, scaling_factor: float = 1.0):

self.scaling_factor = scaling_factor

self.event_logger = EventLogger('Sample norm normalization')

def transform(self, data: np.ndarray) -> np.ndarray:
"""A function to transform provided input data.

Expand All @@ -31,8 +28,6 @@ def transform(self, data: np.ndarray) -> np.ndarray:
Returns:
np.ndarray: Processed data.
"""
self.event_logger.info('\Transforming data using sample norm.')

data *= (self.scaling_factor / (data.sum(axis=1).reshape(len(data), 1)))
return data

Expand Down
18 changes: 0 additions & 18 deletions scalr/data/preprocess/standard_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import numpy as np

from scalr.data.preprocess import PreprocessorBase
from scalr.utils import EventLogger


class StandardScaler(PreprocessorBase):
Expand All @@ -28,8 +27,6 @@ def __init__(self, with_mean: bool = True, with_std: bool = True):
self.train_mean = None
self.train_std = None

self.event_logger = EventLogger('Standard Scaler Normalization')

def transform(self, data: np.ndarray) -> np.ndarray:
"""A function to transform provided input data.

Expand All @@ -39,9 +36,6 @@ def transform(self, data: np.ndarray) -> np.ndarray:
Returns:
np.ndarray: processed data
"""
self.event_logger.info(
'\Transforming data using standard scaler object')

if not self.with_mean:
train_mean = np.zeros((1, data.shape[1]))
else:
Expand All @@ -58,9 +52,6 @@ def fit(self, data: Union[AnnData, AnnCollection],

"""

self.event_logger.info('\n\nStarting standardscaler normalization')
self.event_logger.info('\nFitting standard scaler object on train data')

self.calculate_mean(data, sample_chunksize)
self.calculate_std(data, sample_chunksize)

Expand All @@ -76,7 +67,6 @@ def calculate_mean(self, data: Union[AnnData, AnnCollection],
Nothing, stores mean per feature of the train data.
"""

self.event_logger.info('Calculating mean of data...')
train_sum = np.zeros(data.shape[1]).reshape(1, -1)

# Iterate through batches of data to get mean statistics
Expand All @@ -85,11 +75,6 @@ def calculate_mean(self, data: Union[AnnData, AnnCollection],
sample_chunksize].X.sum(axis=0)
self.train_mean = train_sum / data.shape[0]

if not self.with_mean:
self.event_logger.info(
'`train_mean` will be set to zero during `transform()`, as `with_mean` is set to False!'
)

def calculate_std(self, data: Union[AnnData, AnnCollection],
sample_chunksize: int) -> None:
"""A function to calculate standard deviation for each feature in the train data.
Expand All @@ -104,7 +89,6 @@ def calculate_std(self, data: Union[AnnData, AnnCollection],

# Getting standard deviation of entire train data per feature.
if self.with_std:
self.event_logger.info('Calculating standard deviation of data...')
self.train_std = np.zeros(data.shape[1]).reshape(1, -1)
# Iterate through batches of data to get std statistics
for i in range(int(np.ceil(data.shape[0] / sample_chunksize))):
Expand All @@ -119,8 +103,6 @@ def calculate_std(self, data: Union[AnnData, AnnCollection],
self.train_std[self.train_std == 0] = 1
else:
# If `with_std` is False, set train_std to 1.
self.event_logger.info(
'Setting `train_std` to be 1, as `with_std` is set to False!')
self.train_std = np.ones((1, data.shape[1]))

@classmethod
Expand Down
20 changes: 12 additions & 8 deletions scalr/data/split/_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,27 +92,31 @@ def check_splits(self, datapath: str, data_splits: dict, target: str):
self.event_logger.info(
f'{metadata[target].iloc[test_inds].value_counts()}\n')

def write_splits(self, full_data: Union[AnnData, AnnCollection],
data_split_indices: dict, sample_chunksize: int,
dirpath: int):
def write_splits(self,
full_data: Union[AnnData, AnnCollection],
data_split_indices: dict,
sample_chunksize: int,
dirpath: int,
num_workers: int = None):
"""THis function writes the train validation and test splits to the disk.

Args:
full_data (Union[AnnData, AnnCollection]): Full data to be split.
data_split_indices (dict): Indices of each split.
sample_chunksize (int): Number of samples to be written in one file.
dirpath (int): Path to write data into.

Returns:
dict: Path of each split.
num_workers (int): number of jobs to run in parallel for data writing.
"""

for split in data_split_indices.keys():
if sample_chunksize:
split_dirpath = path.join(dirpath, split)
os.makedirs(split_dirpath, exist_ok=True)
write_chunkwise_data(full_data, sample_chunksize, split_dirpath,
data_split_indices[split])
write_chunkwise_data(full_data,
sample_chunksize,
split_dirpath,
data_split_indices[split],
num_workers=num_workers)
else:
filepath = path.join(dirpath, f'{split}.h5ad')
write_data(full_data[data_split_indices[split]].to_memory(),
Expand Down
7 changes: 5 additions & 2 deletions scalr/data_ingestion_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(self, data_config: dict, dirpath: str = '.'):
self.data_config = deepcopy(data_config)
self.target = self.data_config.get('target')
self.sample_chunksize = self.data_config.get('sample_chunksize')
self.num_workers = self.data_config.get('num_workers', 1)

# Make some necessary checks and logs.
if not self.target:
Expand Down Expand Up @@ -99,7 +100,8 @@ def generate_train_val_test_split(self):

splitter.write_splits(self.full_data, train_val_test_split_indices,
self.sample_chunksize,
train_val_test_split_dirpath)
train_val_test_split_dirpath,
self.num_workers)

# Garbage collection
del self.full_data
Expand Down Expand Up @@ -146,7 +148,8 @@ def preprocess_data(self):
for split in ['train', 'val', 'test']:
split_data = read_data(path.join(datapath, split))
preprocessor.process_data(split_data, self.sample_chunksize,
path.join(processed_datapath, split))
path.join(processed_datapath, split),
self.num_workers)

datapath = processed_datapath

Expand Down
100 changes: 86 additions & 14 deletions scalr/feature/feature_subsetting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,15 @@

from anndata import AnnData
from anndata.experimental import AnnCollection
from joblib import delayed
from joblib import Parallel
from torch import nn

from scalr.model_training_pipeline import ModelTrainingPipeline
from scalr.utils import EventLogger
from scalr.utils import FlowLogger
from scalr.utils import read_data
from scalr.utils import write_chunkwise_data


class FeatureSubsetting:
Expand All @@ -30,7 +34,9 @@ def __init__(self,
target: str,
mappings: dict,
dirpath: str = None,
device: str = 'cpu'):
device: str = 'cpu',
num_workers: int = 1,
sample_chunksize: int = None):
"""Initialize required parameters for feature subset training.

Args:
Expand All @@ -43,9 +49,12 @@ def __init__(self,
mappings (dict): mapping of target to labels.
dirpath (str, optional): Dirpath to store chunked model weights. Defaults to None.
device (str, optional): Device to train models on. Defaults to 'cpu'.
num_workers (int, optional): Number of parallel processes to launch to train multiple
feature subsets simultaneously. Defaults to using single
process.
sample_chunksize (int, optional): Chunks of samples to be loaded in memory at once.
Required when `num_workers` > 1.
"""
self.event_logger = EventLogger('FeatureSubsetting')

self.feature_subsetsize = feature_subsetsize
self.chunk_model_config = chunk_model_config
self.chunk_model_train_config = chunk_model_train_config
Expand All @@ -55,30 +64,83 @@ def __init__(self,
self.mappings = mappings
self.dirpath = dirpath
self.device = device
self.num_workers = num_workers if num_workers else 1
self.sample_chunksize = sample_chunksize

self.total_features = len(self.train_data.var_names)

# Note that EventLogger does not work with parallel training
# You may use tensorboard logging to track model training logs
if self.num_workers == 1:
self.event_logger = EventLogger('FeatureSubsetting')

def write_feature_subsetted_data(self):
    """Write each feature-subset of the train/val data to disk as chunked files.

    This pre-materializes one directory of chunked data per feature subset
    (``<dirpath>/chunked_data/{train,val}/<chunk_index>``) so that multiple
    worker processes can later train models on different subsets in parallel,
    each reading its own chunk from disk instead of slicing a shared
    in-memory object.

    No-op when ``num_workers == 1``: the single-process path slices
    ``self.train_data`` / ``self.val_data`` directly and needs no on-disk copies.

    Side effects:
        - Creates ``self.feature_chunked_data_dirpath`` and writes chunk files.
        - Deletes ``self.train_data`` and ``self.val_data`` attributes at the
          end — after this call the parallel path must read data back from
          disk (as `train_chunked_models` does when ``num_workers > 1``).

    NOTE(review): assumes `self.dirpath` is not None when ``num_workers > 1``
    — TODO confirm callers always provide it in the parallel case.
    """
    if self.num_workers == 1:
        return

    self.feature_chunked_data_dirpath = path.join(self.dirpath,
                                                  'chunked_data')
    os.makedirs(self.feature_chunked_data_dirpath, exist_ok=True)

    # i indexes the feature subset (chunk number); `start` walks the feature
    # axis in strides of `feature_subsetsize`.
    i = 0
    for start in range(0, self.total_features, self.feature_subsetsize):

        # Column indices for this subset; min() clips the final, possibly
        # partial subset at `total_features`.
        feature_subset_inds = list(
            range(start,
                  min(start + self.feature_subsetsize,
                      self.total_features)))

        # Write the train split restricted to this feature subset, itself
        # chunked row-wise by `sample_chunksize`.
        write_chunkwise_data(self.train_data,
                             self.sample_chunksize,
                             path.join(self.feature_chunked_data_dirpath,
                                       'train', str(i)),
                             feature_inds=feature_subset_inds,
                             num_workers=self.num_workers)

        # Same for the validation split.
        write_chunkwise_data(self.val_data,
                             self.sample_chunksize,
                             path.join(self.feature_chunked_data_dirpath,
                                       'val', str(i)),
                             feature_inds=feature_subset_inds,
                             num_workers=self.num_workers)

        i += 1

    # Drop the in-memory references: parallel training reads from the
    # per-subset files written above, not from these attributes.
    del self.train_data
    del self.val_data

def train_chunked_models(self) -> list[nn.Module]:
"""Trains a model for each subset data.

Returns:
list[nn.Module]: List of models for each subset.
"""
self.event_logger.info('Feature subset models training')
models = []
if self.num_workers == 1:
self.event_logger.info('Feature subset models training')

chunked_models_dirpath = path.join(self.dirpath, 'chunked_models')
os.makedirs(chunked_models_dirpath, exist_ok=True)

i = 0
for start in range(0, len(self.train_data.var_names),
self.feature_subsetsize):
self.event_logger.info(f'\nChunk {i}')
def train_chunked_model(i, start):
if self.num_workers == 1:
self.event_logger.info(f'\nChunk {i}')

chunk_dirpath = path.join(chunked_models_dirpath, str(i))
os.makedirs(chunk_dirpath, exist_ok=True)
i += 1

train_features_subset = self.train_data[:, start:start +
if self.num_workers > 1:
train_features_subset = read_data(
path.join(self.feature_chunked_data_dirpath, 'train',
str(i)))
val_features_subset = read_data(
path.join(self.feature_chunked_data_dirpath, 'val', str(i)))
else:
train_features_subset = self.train_data[:, start:start +
self.feature_subsetsize]
val_features_subset = self.val_data[:, start:start +
self.feature_subsetsize]
val_features_subset = self.val_data[:, start:start +
self.feature_subsetsize]

chunk_model_config = deepcopy(self.chunk_model_config)

Expand All @@ -89,14 +151,24 @@ def train_chunked_models(self) -> list[nn.Module]:
model_trainer.set_data_and_targets(train_features_subset,
val_features_subset, self.target,
self.mappings)

model_trainer.build_model_training_artifacts()
best_model = model_trainer.train()

self.chunk_model_config, self.chunk_model_train_config = model_trainer.get_updated_config(
)

models.append(best_model)
return i, best_model

parallel = Parallel(n_jobs=self.num_workers)
models = parallel(
delayed(train_chunked_model)(i, start) for i, (start) in enumerate(
range(0, self.total_features, self.feature_subsetsize)))

# The parallel loop yields (chunk_index, model) tuples; sorting by the chunk
# index restores training order, then the index is dropped to keep only the models.
models = sorted(models)
models = [model[1] for model in models]
return models

def get_updated_configs(self):
Expand Down
Loading
Loading