From 2d61fcddc4f745a20e3db1a3dc13a11b50355964 Mon Sep 17 00:00:00 2001 From: gwirn <71886945+gwirn@users.noreply.github.com> Date: Fri, 6 Dec 2024 13:20:03 +0100 Subject: [PATCH] formatting --- src/molearn/analysis/analyser.py | 149 +++++++++++++++++++------------ 1 file changed, 91 insertions(+), 58 deletions(-) diff --git a/src/molearn/analysis/analyser.py b/src/molearn/analysis/analyser.py index d096b3d..e7d9e1b 100644 --- a/src/molearn/analysis/analyser.py +++ b/src/molearn/analysis/analyser.py @@ -265,44 +265,59 @@ def get_bondlengths(self, key): """ # Get the atomic indices to calculate different types of bond lengths - if set(['CA','C', 'N','CB']).issubset(set(self.atoms)): - indices = {'N-Ca':[], 'Ca-C':[], 'C-N':[], 'CA-CB':[]} - elif set(['CA','C', 'N']).issubset(set(self.atoms)): - indices = {'N-Ca':[], 'Ca-C':[], 'C-N':[]} + if set(["CA", "C", "N", "CB"]).issubset(set(self.atoms)): + indices = {"N-Ca": [], "Ca-C": [], "C-N": [], "CA-CB": []} + elif set(["CA", "C", "N"]).issubset(set(self.atoms)): + indices = {"N-Ca": [], "Ca-C": [], "C-N": []} else: - raise ValueError(f"Selected atoms should contain at least N, CA, and C.") + raise ValueError("Selected atoms should contain at least N, CA, and C.") mol_df = self.mol.data for resid in mol_df.resid.unique(): - resname = mol_df[mol_df['resid'] == resid].resname.unique()[0] - - N_id = mol_df[(mol_df['resid'] == resid) & (mol_df['name'] == 'N')].index[0] - CA_id = mol_df[(mol_df['resid'] == resid) & (mol_df['name'] == 'CA')].index[0] - C_id = mol_df[(mol_df['resid'] == resid) & (mol_df['name'] == 'C')].index[0] - indices['N-Ca'].append((N_id, CA_id)) - indices['Ca-C'].append((CA_id, C_id)) - if resname != 'GLY' and 'CB' in self.atoms: - CB_id = mol_df[(mol_df['resid'] == resid) & (mol_df['name'] == 'CB')].index[0] - indices['Ca-Cb'].append((CA_id, CB_id)) + resname = mol_df[mol_df["resid"] == resid].resname.unique()[0] + + N_id = mol_df[(mol_df["resid"] == resid) & (mol_df["name"] == "N")].index[0] + CA_id = mol_df[(mol_df["resid"] == resid) & (mol_df["name"] == "CA")].index[ + 0 + ] + C_id = mol_df[(mol_df["resid"] == resid) & (mol_df["name"] == "C")].index[0] + indices["N-Ca"].append((N_id, CA_id)) + indices["Ca-C"].append((CA_id, C_id)) + if resname != "GLY" and "CB" in self.atoms: + CB_id = mol_df[ + (mol_df["resid"] == resid) & (mol_df["name"] == "CB") + ].index[0] + indices["Ca-Cb"].append((CA_id, CB_id)) if resid != len(mol_df.resid.unique()): - next_N_id = mol_df[(mol_df['resid'] == (resid+1)) & (mol_df['name'] == 'N')].index[0] - indices['C-N'].append((C_id, next_N_id)) - + next_N_id = mol_df[ + (mol_df["resid"] == (resid + 1)) & (mol_df["name"] == "N") + ].index[0] + indices["C-N"].append((C_id, next_N_id)) + # Look for the key in self._datasets and self._encoded if key in self._datasets.keys(): - dataset = self.get_dataset(key)*self.stdval + self.meanval - decoded = self.get_decoded(key)*self.stdval + self.meanval - dataset_bondlen = {k: MolearnAnalysis._bond_lengths(dataset, v) for k, v in indices.items()} - decoded_bondlen = {k: MolearnAnalysis._bond_lengths(decoded, v) for k, v in indices.items()} - return dict(dataset_bondlen=dataset_bondlen, decoded_bondlen=decoded_bondlen) + dataset = self.get_dataset(key) * self.stdval + self.meanval + decoded = self.get_decoded(key) * self.stdval + self.meanval + dataset_bondlen = { + k: MolearnAnalysis._bond_lengths(dataset, v) for k, v in indices.items() + } + decoded_bondlen = { + k: MolearnAnalysis._bond_lengths(decoded, v) for k, v in indices.items() + } + return dict( + dataset_bondlen=dataset_bondlen, decoded_bondlen=decoded_bondlen + ) elif key in self._encoded.keys(): - decoded = self.get_decoded(key)*self.stdval + self.meanval - decoded_bondlen = {k: MolearnAnalysis._bond_lengths(decoded, v) for k, v in indices.items()} + decoded = self.get_decoded(key) * self.stdval + self.meanval + decoded_bondlen = { + k: MolearnAnalysis._bond_lengths(decoded, v) for k, v in indices.items() + } return dict(decoded_bondlen=decoded_bondlen) else: - raise ValueError(f"Key {key} not found in _datasets or _encoded. Please load the dataset or setup a grid first.") - + raise ValueError( + f"Key {key} not found in _datasets or _encoded. Please load the dataset or setup a grid first." + ) def setup_grid(self, samples=64, bounds_from=None, bounds=None, padding=0.1): """ @@ -521,9 +536,11 @@ def _ca_chirality(N, CA, C, CB): @staticmethod def _bond_lengths(crds, indices): - bond_lengths = [np.linalg.norm(crds[:,:,i[0]] - crds[:,:,i[1]], axis=1) for i in indices] + bond_lengths = [ + np.linalg.norm(crds[:, :, i[0]] - crds[:, :, i[1]], axis=1) for i in indices + ] return np.array(bond_lengths) - + def get_all_ramachandran_score(self, tensor): """ Calculate Ramachandran score of an ensemble of atomic conrdinates. @@ -642,10 +659,10 @@ def scan_ca_chirality(self): :return: x-axis values :return: y-axis values """ - assert ( - set(['CA','C','N','CB']).issubset(set(self.atoms)) - ), "Atom selection shoud at least include CA, C, N, and CB" - + assert set(["CA", "C", "N", "CB"]).issubset( + set(self.atoms) + ), "Atom selection shoud at least include CA, C, N, and CB" + key = "Chirality" if key not in self.surfaces: assert ( @@ -655,32 +672,41 @@ def scan_ca_chirality(self): mol_df = self.mol.data - # Get atom indices + # Get atom indices indices = dict() for resid in mol_df.resid.unique(): - resname = mol_df[mol_df['resid'] == resid].resname.unique()[0] - if not resname == 'GLY': - N_id = mol_df[(mol_df['resid'] == resid) & (mol_df['name'] == 'N')].index[0] - C_id = mol_df[(mol_df['resid'] == resid) & (mol_df['name'] == 'C')].index[0] - CA_id = mol_df[(mol_df['resid'] == resid) & (mol_df['name'] == 'CA')].index[0] - CB_id = mol_df[(mol_df['resid'] == resid) & (mol_df['name'] == 'CB')].index[0] + resname = mol_df[mol_df["resid"] == resid].resname.unique()[0] + if not resname == "GLY": + N_id = mol_df[ + (mol_df["resid"] == resid) & (mol_df["name"] == "N") + ].index[0] + C_id = mol_df[ + (mol_df["resid"] == resid) & (mol_df["name"] == "C") + ].index[0] + CA_id = mol_df[ + (mol_df["resid"] == resid) & (mol_df["name"] == "CA") + ].index[0] + CB_id = mol_df[ + (mol_df["resid"] == resid) & (mol_df["name"] == "CB") + ].index[0] indices[resname + str(resid)] = (N_id, CA_id, C_id, CB_id) - + results = [] for j in decoded: s = (j.view(1, 3, -1).permute(0, 2, 1) * self.stdval).numpy() inversions = {} for k, v in indices.items(): - dot_prod = self._ca_chirality(s[0,v[0],:], - s[0,v[1],:], - s[0,v[2],:], - s[0,v[3],:]) + dot_prod = self._ca_chirality( + s[0, v[0], :], s[0, v[1], :], s[0, v[2], :], s[0, v[3], :] + ) if dot_prod < 0: inversions[k] = dot_prod results.append(len(inversions)) - self.surfaces[key] = np.array(results).reshape(self.n_samples, self.n_samples) + self.surfaces[key] = np.array(results).reshape( + self.n_samples, self.n_samples + ) - return self.surfaces[key], self.xvals, self.yvals + return self.surfaces[key], self.xvals, self.yvals def scan_custom(self, fct, params, key): """ @@ -702,7 +728,12 @@ def scan_custom(self, fct, params, key): return self.surfaces[key], self.xvals, self.yvals - def _relax(self, pdb_file: Union[str, Path], out_path: Union[str, Path], maxIterations: int = 1000) -> None: + def _relax( + self, + pdb_file: Union[str, Path], + out_path: Union[str, Path], + maxIterations: int = 1000, + ) -> None: """ Model the sidechains and relax generated structure @@ -720,14 +751,14 @@ def _relax(self, pdb_file: Union[str, Path], out_path: Union[str, Path], maxIter modelled_file = out_path + os.sep + (pdb_file.stem + "_modelled.pdb") try: env = Environ() - env.libs.topology.read(file='$(LIB)/top_heav.lib') - env.libs.parameters.read(file='$(LIB)/par.lib') + env.libs.topology.read(file="$(LIB)/top_heav.lib") + env.libs.parameters.read(file="$(LIB)/par.lib") mdl = complete_pdb(env, str(pdb_file)) mdl.write(str(modelled_file)) pdb_file = modelled_file except Exception as e: - print(f'Failed to model {pdb_file}\n{e}') + print(f"Failed to model {pdb_file}\n{e}") try: relaxed_file = out_path + os.sep + (pdb_file.stem + "_relaxed.pdb") # Read pdb @@ -737,7 +768,9 @@ def _relax(self, pdb_file: Union[str, Path], out_path: Union[str, Path], maxIter modeller = Modeller(pdb.topology, pdb.positions) modeller.addHydrogens(forcefield) - system = forcefield.createSystem(modeller.topology, nonbondedMethod=NoCutoff) + system = forcefield.createSystem( + modeller.topology, nonbondedMethod=NoCutoff + ) integrator = VerletIntegrator(0.001 * picoseconds) simulation = Simulation(modeller.topology, system, integrator) simulation.context.setPositions(modeller.positions) @@ -747,8 +780,8 @@ def _relax(self, pdb_file: Union[str, Path], out_path: Union[str, Path], maxIter # Write energy minimized file PDBFile.writeFile(simulation.topology, positions, open(relaxed_file, "w+")) except Exception as e: - print(f'Failed to relax {pdb_file}\n{e}') - + print(f"Failed to relax {pdb_file}\n{e}") + def _pdb_file( self, prot_coords: np.ndarray[tuple[int, int], np.dtype[np.float64]], @@ -807,14 +840,14 @@ def generate( gen_prot_coords = s * self.stdval + self.meanval # create pdb file if pdb_path is not None: - for i, coord in enumerate(tqdm(gen_prot_coords, desc="Generating pdb files")): + for i, coord in enumerate( + tqdm(gen_prot_coords, desc="Generating pdb files") + ): struct_path = os.path.join(pdb_path, f"s{i}.pdb") self._pdb_file(coord, struct_path) # relax and save as new file if relax: - self._relax( - struct_path, pdb_path, maxIterations=1000 - ) + self._relax(struct_path, pdb_path, maxIterations=1000) return gen_prot_coords