DockFormer / dockformer /data /ligand_features.py
bshor's picture
add code
bca3a49
import os
import numpy as np
import torch
from torch import nn
from rdkit import Chem
from dockformer.data.utils import FeatureTensorDict
from dockformer.utils.consts import POSSIBLE_BOND_TYPES, POSSIBLE_ATOM_TYPES, POSSIBLE_CHARGES, POSSIBLE_CHIRALITIES
def get_atom_features(atom: Chem.Atom):
    """Encode one RDKit atom as indices into the project's feature vocabularies.

    Args:
        atom: The RDKit atom to featurize.

    Returns:
        A dict with three integer entries:
          - "atom_type": position of the atom's symbol in POSSIBLE_ATOM_TYPES
            (the position of "Ni" is used when the symbol is not listed).
          - "atom_charge": position in POSSIBLE_CHARGES of the formal charge
            after clamping it to the range [-1, 1].
          - "atom_chirality": position of the atom's chiral tag in
            POSSIBLE_CHIRALITIES.

    Note:
        The lookups use ``.index``; for list/tuple vocabularies this raises
        ValueError when the clamped charge or the chiral tag is not listed.
    """
    # TODO: this is temporary, we need to add more features, for example for Zn
    symbol = atom.GetSymbol()
    if symbol in POSSIBLE_ATOM_TYPES:
        type_index = POSSIBLE_ATOM_TYPES.index(symbol)
    else:
        # Unknown elements are reported and mapped onto the "Ni" slot.
        print(f"********Unknown atom type {atom.GetSymbol()}")
        type_index = POSSIBLE_ATOM_TYPES.index("Ni")

    # Charges beyond +/-1 are collapsed onto +/-1 before the lookup.
    formal_charge = atom.GetFormalCharge()
    clamped_charge = max(min(formal_charge, 1), -1)
    charge_index = POSSIBLE_CHARGES.index(clamped_charge)

    chirality_index = POSSIBLE_CHIRALITIES.index(atom.GetChiralTag())

    return {
        "atom_type": type_index,
        "atom_charge": charge_index,
        "atom_chirality": chirality_index,
    }
def get_bond_features(bond: Chem.Bond):
    """Encode one RDKit bond as an index into POSSIBLE_BOND_TYPES.

    Args:
        bond: The RDKit bond to featurize.

    Returns:
        A dict with a single entry, "bond_type", holding the position of the
        bond's type in POSSIBLE_BOND_TYPES.
    """
    kind = bond.GetBondType()
    return {"bond_type": POSSIBLE_BOND_TYPES.index(kind)}
def make_ligand_features(ligand: Chem.Mol) -> FeatureTensorDict:
    """Build per-atom and per-bond feature tensors for a ligand molecule.

    Args:
        ligand: RDKit molecule whose atoms and bonds are featurized.

    Returns:
        A dict with the following tensors (n_atoms / n_bonds are the number of
        atoms / bonds in ``ligand``):
          - "ligand_atype", "ligand_charge", "ligand_chirality": int64 tensors
            of shape (n_atoms,) holding the class indices produced by
            ``get_atom_features``.
          - "ligand_bonds": int64 tensor of shape (n_bonds, 3); each row is
            (begin atom position, end atom position, bond type index). The
            shape is (0, 3) when the ligand has no bonds.
          - "ligand_target_feat": float tensor of shape
            (n_atoms, len(POSSIBLE_ATOM_TYPES) + len(POSSIBLE_CHARGES)
            + len(POSSIBLE_CHIRALITIES)), the concatenated one-hot encodings.
          - "ligand_bonds_feat": float tensor of shape
            (n_atoms, n_atoms, len(POSSIBLE_BOND_TYPES)), with a 1 at
            [begin, end, bond type] for every bond. Only the (begin, end)
            direction is set; the entry is not mirrored to (end, begin).
    """
    atoms_features = []
    # Maps RDKit atom indices to row positions in the feature tensors.
    atom_idx_to_atom_pos_idx = {}
    for atom in ligand.GetAtoms():
        atom_idx_to_atom_pos_idx[atom.GetIdx()] = len(atoms_features)
        atoms_features.append(get_atom_features(atom))
    n_atoms = len(atoms_features)

    atom_types = torch.tensor(np.array([atom["atom_type"] for atom in atoms_features], dtype=np.int64))
    atom_types_one_hot = nn.functional.one_hot(atom_types, num_classes=len(POSSIBLE_ATOM_TYPES))

    atom_charges = torch.tensor(np.array([atom["atom_charge"] for atom in atoms_features], dtype=np.int64))
    atom_charges_one_hot = nn.functional.one_hot(atom_charges, num_classes=len(POSSIBLE_CHARGES))

    atom_chiralities = torch.tensor(np.array([atom["atom_chirality"] for atom in atoms_features], dtype=np.int64))
    atom_chiralities_one_hot = nn.functional.one_hot(atom_chiralities, num_classes=len(POSSIBLE_CHIRALITIES))

    ligand_target_feat = torch.cat(
        [
            atom_types_one_hot.float(),
            atom_charges_one_hot.float(),
            atom_chiralities_one_hot.float(),
        ],
        dim=1,
    )

    # create one-hot matrix encoding for bonds
    ligand_bonds_feat = torch.zeros((n_atoms, n_atoms, len(POSSIBLE_BOND_TYPES)))
    ligand_bonds = []
    for bond in ligand.GetBonds():
        atom1_idx = atom_idx_to_atom_pos_idx[bond.GetBeginAtomIdx()]
        atom2_idx = atom_idx_to_atom_pos_idx[bond.GetEndAtomIdx()]
        bond_type = get_bond_features(bond)["bond_type"]
        ligand_bonds.append((atom1_idx, atom2_idx, bond_type))
        ligand_bonds_feat[atom1_idx, atom2_idx, bond_type] = 1

    # torch.tensor([]) is 1-D with shape (0,); the explicit reshape keeps the
    # bond list 2-D, shape (0, 3), for bond-less ligands such as single atoms.
    # For a non-empty list the reshape is a no-op.
    ligand_bonds_tensor = torch.tensor(ligand_bonds, dtype=torch.int64).reshape(len(ligand_bonds), 3)

    return {
        # These are used for reconstruction at the end of the pipeline
        "ligand_atype": atom_types,
        "ligand_charge": atom_charges,
        "ligand_chirality": atom_chiralities,
        "ligand_bonds": ligand_bonds_tensor,
        # these are the actual features
        "ligand_target_feat": ligand_target_feat,
        "ligand_bonds_feat": ligand_bonds_feat,
    }