diff --git a/src/molearn/analysis/analyser.py b/src/molearn/analysis/analyser.py index e7d9e1b..c15e7ac 100644 --- a/src/molearn/analysis/analyser.py +++ b/src/molearn/analysis/analyser.py @@ -515,6 +515,25 @@ def _dope_score(self, frame, refine=True, **kwargs): return self.dope_score_class.get_score(f * self.stdval, refine=refine, **kwargs) + def _chirality_whole( + self, + n: np.ndarray[tuple[int,], np.dtype[np.float64]], + ca: np.ndarray[tuple[int,], np.dtype[np.float64]], + c: np.ndarray[tuple[int,], np.dtype[np.float64]], + cb: np.ndarray[tuple[int,], np.dtype[np.float64]], + ): + """ + check chirality for a set o amino acid + """ + ca_n = n - ca + ca_c = c - ca + cb_ca = cb - ca + normal = np.cross(ca_n, ca_c) + dot = np.einsum("ij,ij->i", normal, cb_ca) + # same but more calculations + # dot = np.diagonal(np.matmul(normal, cb_ca.T)) + return dot + @staticmethod def _ca_chirality(N, CA, C, CB): """ @@ -691,17 +710,19 @@ def scan_ca_chirality(self): ].index[0] indices[resname + str(resid)] = (N_id, CA_id, C_id, CB_id) + idx = np.asarray(list(indices.values())) results = [] for j in decoded: - s = (j.view(1, 3, -1).permute(0, 2, 1) * self.stdval).numpy() - inversions = {} - for k, v in indices.items(): - dot_prod = self._ca_chirality( - s[0, v[0], :], s[0, v[1], :], s[0, v[2], :], s[0, v[3], :] - ) - if dot_prod < 0: - inversions[k] = dot_prod - results.append(len(inversions)) + s = (j.view(1, 3, -1).permute(0, 2, 1) * self.stdval).numpy().squeeze() + chir_test = self._chirality_whole( + s[idx[:, 0], :], + s[idx[:, 1], :], + s[idx[:, 2], :], + s[idx[:, 3], :], + ) + wrong_chir = chir_test < 0 + results.append(wrong_chir.sum()) + results = np.asarray(results) self.surfaces[key] = np.array(results).reshape( self.n_samples, self.n_samples )