Spaces:
Sleeping
Sleeping
import os | |
import numpy as np | |
import torch | |
from torch import nn | |
from rdkit import Chem | |
from dockformer.data.utils import FeatureTensorDict | |
from dockformer.utils.consts import POSSIBLE_BOND_TYPES, POSSIBLE_ATOM_TYPES, POSSIBLE_CHARGES, POSSIBLE_CHIRALITIES | |
def get_atom_features(atom: Chem.Atom):
    """Encode an RDKit atom as integer indices into the project vocabularies.

    Args:
        atom: The RDKit atom to encode.

    Returns:
        A dict with three integer entries:
        ``"atom_type"`` (position of the element symbol in
        ``POSSIBLE_ATOM_TYPES``), ``"atom_charge"`` (position of the formal
        charge, clamped to the range -1..1, in ``POSSIBLE_CHARGES``) and
        ``"atom_chirality"`` (position of the chiral tag in
        ``POSSIBLE_CHIRALITIES``).

    Note:
        The lookups use ``.index``; presumably the vocabularies are lists or
        tuples, in which case a value missing from them raises ``ValueError``.
    """
    symbol = atom.GetSymbol()
    is_known_symbol = symbol in POSSIBLE_ATOM_TYPES
    if not is_known_symbol:
        # TODO: this is temporary, we need to add more features, for example for Zn
        # Elements outside the vocabulary are reported and encoded as "Ni".
        print(f"********Unknown atom type {atom.GetSymbol()}")
    lookup_symbol = symbol if is_known_symbol else "Ni"
    type_index = POSSIBLE_ATOM_TYPES.index(lookup_symbol)

    # Clamp the formal charge into [-1, 1] before looking it up.
    clamped_charge = atom.GetFormalCharge()
    if clamped_charge > 1:
        clamped_charge = 1
    if clamped_charge < -1:
        clamped_charge = -1
    charge_index = POSSIBLE_CHARGES.index(clamped_charge)

    chirality_index = POSSIBLE_CHIRALITIES.index(atom.GetChiralTag())

    return {
        "atom_type": type_index,
        "atom_charge": charge_index,
        "atom_chirality": chirality_index,
    }
def get_bond_features(bond: Chem.Bond):
    """Encode an RDKit bond as an index into ``POSSIBLE_BOND_TYPES``.

    Args:
        bond: The RDKit bond to encode.

    Returns:
        A dict with a single entry ``"bond_type"``: the position of the
        bond's type within ``POSSIBLE_BOND_TYPES``.
    """
    rdkit_bond_type = bond.GetBondType()
    type_index = POSSIBLE_BOND_TYPES.index(rdkit_bond_type)
    return {"bond_type": type_index}
def make_ligand_features(ligand: Chem.Mol) -> FeatureTensorDict:
    """Build the per-atom and per-bond tensors describing a ligand.

    Args:
        ligand: The RDKit molecule to featurize.

    Returns:
        A dict of tensors, where ``N`` is the number of atoms and ``B`` the
        number of bonds:

        * ``"ligand_atype"``, ``"ligand_charge"``, ``"ligand_chirality"``:
          int64 tensors of shape ``(N,)`` holding the vocabulary indices
          produced by ``get_atom_features``.
        * ``"ligand_bonds"``: int64 tensor of shape ``(B, 3)`` whose rows are
          ``(begin_atom_position, end_atom_position, bond_type_index)``. The
          shape is ``(0, 3)`` when the ligand has no bonds.
        * ``"ligand_target_feat"``: float32 tensor of shape
          ``(N, len(POSSIBLE_ATOM_TYPES) + len(POSSIBLE_CHARGES)
          + len(POSSIBLE_CHIRALITIES))``, the concatenated one-hot encodings
          of type, charge and chirality.
        * ``"ligand_bonds_feat"``: float32 tensor of shape
          ``(N, N, len(POSSIBLE_BOND_TYPES))`` with a 1 at
          ``[begin, end, bond_type_index]`` for every bond.
    """
    atoms_features = []
    # Maps RDKit atom index -> row position in the feature tensors.
    atom_idx_to_atom_pos_idx = {}
    for atom in ligand.GetAtoms():
        atom_idx_to_atom_pos_idx[atom.GetIdx()] = len(atoms_features)
        atoms_features.append(get_atom_features(atom))
    num_atoms = len(atoms_features)

    # An explicit int64 dtype keeps the tensors usable by one_hot even when
    # the ligand has no atoms.
    atom_types = torch.tensor(
        [feat["atom_type"] for feat in atoms_features], dtype=torch.int64
    )
    atom_charges = torch.tensor(
        [feat["atom_charge"] for feat in atoms_features], dtype=torch.int64
    )
    atom_chiralities = torch.tensor(
        [feat["atom_chirality"] for feat in atoms_features], dtype=torch.int64
    )

    atom_types_one_hot = nn.functional.one_hot(
        atom_types, num_classes=len(POSSIBLE_ATOM_TYPES)
    )
    atom_charges_one_hot = nn.functional.one_hot(
        atom_charges, num_classes=len(POSSIBLE_CHARGES)
    )
    atom_chiralities_one_hot = nn.functional.one_hot(
        atom_chiralities, num_classes=len(POSSIBLE_CHIRALITIES)
    )
    ligand_target_feat = torch.cat(
        [
            atom_types_one_hot.float(),
            atom_charges_one_hot.float(),
            atom_chiralities_one_hot.float(),
        ],
        dim=1,
    )

    # create one-hot matrix encoding for bonds
    ligand_bonds_feat = torch.zeros(
        (num_atoms, num_atoms, len(POSSIBLE_BOND_TYPES)), dtype=torch.float32
    )
    ligand_bonds = []
    for bond in ligand.GetBonds():
        atom1_idx = atom_idx_to_atom_pos_idx[bond.GetBeginAtomIdx()]
        atom2_idx = atom_idx_to_atom_pos_idx[bond.GetEndAtomIdx()]
        bond_type = get_bond_features(bond)["bond_type"]
        ligand_bonds.append((atom1_idx, atom2_idx, bond_type))
        # NOTE(review): only the [begin, end] entry is set, so the matrix is
        # not symmetric — confirm this is what downstream consumers expect.
        ligand_bonds_feat[atom1_idx, atom2_idx, bond_type] = 1

    # Without the reshape, a ligand with no bonds would yield a tensor of
    # shape (0,) instead of (0, 3), breaking the documented row layout.
    ligand_bonds_tensor = torch.tensor(ligand_bonds, dtype=torch.int64).reshape(
        len(ligand_bonds), 3
    )

    return {
        # These are used for reconstruction at the end of the pipeline
        "ligand_atype": atom_types,
        "ligand_charge": atom_charges,
        "ligand_chirality": atom_chiralities,
        "ligand_bonds": ligand_bonds_tensor,
        # these are the actual features
        "ligand_target_feat": ligand_target_feat,
        "ligand_bonds_feat": ligand_bonds_feat,
    }