diff --git a/CELLULAR_CL/__init__.py b/CELLULAR_CL/__init__.py index eb46287..b4cfebf 100644 --- a/CELLULAR_CL/__init__.py +++ b/CELLULAR_CL/__init__.py @@ -15,6 +15,7 @@ def train(adata, target_key: str, batch_key: str, latent_dim: int=100, + HVG: bool=True, HVGs: int=2000, model_path: str="trained_models/", train_classifier: bool=False, @@ -60,7 +61,10 @@ def train(adata, latent_dim (int, optional) Dimension of latent space produced by CELLULAR. Default is 100. - + + HVG (bool, optional) + A boolean for whether to filter for highly variable genes or not. Default is True. + HVGs (int, optional) Number of highly variable genes (HVGs) to select as input to CELLULAR. Default is 2000. @@ -171,7 +175,11 @@ def train(adata, ------- None """ - + + # If not using HVG filtering, set the number of HVGs to the total number of genes + if not HVG: + HVGs = adata.n_vars + # Raise error if the number of HVGs is not possible to achieve if adata.n_vars < HVGs: raise ValueError('Number of genes in adata is less than number of HVGs specified to be used.') @@ -183,7 +191,7 @@ def train(adata, # Initiate training class train_env = trainer_fun.train_module(data_path=adata, save_model_path=model_path, - HVG=True, + HVG=HVG, HVGs=HVGs, target_key=target_key, batch_keys=[batch_key], @@ -684,4 +692,4 @@ def rep_seed(seed: int = 42): np.random.seed(seed) random.seed(seed) torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False \ No newline at end of file + torch.backends.cudnn.benchmark = False diff --git a/CELLULAR_CL/functions/train.py b/CELLULAR_CL/functions/train.py index 50825fc..df2838c 100644 --- a/CELLULAR_CL/functions/train.py +++ b/CELLULAR_CL/functions/train.py @@ -164,15 +164,15 @@ def cell_type_centroid_distances(self, n_components: int=100): # Step 2: Calculate centroids for each cell type cluster of each batch effect centroids = {} for batch_effect in adata.obs[self.batch_keys[0]].unique(): - for cell_type in adata.obs['cell_type'].unique(): - mask = (adata.obs[self.batch_keys[0]] == 
batch_effect) & (adata.obs['cell_type'] == cell_type) + for cell_type in adata.obs[self.target_key].unique(): + mask = (adata.obs[self.batch_keys[0]] == batch_effect) & (adata.obs[self.target_key] == cell_type) centroid = np.mean(adata_pca[mask], axis=0) centroids[(batch_effect, cell_type)] = centroid # Step 3: Calculate the average centroid distance between all batch effects - average_distance_matrix = np.zeros((len(adata.obs['cell_type'].unique()), len(adata.obs['cell_type'].unique()))) - for i, cell_type_i in enumerate(adata.obs['cell_type'].unique()): - for j, cell_type_j in enumerate(adata.obs['cell_type'].unique()): + average_distance_matrix = np.zeros((len(adata.obs[self.target_key].unique()), len(adata.obs[self.target_key].unique()))) + for i, cell_type_i in enumerate(adata.obs[self.target_key].unique()): + for j, cell_type_j in enumerate(adata.obs[self.target_key].unique()): distances = [] for batch_effect in adata.obs[self.batch_keys[0]].unique(): centroid_i = torch.tensor(centroids[(batch_effect, cell_type_i)], dtype=torch.float32, requires_grad=False) @@ -188,7 +188,7 @@ def cell_type_centroid_distances(self, n_components: int=100): average_distance_matrix[i, j] = average_distance # Convert average_distance_matrix into a DataFrame - average_distance_df = pd.DataFrame(average_distance_matrix, index=self.label_encoder.transform(adata.obs['cell_type'].unique()), columns=self.label_encoder.transform(adata.obs['cell_type'].unique())) + average_distance_df = pd.DataFrame(average_distance_matrix, index=self.label_encoder.transform(adata.obs[self.target_key].unique()), columns=self.label_encoder.transform(adata.obs[self.target_key].unique())) # Replace NaN values with 0 average_distance_df = average_distance_df.fillna(0) @@ -1021,8 +1021,9 @@ def train_model(self, # Update learning rate lr_scheduler.step() - except: + except Exception as e: print(f"**Training forced to finish early due to error during training**") + print(e) print() print(f"**Finished 
training**") @@ -1151,7 +1152,8 @@ def rep_seed(seed): batch_keys=self.batch_keys, temperature=init_temperature, min_temperature=min_temperature, - max_temperature=max_temperature) + max_temperature=max_temperature, + device = device) # Define Adam optimer optimizer = optim.Adam([{'params': model_step_1.parameters(), 'lr': init_lr}, {'params': loss_module.parameters(), 'lr': init_lr}], weight_decay=5e-5) diff --git a/requirements.txt b/requirements.txt index baf66f2..5a15fcb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,7 +25,7 @@ joblib==1.3.2 jupyter_client==8.6.0 jupyter_core==5.7.1 kiwisolver==1.4.5 -llvmlite==0.42.0 +llvmlite==0.43.0 Mako==1.3.2 MarkupSafe==2.1.5 matplotlib==3.8.3 @@ -34,7 +34,7 @@ mpmath==1.3.0 natsort==8.4.0 nest-asyncio==1.6.0 networkx==3.2.1 -numba==0.59.0 +numba==0.60.0 numpy==1.26.4 optuna==3.5.0 packaging==23.2 @@ -53,7 +53,7 @@ python-dateutil==2.9.0.post0 pytz==2024.1 PyYAML==6.0.1 pyzmq==25.1.2 -scanpy==1.9.8 +scanpy==1.10.0 scikit-learn==1.4.1.post1 scipy==1.12.0 seaborn==0.13.2