Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
gwirn committed Dec 6, 2024
1 parent 192d3bf commit 2d61fcd
Showing 1 changed file with 91 additions and 58 deletions.
149 changes: 91 additions & 58 deletions src/molearn/analysis/analyser.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,44 +265,59 @@ def get_bondlengths(self, key):
"""
# Get the atomic indices to calculate different types of bond lengths
if set(['CA','C', 'N','CB']).issubset(set(self.atoms)):
indices = {'N-Ca':[], 'Ca-C':[], 'C-N':[], 'CA-CB':[]}
elif set(['CA','C', 'N']).issubset(set(self.atoms)):
indices = {'N-Ca':[], 'Ca-C':[], 'C-N':[]}
if set(["CA", "C", "N", "CB"]).issubset(set(self.atoms)):
indices = {"N-Ca": [], "Ca-C": [], "C-N": [], "CA-CB": []}
elif set(["CA", "C", "N"]).issubset(set(self.atoms)):
indices = {"N-Ca": [], "Ca-C": [], "C-N": []}
else:
raise ValueError(f"Selected atoms should contain at least N, CA, and C.")
raise ValueError("Selected atoms should contain at least N, CA, and C.")

mol_df = self.mol.data
for resid in mol_df.resid.unique():
resname = mol_df[mol_df['resid'] == resid].resname.unique()[0]

N_id = mol_df[(mol_df['resid'] == resid) & (mol_df['name'] == 'N')].index[0]
CA_id = mol_df[(mol_df['resid'] == resid) & (mol_df['name'] == 'CA')].index[0]
C_id = mol_df[(mol_df['resid'] == resid) & (mol_df['name'] == 'C')].index[0]
indices['N-Ca'].append((N_id, CA_id))
indices['Ca-C'].append((CA_id, C_id))
if resname != 'GLY' and 'CB' in self.atoms:
CB_id = mol_df[(mol_df['resid'] == resid) & (mol_df['name'] == 'CB')].index[0]
indices['Ca-Cb'].append((CA_id, CB_id))
resname = mol_df[mol_df["resid"] == resid].resname.unique()[0]

N_id = mol_df[(mol_df["resid"] == resid) & (mol_df["name"] == "N")].index[0]
CA_id = mol_df[(mol_df["resid"] == resid) & (mol_df["name"] == "CA")].index[
0
]
C_id = mol_df[(mol_df["resid"] == resid) & (mol_df["name"] == "C")].index[0]
indices["N-Ca"].append((N_id, CA_id))
indices["Ca-C"].append((CA_id, C_id))
if resname != "GLY" and "CB" in self.atoms:
CB_id = mol_df[
(mol_df["resid"] == resid) & (mol_df["name"] == "CB")
].index[0]
indices["Ca-Cb"].append((CA_id, CB_id))

if resid != len(mol_df.resid.unique()):
next_N_id = mol_df[(mol_df['resid'] == (resid+1)) & (mol_df['name'] == 'N')].index[0]
indices['C-N'].append((C_id, next_N_id))

next_N_id = mol_df[
(mol_df["resid"] == (resid + 1)) & (mol_df["name"] == "N")
].index[0]
indices["C-N"].append((C_id, next_N_id))

# Look for the key in self._datasets and self._encoded
if key in self._datasets.keys():
dataset = self.get_dataset(key)*self.stdval + self.meanval
decoded = self.get_decoded(key)*self.stdval + self.meanval
dataset_bondlen = {k: MolearnAnalysis._bond_lengths(dataset, v) for k, v in indices.items()}
decoded_bondlen = {k: MolearnAnalysis._bond_lengths(decoded, v) for k, v in indices.items()}
return dict(dataset_bondlen=dataset_bondlen, decoded_bondlen=decoded_bondlen)
dataset = self.get_dataset(key) * self.stdval + self.meanval
decoded = self.get_decoded(key) * self.stdval + self.meanval
dataset_bondlen = {
k: MolearnAnalysis._bond_lengths(dataset, v) for k, v in indices.items()
}
decoded_bondlen = {
k: MolearnAnalysis._bond_lengths(decoded, v) for k, v in indices.items()
}
return dict(
dataset_bondlen=dataset_bondlen, decoded_bondlen=decoded_bondlen
)
elif key in self._encoded.keys():
decoded = self.get_decoded(key)*self.stdval + self.meanval
decoded_bondlen = {k: MolearnAnalysis._bond_lengths(decoded, v) for k, v in indices.items()}
decoded = self.get_decoded(key) * self.stdval + self.meanval
decoded_bondlen = {
k: MolearnAnalysis._bond_lengths(decoded, v) for k, v in indices.items()
}
return dict(decoded_bondlen=decoded_bondlen)
else:
raise ValueError(f"Key {key} not found in _datasets or _encoded. Please load the dataset or setup a grid first.")

raise ValueError(
f"Key {key} not found in _datasets or _encoded. Please load the dataset or setup a grid first."
)

def setup_grid(self, samples=64, bounds_from=None, bounds=None, padding=0.1):
"""
Expand Down Expand Up @@ -521,9 +536,11 @@ def _ca_chirality(N, CA, C, CB):

