
[BUG] Two differences between the Uni-Mol2 source code and the unimol_tools package lead to different generated embeddings #371

@Du-JP

Description


Describe the bug

First Difference:

unimol_tools encodes inputs differently for Uni-Mol v1 and Uni-Mol2.
For Uni-Mol v1, the 3D conformer coordinates are normalized (centered on their mean):

    # coordinates normalize & padding
    src_coord = coordinates - coordinates.mean(axis=0)
    src_coord = np.concatenate([np.zeros((1, 3)), src_coord, np.zeros((1, 3))], axis=0)
    # distance matrix
    src_distance = distance_matrix(src_coord, src_coord)
    # edge type
    src_edge_type = src_tokens.reshape(-1, 1) * len(dictionary) + src_tokens.reshape(
        1, -1
    )

In Uni-Mol.unimol_tools.unimol_tools.data.conformer.coords2unimol

For Uni-Mol2, the 3D conformer coordinates are not normalized; they are used as-is:

    # cropping atoms and coordinates
    if len(atoms) > max_atoms:
        mask = np.zeros(len(atoms), dtype=bool)
        mask[:max_atoms] = True
        np.random.shuffle(mask)  # shuffle the mask
        atoms = atoms[mask]
        coordinates = coordinates[mask]
    else:
        mask = np.ones(len(atoms), dtype=bool)
    # tokens padding
    src_tokens = [AllChem.GetPeriodicTable().GetAtomicNumber(item) for item in atoms]
    src_coord = coordinates
    #

In Uni-Mol.unimol_tools.unimol_tools.data.conformer.mol2unimolv2

In the Uni-Mol2 source code, however, normalization of the 3D conformer is configurable, and it is enabled by default (normalize_coord=True):

    dataset = CroppingDataset(
        dataset, self.seed, "atoms", "coordinates", "coordinates_2d", self.args.max_atoms
    )
    dataset = NormalizeDataset(dataset, "coordinates", "coordinates_2d", normalize_coord=True)

In Uni-Mol.unimol2.unimol2.tasks.unimol_finetune and Uni-Mol.unimol2.unimol2.data.normalize_dataset.NormalizeDataset.cached_item
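A minimal sketch of what aligning the first difference could look like, assuming NormalizeDataset(normalize_coord=True) centers coordinates by subtracting their mean, as the v1 code in unimol_tools does (this should be checked against the source). `center_coordinates` and `pairwise_distances` are hypothetical helpers, not part of either codebase. The sketch also shows why the difference can be invisible in some places and not others: centering changes the raw coordinates but leaves pairwise distances unchanged, so only inputs built from raw coordinates would differ.

```python
import numpy as np


def center_coordinates(coords: np.ndarray) -> np.ndarray:
    """Hypothetical helper: translate a conformer so its centroid is at the origin.

    Mirrors `coordinates - coordinates.mean(axis=0)` from the Uni-Mol v1 path;
    assumed (not verified) to match NormalizeDataset(normalize_coord=True).
    """
    coords = np.asarray(coords, dtype=np.float32)
    return coords - coords.mean(axis=0)


def pairwise_distances(coords: np.ndarray) -> np.ndarray:
    """Euclidean distance matrix of shape (n_atoms, n_atoms)."""
    diff = coords[:, None, :] - coords[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    raw = rng.normal(loc=5.0, scale=1.0, size=(4, 3)).astype(np.float32)
    centered = center_coordinates(raw)

    # The raw coordinates change after centering...
    print("max |raw - centered|:", np.abs(raw - centered).max())
    # ...but the distance matrix is translation-invariant.
    print("distances equal:",
          np.allclose(pairwise_distances(raw), pairwise_distances(centered), atol=1e-5))
```

In unimol_tools, such a helper would be applied to `coordinates` before `src_coord = coordinates` in mol2unimolv2; whether that is the intended fix is for the maintainers to confirm.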


Second Difference:

The drop-feature setting (drop_feat / drop_feat_prob), which is used when generating the model input, also differs between the Uni-Mol2 source code and the unimol_tools package.
In unimol_tools, drop_feat is fixed to 0:

    node_attr, edge_index, edge_attr = get_graph(mol)
    feat = get_graph_features(edge_attr, edge_index, node_attr, drop_feat=0, mask=mask)
    feat['src_tokens'] = src_tokens
    feat['src_coord'] = src_coord
    return feat

In Uni-Mol.unimol_tools.unimol_tools.data.conformer.mol2unimolv2

In the Uni-Mol2 source code, drop_feat_prob can be set manually, but its default differs between two classes: 0.5 in MoleculeFeatureDataset and 1.0 in Unimol2FeatureDataset.

    class MoleculeFeatureDataset(BaseWrapperDataset):
        def __init__(self, dataset, smi_key='smi', drop_feat_prob=0.5, seed=None):
            self.dataset = dataset
            self.smi_key = smi_key
            self.drop_feat_prob = drop_feat_prob
            self.seed = seed
            self.set_epoch(None)

In Uni-Mol.unimol2.unimol2.data.molecule_dataset.MoleculeFeatureDataset.__init__

    class Unimol2FeatureDataset(BaseWrapperDataset):
        def __init__(
            self,
            smi_dataset: torch.utils.data.Dataset,
            token_dataset: torch.utils.data.Dataset,
            src_pos_dataset: torch.utils.data.Dataset,
            src_2d_pos_dataset: torch.utils.data.Dataset,
            pad_idx: int,
            mask_idx: int,
            mask_token_prob: float = 0.15,
            mask_pos_prob: float = 1.0,
            noise: float = 1.0,
            noise_type: str = "uniform",
            drop_feat_prob: float = 1.0,
            use_2d_pos: float = 0.5,
            seed: int = 1,
        ):

In Uni-Mol.unimol2.unimol2.data.unimol2_dataset.Unimol2FeatureDataset.__init__

On top of that, the flag is sampled randomly per sample inside the pipeline: drop_feat is set to True (1) or False (0) by comparing a seeded random draw with drop_feat_prob, and it is later used to generate the input.

    with data_utils.numpy_seed(self.seed, epoch, idx):
        data['drop_feat'] = np.random.rand() < self.drop_feat_prob
    return data

In Uni-Mol.unimol2.unimol2.data.molecule_dataset.MoleculeFeatureDataset.cached_item
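To make the effect of the different defaults concrete, here is a simplified stand-in for the per-sample draw above. `seeded_rng` and `sample_drop_feat` are hypothetical: `seeded_rng` derives a NumPy generator from the (seed, epoch, idx) triple instead of reproducing `data_utils.numpy_seed` exactly, so the specific True/False pattern will not match the real pipeline. What carries over is the behavior at each default: drop_feat_prob=0 (unimol_tools) never drops, 1.0 (Unimol2FeatureDataset) always drops, and 0.5 (MoleculeFeatureDataset) drops about half of the samples.

```python
import numpy as np


def seeded_rng(seed: int, epoch: int, idx: int) -> np.random.Generator:
    # Hypothetical stand-in for data_utils.numpy_seed: one deterministic
    # generator per (seed, epoch, idx) triple (all non-negative ints).
    return np.random.default_rng([seed, epoch, idx])


def sample_drop_feat(drop_feat_prob: float, seed: int, epoch: int, idx: int) -> bool:
    # Same Bernoulli draw as `np.random.rand() < self.drop_feat_prob`:
    # random() lies in [0, 1), so prob 0.0 never drops and prob 1.0 always drops.
    return bool(seeded_rng(seed, epoch, idx).random() < drop_feat_prob)


if __name__ == "__main__":
    for prob in (0.0, 0.5, 1.0):
        flags = [sample_drop_feat(prob, seed=1, epoch=1, idx=i) for i in range(1000)]
        print(f"drop_feat_prob={prob}: dropped {sum(flags)} / {len(flags)}")
```

For embeddings to match, both pipelines would need the same effective drop_feat for every molecule; which value is correct at inference time is a decision for the maintainers.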

Uni-Mol Version

Uni-Mol2, Uni-Mol Tools

Expected behavior

Generating embeddings for the same molecule with the source code and with the unimol_tools package should yield the same embeddings.

[Image: embedding generated by the source code]

[Image: embedding generated by the unimol_tools package]

After unifying both differences and fixing the related bugs, I obtained the embeddings shown in the two images above, and they are essentially the same.
PS: The source-code output includes padding, so it is longer than the output of the unimol_tools package.
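A minimal sketch of how the two outputs could be compared numerically, assuming both are per-atom embedding arrays of shape (length, dim) and that the extra rows in the source-code output are trailing padding. If padding or special tokens sit at other positions, the slicing has to be adjusted. `compare_embeddings` is a hypothetical helper, not part of either codebase.

```python
import numpy as np


def compare_embeddings(src_emb: np.ndarray, tools_emb: np.ndarray) -> dict:
    """Compare two (length, dim) embedding arrays after dropping trailing padding.

    Assumption: the first rows of src_emb line up one-to-one with tools_emb.
    """
    src_emb = np.asarray(src_emb, dtype=np.float64)
    tools_emb = np.asarray(tools_emb, dtype=np.float64)
    if src_emb.shape[1] != tools_emb.shape[1]:
        raise ValueError("embedding dimensions differ")
    n = min(len(src_emb), len(tools_emb))
    a, b = src_emb[:n], tools_emb[:n]

    # Row-wise cosine similarity; the epsilon guards against all-zero rows.
    norms = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1)
    cosine = (a * b).sum(axis=1) / np.maximum(norms, 1e-12)
    return {
        "compared_rows": n,
        "max_abs_diff": float(np.abs(a - b).max()),
        "min_cosine": float(cosine.min()),
    }


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    tools = rng.normal(size=(5, 8))
    src = np.vstack([tools + 1e-6, np.zeros((3, 8))])  # 3 trailing padding rows
    print(compare_embeddings(src, tools))
```

Reporting both the maximum absolute difference and the minimum row-wise cosine similarity makes "basically the same" checkable against an explicit tolerance.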

To Reproduce

seq1:CCC(C1=CC=CC=C1)C1=C(O)C2=C(OC1=O)C=CC=C2
seq2:[H][C@]12SCC(C[N+]3=CC=CC=C3)=C(N1C(=O)[C@H]2NC(=O)CC1=CC=CS1)C([O-])=O

Environment

No response

Additional Context

Before addressing this bug, #364 should be resolved first.
In addition, the 3D conformers fed to the source code must be generated with the method below, because the source code cannot generate 3D conformers on the fly.

    import numpy as np
    from rdkit.Chem import AllChem

    def generate_3d_conformer(mol, seed=23, mode='fast'):
        # Number of atoms, used for the checks and the zero-coordinate fallback below
        num_atoms = mol.GetNumAtoms()
        if num_atoms == 0:
            return np.zeros((0, 3), dtype=np.float32)

        try:
            # 1. Try embedding (generate an initial 3D conformer)
            # randomSeed=-1 means a random seed; any other value fixes the seed
            res = AllChem.EmbedMolecule(mol, randomSeed=seed)

            if res == 0:
                # 1.1 Embedding succeeded: try MMFF optimization
                try:
                    # Prefer MMFF (usually more accurate than UFF)
                    AllChem.MMFFOptimizeMolecule(mol)
                except Exception:
                    # Some molecules (e.g. with unusual metal atoms) cannot be
                    # handled by MMFF; skip optimization for them
                    pass

                # Read the coordinates
                coordinates = mol.GetConformer().GetPositions().astype(np.float32)

            # 1.2 Embedding failed (-1) in 'heavy' mode: retry with more attempts
            elif res == -1 and mode == 'heavy':
                AllChem.EmbedMolecule(mol, maxAttempts=5000, randomSeed=seed)
                try:
                    AllChem.MMFFOptimizeMolecule(mol)
                    coordinates = mol.GetConformer().GetPositions().astype(np.float32)
                except Exception:
                    # If optimization still fails after the retry, or no conformer
                    # was generated at all, fall back to 2D coordinates
                    AllChem.Compute2DCoords(mol)
                    coordinates = mol.GetConformer().GetPositions().astype(np.float32)

            # 1.3 Embedding failed and not in 'heavy' mode: fall back to 2D directly
            else:
                AllChem.Compute2DCoords(mol)
                coordinates = mol.GetConformer().GetPositions().astype(np.float32)

        except Exception as e:
            # 2. Last-resort fallback: if anything above raised an error,
            #    return all-zero coordinates so the program does not crash
            print(f"Failed to generate conformer ({e}), replace with zeros.")
            coordinates = np.zeros((num_atoms, 3), dtype=np.float32)

        return coordinates

Metadata


Assignees

No one assigned

Labels

bug (Something isn't working)
