Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions CELLULAR_CL/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def train(adata,
target_key: str,
batch_key: str,
latent_dim: int=100,
HVG: bool=True,
HVGs: int=2000,
model_path: str="trained_models/",
train_classifier: bool=False,
Expand Down Expand Up @@ -60,7 +61,10 @@ def train(adata,

latent_dim (int, optional)
Dimension of latent space produced by CELLULAR. Default is 100.


HVG (bool, optional)
A boolean for whether to filter for highly variable genes (True) or use all genes (False). Default is True.

HVGs (int, optional)
Number of highly variable genes (HVGs) to select as input to CELLULAR. Default is 2000.

Expand Down Expand Up @@ -171,7 +175,11 @@ def train(adata,
-------
None
"""


# If HVG filtering is disabled, set the number of HVGs to the total number of genes
if not HVG:
HVGs = adata.n_vars

# Raise error if the number of HVGs is not possible to achieve
if adata.n_vars < HVGs:
raise ValueError('Number of genes in adata is less than number of HVGs specified to be used.')
Expand All @@ -183,7 +191,7 @@ def train(adata,
# Initiate training class
train_env = trainer_fun.train_module(data_path=adata,
save_model_path=model_path,
HVG=True,
HVG=HVG,
HVGs=HVGs,
target_key=target_key,
batch_keys=[batch_key],
Expand Down Expand Up @@ -684,4 +692,4 @@ def rep_seed(seed: int = 42):
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.benchmark = False
18 changes: 10 additions & 8 deletions CELLULAR_CL/functions/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,15 +164,15 @@ def cell_type_centroid_distances(self, n_components: int=100):
# Step 2: Calculate centroids for each cell type cluster of each batch effect
centroids = {}
for batch_effect in adata.obs[self.batch_keys[0]].unique():
for cell_type in adata.obs['cell_type'].unique():
mask = (adata.obs[self.batch_keys[0]] == batch_effect) & (adata.obs['cell_type'] == cell_type)
for cell_type in adata.obs[self.target_key].unique():
mask = (adata.obs[self.batch_keys[0]] == batch_effect) & (adata.obs[self.target_key] == cell_type)
centroid = np.mean(adata_pca[mask], axis=0)
centroids[(batch_effect, cell_type)] = centroid

# Step 3: Calculate the average centroid distance between all batch effects
average_distance_matrix = np.zeros((len(adata.obs['cell_type'].unique()), len(adata.obs['cell_type'].unique())))
for i, cell_type_i in enumerate(adata.obs['cell_type'].unique()):
for j, cell_type_j in enumerate(adata.obs['cell_type'].unique()):
average_distance_matrix = np.zeros((len(adata.obs[self.target_key].unique()), len(adata.obs[self.target_key].unique())))
for i, cell_type_i in enumerate(adata.obs[self.target_key].unique()):
for j, cell_type_j in enumerate(adata.obs[self.target_key].unique()):
distances = []
for batch_effect in adata.obs[self.batch_keys[0]].unique():
centroid_i = torch.tensor(centroids[(batch_effect, cell_type_i)], dtype=torch.float32, requires_grad=False)
Expand All @@ -188,7 +188,7 @@ def cell_type_centroid_distances(self, n_components: int=100):
average_distance_matrix[i, j] = average_distance

# Convert average_distance_matrix into a DataFrame
average_distance_df = pd.DataFrame(average_distance_matrix, index=self.label_encoder.transform(adata.obs['cell_type'].unique()), columns=self.label_encoder.transform(adata.obs['cell_type'].unique()))
average_distance_df = pd.DataFrame(average_distance_matrix, index=self.label_encoder.transform(adata.obs[self.target_key].unique()), columns=self.label_encoder.transform(adata.obs[self.target_key].unique()))

# Replace NaN values with 0
average_distance_df = average_distance_df.fillna(0)
Expand Down Expand Up @@ -1021,8 +1021,9 @@ def train_model(self,

# Update learning rate
lr_scheduler.step()
except:
except Exception as e:
print(f"**Training forced to finish early due to error during training**")
print(e)

print()
print(f"**Finished training**")
Expand Down Expand Up @@ -1151,7 +1152,8 @@ def rep_seed(seed):
batch_keys=self.batch_keys,
temperature=init_temperature,
min_temperature=min_temperature,
max_temperature=max_temperature)
max_temperature=max_temperature,
device = device)

# Define Adam optimer
optimizer = optim.Adam([{'params': model_step_1.parameters(), 'lr': init_lr}, {'params': loss_module.parameters(), 'lr': init_lr}], weight_decay=5e-5)
Expand Down
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ joblib==1.3.2
jupyter_client==8.6.0
jupyter_core==5.7.1
kiwisolver==1.4.5
llvmlite==0.42.0
llvmlite==0.43.0
Mako==1.3.2
MarkupSafe==2.1.5
matplotlib==3.8.3
Expand All @@ -34,7 +34,7 @@ mpmath==1.3.0
natsort==8.4.0
nest-asyncio==1.6.0
networkx==3.2.1
numba==0.59.0
numba==0.60.0
numpy==1.26.4
optuna==3.5.0
packaging==23.2
Expand All @@ -53,7 +53,7 @@ python-dateutil==2.9.0.post0
pytz==2024.1
PyYAML==6.0.1
pyzmq==25.1.2
scanpy==1.9.8
scanpy==1.10.0
scikit-learn==1.4.1.post1
scipy==1.12.0
seaborn==0.13.2
Expand Down