From 28662aae92f3b67beaf9db60f8577b0b53d54d68 Mon Sep 17 00:00:00 2001 From: Surdeg190 Date: Tue, 18 Mar 2025 11:10:28 +0100 Subject: [PATCH 1/7] changed hard coded variable 'cell_type' to self.target_key in cell_type_centroid_distances --- CELLULAR_CL/functions/train.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/CELLULAR_CL/functions/train.py b/CELLULAR_CL/functions/train.py index 50825fc..cb4b980 100644 --- a/CELLULAR_CL/functions/train.py +++ b/CELLULAR_CL/functions/train.py @@ -164,15 +164,15 @@ def cell_type_centroid_distances(self, n_components: int=100): # Step 2: Calculate centroids for each cell type cluster of each batch effect centroids = {} for batch_effect in adata.obs[self.batch_keys[0]].unique(): - for cell_type in adata.obs['cell_type'].unique(): - mask = (adata.obs[self.batch_keys[0]] == batch_effect) & (adata.obs['cell_type'] == cell_type) + for cell_type in adata.obs[self.target_key].unique(): + mask = (adata.obs[self.batch_keys[0]] == batch_effect) & (adata.obs[self.target_key] == cell_type) centroid = np.mean(adata_pca[mask], axis=0) centroids[(batch_effect, cell_type)] = centroid # Step 3: Calculate the average centroid distance between all batch effects - average_distance_matrix = np.zeros((len(adata.obs['cell_type'].unique()), len(adata.obs['cell_type'].unique()))) - for i, cell_type_i in enumerate(adata.obs['cell_type'].unique()): - for j, cell_type_j in enumerate(adata.obs['cell_type'].unique()): + average_distance_matrix = np.zeros((len(adata.obs[self.target_key].unique()), len(adata.obs[self.target_key].unique()))) + for i, cell_type_i in enumerate(adata.obs[self.target_key].unique()): + for j, cell_type_j in enumerate(adata.obs[self.target_key].unique()): distances = [] for batch_effect in adata.obs[self.batch_keys[0]].unique(): centroid_i = torch.tensor(centroids[(batch_effect, cell_type_i)], dtype=torch.float32, requires_grad=False) @@ -188,7 +188,7 @@ def cell_type_centroid_distances(self, 
n_components: int=100): average_distance_matrix[i, j] = average_distance # Convert average_distance_matrix into a DataFrame - average_distance_df = pd.DataFrame(average_distance_matrix, index=self.label_encoder.transform(adata.obs['cell_type'].unique()), columns=self.label_encoder.transform(adata.obs['cell_type'].unique())) + average_distance_df = pd.DataFrame(average_distance_matrix, index=self.label_encoder.transform(adata.obs[self.target_key].unique()), columns=self.label_encoder.transform(adata.obs[self.target_key].unique())) # Replace NaN values with 0 average_distance_df = average_distance_df.fillna(0) From 8490d2274c29803b0715f8585587a30fd4168ff6 Mon Sep 17 00:00:00 2001 From: Surdeg190 Date: Tue, 18 Mar 2025 11:13:25 +0100 Subject: [PATCH 2/7] added device when initializing loss module to enable running on CPU --- CELLULAR_CL/functions/train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CELLULAR_CL/functions/train.py b/CELLULAR_CL/functions/train.py index cb4b980..8a0d481 100644 --- a/CELLULAR_CL/functions/train.py +++ b/CELLULAR_CL/functions/train.py @@ -1151,7 +1151,8 @@ def rep_seed(seed): batch_keys=self.batch_keys, temperature=init_temperature, min_temperature=min_temperature, - max_temperature=max_temperature) + max_temperature=max_temperature, + device = device) # Define Adam optimer optimizer = optim.Adam([{'params': model_step_1.parameters(), 'lr': init_lr}, {'params': loss_module.parameters(), 'lr': init_lr}], weight_decay=5e-5) From 7492cf214f4e12b47d58f205113faf2d7a9caff9 Mon Sep 17 00:00:00 2001 From: Surdeg190 Date: Tue, 22 Apr 2025 13:23:43 +0200 Subject: [PATCH 3/7] printing except --- CELLULAR_CL/functions/train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CELLULAR_CL/functions/train.py b/CELLULAR_CL/functions/train.py index 8a0d481..df2838c 100644 --- a/CELLULAR_CL/functions/train.py +++ b/CELLULAR_CL/functions/train.py @@ -1021,8 +1021,9 @@ def train_model(self, # Update learning rate 
lr_scheduler.step() - except: + except Exception as e: print(f"**Training forced to finish early due to error during training**") + print(e) print() print(f"**Finished training**") From 58ffc5f0c2b6155830081554cff3ef764365c841 Mon Sep 17 00:00:00 2001 From: Surdeg190 Date: Tue, 22 Apr 2025 13:33:54 +0200 Subject: [PATCH 4/7] updated dependencies --- requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index baf66f2..7ba0c1f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -34,7 +34,7 @@ mpmath==1.3.0 natsort==8.4.0 nest-asyncio==1.6.0 networkx==3.2.1 -numba==0.59.0 +numba==0.60.0 numpy==1.26.4 optuna==3.5.0 packaging==23.2 @@ -53,7 +53,7 @@ python-dateutil==2.9.0.post0 pytz==2024.1 PyYAML==6.0.1 pyzmq==25.1.2 -scanpy==1.9.8 +scanpy==1.10.0 scikit-learn==1.4.1.post1 scipy==1.12.0 seaborn==0.13.2 From 5ed34e3e80d38f05f82c2c3d50ec447a8a252d97 Mon Sep 17 00:00:00 2001 From: Surdeg190 Date: Tue, 22 Apr 2025 13:35:36 +0200 Subject: [PATCH 5/7] updated dependencies --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 7ba0c1f..967d24e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,7 +25,7 @@ joblib==1.3.2 jupyter_client==8.6.0 jupyter_core==5.7.1 kiwisolver==1.4.5 -llvmlite==0.42.0 +llvmlite==0.43.0dev0 Mako==1.3.2 MarkupSafe==2.1.5 matplotlib==3.8.3 From a22b0f3c48fbe04d3dc8dfa6643ba39c353f7a5d Mon Sep 17 00:00:00 2001 From: Surdeg190 Date: Tue, 22 Apr 2025 13:36:38 +0200 Subject: [PATCH 6/7] updated dependencies --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 967d24e..5a15fcb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,7 +25,7 @@ joblib==1.3.2 jupyter_client==8.6.0 jupyter_core==5.7.1 kiwisolver==1.4.5 -llvmlite==0.43.0dev0 +llvmlite==0.43.0 Mako==1.3.2 MarkupSafe==2.1.5 matplotlib==3.8.3 From 
2c785de2c4256366b648bdce1e66ab4dcd3f6b86 Mon Sep 17 00:00:00 2001 From: Marcus Johansson <119714470+Surdeg190@users.noreply.github.com> Date: Wed, 23 Apr 2025 10:33:39 +0200 Subject: [PATCH 7/7] Added option to not use HVGs Set HVGs to adata.n_vars if HVG = False --- CELLULAR_CL/__init__.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/CELLULAR_CL/__init__.py b/CELLULAR_CL/__init__.py index eb46287..b4cfebf 100644 --- a/CELLULAR_CL/__init__.py +++ b/CELLULAR_CL/__init__.py @@ -15,6 +15,7 @@ def train(adata, target_key: str, batch_key: str, latent_dim: int=100, + HVG: bool=True, HVGs: int=2000, model_path: str="trained_models/", train_classifier: bool=False, @@ -60,7 +61,10 @@ def train(adata, latent_dim (int, optional) Dimension of latent space produced by CELLULAR. Default is 100. - + + HVG (bool, optional) + A boolean for whether to filter for highly variable genes or not + HVGs (int, optional) Number of highly variable genes (HVGs) to select as input to CELLULAR. Default is 2000. @@ -171,7 +175,11 @@ def train(adata, ------- None """ - + + # If not using HVGs set of HVGs to the number of genes + if not HVG: + HVGs = adata.n_vars + # Raise error if the number of HVGs is not possible to achieve if adata.n_vars < HVGs: raise ValueError('Number of genes in adata is less than number of HVGs specified to be used.') @@ -183,7 +191,7 @@ def train(adata, # Initiate training class train_env = trainer_fun.train_module(data_path=adata, save_model_path=model_path, - HVG=True, + HVG=HVG, HVGs=HVGs, target_key=target_key, batch_keys=[batch_key], @@ -684,4 +692,4 @@ def rep_seed(seed: int = 42): np.random.seed(seed) random.seed(seed) torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False \ No newline at end of file + torch.backends.cudnn.benchmark = False