diff --git a/cmeutils/gsd_utils.py b/cmeutils/gsd_utils.py index 88893d3..6f4b280 100644 --- a/cmeutils/gsd_utils.py +++ b/cmeutils/gsd_utils.py @@ -380,6 +380,102 @@ def xml_to_gsd(xmlfile, gsdfile): print(f"XML data written to {gsdfile}") +def trim_snapshot_molecules(parent_snapshot, mol_indices): + """Given a snapshot of a system, trim the snapshot to only include + a subset of the molecules. + + Parameters + ---------- + parent_snapshot : gsd.hoomd.Frame + The snapshot to read in. + mol_indices : list of np.ndarray + List of arrays where each array contains the indices + of the particles in a molecule to include. + + Returns + ------- + gsd.hoomd.Frame + The new snapshot with only the specified molecules. + + Notes + ----- + See cmetuils.gsd_utils.get_molecule_cluster for a method to obtain + mol_indices. + + """ + new_snap = gsd.hoomd.Frame() + new_snap.configuration.box = parent_snapshot.configuration.box + new_snap.particles.N = sum(len(i) for i in mol_indices) + + # Write out particle info + for attr in ["position", "mass", "velocity", "orientation", "image", "diameter", "angmom", "typeid"]: + setattr( + new_snap.particles, + attr, + np.concatenate( + list(getattr(parent_snapshot.particles, attr)[i] for i in mol_indices) + ) + ) + new_snap.particles.types = parent_snapshot.particles.types + + particle_index_map = dict() + count = 0 + for indices in mol_indices: + for i in indices: + particle_index_map[i] = count + count += 1 + + # Write out bond info + mol_bond_groups = [] + mol_bond_ids = [] + for count, indices in enumerate(mol_indices): + mask = np.any(np.isin(parent_snapshot.bonds.group, indices.flatten()), axis=1) + parent_mol_bonds = parent_snapshot.bonds.group[np.where(mask)[0]] + parent_mol_bond_typeids = parent_snapshot.bonds.typeid[np.where(mask)[0]] + new_mol_bonds = np.vectorize(particle_index_map.get)(parent_mol_bonds) + mol_bond_groups.append(new_mol_bonds) + mol_bond_ids.append(parent_mol_bond_typeids) + + new_snap.bonds.types = parent_snapshot.bonds.types + new_snap.bonds.group = np.concatenate(mol_bond_groups) + new_snap.bonds.typeid = np.concatenate(mol_bond_ids) + new_snap.bonds.N = sum(len(i) for i in mol_bond_ids) + + # Write out angle info + mol_angle_groups = [] + mol_angle_ids = [] + for count, indices in enumerate(mol_indices): + mask = np.any(np.isin(parent_snapshot.angles.group, indices.flatten()), axis=1) + parent_mol_angles = parent_snapshot.angles.group[np.where(mask)[0]] + parent_mol_angle_typeids = parent_snapshot.angles.typeid[np.where(mask)[0]] + new_mol_angles = np.vectorize(particle_index_map.get)(parent_mol_angles) + mol_angle_groups.append(new_mol_angles) + mol_angle_ids.append(parent_mol_angle_typeids) + + new_snap.angles.types = parent_snapshot.angles.types + new_snap.angles.group = np.concatenate(mol_angle_groups) + new_snap.angles.typeid = np.concatenate(mol_angle_ids) + new_snap.angles.N = sum(len(i) for i in mol_angle_ids) + + # Write out dihedral info + mol_dihedral_groups = [] + mol_dihedral_ids = [] + for count, indices in enumerate(mol_indices): + mask = np.any(np.isin(parent_snapshot.dihedrals.group, indices.flatten()), axis=1) + parent_mol_dihedrals = parent_snapshot.dihedrals.group[np.where(mask)[0]] + parent_mol_dihedral_typeids = parent_snapshot.dihedrals.typeid[np.where(mask)[0]] + new_mol_dihedrals = np.vectorize(particle_index_map.get)(parent_mol_dihedrals) + mol_dihedral_groups.append(new_mol_dihedrals) + mol_dihedral_ids.append(parent_mol_dihedral_typeids) + + new_snap.dihedrals.types = parent_snapshot.dihedrals.types + new_snap.dihedrals.group = np.concatenate(mol_dihedral_groups) + new_snap.dihedrals.typeid = np.concatenate(mol_dihedral_ids) + new_snap.dihedrals.N = sum(len(i) for i in mol_dihedral_ids) + + new_snap.validate() + return new_snap + def identify_snapshot_connections(snapshot): """Identify angle and dihedral connections in a snapshot from bonds.