From 032beb063866fd260a69900c2b246920f4ce7257 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Tue, 5 May 2026 16:53:17 +0200 Subject: [PATCH 1/3] update pubchem for dynamic dataset --- chebai/preprocessing/datasets/pubchem.py | 819 ++++------------------- tests/integration/testPubChemData.py | 18 +- 2 files changed, 116 insertions(+), 721 deletions(-) diff --git a/chebai/preprocessing/datasets/pubchem.py b/chebai/preprocessing/datasets/pubchem.py index 24ba4019..fe903322 100644 --- a/chebai/preprocessing/datasets/pubchem.py +++ b/chebai/preprocessing/datasets/pubchem.py @@ -1,40 +1,30 @@ -__all__ = [ - "PubchemBPE", - "PubChemTokens", - "SWJSelfies", - "SWJPreChem", - "SWJBPE", - "SWJChem", -] - import gzip import os import random import shutil import tempfile -import time from datetime import datetime from typing import Generator, List, Optional, Tuple, Type, Union -import numpy as np import pandas as pd import requests import torch import tqdm -from rdkit import Chem, DataStructs -from rdkit.Chem import AllChem from chebai.preprocessing import reader as dr -from chebai.preprocessing.datasets.base import DataLoader, XYBaseDataModule +from chebai.preprocessing.datasets.base import ( + DataLoader, + XYBaseDataModule, + _DynamicDataset, +) from chebai.preprocessing.datasets.chebi import ( ChEBIOver50, ChEBIOver100, ChEBIOverX, - _ChEBIDataExtractor, ) -class PubChem(XYBaseDataModule): +class PubChem(_DynamicDataset): """ Dataset module for PubChem compounds. """ @@ -45,6 +35,11 @@ class PubChem(XYBaseDataModule): UNLABELED = True READER = dr.ChemDataReader + # Column indices in data.pkl + _ID_IDX: int = 0 + _DATA_REPRESENTATION_IDX: int = 1 + _LABELS_START_IDX: int = 2 + def __init__(self, *args, k: Optional[int] = 100000, **kwargs): """ Args: @@ -67,13 +62,23 @@ def _name(self) -> str: """ return "Pubchem" + @property + def base_dir(self) -> str: + """ + Returns: + str: Base directory for this dataset. + """ + if self._base_dir is not None: + return self._base_dir + return os.path.join("data", self._name) + @property def identifier(self) -> tuple: """ Returns: - tuple: Tuple containing reader name and split label. + tuple: Tuple containing only the reader name (split is encoded in processed_dir_main). """ - return self.reader.name(), self.split_label + return (self.reader.name(),) @property def split_label(self) -> str: @@ -86,6 +91,14 @@ def split_label(self) -> str: else: return "full" + @property + def processed_dir_main(self) -> str: + """ + Returns: + str: Directory where data.pkl and splits.csv are stored (split-specific). + """ + return os.path.join(self.base_dir, "processed", self.split_label) + @property def raw_dir(self) -> str: """ @@ -94,20 +107,41 @@ def raw_dir(self) -> str: """ return os.path.join(self.base_dir, "raw", self.split_label) + @property + def _raw_data_source_path(self) -> str: + """Path to the raw text file used to build data.pkl.""" + return os.path.join(self.raw_dir, "smiles.txt") + @staticmethod - def _load_dict(input_file_path: str) -> Generator[dict, None, None]: + def _parse_raw_line(line: str) -> dict: + """Parse a single tab-separated line from a raw smiles text file.""" + ident, smiles = line.split("\t") + return dict(id=ident.strip(), smiles=smiles.strip()) + + def _load_dict(self, input_file_path: str) -> Generator[dict, None, None]: """ + Load data from data.pkl and yield dicts with features, labels and ident. + Args: - input_file_path (str): Path to the input file. + input_file_path (str): Path to the data.pkl file. Yields: dict: Dictionary containing 'features', 'labels' (None), and 'ident' fields. """ - # pubchem IDs are here - with open(input_file_path, "r") as input_file: - for row in input_file: - ident, smiles = row.split("\t") - yield dict(features=smiles, labels=None, ident=ident) + with open(input_file_path, "rb") as f: + df = pd.read_pickle(f) + for _, row in df.iterrows(): + yield dict(features=row["smiles"], labels=None, ident=str(row["id"])) + + def _download_required_data(self) -> str: + """Download raw data and return the path to the source file.""" + self.download() + return self._raw_data_source_path + + def _graph_to_raw_dataset(self, graph): + raise NotImplementedError( + "PubChem does not use a graph-based data preparation pipeline." + ) def download(self): """ @@ -144,29 +178,6 @@ def download(self): with open(os.path.join(self.raw_dir, "smiles.txt"), "w") as f_out: f_out.writelines([line for _, line in selected_lines]) - def setup_processed(self): - """ - Prepares processed data and saves them as Torch tensors. - """ - from sklearn.model_selection import train_test_split - - filename = os.path.join(self.raw_dir, self.raw_file_names[0]) - print("Load data from file", filename) - data = self._load_data_from_file(filename) - print("Create splits") - train, test = train_test_split( - data, train_size=1 - (self.validation_split + self.test_split) - ) - del data - test, val = train_test_split( - test, train_size=self.test_split / (self.validation_split + self.test_split) - ) - torch.save(train, os.path.join(self.processed_dir, "train.pt")) - torch.save(test, os.path.join(self.processed_dir, "test.pt")) - torch.save(val, os.path.join(self.processed_dir, "validation.pt")) - - self.reader.on_finish() - @property def raw_file_names(self) -> List[str]: """ @@ -175,14 +186,6 @@ def raw_file_names(self) -> List[str]: """ return ["smiles.txt"] - @property - def processed_file_names_dict(self) -> List[str]: - """ - Returns: - List[str]: List of processed data file names. - """ - return {"train": "train.pt", "test": "test.pt", "validation": "validation.pt"} - def _set_processed_data_props(self): """ Self-supervised learning with PubChem does not use this metadata, therefore set them to zero. @@ -201,7 +204,7 @@ def _set_processed_data_props(self): def _perform_data_preparation(self, *args, **kwargs): """ - Checks for raw data and downloads if necessary. + Checks for raw data, downloads if necessary, then builds data.pkl. """ print("Check for raw data in", self.raw_dir) if any( @@ -212,6 +215,45 @@ def _perform_data_preparation(self, *args, **kwargs): self.download() print("Done") + pkl_path = os.path.join( + self.processed_dir_main, self.processed_main_file_names_dict["data"] + ) + if not os.path.isfile(pkl_path): + os.makedirs(self.processed_dir_main, exist_ok=True) + print(f"Building data.pkl from {self._raw_data_source_path}...") + rows = [] + with open(self._raw_data_source_path, "r") as f: + for line in tqdm.tqdm(f): + line = line.rstrip("\n") + if line: + rows.append(self._parse_raw_line(line)) + df = pd.DataFrame(rows, columns=["id", "smiles"]) + pd.to_pickle(df, pkl_path) + print(f"Saved {len(df)} entries to {pkl_path}") + + def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: + """ + Load encoded data and split into train, validation and test. + """ + from sklearn.model_selection import train_test_split + + filename = self.processed_file_names_dict["data"] + data = self.load_processed_data_from_file(filename) + df = pd.DataFrame(data) + + train, rest = train_test_split( + df, + train_size=1 - (self.validation_split + self.test_split), + random_state=self.dynamic_data_split_seed, + ) + val, test = train_test_split( + rest, + train_size=self.validation_split + / (self.validation_split + self.test_split), + random_state=self.dynamic_data_split_seed, + ) + return train, val, test + class PubChemBatched(PubChem): """Store train data as batches of 10m, validation and test should each be 100k max""" @@ -297,9 +339,11 @@ def setup_processed(self): """ from sklearn.model_selection import train_test_split - filename = os.path.join(self.raw_dir, self.raw_file_names[0]) - print("Load data from file", filename) - data_not_tokenized = [entry for entry in self._load_dict(filename)] + pkl_path = os.path.join( + self.processed_dir_main, self.processed_main_file_names_dict["data"] + ) + print("Load data from file", pkl_path) + data_not_tokenized = list(self._load_dict(pkl_path)) print("Create splits") train, test = train_test_split( data_not_tokenized, test_size=self.test_batch_size + self.val_batch_size @@ -355,497 +399,6 @@ def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader ) -class PubChemDissimilar(PubChem): - """ - Subset of PubChem, but choosing the most dissimilar molecules (according to fingerprint) - """ - - def __init__( - self, - *args, - k: Optional[int] = 100000, - n_random_subsets: Optional[int] = 100, - random_size_factor: Optional[int] = 5, - **kwargs, - ): - """ - Args: - k (Optional[int]): Number of entries in this dataset. - n_random_subsets (Optional[int]): Number of subsets of random data to draw most dissimilar molecules from. - random_size_factor (Optional[int]): Size of random subsets in relation to k. - *args: Additional arguments for superclass initialization. - **kwargs: Additional keyword arguments for superclass initialization. - """ - self.n_random_subsets = n_random_subsets - self.random_size_factor = random_size_factor - super(PubChemDissimilar, self).__init__(*args, k=k, **kwargs) - - @property - def _name(self) -> str: - """ - Returns: - str: Name of the dataset. - """ - return "PubchemDissimilar" - - def download(self): - """ - Downloads the PubChemDissimilar dataset. - - If `k` is set to `PubChem.FULL`, downloads the full dataset. - Otherwise, generates random subsets to select the most dissimilar molecules based on fingerprints. - """ - if self._k == PubChem.FULL: - super().download() - else: - # split random subset into n parts, from each part, select the most dissimilar entities - random_dataset = PubChem(k=self._k * self.random_size_factor) - random_dataset.download() - - with open(os.path.join(random_dataset.raw_dir, "smiles.txt"), "r") as f_in: - random_smiles = [ - [x.strip() for x in s.split("\t")] for s in f_in.readlines() - ] - fpgen = AllChem.GetRDKitFPGenerator() - selected_smiles = [] - print("Selecting most dissimilar values from random subsets...") - for i in tqdm.tqdm(range(self.n_random_subsets)): - smiles_i = random_smiles[ - i * len(random_smiles) // self.n_random_subsets : (i + 1) - * len(random_smiles) - // self.n_random_subsets - ] - mols_i = [Chem.MolFromSmiles(smiles) for _, smiles in smiles_i] - fps = [ - fpgen.GetFingerprint(m) if m is not None else m for m in mols_i - ] - nonnull_fps = [fp for fp in fps if fp is not None] - similarity = [] - for i, fp in enumerate(fps): - try: - if fp is not None: - bulk = DataStructs.BulkTanimotoSimilarity( - fp, nonnull_fps - ) - similarity.append(sum(bulk)) - else: - similarity.append(len(smiles_i)) - except Exception as e: - print(i, smiles_i[i]) - print(e.with_traceback(None)) - similarity.append(len(smiles_i)) - - similarity = sorted(zip(smiles_i, similarity), key=lambda x: x[1]) - selected_smiles += list( - list( - zip(*similarity[: len(smiles_i) // self.random_size_factor]) - )[0] - ) - with open(os.path.join(self.raw_dir, "smiles.txt"), "w") as f_out: - f_out.writelines( - "\n".join(["\t".join(smiles) for smiles in selected_smiles]) - ) - - -class PubChemKMeans(PubChem): - """ - Dataset class representing a subset of PubChem dataset clustered using K-Means algorithm. - The idea is to create distinct distributions where pretraining and test sets are formed from dissimilar data. - """ - - def __init__( - self, - *args, - n_clusters: int = 10000, - random_size: int = 1000000, - exclude_data_from: _ChEBIDataExtractor = None, - validation_size_limit: int = 4000, - include_min_n_clusters: int = 100, - **kwargs, - ): - """ - Args: - n_clusters (int): Number of clusters to create using K-Means. - random_size (int): Size of random dataset to download. - exclude_data_from (_ChEBIDataExtractor): Dataset which should not overlap with selected clusters - (remove all clusters that contain data from this dataset). - validation_size_limit (int): Validation set will contain at most this number of instances. - include_min_n_clusters (int): Minimum number of clusters to keep if there are not enough clusters that don't - overlap with the `exclude_data_from` dataset. - - *args: Additional arguments for superclass initialization. - **kwargs: Additional keyword arguments for superclass initialization. - """ - self.n_clusters = int(n_clusters) - self.exclude_data_from = exclude_data_from - self.validation_size_limit = validation_size_limit - self.include_min_n_clusters = include_min_n_clusters - super(PubChemKMeans, self).__init__(*args, k=int(random_size), **kwargs) - self._fingerprints = None - self._cluster_centers = None - self._fingerprints_clustered = None - self._exclusion_data_clustered = None - self._cluster_centers_superclustered = None - - @property - def _name(self) -> str: - """ - Returns: - str: Name of the dataset. - """ - return "PubchemKMeans" - - @property - def split_label(self) -> str: - """ - Returns: - str: Label describing the split based on number of clusters. - """ - if self._k and self._k != self.FULL: - return f"{self.n_clusters}_centers_out_of_{self._k}" - else: - return f"{self.n_clusters}_centers_out_of_full" - - @property - def raw_file_names(self) -> List[str]: - """ - Clusters generated by K-Means, sorted by size (cluster0 is the largest). - cluster0 is the training cluster (will be split into train/val/test in processed, used for pretraining) - Returns: - List[str]: List of raw file names expected in the raw directory. - """ - return ["cluster0.txt", "cluster1.txt", "cluster2.txt"] - - @property - def fingerprints(self) -> pd.DataFrame: - """ - Creates random dataset, sanitises, creates Mol objects, generates fingerprints (RDKit) - Saves `fingerprints_df` to `fingerprints.pkl` - - Returns: - pd.DataFrame: DataFrame containing SMILES and corresponding fingerprints. - """ - if self._fingerprints is None: - fingerprints_path = os.path.join(self.raw_dir, "fingerprints.pkl") - if not os.path.exists(fingerprints_path): - print("No fingerprints found...") - print(f"Loading random dataset (size: {self._k})...") - random_dataset = PubChem(k=self._k) - random_dataset.download() - with open( - os.path.join(random_dataset.raw_dir, "smiles.txt"), "r" - ) as f_in: - random_smiles = [s.split("\t")[1].strip() for s in f_in.readlines()] - fpgen = AllChem.GetRDKitFPGenerator() - print("Converting SMILES to molecules...") - mols = [Chem.MolFromSmiles(s) for s in tqdm.tqdm(random_smiles)] - print("Generating Fingerprints...") - fps = [ - fpgen.GetFingerprint(m) if m is not None else m - for m in tqdm.tqdm(mols) - ] - d = {"smiles": random_smiles, "fps": fps} - fingerprints_df = pd.DataFrame(d, columns=["smiles", "fps"]) - fingerprints_df = fingerprints_df.dropna() - fingerprints_df.to_pickle(open(fingerprints_path, "wb")) - self._fingerprints = fingerprints_df - else: - self._fingerprints = pd.read_pickle(open(fingerprints_path, "rb")) - return self._fingerprints - - def _build_clusters(self) -> tuple[pd.DataFrame, pd.DataFrame]: - """ - Performs K-Means clustering on fingerprints and saves cluster information. - - Returns: - tuple: Tuple containing cluster centers DataFrame and clustered fingerprints DataFrame. - """ - from sklearn.cluster import KMeans - - fingerprints_clustered_path = os.path.join( - self.raw_dir, "fingerprints_clustered.pkl" - ) - cluster_centers_path = os.path.join(self.raw_dir, "cluster_centers.pkl") - print("Starting k-means clustering...") - start_time = time.perf_counter() - kmeans = KMeans(n_clusters=self.n_clusters, random_state=0, n_init="auto") - fps = np.array([list(vec) for vec in self.fingerprints["fps"].tolist()]) - kmeans.fit(fps) - print(f"Finished k-means in {time.perf_counter() - start_time:.2f} seconds") - fingerprints_df = self.fingerprints - fingerprints_df["label"] = kmeans.labels_ - fingerprints_df.to_pickle( - open( - fingerprints_clustered_path, - "wb", - ) - ) - cluster_df = pd.DataFrame( - data={"centers": [center for center in kmeans.cluster_centers_]} - ) - cluster_df.to_pickle( - open( - cluster_centers_path, - "wb", - ) - ) - - return cluster_df, fingerprints_df - - def _exclude_clusters(self, cluster_centers: pd.DataFrame) -> pd.DataFrame: - """ - Excludes clusters based on data from an exclusion dataset (in a training setup, this is the labeled dataset, - usually ChEBI). The goal is to avoid having similar data in the labeled training and the PubChem evaluation. - - Loads data from `exclude_data_from` dataset, generates mols, fingerprints, finds closest cluster centre for - each fingerprint, saves data to `exclusion_data_clustered.pkl`, returns all clusters with no instances from the - exclusion data (or the n clusters with the lowest number of instances if there are less than n clusters with no - instances, n being the minimum number of clusters to include) - - Args: - cluster_centers (pd.DataFrame): DataFrame of cluster centers. - - Returns: - pd.DataFrame: DataFrame of filtered cluster centers. - """ - from scipy import spatial - - exclusion_data_path = os.path.join(self.raw_dir, "exclusion_data_clustered.pkl") - cluster_centers_np = np.array( - [ - [cci for cci in cluster_center] - for cluster_center in cluster_centers["centers"] - ] - ) - if self.exclude_data_from is not None: - if not os.path.exists(exclusion_data_path): - print("Loading data for exclusion of clusters...") - raw_chebi = [] - for filename in self.exclude_data_from.raw_file_names: - raw_chebi.append( - pd.read_pickle( - open( - os.path.join(self.exclude_data_from.raw_dir, filename), - "rb", - ) - ) - ) - raw_chebi = pd.concat(raw_chebi) - raw_chebi_smiles = np.array(raw_chebi["SMILES"]) - fpgen = AllChem.GetRDKitFPGenerator() - print("Converting SMILES to molecules...") - mols = [Chem.MolFromSmiles(s) for s in tqdm.tqdm(raw_chebi_smiles)] - print("Generating Fingerprints...") - chebi_fps = [ - fpgen.GetFingerprint(m) if m is not None else m - for m in tqdm.tqdm(mols) - ] - print("Finding cluster for each instance from exclusion-data") - chebi_fps = np.array([list(fp) for fp in chebi_fps if fp is not None]) - tree = spatial.KDTree(cluster_centers_np) - chebi_clusters = [tree.query(fp)[1] for fp in chebi_fps] - chebi_clusters_df = pd.DataFrame( - {"fp": [fp for fp in chebi_fps], "center_id": chebi_clusters}, - columns=["fp", "center_id"], - ) - chebi_clusters_df.to_pickle(open(exclusion_data_path, "wb")) - else: - chebi_clusters_df = pd.read_pickle(open(exclusion_data_path, "rb")) - # filter pubchem clusters and remove all that contain data from the exclusion set - print("Removing clusters with data from exclusion-set") - counts = chebi_clusters_df["center_id"].value_counts() - cluster_centers["n_chebi_instances"] = counts - cluster_centers["n_chebi_instances"].fillna(0, inplace=True) - cluster_centers.sort_values( - by="n_chebi_instances", ascending=False, inplace=True - ) - zero_centers = cluster_centers[cluster_centers["n_chebi_instances"] == 0] - if len(zero_centers) > self.include_min_n_clusters: - cluster_centers = zero_centers - else: - cluster_centers = cluster_centers[-self.include_min_n_clusters :] - return cluster_centers - - @property - def cluster_centers(self) -> pd.DataFrame: - """ - Loads cluster centers from file if possible, otherwise calls `self._build_clusters()`. - Returns: - pd.DataFrame: DataFrame of cluster centers. - """ - cluster_centers_path = os.path.join(self.raw_dir, "cluster_centers.pkl") - if self._cluster_centers is None: - if os.path.exists(cluster_centers_path): - self._cluster_centers = pd.read_pickle(open(cluster_centers_path, "rb")) - else: - self._cluster_centers = self._build_clusters()[0] - return self._cluster_centers - - @property - def fingerprints_clustered(self) -> pd.DataFrame: - """ - Loads fingerprints with assigned clusters from file if possible, otherwise calls `self._build_clusters()`. - Returns: - pd.DataFrame: DataFrame of clustered fingerprints. - """ - fingerprints_path = os.path.join(self.raw_dir, "fingerprints_clustered.pkl") - if self._fingerprints_clustered is None: - if os.path.exists(fingerprints_path): - self._fingerprints_clustered = pd.read_pickle( - open(fingerprints_path, "rb") - ) - else: - self._fingerprints_clustered = self._build_clusters()[1] - return self._fingerprints_clustered - - @property - def cluster_centers_superclustered(self) -> pd.DataFrame: - """ - Calls `_exclude_clusters()` which removes all clusters that contain data from the exclusion set (usually the - ChEBI, i.e., the labeled dataset). - Runs KMeans with 3 clusters on remaining data, saves cluster centres with assigned supercluster-labels to - `cluster_centers_superclustered.pkl` - Returns: - pd.DataFrame: DataFrame of superclustered cluster centers. - """ - from sklearn.cluster import KMeans - - cluster_centers_path = os.path.join( - self.raw_dir, "cluster_centers_superclustered.pkl" - ) - if self._cluster_centers_superclustered is None: - if not os.path.exists(cluster_centers_path): - clusters_filtered = self._exclude_clusters(self.cluster_centers) - print("Superclustering PubChem clusters") - kmeans = KMeans(n_clusters=3, random_state=0, n_init="auto") - clusters_np = np.array( - [[cci for cci in center] for center in clusters_filtered["centers"]] - ) - kmeans.fit(clusters_np) - clusters_filtered["label"] = kmeans.labels_ - clusters_filtered.to_pickle( - open( - os.path.join( - self.raw_dir, "cluster_centers_superclustered.pkl" - ), - "wb", - ) - ) - self._cluster_centers_superclustered = clusters_filtered - else: - self._cluster_centers_superclustered = pd.read_pickle( - open( - os.path.join( - self.raw_dir, "cluster_centers_superclustered.pkl" - ), - "rb", - ) - ) - return self._cluster_centers_superclustered - - def download(self): - """ - Downloads the PubChemKMeans dataset. This function creates the complete dataset (including train, test, and - validation splits). Most of the steps are hidden in properties (e.g., `self.fingerprints_clustered` triggers - the download of a random dataset, the calculation of fingerprints for it and the KMeans clustering) - The final splits are created by assigning all fingerprints that belong to a cluster of a certain supercluster - to a dataset. This creates 3 datasets (for each of the 3 superclusters), the datasets are saved as validation, - test and train based on their size. The validation set is limited to `self.validation_size_limit` entries. - """ - if self._k == PubChem.FULL: - super().download() - else: - if not all( - os.path.exists(os.path.join(self.raw_dir, file)) - for file in self.raw_file_names - ): - fingerprints = self.fingerprints_clustered - fingerprints["big_cluster_assignment"] = fingerprints["label"].apply( - lambda l_: ( - -1 - if l_ not in self.cluster_centers_superclustered.index - else self.cluster_centers_superclustered.loc[int(l_), "label"] - ) - ) - fp_grouped = fingerprints.groupby("big_cluster_assignment") - splits = [fp_grouped.get_group(g) for g in fp_grouped.groups if g != -1] - splits[0] = splits[0][: self.validation_size_limit] - splits.sort(key=lambda x: len(x)) - for i, name in enumerate(["cluster2", "cluster1", "cluster0"]): - if not os.path.exists(os.path.join(self.raw_dir, f"{name}.txt")): - open(os.path.join(self.raw_dir, f"{name}.txt"), "x").close() - with open(os.path.join(self.raw_dir, f"{name}.txt"), "w") as f: - for id, row in splits[i].itertuples(index=True): - f.writelines(f"{id}\t{row.smiles}\n") - - -class PubChemDissimilarSMILES(PubChemDissimilar): - """ - Subset of PubChem, selecting most dissimilar molecules based on fingerprints. - - Inherits from PubChemDissimilar. - - Attributes: - READER (type): Data reader type for chemical data. - """ - - READER: Type[dr.ChemDataReader] = dr.ChemDataReader - - -class SWJPreChem(PubChem): - """ - Subset of PubChem with unlabeled data, specific to SWJPre. - - Inherits from PubChem. - - Attributes: - UNLABELED (bool): Indicates if the data is unlabeled. - _name (str): Name of the dataset. - """ - - UNLABELED: bool = True - - @property - def _name(self) -> str: - """ - Returns the name of the dataset. - """ - return "SWJpre" - - def download(self): - """ - Raises an exception since required raw files are not found. - """ - raise Exception("Required raw files not found") - - @property - def identifier(self) -> Tuple[str]: - """ - Returns the identifier for the dataset. - """ - return (self.reader.name(),) - - @property - def raw_dir(self) -> str: - """ - Returns the directory path for raw data. - """ - return os.path.join("data", self._name, "raw") - - -class SWJSelfies(SWJPreChem): - """ - Subset of SWJPreChem using SelfiesReader for data reading. - - Inherits from SWJPreChem. - - Attributes: - READER (type): Data reader type for chemical data (SelfiesReader). - """ - - READER: Type[dr.SelfiesReader] = dr.SelfiesReader - - class PubchemChem(PubChem): """ Subset of PubChem using ChemDataReader for data reading. @@ -859,156 +412,6 @@ class PubchemChem(PubChem): READER: Type[dr.ChemDataReader] = dr.ChemDataReader -class PubchemBPE(PubChem): - """ - Subset of PubChem using ChemBPEReader for data reading. - - Inherits from PubChem. - - Attributes: - READER (type): Data reader type for chemical data (ChemBPEReader). - """ - - READER: Type[dr.ChemBPEReader] = dr.ChemBPEReader - - -class SWJChem(SWJPreChem): - """ - Subset of SWJPreChem using ChemDataUnlabeledReader for data reading. - - Inherits from SWJPreChem. - - Attributes: - READER (type): Data reader type for chemical data (ChemDataUnlabeledReader). - """ - - READER: Type[dr.ChemDataUnlabeledReader] = dr.ChemDataUnlabeledReader - - -class SWJBPE(SWJPreChem): - """ - Subset of SWJPreChem using ChemBPEReader for data reading. - - Inherits from SWJPreChem. - - Attributes: - READER (type): Data reader type for chemical data (ChemBPEReader). - """ - - READER: Type[dr.ChemBPEReader] = dr.ChemBPEReader - - -class PubChemTokens(PubChem): - """ - Subset of PubChem using ChemDataReader for data reading. - - Inherits from PubChem. - - Attributes: - READER (type): Data reader type for chemical data (ChemDataReader). - """ - - READER: Type[dr.ChemDataReader] = dr.ChemDataReader - - -class Hazardous(SWJChem): - """ - Subset of SWJChem for hazardous compounds. - - Inherits from SWJChem. - - Attributes: - READER (type): Data reader type for chemical data (ChemDataUnlabeledReader). - """ - - READER: Type[dr.ChemDataUnlabeledReader] = dr.ChemDataUnlabeledReader - - @property - def _name(self) -> str: - """ - Returns the name of the dataset. - """ - return "PubChemHazardous" - - def setup_processed(self): - """ - Sets up the processed data. - """ - filename = os.path.join(self.raw_dir, self.raw_file_names[0]) - print("Load data from file", filename) - data = self._load_data_from_file(filename) - torch.save(data, os.path.join(self.processed_dir, "all.pt")) - - self.reader.on_finish() - - def processed_file_names(self) -> List[str]: - """ - Returns the list of processed file names. - """ - return ["all.pt"] - - def download(self): - """ - Downloads hazardous compound data from PubChem. - """ - # requires the / a hazardous subset from pubchem, e.g. obtained by entering - # "PubChem: PubChem Compound TOC: GHS Classification" in the pubchem search -> download -> csv - csv_path = os.path.join(self.raw_dir, "pubchem_hazardous_compound_list.csv") - compounds = pd.read_csv(csv_path) - smiles_list = [] - for compound in compounds.itertuples(index=False): - if ( - not isinstance(compound.cmpdsynonym, str) - or "CHEBI" not in compound.cmpdsynonym - ): - smiles_list.append(f"{compound.cid}\t{compound.isosmiles}") - with open(os.path.join(self.raw_dir, "smiles.txt"), "w") as f: - f.write("\n".join(smiles_list)) - - -class SWJPreChem(PubChem): - """ - Subset of PubChem specific to SWJpre with unlabeled data. - - Inherits from PubChem. - - Attributes: - UNLABELED (bool): Indicates if the data is unlabeled. - _name (str): Name of the dataset. - """ - - UNLABELED: bool = True - - @property - def _name(self) -> str: - """ - Returns the name of the dataset. - - Returns: - str: Name of the dataset. - """ - return "SWJpre" - - def download(self) -> None: - """ - Raises an exception since required raw files are not found. - - Raises: - Exception: If required raw files are not found. - """ - raise Exception("Required raw files not found") - - @property - def identifier(self) -> Tuple[str]: - """ - Returns the identifier for the dataset. - - Returns: - tuple: A tuple containing the name of the reader. - """ - return (self.reader.name(),) - - class LabeledUnlabeledMixed(XYBaseDataModule): """ Mixed dataset combining labeled and unlabeled data. @@ -1053,12 +456,8 @@ def dataloader(self, kind: str, **kwargs) -> DataLoader: Returns: DataLoader: DataLoader instance. """ - labeled_data = torch.load( - os.path.join(self.labeled.processed_dir, f"{kind}.pt"), weights_only=False - ) - unlabeled_data = torch.load( - os.path.join(self.unlabeled.processed_dir, f"{kind}.pt"), weights_only=False - ) + labeled_data = self.labeled.load_processed_data(kind) + unlabeled_data = self.unlabeled.load_processed_data(kind) if self.data_limit is not None: labeled_data = labeled_data[: self.data_limit] unlabeled_data = unlabeled_data[: self.data_limit] @@ -1173,3 +572,9 @@ class PubChemSELFIES(PubChem): """ READER: Type[dr.SelfiesReader] = dr.SelfiesReader + + +if __name__ == "__main__": + dataset = PubchemChem(k=10000) + dataset.prepare_data() + dataset.setup() diff --git a/tests/integration/testPubChemData.py b/tests/integration/testPubChemData.py index 71591f6e..ea835894 100644 --- a/tests/integration/testPubChemData.py +++ b/tests/integration/testPubChemData.py @@ -1,9 +1,6 @@ -import os import unittest from typing import Dict, List, Tuple -import torch - from chebai.preprocessing.datasets.pubchem import PubChem @@ -34,18 +31,11 @@ def getDataSplitsOverlaps(cls) -> None: """ Get the overlap between data splits based on SMILES features and IDs. """ - processed_path = os.path.join(os.getcwd(), cls.pubChem.processed_dir) - print(f"Checking Data from - {processed_path}") + print(f"Checking Data from - {cls.pubChem.processed_dir}") - train_set = torch.load( - os.path.join(processed_path, "train.pt"), weights_only=False - ) - val_set = torch.load( - os.path.join(processed_path, "validation.pt"), weights_only=False - ) - test_set = torch.load( - os.path.join(processed_path, "test.pt"), weights_only=False - ) + train_set = cls.pubChem.load_processed_data("train") + val_set = cls.pubChem.load_processed_data("validation") + test_set = cls.pubChem.load_processed_data("test") train_smiles, train_smiles_ids = cls.get_features_ids(train_set) val_smiles, val_smiles_ids = cls.get_features_ids(val_set) From 966a483e69910f58200ea3cb6422e8b5747b673d Mon Sep 17 00:00:00 2001 From: sfluegel Date: Tue, 5 May 2026 17:14:58 +0200 Subject: [PATCH 2/3] move classes.txt property (not used by unlabeled datasets like PubChem --- chebai/preprocessing/datasets/base.py | 13 ------------- chebai/preprocessing/datasets/chebi.py | 13 +++++++++++++ 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 7f49295c..24655b0d 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -1267,16 +1267,3 @@ def processed_file_names_dict(self) -> dict: if self.n_token_limit is not None: return {"data": f"data_maxlen{self.n_token_limit}.pt"} return {"data": "data.pt"} - - @property - def classes_txt_file_path(self) -> str: - """ - Returns the filename for the classes text file. - - Returns: - str: The filename for the classes text file. - """ - # This property also used in following places: - # - chebai/result/prediction.py: to load class names for csv columns names - # - chebai/cli.py: to link this property to `model.init_args.classes_txt_file_path` - return os.path.join(self.processed_dir_main, "classes.txt") diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 698835fd..36e85e6a 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -519,6 +519,19 @@ def processed_file_names_dict(self) -> dict: } return {"data": f"aug_data_var{self.aug_smiles_variations}.pt"} + @property + def classes_txt_file_path(self) -> str: + """ + Returns the filename for the classes text file. + + Returns: + str: The filename for the classes text file. + """ + # This property also used in following places: + # - chebai/result/prediction.py: to load class names for csv columns names + # - chebai/cli.py: to link this property to `model.init_args.classes_txt_file_path` + return os.path.join(self.processed_dir_main, "classes.txt") + class ChEBIFromList(_ChEBIDataExtractor): """ From 4c0c8e3c0f937d665b4d3279ff89f74e40b59db4 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Tue, 5 May 2026 17:22:33 +0200 Subject: [PATCH 3/3] update configs --- chebai/preprocessing/datasets/pubchem.py | 49 +++++++++------------ configs/data/pubchem/pubchem.yml | 1 + configs/data/pubchem/pubchem_batched.yml | 3 ++ configs/data/pubchem/pubchem_dissimilar.yml | 3 -- 4 files changed, 24 insertions(+), 32 deletions(-) create mode 100644 configs/data/pubchem/pubchem.yml create mode 100644 configs/data/pubchem/pubchem_batched.yml delete mode 100644 configs/data/pubchem/pubchem_dissimilar.yml diff --git a/chebai/preprocessing/datasets/pubchem.py b/chebai/preprocessing/datasets/pubchem.py index fe903322..6df1ee51 100644 --- a/chebai/preprocessing/datasets/pubchem.py +++ b/chebai/preprocessing/datasets/pubchem.py @@ -40,14 +40,14 @@ class PubChem(_DynamicDataset): _DATA_REPRESENTATION_IDX: int = 1 _LABELS_START_IDX: int = 2 - def __init__(self, *args, k: Optional[int] = 100000, **kwargs): + def __init__(self, *args, n_samples: Optional[int] = 100000, **kwargs): """ Args: - k (Optional[int]): Number of samples to use. Set to `PubChem.FULL` for full dataset. + n_samples (Optional[int]): Number of samples to use. Set to `PubChem.FULL` for full dataset. *args: Additional arguments for superclass initialization. **kwargs: Additional keyword arguments for superclass initialization. """ - self._k = k + self._n_samples = n_samples current_year = datetime.today().year current_month = datetime.today().month self.pubchem_url = f"https://ftp.ncbi.nlm.nih.gov/pubchem/Compound/Monthly/{current_year}-{current_month:02d}-01/Extras/CID-SMILES.gz" @@ -86,8 +86,8 @@ def split_label(self) -> str: Returns: str: Label indicating the split of the dataset ('full' or a specific number). """ - if self._k and self._k != self.FULL: - return str(self._k) + if self._n_samples and self._n_samples != self.FULL: + return str(self._n_samples) else: return "full" @@ -148,7 +148,7 @@ def download(self): Downloads PubChem data based on `_k` parameter. """ if not os.path.isfile(os.path.join(self.raw_dir, "smiles.txt")): - if self._k == PubChem.FULL: + if self._n_samples == PubChem.FULL: print("Download from", self.pubchem_url) r = requests.get(self.pubchem_url, allow_redirects=True) with tempfile.NamedTemporaryFile() as tf: @@ -161,13 +161,15 @@ def download(self): ) as f_out: shutil.copyfileobj(f_in, f_out) else: - full_dataset = self.__class__(k=PubChem.FULL) + full_dataset = self.__class__(n_samples=PubChem.FULL) full_dataset.download() with open( os.path.join(full_dataset.raw_dir, "smiles.txt"), "r" ) as f_in: lines = sum(1 for _ in f_in) - selected = frozenset(random.sample(list(range(lines)), k=self._k)) + selected = frozenset( + random.sample(list(range(lines)), k=self._n_samples) + ) f_in.seek(0) selected_lines = list( filter( @@ -264,16 +266,16 @@ def __init__(self, train_batch_size=1_000_000, *args, **kwargs): super(PubChemBatched, self).__init__(*args, **kwargs) self.curr_epoch = 0 self.train_batch_size = train_batch_size - if self._k != self.FULL: + if self._n_samples != self.FULL: self.val_batch_size = ( 100_000 - if self.validation_split * self._k > 100_000 - else int(self.validation_split * self._k) + if self.validation_split * self._n_samples > 100_000 + else int(self.validation_split * self._n_samples) ) self.test_batch_size = ( 100_000 - if self.test_split * self._k > 100_000 - else int(self.test_split * self._k) + if self.test_split * self._n_samples > 100_000 + else int(self.test_split * self._n_samples) ) else: self.val_batch_size = 100_000 @@ -286,7 +288,9 @@ def processed_file_names_dict(self) -> List[str]: List[str]: List of processed data file names. """ train_samples = ( - self._k if self._k != self.FULL else 120_000_000 # estimated PubChem size + self._n_samples + if self._n_samples != self.FULL + else 120_000_000 # estimated PubChem size ) # estimate size train_samples -= self.val_batch_size + self.test_batch_size train_batches = ( @@ -399,19 +403,6 @@ def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader ) -class PubchemChem(PubChem): - """ - Subset of PubChem using ChemDataReader for data reading. - - Inherits from PubChem. - - Attributes: - READER (type): Data reader type for chemical data (ChemDataReader). - """ - - READER: Type[dr.ChemDataReader] = dr.ChemDataReader - - class LabeledUnlabeledMixed(XYBaseDataModule): """ Mixed dataset combining labeled and unlabeled data. @@ -511,7 +502,7 @@ def __init__(self, *args, **kwargs): **kwargs: Additional keyword arguments. """ super().__init__( - self.CHEBI_X(*args, **kwargs), PubchemChem(*args, **kwargs), *args, **kwargs + self.CHEBI_X(*args, **kwargs), PubChem(*args, **kwargs), *args, **kwargs ) @property @@ -575,6 +566,6 @@ class PubChemSELFIES(PubChem): if __name__ == "__main__": - dataset = PubchemChem(k=10000) + dataset = PubChem(k=10000) dataset.prepare_data() dataset.setup() diff --git a/configs/data/pubchem/pubchem.yml b/configs/data/pubchem/pubchem.yml new file mode 100644 index 00000000..9fedb097 --- /dev/null +++ b/configs/data/pubchem/pubchem.yml @@ -0,0 +1 @@ +class_path: chebai.preprocessing.datasets.pubchem.PubChem diff --git a/configs/data/pubchem/pubchem_batched.yml b/configs/data/pubchem/pubchem_batched.yml new file mode 100644 index 00000000..59dda20c --- /dev/null +++ b/configs/data/pubchem/pubchem_batched.yml @@ -0,0 +1,3 @@ +class_path: chebai.preprocessing.datasets.pubchem.PubChemBatched +init_args: + n_samples: 0 diff --git a/configs/data/pubchem/pubchem_dissimilar.yml b/configs/data/pubchem/pubchem_dissimilar.yml deleted file mode 100644 index b6dd5522..00000000 --- a/configs/data/pubchem/pubchem_dissimilar.yml +++ /dev/null @@ -1,3 +0,0 @@ -class_path: chebai.preprocessing.datasets.pubchem.PubChemDissimilarSMILES -init_args: - k: 200000