From 644c1508f84b36fea19aedcbeb6a8329aad19d2b Mon Sep 17 00:00:00 2001 From: rzhu Date: Fri, 29 Nov 2024 11:12:48 +0000 Subject: [PATCH] fix static methods; add scan bond length --- src/molearn/analysis/analyser.py | 59 +++++++++++++++++++++++++++++--- 1 file changed, 55 insertions(+), 4 deletions(-) diff --git a/src/molearn/analysis/analyser.py b/src/molearn/analysis/analyser.py index e2f7e51..6754982 100644 --- a/src/molearn/analysis/analyser.py +++ b/src/molearn/analysis/analyser.py @@ -234,10 +234,10 @@ def get_dope(self, key, refine=True, **kwargs): dataset = self.get_dataset(key) decoded = self.get_decoded(key) - dope_dataset = self.get_all_dope_score(dataset, refine=refine, **kwargs) - dope_decoded = self.get_all_dope_score(decoded, refine=refine, **kwargs) + dataset_dope = self.get_all_dope_score(dataset, refine=refine, **kwargs) + decoded_dope = self.get_all_dope_score(decoded, refine=refine, **kwargs) - return dict(dataset_dope=dope_dataset, decoded_dope=dope_decoded) + return dict(dataset_dope=dataset_dope, decoded_dope=decoded_dope) def get_ramachandran(self, key): """ @@ -259,6 +259,51 @@ def get_ramachandran(self, key): ) return ramachandran + def get_bondlengths(self, key): + """ + Get backbone bond lengths of a dataset and its decoded counterpart. + + """ + # Get the atomic indices to calculate different types of bond lengths + if set(['CA','C', 'N','CB']).issubset(set(self.atoms)): + indices = {'N-Ca':[], 'Ca-C':[], 'C-N':[], 'CA-CB':[]} + elif set(['CA','C', 'N']).issubset(set(self.atoms)): + indices = {'N-Ca':[], 'Ca-C':[], 'C-N':[]} + else: + raise ValueError(f"Selected atoms should contain at least N, CA, and C.") + + mol_df = self.mol.data + for resid in mol_df.resid.unique(): + resname = mol_df[mol_df['resid'] == resid].resname.unique()[0] + + N_id = mol_df[(mol_df['resid'] == resid) & (mol_df['name'] == 'N')].index[0] + CA_id = mol_df[(mol_df['resid'] == resid) & (mol_df['name'] == 'CA')].index[0] + C_id = mol_df[(mol_df['resid'] == resid) & (mol_df['name'] == 'C')].index[0] + indices['N-Ca'].append((N_id, CA_id)) + indices['Ca-C'].append((CA_id, C_id)) + if resname != 'GLY' and 'CB' in self.atoms: + CB_id = mol_df[(mol_df['resid'] == resid) & (mol_df['name'] == 'CB')].index[0] + indices['Ca-Cb'].append((CA_id, CB_id)) + + if resid != len(mol_df.resid.unique()): + next_N_id = mol_df[(mol_df['resid'] == (resid+1)) & (mol_df['name'] == 'N')].index[0] + indices['C-N'].append((C_id, next_N_id)) + + # Look for the key in self._datasets and self._encoded + if key in self._datasets.keys(): + dataset = self.get_dataset(key)*self.stdval + self.meanval + decoded = self.get_decoded(key)*self.stdval + self.meanval + dataset_bondlen = {k: MolearnAnalysis._bond_lengths(dataset, v) for k, v in indices.items()} + decoded_bondlen = {k: MolearnAnalysis._bond_lengths(decoded, v) for k, v in indices.items()} + return dict(dataset_bondlen=dataset_bondlen, decoded_bondlen=decoded_bondlen) + elif key in self._encoded.keys(): + decoded = self.get_decoded(key)*self.stdval + self.meanval + decoded_bondlen = {k: MolearnAnalysis._bond_lengths(decoded, v) for k, v in indices.items()} + return dict(decoded_bondlen=decoded_bondlen) + else: + raise ValueError(f"Key {key} not found in _datasets or _encoded. Please load the dataset or setup a grid first.") + + def setup_grid(self, samples=64, bounds_from=None, bounds=None, padding=0.1): """ Define a NxN point grid regularly sampling the latent space. @@ -455,6 +500,7 @@ def _dope_score(self, frame, refine=True, **kwargs): return self.dope_score_class.get_score(f * self.stdval, refine=refine, **kwargs) + @staticmethod def _ca_chirality(N, CA, C, CB): """ Calculate chirality of Cα atom in a protein residue. @@ -473,6 +519,11 @@ def _ca_chirality(N, CA, C, CB): # L if dot_product > 0 else D return dot_product + @staticmethod + def _bond_lengths(crds, indices): + bond_lengths = [np.linalg.norm(crds[:,:,i[0]] - crds[:,:,i[1]], axis=1) for i in indices] + return np.array(bond_lengths) + def get_all_ramachandran_score(self, tensor): """ Calculate Ramachandran score of an ensemble of atomic conrdinates. @@ -606,7 +657,7 @@ def scan_ca_chirality(self): # Get atom indices indices = dict() - for resid in range(len(mol_df.resid.unique())): + for resid in mol_df.resid.unique(): resname = mol_df[mol_df['resid'] == resid].resname.unique()[0] if not resname == 'GLY': N_id = mol_df[(mol_df['resid'] == resid) & (mol_df['name'] == 'N')].index[0]