import logging
from collections import namedtuple, defaultdict
from copy import deepcopy
from dataclasses import dataclass, field
from typing import List, Tuple
import numpy as np
import networkx as nx
import parmed as pm
from IPython.display import display, SVG
from rdkit import Chem
from rdkit.Chem import AllChem, Draw, rdFMCS, rdCoordGen
from rdkit.Chem.Draw import rdMolDraw2D
from rdkit.Chem.Draw import IPythonConsole
IPythonConsole.molSize = (900, 900) # Change image size
IPythonConsole.ipython_useSVG = True # Change output to SVG
from transformato.system import SystemStructure
from transformato.helper_functions import (
calculate_order_of_LJ_mutations_asfe,
cycle_checks_nx,
cycle_checks,
exclude_Hs_from_mutations,
change_route_cycles,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# NOTE(review): `logging.getLogger()` without a name returns the ROOT logger, so
# this sets the root level to DEBUG as an import-time side effect.
logging.getLogger().setLevel(logging.DEBUG)
def _flattened(list_of_lists: list) -> list:
return [item for sublist in list_of_lists for item in sublist]
def _performe_linear_charge_scaling(
nr_of_steps: int,
intermediate_factory,
mutation,
):
for lambda_value in np.linspace(1, 0, nr_of_steps + 1)[1:]:
print("####################")
print(
f"Coulomb scaling in step: {intermediate_factory.current_step} with lamb: {lambda_value}"
)
print("####################")
intermediate_factory.write_state(
mutation_conf=mutation,
lambda_value_electrostatic=lambda_value,
)
def _performe_linear_cc_scaling(
nr_of_steps: int,
intermediate_factory,
mutation,
) -> int:
for lambda_value in np.linspace(1, 0, nr_of_steps + 1)[1:]:
print("####################")
print(
f"Perform parameter scaling on cc in step: {intermediate_factory.current_step} with lamb: {lambda_value}"
)
print("####################")
intermediate_factory.write_state(
mutation_conf=mutation,
common_core_transformation=lambda_value,
)
def perform_mutations(
    configuration: dict,
    i,
    mutation_list: dict,
    list_of_heavy_atoms_to_be_mutated: list = None,
    nr_of_mutation_steps_charge: int = 5,
    nr_of_mutation_steps_lj_of_hydrogens: int = 1,
    nr_of_mutation_steps_lj_of_heavy_atoms: int = 1,
    nr_of_mutation_steps_cc: int = 5,
    endstate_correction: bool = False,
):
    """Performs the mutations necessary to mutate the physical endstate to the defined common core.

    The states are written in this order: physical endstate, charge scaling,
    hydrogen LJ, heavy-atom LJ (one atom or group per step), terminal LJ
    particle (skipped for ASFE), common-core scaling (only if a ``transform``
    mutation exists) and, optionally, the endstate correction.

    Args:
        configuration (dict): A configuration dictionary. Read keys:
            ``["simulation"]["free-energy-type"]`` and, optionally,
            ``["system"][<structure>]["mutation"]`` with ``steps_charge``,
            ``heavy_atoms`` and ``steps_common_core``; these override the
            corresponding arguments.
        i: intermediate-state factory; must provide ``write_state``,
            ``current_step``, ``system.structure`` and ``endstate_correction``.
        mutation_list (dict): mutation objects keyed by ``charge``,
            ``hydrogen-lj``, ``lj``, ``default-lj`` and ``transform``.
        list_of_heavy_atoms_to_be_mutated (list, optional): Atom indices (or
            tuples of indices to switch off together) defining the order in which
            the vdw parameters of the heavy atoms are turned off. Defaults to
            None, meaning the order of ``mutation_list["lj"]`` is used.
        nr_of_mutation_steps_charge (int, optional): Nr of steps to turn off the charges. Defaults to 5.
        nr_of_mutation_steps_lj_of_hydrogens (int, optional): Nr of steps to turn off lj of hydrogens.
            Only needed for systems with many hydrogens in dummy region. Defaults to 1.
        nr_of_mutation_steps_lj_of_heavy_atoms (int, optional): Nr of steps to turn off the lj of
            each heavy atom. Defaults to 1.
        nr_of_mutation_steps_cc (int, optional): Nr of steps to interpolate between the common core
            parameters. Defaults to 5.
        endstate_correction (bool, optional): call ``i.endstate_correction()`` at the end. Defaults to False.

    Returns:
        None. All output is produced through ``i.write_state``.
    """
    # imported here rather than at module level, as in the original code
    from transformato.utils import map_lj_mutations_to_atom_idx

    banner = "####################"

    ######################################
    # write endpoint mutation
    ######################################
    print(banner)
    print("Physical endstate in step: 1")
    print(banner)
    i.write_state(mutation_conf=[])

    ######################################
    # turn off electrostatics
    ######################################
    m = mutation_list["charge"]
    # a step count defined in the config file overrides the default / passed value
    try:
        nr_of_mutation_steps_charge = configuration["system"][i.system.structure][
            "mutation"
        ]["steps_charge"]
        print("Using number of steps for charge mutattions as defined in config file")
    except KeyError:
        pass

    _performe_linear_charge_scaling(
        nr_of_steps=nr_of_mutation_steps_charge,
        intermediate_factory=i,
        mutation=m,
    )

    ######################################
    # turn off LJ of hydrogens
    ######################################
    hydrogen_lj_mutations = mutation_list["hydrogen-lj"]
    if hydrogen_lj_mutations:
        if nr_of_mutation_steps_lj_of_hydrogens == 1:
            # single step: switch off completely
            hydrogen_lambdas = [0.0]
        else:
            # multiple steps: scale down starting from 0.75
            hydrogen_lambdas = np.linspace(
                0.75, 0, nr_of_mutation_steps_lj_of_hydrogens + 1
            )
        for lambda_value in hydrogen_lambdas:
            print(banner)
            print(
                f"Hydrogen vdW scaling in step: {i.current_step} with lamb: {lambda_value}"
            )
            print(banner)
            i.write_state(
                mutation_conf=hydrogen_lj_mutations,
                lambda_value_vdw=lambda_value,
            )

    ######################################
    # turn off lj of heavy atoms
    ######################################
    # ordering precedence: config file > passed argument > order of mutation_list["lj"]
    try:
        list_of_heavy_atoms_to_be_mutated = configuration["system"][i.system.structure][
            "mutation"
        ]["heavy_atoms"]
        print("Using ordering of LJ mutations as defined in config file.")
    except KeyError:
        if not list_of_heavy_atoms_to_be_mutated:
            list_of_heavy_atoms_to_be_mutated = [
                lj.vdw_atom_idx[0] for lj in mutation_list["lj"]
            ]
            print("Using calculated ordering of LJ mutations.")
        else:
            print("Using passed ordering of LJ mutations.")

    mapping_of_atom_idx_to_mutation = map_lj_mutations_to_atom_idx(mutation_list["lj"])
    for heavy_atoms_to_turn_off_in_a_single_step in list_of_heavy_atoms_to_be_mutated:
        logger.info(
            f"turning off lj of heavy atom: {heavy_atoms_to_turn_off_in_a_single_step}"
        )

        # an entry is either an iterable of atom indices or a single integer
        try:
            mutations = [
                mapping_of_atom_idx_to_mutation[heavy_atom_idx]
                for heavy_atom_idx in heavy_atoms_to_turn_off_in_a_single_step
            ]
        except TypeError:
            mutations = [
                mapping_of_atom_idx_to_mutation[
                    heavy_atoms_to_turn_off_in_a_single_step
                ]
            ]

        if (
            heavy_atoms_to_turn_off_in_a_single_step
            == list_of_heavy_atoms_to_be_mutated[-1]
            and configuration["simulation"]["free-energy-type"] == "asfe"
        ):
            # ASFE only: the last atom always gets an extra intermediate, starting at 0.5
            for lambda_value in np.linspace(
                0.5, 0, nr_of_mutation_steps_lj_of_heavy_atoms + 1
            ):
                print(banner)
                print(
                    f"Turn off last heavy atom vdW parameter in: {i.current_step} on atoms: {heavy_atoms_to_turn_off_in_a_single_step} with lambda {lambda_value}"
                )
                print(banner)
                i.write_state(
                    mutation_conf=mutations,
                    lambda_value_vdw=lambda_value,
                )
        elif nr_of_mutation_steps_lj_of_heavy_atoms == 1:
            print(banner)
            print(
                f"Turn off heavy atom vdW parameter in: {i.current_step} on atoms: {heavy_atoms_to_turn_off_in_a_single_step}"
            )
            print(banner)
            i.write_state(
                mutation_conf=mutations,
                lambda_value_vdw=0.0,
            )
        else:
            for lambda_value in np.linspace(
                0.75, 0, nr_of_mutation_steps_lj_of_heavy_atoms + 1
            ):
                print(banner)
                print(
                    f"Turn off heavy atom vdW parameter in: {i.current_step} on atoms: {heavy_atoms_to_turn_off_in_a_single_step} with lambda {lambda_value}"
                )
                print(banner)
                i.write_state(
                    mutation_conf=mutations,
                    lambda_value_vdw=lambda_value,
                )

    ######################################
    # generate terminal LJ
    ######################################
    if configuration["simulation"]["free-energy-type"] != "asfe":
        print(banner)
        print(
            f"Generate terminal LJ particle in step: {i.current_step} on atoms: {[v.vdw_atom_idx for v in mutation_list['default-lj']]}"
        )
        print(banner)
        i.write_state(
            mutation_conf=mutation_list["default-lj"],
            lambda_value_vdw=0.0,
        )

    ######################################
    # mutate common core
    ######################################
    if mutation_list["transform"]:
        try:
            nr_of_mutation_steps_cc = configuration["system"][i.system.structure][
                "mutation"
            ]["steps_common_core"]
        except KeyError:
            # not configured: keep the passed / default value
            pass

        # change bonded parameters on common core
        _performe_linear_cc_scaling(
            nr_of_steps=nr_of_mutation_steps_cc,
            intermediate_factory=i,
            mutation=mutation_list["transform"],
        )

    if endstate_correction:
        i.endstate_correction()
[docs]
@dataclass
class DummyRegion:
    """Container describing the dummy region(s) of one molecule."""

    # short molecule label as used in this module ("m1" / "m2")
    mol_name: str
    # terminal real atom idx -> dummy atom indices bonded to it
    # (the ASFE code path passes an empty list here instead of a dict)
    match_termin_real_and_dummy_atoms: dict
    # one list of atom indices per connected dummy region
    connected_dummy_regions: list
    # residue name (three letter code) of the ligand
    tlc: str
    # dummy atom indices that are turned into default LJ particles
    lj_default: list

    def return_connecting_real_atom(self, dummy_atoms: list):
        """Return the first terminal real atom bonded to any atom in ``dummy_atoms``.

        Logs a critical message and returns None if no such atom exists.
        """
        terminal_map = self.match_termin_real_and_dummy_atoms
        # iterate keys and index explicitly: the container may be an empty list
        for real_atom in terminal_map:
            if any(dummy in dummy_atoms for dummy in terminal_map[real_atom]):
                logger.debug(f"Connecting real atom: {real_atom}")
                return real_atom

        logger.critical("No connecting real atom was found!")
        return None
[docs]
@dataclass
class MutationDefinition:
    """Plain description of one mutation step."""

    # indices of the atoms affected by this mutation
    atoms_to_be_mutated: List[int]
    # indices of the common core atoms
    common_core: List[int]
    # dummy region this mutation belongs to
    dummy_region: DummyRegion
    # atoms whose vdW interactions are decoupled (empty by default)
    vdw_atom_idx: List[int] = field(default_factory=list)
    steric_mutation_to_default: bool = False

    def print_details(self):
        """Print a short human-readable summary of the mutation to stdout."""
        summary = [
            "####################",
            f"Atoms to be mutated: {self.atoms_to_be_mutated}",
            f"Mutated on common core: {self.common_core}",
        ]
        if self.vdw_atom_idx:
            summary.append(f"VDW atoms to be decoupled: {self.vdw_atom_idx}")
        for line in summary:
            print(line)
[docs]
class ProposeMutationRoute(object):
[docs]
def __init__(
    self,
    s1: SystemStructure,
    s2: SystemStructure = None,
):
    """
    A class that proposes the mutation route between two molecules with a
    common core (same atom types) based on two mols and generates the mutation
    objects to perform the mutation on the psf objects.

    If only ``s1`` is given, an ASFE setup is assumed (``self.asfe = True``)
    and only the attributes for molecule "m1" are created.

    Parameters
    ----------
    s1: SystemStructure
        first system; ``mol``, ``graph``, ``psfs`` and ``tlc`` are read from it
    s2: SystemStructure, optional
        second system; omit (None) to request an ASFE setup
    """
    mol1_name: str = "m1"
    mol2_name: str = "m2"

    # Decide on the setup explicitly instead of relying on a bare `except:`
    # around the two-molecule setup, which silently turned *any* error
    # (e.g. a missing psf) into an ASFE run.
    if s2 is not None:
        self.system: dict = {"system1": s1, "system2": s2}
        self.mols: dict = {mol1_name: s1.mol, mol2_name: s2.mol}
        self.graphs: dict = {mol1_name: s1.graph, mol2_name: s2.graph}
        # psfs for reference of only ligand
        self.psfs: dict = {
            mol1_name: s1.psfs["waterbox"][f":{s1.tlc}"],
            mol2_name: s2.psfs["waterbox"][f":{s2.tlc}"],
        }
        self.psf1: pm.charmm.CharmmPsfFile = s1.psfs
        self.psf2: pm.charmm.CharmmPsfFile = s2.psfs
        self._substructure_match: dict = {mol1_name: [], mol2_name: []}
        self.removed_indeces: dict = {mol1_name: [], mol2_name: []}
        self.added_indeces: dict = {mol1_name: [], mol2_name: []}
        self.s1_tlc = s1.tlc
        self.s2_tlc = s2.tlc

        self.terminal_real_atom_cc1: list = []
        self.terminal_real_atom_cc2: list = []
        self.terminal_dummy_atom_cc1: list = []
        self.terminal_dummy_atom_cc2: list = []

        # settings logged by the MCS search
        self.bondCompare = rdFMCS.BondCompare.CompareAny
        self.atomCompare = rdFMCS.AtomCompare.CompareElements
        self.maximizeBonds: bool = True
        self.matchValences: bool = False
        self.completeRingsOnly: bool = False
        self.ringMatchesRingOnly: bool = True

        # declared here, assigned in finish_common_core
        self.dummy_region_cc1: DummyRegion
        self.dummy_region_cc2: DummyRegion

        self.asfe: bool = False
        # self._check_cgenff_versions()
    else:
        logger.info(
            "Only information about one structure, assume an ASFE simulation is requested"
        )
        self.system: dict = {"system1": s1}
        self.mols: dict = {mol1_name: s1.mol}
        self.graphs: dict = {mol1_name: s1.graph}
        # psfs for reference of only ligand
        # (was a set literal `{psf}`; keyed by molecule name like the two-molecule case)
        self.psfs: dict = {mol1_name: s1.psfs["waterbox"][f":{s1.tlc}"]}
        self.psf1: pm.charmm.CharmmPsfFile = s1.psfs
        self._substructure_match: dict = {mol1_name: []}
        self.removed_indeces: dict = {mol1_name: []}
        self.added_indeces: dict = {mol1_name: []}
        self.s1_tlc = s1.tlc
        self.asfe: bool = True
        # declared here, assigned in finish_common_core
        self.dummy_region_cc1: DummyRegion

    # Drude particles are detected by an atom type starting with "DR"
    self.drude: bool = any(
        atom.type.startswith("DR")
        for atom in self.psf1["waterbox"].view[f":{s1.tlc}"].atoms
    )
    if self.drude:
        logger.info("Assuming Drude particles are present")
def _check_cgenff_versions(self):
cgenff_sys1 = self.system["system1"].cgenff_version
cgenff_sys2 = self.system["system2"].cgenff_version
if cgenff_sys1 == cgenff_sys2:
pass
else:
raise RuntimeError(
f"CGenFF compatibility error. CGenFF: {cgenff_sys1} and CGenFF: {cgenff_sys2} are combined."
)
[docs]
def _match_terminal_real_and_dummy_atoms_for_mol1(self):
"""
Matches the terminal real and dummy atoms and returns a dict with real atom idx as key and a set of dummy atoms that connect
to this real atom as a set
"""
return self._match_terminal_real_and_dummy_atoms(
self.mols["m1"], self.terminal_real_atom_cc1, self.terminal_dummy_atom_cc1
)
[docs]
def _match_terminal_real_and_dummy_atoms_for_mol2(self) -> dict:
"""
Matches the terminal real and dummy atoms and returns a dict with real atom idx as key and a set of dummy atoms that connect
to this real atom as a set
"""
return self._match_terminal_real_and_dummy_atoms(
self.mols["m2"], self.terminal_real_atom_cc2, self.terminal_dummy_atom_cc2
)
[docs]
@staticmethod
def _match_terminal_real_and_dummy_atoms(
mol, real_atoms_cc: list, dummy_atoms_cc: list
) -> dict:
"""
Matches the terminal real and dummy atoms and returns a dict with real atom idx as key and a set of dummy atoms that connect
to this real atom as a set
Parameters
----------
mol : [Chem.Mol]
The mol object with the real and dummy atoms
real_atoms_cc : list
list of real atom idx
dummy_atoms_cc : list
list of dummy atom idx
Returns
-------
[type]
[description]
"""
from collections import defaultdict
real_atom_match_dummy_atom = defaultdict(set)
for real_atom_idx in real_atoms_cc:
real_atom = mol.GetAtomWithIdx(real_atom_idx)
real_neighbors = [x.GetIdx() for x in real_atom.GetNeighbors()]
for dummy_atoms_idx in dummy_atoms_cc:
if dummy_atoms_idx in real_neighbors:
real_atom_match_dummy_atom[real_atom_idx].add(dummy_atoms_idx)
return real_atom_match_dummy_atom
def _set_common_core_parameters(self):
# find terminal atoms
(
self.terminal_dummy_atom_cc1,
self.terminal_real_atom_cc1,
) = self._find_terminal_atom(self.get_common_core_idx_mol1(), self.mols["m1"])
(
self.terminal_dummy_atom_cc2,
self.terminal_real_atom_cc2,
) = self._find_terminal_atom(self.get_common_core_idx_mol2(), self.mols["m2"])
# match terminal real atoms between cc1 and cc2 that connect dummy atoms
cc_idx_mol1 = self.get_common_core_idx_mol1()
cc_idx_mol2 = self.get_common_core_idx_mol2()
matching_terminal_atoms_between_cc = list()
for cc1_idx, cc2_idx in zip(cc_idx_mol1, cc_idx_mol2):
if (
cc1_idx in self.terminal_real_atom_cc1
and cc2_idx in self.terminal_real_atom_cc2
):
logger.info(
f"Dummy regions connect on the same terminal atoms. cc1: {cc1_idx} : cc2: {cc2_idx}"
)
matching_terminal_atoms_between_cc.append((cc1_idx, cc2_idx))
elif (
cc1_idx in self.terminal_real_atom_cc1
and cc2_idx not in self.terminal_real_atom_cc2
) or (
cc1_idx not in self.terminal_real_atom_cc1
and cc2_idx in self.terminal_real_atom_cc2
):
logger.info(
f"Single dummy region connects on terminal atom. cc1: {cc1_idx} : cc2: {cc2_idx}"
)
matching_terminal_atoms_between_cc.append((cc1_idx, cc2_idx))
else:
pass
if not matching_terminal_atoms_between_cc:
raise RuntimeError(
"No terminal real atoms were matched between the common cores. Aborting."
)
self.matching_terminal_atoms_between_cc = matching_terminal_atoms_between_cc
def _match_terminal_dummy_atoms_between_common_cores(
self,
match_terminal_atoms_cc1: dict,
match_terminal_atoms_cc2: dict,
) -> Tuple[list, list]:
cc1_idx = self._substructure_match["m1"]
cc2_idx = self._substructure_match["m2"]
lj_default_cc1 = []
lj_default_cc2 = []
# iterate through the common core substracter (the order represents the matched atoms)
for idx1, idx2 in zip(cc1_idx, cc2_idx):
# if both atoms are terminal atoms connected dummy regions can be identified
if (
idx1 in match_terminal_atoms_cc1.keys()
and idx2 in match_terminal_atoms_cc2.keys()
):
connected_dummy_cc1 = list(match_terminal_atoms_cc1[idx1])
connected_dummy_cc2 = list(match_terminal_atoms_cc2[idx2])
if len(connected_dummy_cc1) == 1 and len(connected_dummy_cc2) == 1:
pass
# multiple, possible dummy regions
elif len(connected_dummy_cc1) > 1 or len(connected_dummy_cc2) > 1:
logger.critical("There is a dual junction. Be careful.")
# NOTE: For now we are just taking the non hydrogen atom
for atom_idx in connected_dummy_cc1:
if self.mols["m1"].GetAtomWithIdx(atom_idx).GetSymbol() != "H":
connected_dummy_cc1 = [atom_idx]
break
for atom_idx in connected_dummy_cc2:
if self.mols["m2"].GetAtomWithIdx(atom_idx).GetSymbol() != "H":
connected_dummy_cc2 = [atom_idx]
break
# hydrogen mutates to dummy atom (but not a LJ particle)
elif len(connected_dummy_cc1) == 0 or len(connected_dummy_cc2) == 0:
logger.debug("Hydrogen to dummy mutation")
raise NotImplementedError()
lj_default_cc1.append(connected_dummy_cc1[0])
lj_default_cc2.append(connected_dummy_cc2[0])
return (lj_default_cc1, lj_default_cc2)
[docs]
@staticmethod
def _calculate_order_of_LJ_mutations(
    connected_dummy_regions: list,
    match_terminal_atoms: dict,
    G: nx.Graph,
    cyclecheck=True,
    ordercycles=True,
    exclude_Hs=False,
) -> list:
    """
    Order the atoms of every connected dummy region for switching off their LJ interactions.

    For each terminal dummy atom the graph is reduced to the dummy region containing
    it and a Dijkstra search (edge attribute "weight") is started from that atom. The
    atoms are then ordered by decreasing distance, i.e. the most distant atom is
    removed first. One route per (dummy atom, region) combination is returned.
    -----
    most functions for these options are imported from the helper_functions file
    cyclecheck: updates weights according to cycle participation (should always be set to True)
    ordercycles: additionally re-orders the route via ``change_route_cycles``
        ('preferential removal'); only used together with cyclecheck
    exclude_Hs: if True, hydrogens are removed before the mutation algorithm is applied
    """
    if exclude_Hs == True:
        connected_dummy_regions, G = exclude_Hs_from_mutations(
            connected_dummy_regions, G
        )

    # the flags are compared with `==`, exactly as before (not by truthiness)
    weight_cycles_only = cyclecheck == True and ordercycles == False
    reorder_by_cycles = cyclecheck == True and ordercycles == True

    mutation_routes = []
    for real_atom in match_terminal_atoms:
        # each terminal dummy atom is the root of the search in its region
        for root in match_terminal_atoms[real_atom]:
            for region in connected_dummy_regions:
                if root not in region:
                    continue

                # restrict a copy of the graph to the atoms of this dummy region
                region_graph = G.copy()
                outside_region = [node for node in G.nodes() if node not in region]
                for node in outside_region:
                    region_graph.remove_node(node)

                if weight_cycles_only:
                    region_graph = cycle_checks_nx(region_graph)
                if reorder_by_cycles:
                    cycledict, degreedict = cycle_checks(region_graph)

                # distances from the root, sorted ascending
                distances = nx.single_source_dijkstra(
                    region_graph, source=root, weight="weight"
                )[0]
                distance_by_node = dict(
                    sorted(distances.items(), key=lambda item: item[1])
                )

                # most distant atom first
                route = list(distance_by_node)
                route.reverse()

                if reorder_by_cycles:
                    route = change_route_cycles(
                        route, cycledict, degreedict, distance_by_node, G
                    )

                print("Final mutation route:")
                print(route)
                mutation_routes.append(route)

    return mutation_routes
[docs]
def _check_for_lp(
    self,
    odered_connected_dummy_regions_cc_with_lp: list,
    psf: pm.charmm.CharmmPsfFile,
    tlc: str,
    name: str,
) -> list:
    """
    Assign lone pairs (atoms whose name starts with "LP") to the dummy region or the common core.

    A lone pair whose host atom (``atom.frame_type.atom1``) is part of a dummy region
    is put at the front of that region's list; any other lone pair is added to the
    common core of the molecule given by ``name``.

    Parameters
    ----------
    odered_connected_dummy_regions_cc_with_lp : list
        ordered dummy regions (list of lists of atom idx); modified in place
    psf : pm.charmm.CharmmPsfFile
        structure containing the ligand
    tlc : str
        residue name used to select the ligand
    name : str
        "m1" or "m2"

    Returns
    -------
    list
        the same (modified) list object that was passed in
    """
    flat_ordered_connected_dummy_regions = [
        item
        for sublist in odered_connected_dummy_regions_cc_with_lp
        for item in sublist
    ]

    lp_dict_dummy_region = defaultdict(list)

    for atom in psf.view[f":{tlc}"].atoms:
        # same condition as the former `atom.name.find("LP") == False`,
        # which was only true for an index of 0
        if not atom.name.startswith("LP"):
            continue

        host_idx = atom.frame_type.atom1.idx
        if host_idx in flat_ordered_connected_dummy_regions:
            lp_dict_dummy_region[host_idx].append(atom.idx)
        elif name == "m1":
            logger.info(f"Adding atom {atom.idx} to the common core of mol1")
            self.add_idx_to_common_core_of_mol1([atom.idx])
        elif name == "m2":
            # message fixed: it used to say "mol1" here
            logger.info(f"Adding atom {atom.idx} to the common core of mol2")
            self.add_idx_to_common_core_of_mol2([atom.idx])

    if lp_dict_dummy_region:
        for region in odered_connected_dummy_regions_cc_with_lp:
            lp_to_insert = []
            for atom_idx in region:
                if atom_idx in lp_dict_dummy_region:
                    lp_to_insert.extend(lp_dict_dummy_region[atom_idx])
            # put the lone pairs in front of the region, keeping their order
            region[0:0] = lp_to_insert

    logger.debug(
        f"Orderd connected dummy atoms containing the lp {odered_connected_dummy_regions_cc_with_lp}"
    )

    return odered_connected_dummy_regions_cc_with_lp
[docs]
def get_idx_of_all_atoms(
self,
mol1_name: str,
):
"""
Iterates over all atoms of the molecule and saves them as a list
----------
mol1_name: str
"""
s1 = []
for atom in self.psf1["waterbox"][f":{self.s1_tlc}"].atoms:
s1.append(atom.idx)
self._substructure_match[mol1_name] = list(s1)
[docs]
def propose_common_core(self):
"""
Searches for the common core using the rdkit module, in case of asfe only a list of
atoms of the ligand is created
"""
if self.asfe:
self.get_idx_of_all_atoms("m1")
else:
# System for RBFE/RSFE contains two mols
mcs = self._find_mcs("m1", "m2")
return mcs
[docs]
def finish_common_core(
    self,
    connected_dummy_regions_cc1: list = None,
    connected_dummy_regions_cc2: list = None,
    odered_connected_dummy_regions_cc1: list = None,
    odered_connected_dummy_regions_cc2: list = None,
):
    """
    The dummy region is created and the final atoms connected to the CC are collected. It is possible
    to define a dummy region on its own or to change the ordering how the lj parameters of the
    heavy atoms in the dummy region are turned off.

    Every argument that is None or empty is calculated automatically. (The
    defaults are None rather than shared mutable ``[]`` objects.)
    ---------
    connected_dummy_regions_cc1: list, optional
        connected dummy regions of mol1 (list of lists of atom idx)
    connected_dummy_regions_cc2: list, optional
        connected dummy regions of mol2
    odered_connected_dummy_regions_cc1: list, optional
        dummy regions of mol1 in the order in which LJ is turned off
    odered_connected_dummy_regions_cc2: list, optional
        dummy regions of mol2 in the order in which LJ is turned off

    Sets ``self.dummy_region_cc1`` and, for non-ASFE runs, ``self.dummy_region_cc2``
    and the charge compensated ligand psfs.
    """
    if not self.asfe:
        # set the terminal real/dummy atom indices
        self._set_common_core_parameters()
        # match the real/dummy atoms
        match_terminal_atoms_cc1 = (
            self._match_terminal_real_and_dummy_atoms_for_mol1()
        )
        match_terminal_atoms_cc2 = (
            self._match_terminal_real_and_dummy_atoms_for_mol2()
        )
        logger.info("Find connected dummy regions")
        # define connected dummy regions
        if not connected_dummy_regions_cc1:
            connected_dummy_regions_cc1 = self._find_connected_dummy_regions(
                mol_name="m1",
            )
        if not connected_dummy_regions_cc2:
            connected_dummy_regions_cc2 = self._find_connected_dummy_regions(
                mol_name="m2",
            )

        logger.debug(
            f"connected dummy regions for mol1: {connected_dummy_regions_cc1}"
        )
        logger.debug(
            f"connected dummy regions for mol2: {connected_dummy_regions_cc2}"
        )

        # calculate the ordering or LJ mutations
        if not odered_connected_dummy_regions_cc1:
            odered_connected_dummy_regions_cc1 = (
                self._calculate_order_of_LJ_mutations(
                    connected_dummy_regions_cc1,
                    match_terminal_atoms_cc1,
                    self.graphs["m1"].copy(),
                )
            )

        if not odered_connected_dummy_regions_cc2:
            odered_connected_dummy_regions_cc2 = (
                self._calculate_order_of_LJ_mutations(
                    connected_dummy_regions_cc2,
                    match_terminal_atoms_cc2,
                    self.graphs["m2"].copy(),
                )
            )

        logger.info(
            f"sorted connected dummy regions for mol1: {odered_connected_dummy_regions_cc1}"
        )
        logger.info(
            f"sorted connected dummy regions for mol2: {odered_connected_dummy_regions_cc2}"
        )

        # sort lone pairs into the dummy regions / the common core
        if odered_connected_dummy_regions_cc1:
            odered_connected_dummy_regions_cc1 = self._check_for_lp(
                odered_connected_dummy_regions_cc1,
                self.psf1["waterbox"],
                self.s1_tlc,
                "m1",
            )

        if odered_connected_dummy_regions_cc2:
            odered_connected_dummy_regions_cc2 = self._check_for_lp(
                odered_connected_dummy_regions_cc2,
                self.psf2["waterbox"],
                self.s2_tlc,
                "m2",
            )

        # find the atoms from dummy_region in s1 that needs to become lj default
        (
            lj_default_cc1,
            lj_default_cc2,
        ) = self._match_terminal_dummy_atoms_between_common_cores(
            match_terminal_atoms_cc1, match_terminal_atoms_cc2
        )

        self.dummy_region_cc1 = DummyRegion(
            mol_name="m1",
            tlc=self.s1_tlc,
            match_termin_real_and_dummy_atoms=match_terminal_atoms_cc1,
            connected_dummy_regions=odered_connected_dummy_regions_cc1,
            lj_default=lj_default_cc1,
        )

        self.dummy_region_cc2 = DummyRegion(
            mol_name="m2",
            tlc=self.s2_tlc,
            match_termin_real_and_dummy_atoms=match_terminal_atoms_cc2,
            connected_dummy_regions=odered_connected_dummy_regions_cc2,
            lj_default=lj_default_cc2,
        )

        # generate charge compensated psfs
        psf1, psf2 = self._prepare_cc_for_charge_transfer()
        self.charge_compensated_ligand1_psf = psf1
        self.charge_compensated_ligand2_psf = psf2

    else:
        # all atoms should become dummy atoms in the end
        central_atoms = nx.center(self.graphs["m1"])

        if not self.drude:
            # Assure, that the central atom is no hydrogen
            for atom in self.psf1["waterbox"][f":{self.s1_tlc}"].atoms:
                if atom.idx in central_atoms and atom.name.startswith("H"):
                    raise RuntimeError(
                        "One of the central atoms seems to be a hydrogen atom"
                    )

            # calculate the ordering or LJ mutations
            if not odered_connected_dummy_regions_cc1:
                odered_connected_dummy_regions_cc1 = (
                    calculate_order_of_LJ_mutations_asfe(
                        central_atoms,
                        self.graphs["m1"].copy(),
                    )
                )

            if odered_connected_dummy_regions_cc1:
                odered_connected_dummy_regions_cc1 = self._check_for_lp(
                    odered_connected_dummy_regions_cc1,
                    self.psf1["waterbox"],
                    self.s1_tlc,
                    "m1",
                )
        else:
            # Drude system: use the psf order, skipping atoms of type "DRUD"
            # and atoms whose type starts with "L"; a passed ordering is ignored
            psf_ordered_atoms = []
            for atom in self.psf1["vacuum"].view[f":{self.s1_tlc}"]:
                if atom.type != "DRUD" and not atom.type.startswith("L"):
                    # if not atom.type.startswith("H"):
                    psf_ordered_atoms.append(atom.idx)
            # a single dummy region containing all remaining atoms
            odered_connected_dummy_regions_cc1 = [psf_ordered_atoms]
            logger.info(
                f"Using the order as given in the psf file {odered_connected_dummy_regions_cc1}"
            )

        self.dummy_region_cc1 = DummyRegion(
            mol_name="m1",
            tlc=self.s1_tlc,
            match_termin_real_and_dummy_atoms=[],
            connected_dummy_regions=odered_connected_dummy_regions_cc1,
            lj_default=[],
        )
def calculate_common_core(self):
self.propose_common_core()
self.finish_common_core()
def _prepare_cc_for_charge_transfer(self):
    """Return copies of both ligand psfs on which the charge mutation has already been applied.

    For each ligand every atom outside the common core is passed to a
    ``Mutation`` that is applied with ``lambda_value_electrostatic=0.0``.
    Returns ``(psf_of_m1, psf_of_m2)``.
    """
    # we have to run the same charge mutation that will be run on cc2 to get the
    # charge distribution AFTER the full mutation

    # make a copy of the full psf
    m2_psf = self.psfs["m2"][:, :, :]
    m1_psf = self.psfs["m1"][:, :, :]

    per_molecule = zip(
        [m1_psf, m2_psf],
        [self.s1_tlc, self.s2_tlc],
        [self.get_common_core_idx_mol1(), self.get_common_core_idx_mol2()],
        [self.dummy_region_cc1, self.dummy_region_cc2],
    )

    charge_transformed_psfs = []
    for psf, tlc, cc_idx, dummy_region in per_molecule:
        ## We need this for point mutations, because if we give a resid, the mol here
        ## consists only of one residue which resid is always 1
        try:
            int(tlc)
        except ValueError:
            pass
        else:
            tlc = "1"

        ligand_atoms = psf.view[f":{tlc}"].atoms

        # set `initial_charge` parameter for Mutation
        # (charge, epsilon and rmin are directly modified later)
        for atom in ligand_atoms:
            atom.initial_charge = atom.charge

        # indices are counted relative to the first ligand atom
        offset = min(atom.idx for atom in ligand_atoms)
        atoms_to_be_mutated = [
            atom.idx - offset
            for atom in ligand_atoms
            if atom.idx - offset not in cc_idx
        ]

        logger.debug("############################")
        logger.debug("Preparing cc2 for charge transfer")
        logger.debug(
            f"Atoms for which charge is set to zero: {atoms_to_be_mutated}"
        )
        logger.debug("############################")

        charge_mutation = Mutation(
            atoms_to_be_mutated=atoms_to_be_mutated, dummy_region=dummy_region
        )
        charge_mutation.mutate(psf, lambda_value_electrostatic=0.0)
        charge_transformed_psfs.append(psf)

    return charge_transformed_psfs[0], charge_transformed_psfs[1]
def remove_idx_from_common_core_of_mol1(self, idx_list: list):
for idx in idx_list:
self._remove_idx_from_common_core("m1", idx)
def remove_idx_from_common_core_of_mol2(self, idx_list: list):
for idx in idx_list:
self._remove_idx_from_common_core("m2", idx)
def _remove_idx_from_common_core(self, name: str, idx: int):
if idx in self.added_indeces[name] or idx in self._get_common_core(name):
if idx in self.removed_indeces[name]:
print(f"Idx: {idx} already removed from common core.")
return
self.removed_indeces[name].append(idx)
else:
print(f"Idx: {idx} not in common core.")
[docs]
def add_idx_to_common_core_of_mol1(self, idx_list: list):
    """Adds a list of atoms to the common core of molecule 1

    .. caution::
        Be aware of the ordering! Atom idx need to be added to match the ordering of the atom idx of common core 2

    Args:
        idx_list: Array of atom idxs to add
    """
    mol_name = "m1"
    for atom_idx in idx_list:
        self._add_common_core_atom(mol_name, atom_idx)

    logger.warning(
        "ATTENTION: Be aware of the ordering! Atom idx need to be added to match the ordering of the atom idx of common core 2"
    )
    new_core = self.get_common_core_idx_mol1()
    logger.info(f"Atom idx of the new common core: {new_core}")
[docs]
def add_idx_to_common_core_of_mol2(self, idx_list: list):
    """Adds a list of atoms to the common core of molecule 2

    .. caution::
        Be aware of the ordering! Atom idx need to be added to match the ordering of the atom idx of common core 1

    Args:
        idx_list: Array of atom idxs to add
    """
    mol_name = "m2"
    for atom_idx in idx_list:
        self._add_common_core_atom(mol_name, atom_idx)

    logger.warning(
        "ATTENTION: Be aware of the ordering! Atom idx need to be added to match the ordering of the atom idx of common core 1"
    )
    new_core = self.get_common_core_idx_mol2()
    logger.info(f" Atom idx of the new common core: {new_core}")
def _add_common_core_atom(self, name: str, idx: int):
if idx in self.added_indeces[name] or idx in self._get_common_core(name):
print(f"Idx: {idx} already in common core.")
return
self.added_indeces[name].append(idx)
def get_idx_not_in_common_core_for_mol1(self) -> list:
return self._get_idx_not_in_common_core_for_mol("m1")
def get_idx_not_in_common_core_for_mol2(self) -> list:
return self._get_idx_not_in_common_core_for_mol("m2")
def _get_idx_not_in_common_core_for_mol(self, mol_name: str) -> list:
dummy_list_mol = [
atom.GetIdx()
for atom in self.mols[mol_name].GetAtoms()
if atom.GetIdx() not in self._get_common_core(mol_name)
]
return dummy_list_mol
[docs]
def get_common_core_idx_mol1(self) -> list:
"""
Returns the common core of mol1.
"""
return self._get_common_core("m1")
[docs]
def get_common_core_idx_mol2(self) -> list:
"""
Returns the common core of mol2.
"""
return self._get_common_core("m2")
[docs]
def _get_common_core(self, name: str) -> list:
"""
Helper Function - should not be called directly.
Returns the common core.
"""
keep_idx = []
# BEWARE: the ordering is important - don't cast set!
for idx in self._substructure_match[name] + self.added_indeces[name]:
if idx not in self.removed_indeces[name]:
keep_idx.append(idx)
return keep_idx
[docs]
def _find_mcs(
    self,
    mol1_name: str,
    mol2_name: str,
    iterate_over_matches: bool = False,
    max_matches: int = 10,
):
    """
    Helper function - do not call directly.
    Find the maximum common substructure (MCS) of two molecules and store the
    common core atom indices of both molecules in ``self._substructure_match``.

    The MCS search runs on copies of both molecules from which all hydrogens
    were removed. The resulting heavy atom match is projected onto the full
    molecules and, for every pair of mapped atoms, extended by as many
    attached hydrogens as both atoms have in common.

    Parameters
    ----------
    mol1_name: str
        key of the first molecule in ``self.mols``
    mol2_name: str
        key of the second molecule in ``self.mols``
    iterate_over_matches: bool
        if False, only the first substructure match of each molecule is used;
        if True, all pairs of substructure matches are evaluated and the first
        pair that yields the largest common core on molecule 1 is kept
    max_matches: int
        maximum number of substructure matches requested per molecule
        (only used if ``iterate_over_matches`` is True)

    Returns
    ----------
    mcs
        the result object returned by ``rdFMCS.FindMCS``
    """

    def _extend_with_shared_hydrogens(match1, match2):
        """Extend a pair of heavy atom matches by the hydrogens both sides share."""
        core1, core2 = list(match1), list(match2)
        for position, idx1 in enumerate(match1):
            # atoms at the same position of both matches are mapped onto each other
            hydrogens1 = [
                nbr.GetIdx()
                for nbr in m1.GetAtomWithIdx(idx1).GetNeighbors()
                if nbr.GetSymbol() == "H"
            ]
            hydrogens2 = [
                nbr.GetIdx()
                for nbr in m2.GetAtomWithIdx(match2[position]).GetNeighbors()
                if nbr.GetSymbol() == "H"
            ]
            # only as many hydrogens as BOTH atoms carry can be part of the
            # common core (e.g. one vs. zero hydrogens -> none are added)
            nr_shared = min(len(hydrogens1), len(hydrogens2))
            core1.extend(hydrogens1[:nr_shared])
            core2.extend(hydrogens2[:nr_shared])
        return core1, core2

    logger.info("MCS starting ...")
    # NOTE: these instance settings are only logged; the MCS search below
    # uses a fixed parameter set
    logger.debug(f"bondCompare: {self.bondCompare}")
    logger.debug(f"atomCompare: {self.atomCompare}")
    logger.debug(f"maximizeBonds: {self.maximizeBonds}")
    logger.debug(f"matchValences: {self.matchValences} ")
    logger.debug(f"ringMatchesRingOnly: {self.ringMatchesRingOnly} ")
    logger.debug(f"completeRingsOnly: {self.completeRingsOnly} ")

    m1, m2 = [deepcopy(self.mols[mol1_name]), deepcopy(self.mols[mol2_name])]

    # hydrogen free representations - if the hydrogens are not removed, the
    # common core for molecule + hydrogens is computed!
    remmols = [Chem.rdmolops.RemoveAllHs(deepcopy(m)) for m in (m1, m2)]

    for m in [m1, m2]:
        logger.debug("Mol in SMILES format: {}.".format(Chem.MolToSmiles(m, True)))

    # find substructure match (bond order is ignored)
    # The parameter set is taken from tf_routes; searching with the instance
    # settings (in particular completeRingsOnly=False) gave common cores that
    # are not usable for molecules with cyclic structures (e.g. 2-CPI/7-CPI).
    mcs = rdFMCS.FindMCS(
        remmols,
        timeout=120,
        ringMatchesRingOnly=True,
        completeRingsOnly=True,
        ringCompare=rdFMCS.RingCompare.StrictRingFusion,
        bondCompare=rdFMCS.BondCompare.CompareAny,
        matchValences=False,
    )
    logger.debug("Substructure match: {}".format(mcs.smartsString))

    # convert from SMARTS
    mcsp = Chem.MolFromSmarts(mcs.smartsString, False)

    # NOTE(review): in both branches the highlighted atoms come from
    # get_common_core_idx_mol1()/mol2(), which are evaluated BEFORE the new
    # match is stored at the end of this method - confirm this is intended.
    if not iterate_over_matches:
        # the common core atoms for a single substructure match are determined;
        # a different match might yield a bigger common core, i.e. one with
        # more hydrogens (neopentane - methane)
        s1 = m1.GetSubstructMatch(mcsp)
        logger.debug("Substructere match idx: {}".format(s1))
        self._show_common_core(
            m1, self.get_common_core_idx_mol1(), show_atom_type=False, internal=True
        )
        s2 = m2.GetSubstructMatch(mcsp)
        logger.debug("Substructere match idx: {}".format(s2))
        self._show_common_core(
            m2, self.get_common_core_idx_mol2(), show_atom_type=False, internal=True
        )
        core1, core2 = _extend_with_shared_hydrogens(s1, s2)
    else:
        # all pairs of substructure matches are evaluated; the pairs share the
        # same heavy atoms but differ in the number of hydrogens, so the pair
        # with the biggest common core is the one with the most hydrogens
        s1s = m1.GetSubstructMatches(mcsp, maxMatches=max_matches)
        logger.debug("Substructere match idx: {}".format(s1s))
        self._show_common_core(
            m1, self.get_common_core_idx_mol1(), show_atom_type=False, internal=True
        )
        s2s = m2.GetSubstructMatches(mcsp, maxMatches=max_matches)
        logger.debug("Substructere match idx: {}".format(s2s))
        self._show_common_core(
            m2, self.get_common_core_idx_mol2(), show_atom_type=False, internal=True
        )
        # start from an empty common core so that "no match at all" results
        # in empty index lists instead of an UnboundLocalError
        core1, core2 = [], []
        for s1 in s1s:
            for s2 in s2s:
                candidate1, candidate2 = _extend_with_shared_hydrogens(s1, s2)
                # strictly bigger: the first pair with the maximal size wins
                if len(candidate1) > len(core1):
                    core1, core2 = candidate1, candidate2

    self._substructure_match[mol1_name] = core1
    self._substructure_match[mol2_name] = core2
    return mcs
def _return_atom_idx_from_bond_idx(self, mol: Chem.Mol, bond_idx: int):
    """
    Return the indices of the two atoms joined by a bond.

    Parameters
    ----------
    mol: Chem.Mol
        molecule that contains the bond
    bond_idx: int
        index of the bond in ``mol``

    Returns
    ----------
    tuple
        (index of the begin atom, index of the end atom)
    """
    begin_atom_idx = mol.GetBondWithIdx(bond_idx).GetBeginAtomIdx()
    end_atom_idx = mol.GetBondWithIdx(bond_idx).GetEndAtomIdx()
    return (begin_atom_idx, end_atom_idx)
def _find_connected_dummy_regions(self, mol_name: str) -> List[set]:
    """
    Return the connected groups of atoms that are not part of the common core.

    All common core atoms are removed from a copy of the molecular graph;
    the connected components that remain are the dummy regions.

    Parameters
    ----------
    mol_name: str
        key of the molecule in ``self.mols`` and ``self.graphs``

    Returns
    ----------
    list of sets
        one set of node (atom) indices per connected dummy region,
        sorted by size with the largest region first
    """
    # a set gives O(1) membership tests in the loop below
    common_core = set(self._get_common_core(mol_name))
    mol = self.mols[mol_name]
    # work on a copy - the stored graph has to keep all atoms
    G = self.graphs[mol_name].copy()

    # remove real atoms from graph to obtain multiple connected compounds
    for atom in mol.GetAtoms():
        atom_idx = atom.GetIdx()
        if atom_idx in common_core:
            G.remove_node(atom_idx)

    # find these connected compounds
    return sorted(nx.connected_components(G), key=len, reverse=True)
[docs]
def show_common_core_on_mol1(self, show_atom_types: bool = False):
    """
    Draw mol1 (``self.mols["m1"]``) with its common core atoms highlighted.

    Parameters
    ----------
    show_atom_types: bool
        forwarded to ``_show_common_core`` as ``show_atom_type``

    Returns
    ----------
    whatever ``_show_common_core`` returns for a non-internal call
    """
    molecule = self.mols["m1"]
    common_core_idx = self.get_common_core_idx_mol1()
    return self._show_common_core(
        molecule, common_core_idx, show_atom_types, internal=False
    )
[docs]
def show_common_core_on_mol2(self, show_atom_types: bool = False):
    """
    Draw mol2 (``self.mols["m2"]``) with its common core atoms highlighted.

    Parameters
    ----------
    show_atom_types: bool
        forwarded to ``_show_common_core`` as ``show_atom_type``

    Returns
    ----------
    whatever ``_show_common_core`` returns for a non-internal call
    """
    molecule = self.mols["m2"]
    common_core_idx = self.get_common_core_idx_mol2()
    return self._show_common_core(
        molecule, common_core_idx, show_atom_types, internal=False
    )
[docs]
def _show_common_core(
    self, mol, highlight: list, show_atom_type: bool, internal: bool
):
    """
    Helper function - do not call directly.
    Render a molecule as SVG with the given atoms highlighted.

    Parameters
    ----------
    mol
        rdkit molecule to draw (a copy is drawn, the argument is not modified)
    highlight: list
        atom indices that are highlighted in the drawing
    show_atom_type: bool
        if True every atom is labelled "<atom_index>:<atom_type>"; otherwise
        molecules with fewer than 30 atoms are labelled
        "<atom_index>:<atom_name>" and bigger ones get no custom labels
    internal: bool
        if True the SVG is additionally displayed via IPython

    Returns
    ----------
    svg: str
        the SVG text with all "svg:" prefixes removed
    """
    # https://rdkit.blogspot.com/2015/02/new-drawing-code.html
    mol = deepcopy(mol)
    drawer = rdMolDraw2D.MolDraw2DSVG(500, 500)
    drawer.SetFontSize(6)
    opts = drawer.drawOptions()

    # decide which atom property (if any) is shown next to the atom index
    if show_atom_type:
        label_property = "atom_type"
    elif mol.GetNumAtoms() < 30:
        label_property = "atom_name"
    else:
        label_property = None

    if label_property is not None:
        for atom in mol.GetAtoms():
            opts.atomLabels[atom.GetIdx()] = (
                str(atom.GetProp("atom_index")) + ":" + atom.GetProp(label_property)
            )

    # generate 2D coordinates for the drawing
    rdCoordGen.AddCoords(mol)
    drawer.DrawMolecule(mol, highlightAtoms=highlight)
    drawer.FinishDrawing()

    svg = drawer.GetDrawingText().replace("svg:", "")
    if internal:
        display(SVG(svg))
    return svg
[docs]
def generate_mutations_to_common_core_for_mol1(self) -> dict:
    """
    Generates the mutation route to the common core for mol1.

    Unless ``self.asfe`` is set, the result of ``_transform_common_core()``
    is added under the key "transform".

    Returns
    ----------
    mutations: dict
        the mutations returned by ``_mutate_to_common_core`` for "m1"
    """
    common_core_idx = self.get_common_core_idx_mol1()
    mutations = self._mutate_to_common_core(
        self.dummy_region_cc1, common_core_idx, mol_name="m1"
    )
    if self.asfe:
        # no common core transformation in this mode
        return mutations

    mutations["transform"] = self._transform_common_core()
    return mutations
[docs]
def generate_mutations_to_common_core_for_mol2(self) -> dict:
    """
    Generates the mutation route to the common core for mol2.

    Returns
    ----------
    mutations: dict
        the mutations returned by ``_mutate_to_common_core`` for "m2"

    Raises
    ----------
    RuntimeError
        if ``self.terminal_real_atom_cc1`` is empty/unset
    """
    if not self.terminal_real_atom_cc1:
        raise RuntimeError("First generate the MCS")

    common_core_idx = self.get_common_core_idx_mol2()
    return self._mutate_to_common_core(
        self.dummy_region_cc2, common_core_idx, mol_name="m2"
    )
[docs]
@staticmethod
def _find_terminal_atom(cc_idx: list, mol: Chem.Mol) -> Tuple[list, list]:
    """
    Find atoms that connect the molecule to the common core.

    Args:
        cc_idx (list): common core index atoms
        mol ([type]): rdkit mol object

    Returns:
        Tuple[list, list]: the terminal dummy atoms (atoms outside the common
        core with at least one neighbor inside it) and the terminal real
        atoms (common core atoms with at least one neighbor outside it)
    """
    # set for O(1) membership tests instead of scanning the list per atom
    common_core = set(cc_idx)
    terminal_dummy_atoms = []
    terminal_real_atoms = []

    for atom in mol.GetAtoms():
        idx = atom.GetIdx()
        neighbors = [x.GetIdx() for x in atom.GetNeighbors()]
        if idx in common_core:
            if any(n not in common_core for n in neighbors):
                terminal_real_atoms.append(idx)
        elif any(n in common_core for n in neighbors):
            terminal_dummy_atoms.append(idx)

    # the round trip through set() is kept so that the ordering of the
    # returned lists stays exactly as before
    terminal_dummy_atoms = list(set(terminal_dummy_atoms))
    terminal_real_atoms = list(set(terminal_real_atoms))

    logger.info(f"Terminal dummy atoms: {terminal_dummy_atoms}")
    logger.info(f"Terminal real atoms: {terminal_real_atoms}")

    return (terminal_dummy_atoms, terminal_real_atoms)
[docs]
def _mutate_to_common_core(
    self, dummy_region: DummyRegion, cc_idx: list, mol_name: str
) -> dict:
    """
    Helper function - do not call directly.
    Generates the mutation route to the common core for mol.

    Parameters
    ----------
    dummy_region: DummyRegion
        dummy region of the molecule
    cc_idx: list
        common core atom indices (ignored and replaced by an empty list if
        ``self.asfe`` is set)
    mol_name: str
        "m1" or "m2"

    Returns
    ----------
    mutations: dict
        lists of ``MutationDefinition`` objects stored under the keys
        "charge", "hydrogen-lj", "default-lj" and "lj" (only keys with at
        least one mutation are present); an empty defaultdict without
        default factory if no atom has to be mutated
    """

    def _residue_selector(tlc_or_resid):
        # We need this for point mutations, because if we give a resid, the
        # mol here consists only of one residue whose resid is always 1
        try:
            int(tlc_or_resid)
        except ValueError:
            return tlc_or_resid
        return "1"

    def _definition(vdw_atom_idx, to_default=False):
        # all mutation definitions of one route share everything except the
        # atoms whose LJ parameters are changed and the kind of LJ change
        return MutationDefinition(
            atoms_to_be_mutated=atoms_to_be_mutated,
            common_core=cc_idx,
            dummy_region=dummy_region,
            vdw_atom_idx=vdw_atom_idx,
            steric_mutation_to_default=to_default,
        )

    mutations = defaultdict(list)
    tlc = _residue_selector(self.s1_tlc)

    if self.asfe:
        psf = self.psf1["waterbox"]
        cc_idx = []  # no CC in ASFE
        list_termin_dummy_atoms = []
    else:
        # copy of the currently used psf
        psf = self.psfs[f"{mol_name}"][:, :, :]
        # only necessary for relative binding/solvation free energies:
        # collect the terminal dummy atoms, i.e. the dummy atoms that are
        # connected to the common core
        list_termin_dummy_atoms = []
        for m in dummy_region.match_termin_real_and_dummy_atoms.values():
            list_termin_dummy_atoms.extend(m)
        logger.info(f"Terminal dummy atoms: {list_termin_dummy_atoms}")

    if mol_name == "m2":
        tlc = _residue_selector(self.s2_tlc)

    # sets for O(1) membership tests in the loops below
    common_core = set(cc_idx)
    terminal_dummy_atoms = set(list_termin_dummy_atoms)

    # iterate through atoms and select atoms that need to be mutated
    atoms_to_be_mutated = []
    hydrogens = []
    for atom in psf.view[f":{tlc}"].atoms:
        idx = atom.idx
        if idx in common_core:
            continue
        # hydrogens are recognized by an atom name starting with "H";
        # terminal hydrogens are handled together with the heavy atoms
        if atom.name.startswith("H") and idx not in terminal_dummy_atoms:
            hydrogens.append(idx)
        atoms_to_be_mutated.append(idx)
        logger.info("Will be decoupled: Idx:{} Element:{}".format(idx, atom.name))

    if not atoms_to_be_mutated:
        logger.critical("No atoms will be decoupled.")
        # NOTE: deliberately kept as a defaultdict WITHOUT default factory
        return defaultdict()

    ############################################
    # charge mutation
    ############################################
    mutations["charge"].append(_definition([]))

    ############################################
    # LJ mutation
    ############################################
    # start with mutation of LJ of hydrogens
    # Only take hydrogens that are not terminal hydrogens
    if hydrogens:
        mutations["hydrogen-lj"].append(_definition(hydrogens))

    already_mutated_hydrogens = set(hydrogens)
    for region in dummy_region.connected_dummy_regions:
        for atom_idx in region:
            if (
                atom_idx in terminal_dummy_atoms
                and atom_idx in dummy_region.lj_default
            ):
                # the atom is a terminal atom and there is a corresponding
                # atom on the other cc - in this case the atom needs to
                # become a default lj particle
                mutations["default-lj"].append(
                    _definition([atom_idx], to_default=True)
                )
            elif atom_idx in already_mutated_hydrogens or psf[
                atom_idx
            ].type.startswith(("LP", "DRUD")):
                # already mutated
                continue
            else:
                # normal lj mutation
                mutations["lj"].append(_definition([atom_idx]))

    return mutations
[docs]
class Mutation(object):
[docs]
def __init__(self, atoms_to_be_mutated: list, dummy_region: DummyRegion):
    """
    Store the atoms that are mutated and the dummy region they belong to.

    Parameters
    ----------
    atoms_to_be_mutated: list
        indices of the atoms that are mutated
    dummy_region: DummyRegion
        dummy region; its ``tlc`` is copied to ``self.tlc``
    """
    # isinstance (instead of an exact type comparison) also accepts
    # subclasses of list
    assert isinstance(
        atoms_to_be_mutated, list
    ), f"atoms_to_be_mutated has to be a list, got {type(atoms_to_be_mutated)}"
    self.atoms_to_be_mutated = atoms_to_be_mutated
    self.dummy_region = dummy_region
    self.tlc = dummy_region.tlc
def _mutate_charge(
    self, psf: pm.charmm.CharmmPsfFile, lambda_value: float, offset: int
):
    """
    Scale the charge of every atom in ``self.atoms_to_be_mutated`` to
    ``initial_charge * lambda_value``.

    Afterwards the removed charge is compensated via ``_compensate_charge``,
    unless ``lambda_value`` is 1 or the dummy region has no terminal
    real/dummy atom matches.

    Parameters
    ----------
    psf : pm.charmm.CharmmPsfFile
        topology that is modified in place
    lambda_value : float
        scaling factor for the charges
    offset : int
        value added to the atom indices to address the atoms in ``psf``
    """
    # total charge of the residue before any scaling - this is the target
    # value for the charge compensation
    residue_atoms = psf.view[f":{self.tlc}"].atoms
    summed_initial_charge = sum([a.initial_charge for a in residue_atoms])
    total_charge = float(round(summed_initial_charge, 4))

    # scale the charge of all atoms
    print(f"Scaling charge on: {self.atoms_to_be_mutated}")
    # in the end, we save the topology for amber (lig_in_env.parm7) using parmed
    # pm.save_parm(), this requires all changes applied via an action tool
    is_amber_topology = type(psf) == pm.amber.AmberParm

    for idx in self.atoms_to_be_mutated:
        atom = psf[idx + offset]
        logger.debug(f"Scale charge on {atom}")
        logger.debug(f"Scaling charge with: {lambda_value}")
        logger.debug(f"Old charge: {atom.charge}")
        scaled_charge = atom.initial_charge * lambda_value
        atom.charge = scaled_charge
        logger.debug(f"New charge: {atom.charge}")

        if is_amber_topology:
            # NOTE(review): the mask uses idx + 1 (without offset) while the
            # LJ code paths use atom.idx + 1 - confirm this is intended
            amber_mask = f":{self.tlc}@{idx+1}"
            pm.tools.actions.change(
                psf, "CHARGE", amber_mask, scaled_charge
            ).execute()

    if lambda_value == 1:
        return
    # check to avoid compensating charges when doing asfe
    if len(self.dummy_region.match_termin_real_and_dummy_atoms) == 0:
        return
    # compensate for the total change in charge the terminal atom
    self._compensate_charge(psf, total_charge, offset)
[docs]
def _mutate_vdw(
    self,
    psf: pm.charmm.CharmmPsfFile,
    lambda_value: float,
    vdw_atom_idx: List[int],
    offset: int,
    to_default: bool,
):
    """
    Change the LJ parameters and the atom type of the given atoms.

    With ``to_default`` the atoms get fixed default LJ parameters and a
    "DDX" atom type, otherwise their epsilon and rmin are scaled with
    ``lambda_value`` and they get a "DDD" atom type.

    Parameters
    ----------
    psf : pm.charmm.CharmmPsfFile
        topology that is modified in place
    lambda_value : float
        scaling factor for epsilon and rmin (unused if ``to_default``)
    vdw_atom_idx : List[int]
        atoms to act on; has to be a subset of ``self.atoms_to_be_mutated``
    offset : int
        value added to the atom indices to address the atoms in ``psf``;
        only used if the offset cannot be determined from the residue
        selection ``self.tlc``
    to_default : bool
        mutate to default LJ parameters instead of scaling them

    Raises
    ------
    RuntimeError
        if ``vdw_atom_idx`` is not a subset of ``self.atoms_to_be_mutated``
    """
    if not set(vdw_atom_idx).issubset(self.atoms_to_be_mutated):
        raise RuntimeError(
            f"Specified atom {vdw_atom_idx} is not in atom_idx list {self.atoms_to_be_mutated}. Aborting."
        )

    logger.info(f"Acting on atoms: {vdw_atom_idx}")
    try:
        offset = min(a.idx for a in psf.view[f":{self.tlc.upper()}"].atoms)
    except ValueError:
        # The selection is empty (e.g. point mutations, where a resid is
        # specified). ``mutate`` handles this case when it computes the
        # offset, so the offset handed in by the caller is kept instead of
        # failing here.
        pass

    for i in vdw_atom_idx:
        atom = psf[i + offset]
        if to_default:
            logger.info("Mutate to default")
            atom_type_suffix = "DDX"
            atom.rmin = 1.5
            atom.epsilon = -0.15
            # do this only when using GAFF
            if type(psf) == pm.amber.AmberParm:
                # check that the AMBER mask selects the atom we want to modify
                assert psf[f":{self.tlc}@{atom.idx+1}"].atoms[0].name == atom.name
                pm.tools.actions.changeLJSingleType(
                    psf,
                    f":{self.tlc}@{atom.idx+1}",
                    1.5,
                    0.15,  ### ATTENTION: This should be -0.15 but somehow GAFF does not like negative values
                ).execute()
        else:
            logger.info("Mutate to dummy")
            atom_type_suffix = "DDD"
            self._scale_epsilon_and_rmin(atom, lambda_value, psf, self.tlc)

        # NOTE: there is always a type change
        self._modify_type(atom, psf, atom_type_suffix, self.tlc)
[docs]
def mutate(
    self,
    psf: pm.charmm.CharmmPsfFile,
    lambda_value_electrostatic: float = 1.0,
    lambda_value_vdw: float = 1.0,
    vdw_atom_idx: List[int] = [],
    steric_mutation_to_default: bool = False,
):
    """
    Performs the mutation.

    Charges are scaled if ``lambda_value_electrostatic`` is smaller than 1,
    LJ parameters are changed if ``lambda_value_vdw`` is smaller than 1.

    Parameters
    ----------
    psf : pm.charmm.CharmmPsfFile
        topology that is modified in place
    lambda_value_electrostatic : float
        scaling factor for the charges, between 0.0 and 1.0
    lambda_value_vdw : float
        scaling factor for the LJ parameters, between 0.0 and 1.0
    vdw_atom_idx : List[int]
        atoms whose LJ parameters are changed (the default list is only
        read, never modified)
    steric_mutation_to_default : bool
        forwarded to ``_mutate_vdw`` as ``to_default``

    Raises
    ------
    RuntimeError
        if one of the lambda values is outside of [0.0, 1.0]
    """
    if lambda_value_electrostatic < 0.0 or lambda_value_electrostatic > 1.0:
        raise RuntimeError(
            "Lambda value for electrostatics needs to be between 0.0 and 1.0."
        )
    if lambda_value_vdw < 0.0 or lambda_value_vdw > 1.0:
        raise RuntimeError("Lambda value for vdw needs to be between 0.0 and 1.0.")

    logger.debug(f"Electrostatic scaling factor: {lambda_value_electrostatic}")
    logger.debug(f"VDW scaling factor: {lambda_value_vdw}")

    try:
        offset = min([a.idx for a in psf.view[f":{self.tlc.upper()}"].atoms])
    except ValueError:
        # The selection is empty for point mutations, where a resid is
        # specified; then only one ligand (or the residue that should be
        # mutated) is left and it is addressed as residue 1.
        offset = min([a.idx for a in psf.view[":1"].atoms])

    if lambda_value_electrostatic < 1.0:
        self._mutate_charge(psf, lambda_value_electrostatic, offset)

    if lambda_value_vdw < 1.0:
        self._mutate_vdw(
            psf, lambda_value_vdw, vdw_atom_idx, offset, steric_mutation_to_default
        )
[docs]
def _compensate_charge(
    self, psf: pm.charmm.CharmmPsfFile, total_charge: int, offset: int
):
    """
    _compensate_charge This function compensates the charge changes of a dummy region on the terminal real atom
    that connects the specific dummy group to the real region.

    Parameters
    ----------
    psf : pm.charmm.CharmmPsfFile
        topology whose atom charges are modified in place
    total_charge : int
        total charge the residue selected by ``self.tlc`` has to have after
        the compensation
    offset : int
        value added to the atom indices of the dummy region to address the
        atoms in ``psf``

    Raises
    ------
    RuntimeError
        if a dummy region has no connecting real atom, or if the total
        charge after the compensation differs from ``total_charge``
    """
    logger.debug(f"Compensating charge ...")
    # remember the atoms that were already used for charge compensation: if two
    # regions use the same atom, the second compensation has to be added on
    # top of the first one instead of starting from the initial charge again
    already_compensated = set()

    # check for each dummy region how much charge has changed and compensate on atom that connects
    # the real region with specific dummy regions
    for dummy_idx in self.dummy_region.connected_dummy_regions:
        logger.debug(f"Dummy idx region: {dummy_idx}")
        connecting_real_atom = self.dummy_region.return_connecting_real_atom(
            dummy_idx
        )
        logger.debug(f"Connecting atom: {connecting_real_atom}")

        if connecting_real_atom is None:
            raise RuntimeError(
                "Something went wrong with the charge compensation. Aborting."
            )

        charge_acceptor = psf[connecting_real_atom + offset]

        # explicit accumulation (same summation order as before)
        charge_to_compensate = 0.0
        for atom_idx in dummy_idx:
            dummy_atom = psf[atom_idx + offset]
            charge_to_compensate += dummy_atom.initial_charge - dummy_atom.charge
        logger.debug(f"Charge to compensate: {charge_to_compensate}")

        # adding charge difference to initial charge on real terminal atom
        if connecting_real_atom in already_compensated:
            base_charge = charge_acceptor.charge
        else:
            base_charge = charge_acceptor.initial_charge
        charge_acceptor.charge = base_charge + charge_to_compensate
        already_compensated.add(connecting_real_atom)

    #### check if rest charge is missing
    new_charge = sum(
        [atom.charge for atom in psf.view[f":{self.tlc.upper()}"].atoms]
    )
    if not (np.isclose(round(new_charge, 3), round(total_charge, 3), rtol=1e-4)):
        raise RuntimeError(
            f"Charge compensation failed. Introducing non integer total charge: {new_charge}. Target total charge: {total_charge}."
        )
[docs]
@staticmethod
def _scale_epsilon_and_rmin(atom, lambda_value, psf, tlc):
    """
    Scale the LJ parameters (epsilon and rmin) of an atom with ``lambda_value``,
    starting from its initial values.

    This turns a non-interacting DDD atom (no charge) into a 'real' dummy
    atom (no LJ!). Typically this is performed in one step, but scaling is
    offered here as well.

    Parameters
    ----------
    atom
        atom whose ``epsilon`` and ``rmin`` are set
    lambda_value
        scaling factor applied to ``initial_epsilon`` and ``initial_rmin``
    psf
        topology the atom belongs to
    tlc
        residue selector used in the AMBER mask
    """
    for logged_value in (atom, atom.initial_epsilon, atom.initial_rmin):
        logger.debug(logged_value)

    scaled_epsilon = atom.initial_epsilon * lambda_value
    scaled_rmin = atom.initial_rmin * lambda_value
    atom.epsilon = scaled_epsilon
    atom.rmin = scaled_rmin

    ### do this only when using GAFF
    if type(psf) == pm.amber.AmberParm:
        amber_mask = f":{tlc}@{atom.idx+1}"
        # Quick check, if selected atom via AMBER mask is the same as the atom
        # we want to modify
        assert psf[amber_mask].atoms[0].type == atom.type
        pm.tools.actions.addLJType(
            psf,
            amber_mask,
            radius=scaled_rmin,
            epsilon=scaled_epsilon,
        ).execute()
@staticmethod
def _modify_type(atom, psf, atom_type_suffix, tlc):
    """
    Give an atom a new, numbered dummy atom type.

    The original type is saved as ``atom.initial_type``. Atoms that already
    have an ``initial_type`` are left untouched. The new type is the suffix
    followed by the incremented counter of the topology
    (``psf.number_of_dummys`` for "DDD", ``psf.mutations_to_default`` for "DDX").

    Parameters
    ----------
    atom
        atom whose type is changed
    psf
        topology the atom belongs to; holds the counters
    atom_type_suffix
        "DDD" or "DDX"
    tlc
        residue selector used in the AMBER mask
    """
    if hasattr(atom, "initial_type"):
        # the type was changed before - only the parameters change
        return

    atom.initial_type = atom.type

    if atom_type_suffix == "DDD":
        psf.number_of_dummys += 1
        new_type = f"{atom_type_suffix}{psf.number_of_dummys}"
    elif atom_type_suffix == "DDX":
        psf.mutations_to_default += 1
        new_type = f"{atom_type_suffix}{psf.mutations_to_default}"

    atom.type = new_type

    if type(psf) == pm.amber.AmberParm:
        amber_mask = f":{tlc}@{atom.idx+1}"
        pm.tools.actions.change(
            psf, "AMBER_ATOM_TYPE", amber_mask, new_type
        ).execute()