Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/ur-whitelab/exmol
Browse files Browse the repository at this point in the history
  • Loading branch information
hgandhi2411 committed Nov 19, 2022
2 parents f9176c7 + 08b5739 commit d9558a5
Show file tree
Hide file tree
Showing 7 changed files with 403 additions and 346 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@ jobs:
- name: Run Test
run: |
pytest tests
mypy -p exmol --ignore-missing-imports
# mypy -p exmol --ignore-missing-imports
5 changes: 5 additions & 0 deletions NOTICE.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,8 @@ Copyright (c) 2021 White Laboratory
This product includes software develped by Aspuru-Guzik group (Apache 2.0)
https://github.com/aspuru-guzik-group/stoned-selfies
Copyright (c) 2021 Aspuru-Guzik group

This product includes software developed by Christian Laggner (LGPL 2.0)
This software is specifically the SMARTS file, which is separate and distributed
as source code.
Copyright 2005 Inte:Ligand Software-Entwicklungs und Consulting GmbH
7 changes: 7 additions & 0 deletions docs/source/changelog.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
Change Log
==========


v2.2.1
-------------------
* Fixed bug in sorting for text explanations
* Fixed empty plot names saying `None`
* Added priority for naming and removed invalid names

v2.2.0 (2022-11-3)
-------------------
* Added natural language explanation method
Expand Down
94 changes: 61 additions & 33 deletions exmol/exmol.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,22 +175,32 @@ def _bit2atoms(m, bitInfo, key):
# get the atoms for highlighting
atoms = set((i,))
for b in bitPath:
atoms.add(m.GetBondWithIdx(b).GetBeginAtomIdx())
atoms.add(m.GetBondWithIdx(b).GetEndAtomIdx())
a = m.GetBondWithIdx(b).GetBeginAtomIdx()
atoms.add(a)
a = m.GetBondWithIdx(b).GetEndAtomIdx()
atoms.add(a)
return atoms


def _load_smarts(path):
smarts = []
def _load_smarts(path, rank_cutoff=500):
# we have a rank cut for SMARTS that match too often
smarts = {}
with open(path) as f:
for line in f.readlines():
if line[0] == "#":
continue
i = line.find(":")
sm = line[i + 1 :].strip()
m = MolFromSmarts(sm)
smarts.append((line[:i].strip(), m))
return smarts[::-1]
i1 = line.find(":")
i2 = line.find(":", i1 + 1)
m = MolFromSmarts(line[i2 + 1 :].strip())
rank = int(line[i1 + 1 : i2])
if rank > rank_cutoff:
continue
name = line[:i1]
if m is None:
print(f"Could not parse SMARTS: {line}")
print(line[i2:].strip())
smarts[name] = (m, rank)
return smarts


def _name_morgan_bit(m, bitInfo, key):
Expand All @@ -203,20 +213,33 @@ def _name_morgan_bit(m, bitInfo, key):
_SMARTS = _load_smarts(sp)
morgan_atoms = _bit2atoms(m, bitInfo, key)
names = []
for name, sm in _SMARTS:
for name, (sm, r) in _SMARTS.items():
matches = m.GetSubstructMatches(sm)
for match in matches:
# check if match is in morgan bit
match = set(match)
if match.issubset(morgan_atoms):
names.append((len(match), name))
names.sort()
names.append((r, name))
names.sort(key=lambda x: x[0])
# short-circuit if single atom
# if len(morgan_atoms) == 1:
# return m.GetAtomWithIdx(bitInfo[key][0][0]).GetSymbol()
if len(names) == 0:
if len(morgan_atoms) == 1:
# only 1 atom, just return element
return m.GetAtomWithIdx(list(morgan_atoms)[0]).GetSymbol()
return None
return names[-1][1].replace("_", " ")
return names[0][1].replace("_", " ")


def clear_descriptors(
examples: List[Example],
) -> List[Example]:
"""Clears all descriptors from examples
:param examples: list of examples
:param descriptor_type: type of descriptor to clear, if None, all descriptors are cleared
"""
for e in examples:
e.descriptors = None # type: ignore
return examples


def add_descriptors(
Expand Down Expand Up @@ -281,9 +304,18 @@ def add_descriptors(
descriptor_names = _get_joint_ecfp_descriptors(examples)
for e, m in zip(examples, mols):
# Now compare to reference and get other fp vectors
b = {} # type: Dict[Any, Any]
temp_fp = AllChem.GetMorganFingerprint(m, 3, bitInfo=b)
descriptors = tuple([1 if x in b.keys() else 0 for x in descriptor_names])
bitInfo = {} # type: Dict[Any, Any]
temp_fp = AllChem.GetMorganFingerprint(m, 3, bitInfo=bitInfo)
# remove single atoms from ecfp descriptors
to_del = []
for b in bitInfo:
if bitInfo[b][0][1] == 0:
to_del.append(b)
for b in to_del:
del bitInfo[b]
descriptors = tuple(
[1 if x in bitInfo.keys() else 0 for x in descriptor_names]
)
e.descriptors = Descriptors(
descriptor_type=descriptor_type,
descriptors=descriptors,
Expand Down Expand Up @@ -1068,6 +1100,9 @@ def plot_descriptors(
m = smi2mol(examples[0].smiles)
fp = AllChem.GetMorganFingerprint(m, 3, bitInfo=bi)
for rect, ti, k, ki, n in zip(bar1, t, keys, key_ids, names):
# account for Nones
if n is None:
n = ""
# annotate patches with text desciption
y = rect.get_y() + rect.get_height() / 2.0
n = textwrap.fill(str(n), 20)
Expand Down Expand Up @@ -1229,35 +1264,28 @@ def text_explain(
multiple_bases = _check_multiple_bases(examples)

# Take t-statistics, rank them
tstats = list(examples[0].descriptors.tstats)
d_importance = {
n: t # name: [t-stat, index]
d_importance = [
(n, t) # name: t-stat
for i, (n, t) in enumerate(
zip(
examples[0].descriptors.plotting_names,
tstats,
examples[0].descriptors.tstats,
)
)
# don't want NANs and want match (if not multiple bases)
if not np.isnan(t)
and multiple_bases
or examples[0].descriptors.descriptors[i] != 0
}
if not np.isnan(t) and True or examples[0].descriptors.descriptors[i] != 0
]

d_importance = dict(
sorted(d_importance.items(), key=lambda item: abs(item[1]), reverse=True)
)
d_importance = sorted(d_importance, key=lambda x: abs(x[1]), reverse=True)
# get significance value - if >significance, then important else weakly important?
w = np.array([1 / (1 + (1 / (e.similarity + 0.000001) - 1) ** 5) for e in examples])
effective_n = np.sum(w) ** 2 / np.sum(w**2)
T = ss.t.ppf(0.975, df=effective_n)

# text explanation
success = 0
result = []
existing_names = set()
for k, v in d_importance.items():

for k, v in d_importance:
if success == count:
break
name = k
Expand Down
Loading

0 comments on commit d9558a5

Please sign in to comment.