@staticmethod
def _bond_lengths(crds, indices):
bond_lengths = [np.linalg.norm(crds[:,:,i[0]] - crds[:,:,i[1]], axis=1) for i in indices]
bond_lengths = [
np.linalg.norm(crds[:, :, i[0]] - crds[:, :, i[1]], axis=1) for i in indices
]
return np.array(bond_lengths)

def get_all_ramachandran_score(self, tensor):
"""
Calculate Ramachandran score of an ensemble of atomic coordinates.
Expand Down Expand Up @@ -642,10 +659,10 @@ def scan_ca_chirality(self):
:return: x-axis values
:return: y-axis values
"""
assert (
set(['CA','C','N','CB']).issubset(set(self.atoms))
), "Atom selection shoud at least include CA, C, N, and CB"
assert set(["CA", "C", "N", "CB"]).issubset(
set(self.atoms)
), "Atom selection shoud at least include CA, C, N, and CB"

key = "Chirality"
if key not in self.surfaces:
assert (
Expand All @@ -655,32 +672,41 @@ def scan_ca_chirality(self):

mol_df = self.mol.data

# Get atom indices
# Get atom indices
indices = dict()
for resid in mol_df.resid.unique():
resname = mol_df[mol_df['resid'] == resid].resname.unique()[0]
if not resname == 'GLY':
N_id = mol_df[(mol_df['resid'] == resid) & (mol_df['name'] == 'N')].index[0]
C_id = mol_df[(mol_df['resid'] == resid) & (mol_df['name'] == 'C')].index[0]
CA_id = mol_df[(mol_df['resid'] == resid) & (mol_df['name'] == 'CA')].index[0]
CB_id = mol_df[(mol_df['resid'] == resid) & (mol_df['name'] == 'CB')].index[0]
resname = mol_df[mol_df["resid"] == resid].resname.unique()[0]
if not resname == "GLY":
N_id = mol_df[
(mol_df["resid"] == resid) & (mol_df["name"] == "N")
].index[0]
C_id = mol_df[
(mol_df["resid"] == resid) & (mol_df["name"] == "C")
].index[0]
CA_id = mol_df[
(mol_df["resid"] == resid) & (mol_df["name"] == "CA")
].index[0]
CB_id = mol_df[
(mol_df["resid"] == resid) & (mol_df["name"] == "CB")
].index[0]
indices[resname + str(resid)] = (N_id, CA_id, C_id, CB_id)

results = []
for j in decoded:
s = (j.view(1, 3, -1).permute(0, 2, 1) * self.stdval).numpy()
inversions = {}
for k, v in indices.items():
dot_prod = self._ca_chirality(s[0,v[0],:],
s[0,v[1],:],
s[0,v[2],:],
s[0,v[3],:])
dot_prod = self._ca_chirality(
s[0, v[0], :], s[0, v[1], :], s[0, v[2], :], s[0, v[3], :]
)
if dot_prod < 0:
inversions[k] = dot_prod
results.append(len(inversions))
self.surfaces[key] = np.array(results).reshape(self.n_samples, self.n_samples)
self.surfaces[key] = np.array(results).reshape(
self.n_samples, self.n_samples
)

return self.surfaces[key], self.xvals, self.yvals
return self.surfaces[key], self.xvals, self.yvals

def scan_custom(self, fct, params, key):
"""
Expand All @@ -702,7 +728,12 @@ def scan_custom(self, fct, params, key):

return self.surfaces[key], self.xvals, self.yvals

def _relax(self, pdb_file: Union[str, Path], out_path: Union[str, Path], maxIterations: int = 1000) -> None:
def _relax(
self,
pdb_file: Union[str, Path],
out_path: Union[str, Path],
maxIterations: int = 1000,
) -> None:
"""
Model the sidechains and relax generated structure
Expand All @@ -720,14 +751,14 @@ def _relax(self, pdb_file: Union[str, Path], out_path: Union[str, Path], maxIter
modelled_file = out_path + os.sep + (pdb_file.stem + "_modelled.pdb")
try:
env = Environ()
env.libs.topology.read(file='$(LIB)/top_heav.lib')
env.libs.parameters.read(file='$(LIB)/par.lib')
env.libs.topology.read(file="$(LIB)/top_heav.lib")
env.libs.parameters.read(file="$(LIB)/par.lib")

mdl = complete_pdb(env, str(pdb_file))
mdl.write(str(modelled_file))
pdb_file = modelled_file
except Exception as e:
print(f'Failed to model {pdb_file}\n{e}')
print(f"Failed to model {pdb_file}\n{e}")
try:
relaxed_file = out_path + os.sep + (pdb_file.stem + "_relaxed.pdb")
# Read pdb
Expand All @@ -737,7 +768,9 @@ def _relax(self, pdb_file: Union[str, Path], out_path: Union[str, Path], maxIter
modeller = Modeller(pdb.topology, pdb.positions)
modeller.addHydrogens(forcefield)

system = forcefield.createSystem(modeller.topology, nonbondedMethod=NoCutoff)
system = forcefield.createSystem(
modeller.topology, nonbondedMethod=NoCutoff
)
integrator = VerletIntegrator(0.001 * picoseconds)
simulation = Simulation(modeller.topology, system, integrator)
simulation.context.setPositions(modeller.positions)
Expand All @@ -747,8 +780,8 @@ def _relax(self, pdb_file: Union[str, Path], out_path: Union[str, Path], maxIter
# Write energy minimized file
PDBFile.writeFile(simulation.topology, positions, open(relaxed_file, "w+"))
except Exception as e:
print(f'Failed to relax {pdb_file}\n{e}')
print(f"Failed to relax {pdb_file}\n{e}")

def _pdb_file(
self,
prot_coords: np.ndarray[tuple[int, int], np.dtype[np.float64]],
Expand Down Expand Up @@ -807,14 +840,14 @@ def generate(
gen_prot_coords = s * self.stdval + self.meanval
# create pdb file
if pdb_path is not None:
for i, coord in enumerate(tqdm(gen_prot_coords, desc="Generating pdb files")):
for i, coord in enumerate(
tqdm(gen_prot_coords, desc="Generating pdb files")
):
struct_path = os.path.join(pdb_path, f"s{i}.pdb")
self._pdb_file(coord, struct_path)
# relax and save as new file
if relax:
self._relax(
struct_path, pdb_path, maxIterations=1000
)
self._relax(struct_path, pdb_path, maxIterations=1000)

return gen_prot_coords

Expand Down

0 comments on commit 2d61fcd

Please sign in to comment.