Skip to content

Commit

Permalink
Add filtering fingerprint histogram by elements
Browse files Browse the repository at this point in the history
  • Loading branch information
peterspackman committed Feb 26, 2024
1 parent 4b2a117 commit 18ab2f7
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 5 deletions.
2 changes: 1 addition & 1 deletion CITATION.cff
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ authors:
given-names: Peter R.
orcid: https://orcid.org/0000-0002-6532-8571
title: "Chmpy: A python library for computational chemistry"
version: v1.1.3
version: v1.1.4
date-released: 2024-02-24
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "chmpy"
readme = "README.md"
version = "1.1.3"
version = "1.1.4"
requires-python = ">=3.9"
description = "Molecules, crystals, promolecule and Hirshfeld surfaces using python."
authors = [
Expand Down
15 changes: 15 additions & 0 deletions src/chmpy/crystal/crystal.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,16 @@ def _nearest_molecule_idx(vertices, el, pos):
u, idxs = np.unique(l, return_inverse=True)
return np.arange(len(u), dtype=np.uint8)[idxs]

def _nearest_atom_idx(vertices, el, pos):
from scipy.sparse.csgraph import connected_components
import pandas as pd
from time import time

t1 = time()
tree = KDTree(pos)
d, idxs = tree.query(vertices, k=1)
t2 = time()
return idxs

class Crystal:
"""
Expand Down Expand Up @@ -1016,6 +1026,9 @@ def stockholder_weight_isosurfaces(self, kind="mol", **kwargs) -> List[Trimesh]:
meshes = []
extra_props = {}
isos = []
def nearest_atomic_number(pos, n_e, n_p):
return np.array(n_e[_nearest_atom_idx(pos, n_e, n_p)], dtype=np.uint8)

if kind == "atom":
for surrounds in self.atomic_surroundings(radius=radius):
n = surrounds["centre"]["element"]
Expand All @@ -1037,6 +1050,8 @@ def stockholder_weight_isosurfaces(self, kind="mol", **kwargs) -> List[Trimesh]:
extra_props["fragment_patch"] = lambda x: _nearest_molecule_idx(
x, n_e, n_p
)
extra_props["nearest_atom_external"] = lambda x: nearest_atomic_number(x, n_e, n_p)
extra_props["nearest_atom_internal"] = lambda x: nearest_atomic_number(x, mol.atomic_numbers, mol.positions)
s = StockholderWeight.from_arrays(
mol.atomic_numbers, mol.positions, n_e, n_p
)
Expand Down
47 changes: 44 additions & 3 deletions src/chmpy/crystal/fingerprint.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,63 @@
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np

def filtered_histogram(mesh, internal, external, bins=200, xrange=None, yrange=None):
di = mesh.vertex_attributes["d_i"]
de = mesh.vertex_attributes["d_e"]
if xrange is None:
xrange = np.min(di), np.max(di)
if yrange is None:
yrange = np.min(de), np.max(de)
di_atom = mesh.vertex_attributes["nearest_atom_internal"]
de_atom = mesh.vertex_attributes["nearest_atom_external"]
mask = (de_atom == external) & (di_atom == internal)
return np.histogram2d(di[mask], de[mask], bins=bins, range=(xrange, yrange))

def fingerprint_histogram(mesh, bins=200, xrange=(0.5, 2.5), yrange=(0.5, 2.5)):
def fingerprint_histogram(mesh, bins=200, xrange=None, yrange=None):
di = mesh.vertex_attributes["d_i"]
de = mesh.vertex_attributes["d_e"]
if xrange is None:
xrange = np.min(di), np.max(di)
if yrange is None:
yrange = np.min(de), np.max(de)
return np.histogram2d(di, de, bins=bins, range=(xrange, yrange))


def plot_fingerprint_histogram(hist, ax=None, filename=None):
def plot_fingerprint_histogram(hist, ax=None, filename=None, cmap="coolwarm",
xlim=(0.5, 2.5), ylim=(0.5, 2.5)):
if ax is None:
fig, ax = plt.subplots()
H1, xedges, yedges = hist
X, Y = np.meshgrid(xedges, yedges)
H1[H1 == 0] = np.nan
ax.pcolormesh(X, Y, H1, cmap='coolwarm')
ax.pcolormesh(X, Y, H1, cmap=cmap)
ax.set_xlabel(r'$d_i$')
ax.set_ylabel(r'$d_e$')
ax.set_xlim(*xlim)
ax.set_ylim(*ylim)

if filename is not None:
plt.savefig(filename, dpi=300, bbox_inches="tight")

def plot_filtered_histogram(hist_filtered, hist, ax=None, filename=None, cmap="coolwarm",
xlim=(0.5, 2.5), ylim=(0.5, 2.5)):

if ax is None:
fig, ax = plt.subplots()
fig.set_size_inches(4, 4)
H1, xedges1, yedges1 = hist
H2, xedges2, yedges2 = hist_filtered
X1, Y1 = np.meshgrid(xedges1, yedges1)
H1_binary = np.where(H1 > 0, 1, np.nan)
H2[H2 == 0] = np.nan
ax.pcolormesh(X1, Y1, H1_binary, cmap='Greys_r', alpha=0.15)
ax.pcolormesh(X1, Y1, H2, cmap=cmap)
ax.set_xlabel(r'$d_i$')
ax.set_ylabel(r'$d_e$')
ax.set_xlim(*xlim)
ax.set_ylim(*ylim)

if filename is not None:
plt.savefig(filename, dpi=300, bbox_inches="tight")

0 comments on commit 18ab2f7

Please sign in to comment.