Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Models #3

Merged
merged 4 commits into from
Mar 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 127 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,131 @@
# Graph neural network for predicting NMR chemical shifts


## TODO
## Model Performance

* get initial layer
| Target      |    N | baseline             |
|:------------|-----:|:--------------------|
| Mol-H-r | 307 | 0.9591749434360993 |
| Mol-H-rmsd | 307 | 0.39710393617916234 |
| P-C-r | 6701 | 0.864163 |
| P-H-r | 7747 | 0.72265 |
| P-N-r | 7640 | 0.890842 |
| P-CA-r | 8305 | 0.97374 |
| P-CB-r | 6827 | 0.990706 |
| P-CD-r | 739 | 0.996123 |
| P-CD1-r | 961 | 0.999515 |
| P-CD2-r | 609 | 0.999223 |
| P-CE-r | 340 | 0.991736 |
| P-CE1-r | 261 | 0.958121 |
| P-CE2-r | 173 | 0.943739 |
| P-CE3-r | 37 | -0.215088 |
| P-CG-r | 1674 | 0.998763 |
| P-CG1-r | 589 | 0.93124 |
| P-CG2-r | 839 | 0.829016 |
| P-CH2-r | 43 | 0.158363 |
| P-CZ-r | 125 | 0.984575 |
| P-CZ2-r | 45 | 0.311805 |
| P-CZ3-r | 37 | 0.164961 |
| P-HA-r | 5565 | 0.839377 |
| P-HA2-r | 462 | 0.495514 |
| P-HA3-r | 449 | 0.262298 |
| P-HB-r | 960 | 0.958713 |
| P-HB2-r | 3427 | 0.901358 |
| P-HB3-r | 3255 | 0.901234 |
| P-HD1-r | 383 | 0.44733 |
| P-HD11-r | 753 | 0.615756 |
| P-HD12-r | 753 | 0.585852 |
| P-HD13-r | 753 | 0.609181 |
| P-HD2-r | 1043 | 0.988991 |
| P-HD21-r | 428 | 0.617599 |
| P-HD22-r | 428 | 0.651927 |
| P-HD23-r | 428 | 0.605888 |
| P-HD3-r | 637 | 0.95089 |
| P-HE-r | 93 | 0.396258 |
| P-HE1-r | 413 | 0.879142 |
| P-HE2-r | 561 | 0.98963 |
| P-HE3-r | 293 | 0.985685 |
| P-HG-r | 389 | 0.810401 |
| P-HG1-r | 11 | 0.0653286 |
| P-HG11-r | 350 | 0.572609 |
| P-HG12-r | 350 | 0.498696 |
| P-HG13-r | 350 | 0.558426 |
| P-HG2-r | 1317 | 0.867619 |
| P-HG21-r | 936 | 0.689592 |
| P-HG22-r | 936 | 0.674086 |
| P-HG23-r | 936 | 0.662057 |
| P-HG3-r | 1200 | 0.856177 |
| P-HH-r | 1 | nan |
| P-HH2-r | 51 | 0.217372 |
| P-HZ-r | 134 | 0.407285 |
| P-HZ2-r | 54 | 0.419415 |
| P-HZ3-r | 45 | 0.318577 |
| P-ND1-r | 9 | 0.184443 |
| P-ND2-r | 173 | 0.320299 |
| P-NE-r | 88 | 0.0135033 |
| P-NE1-r | 64 | 0.0998792 |
| P-NE2-r | 149 | 0.972614 |
| P-NH1-r | 3 | -0.914066 |
| P-NH2-r | 3 | -0.276087 |
| P-NZ-r | 1 | nan |
| P-C-rmsd | 6701 | 1.22819 |
| P-H-rmsd | 7747 | 0.279766 |
| P-N-rmsd | 7640 | 6.65505 |
| P-CA-rmsd | 8305 | 1.3298 |
| P-CB-rmsd | 6827 | 3.10571 |
| P-CD-rmsd | 739 | 10.3192 |
| P-CD1-rmsd | 961 | 2.74597 |
| P-CD2-rmsd | 609 | 4.35399 |
| P-CE-rmsd | 340 | 1.14623 |
| P-CE1-rmsd | 261 | 4.69154 |
| P-CE2-rmsd | 173 | 4.82229 |
| P-CE3-rmsd | 37 | 3.0327 |
| P-CG-rmsd | 1674 | 1.63828 |
| P-CG1-rmsd | 589 | 1.558 |
| P-CG2-rmsd | 839 | 1.87753 |
| P-CH2-rmsd | 43 | 1.95861 |
| P-CZ-rmsd | 125 | 4.32496 |
| P-CZ2-rmsd | 45 | 1.22984 |
| P-CZ3-rmsd | 37 | 1.99567 |
| P-HA-rmsd | 5565 | 0.0903255 |
| P-HA2-rmsd | 462 | 0.119584 |
| P-HA3-rmsd | 449 | 0.234069 |
| P-HB-rmsd | 960 | 0.103812 |
| P-HB2-rmsd | 3427 | 0.10552 |
| P-HB3-rmsd | 3255 | 0.117287 |
| P-HD1-rmsd | 383 | 0.114696 |
| P-HD11-rmsd | 753 | 0.0699893 |
| P-HD12-rmsd | 753 | 0.0744762 |
| P-HD13-rmsd | 753 | 0.0711484 |
| P-HD2-rmsd | 1043 | 0.105893 |
| P-HD21-rmsd | 428 | 0.0737762 |
| P-HD22-rmsd | 428 | 0.0689306 |
| P-HD23-rmsd | 428 | 0.0764191 |
| P-HD3-rmsd | 637 | 0.0869007 |
| P-HE-rmsd | 93 | 0.422132 |
| P-HE1-rmsd | 413 | 0.376196 |
| P-HE2-rmsd | 561 | 0.0861489 |
| P-HE3-rmsd | 293 | 0.0855213 |
| P-HG-rmsd | 389 | 0.118694 |
| P-HG1-rmsd | 11 | 10.3704 |
| P-HG11-rmsd | 350 | 0.0504736 |
| P-HG12-rmsd | 350 | 0.0552385 |
| P-HG13-rmsd | 350 | 0.0516929 |
| P-HG2-rmsd | 1317 | 0.0654069 |
| P-HG21-rmsd | 936 | 0.0634577 |
| P-HG22-rmsd | 936 | 0.0650697 |
| P-HG23-rmsd | 936 | 0.0679991 |
| P-HG3-rmsd | 1200 | 0.0775636 |
| P-HH-rmsd | 1 | 4.07231 |
| P-HH2-rmsd | 51 | 0.0862706 |
| P-HZ-rmsd | 134 | 0.147387 |
| P-HZ2-rmsd | 54 | 0.13507 |
| P-HZ3-rmsd | 45 | 0.083249 |
| P-ND1-rmsd | 9 | 1576.13 |
| P-ND2-rmsd | 173 | 6.56618 |
| P-NE-rmsd | 88 | 231.589 |
| P-NE1-rmsd | 64 | 4.51713 |
| P-NE2-rmsd | 149 | 13.9975 |
| P-NH1-rmsd | 3 | 5.76985 |
| P-NH2-rmsd | 3 | 0.91028 |
| P-NZ-rmsd | 1 | 165.069 |
Binary file added models/baseline/saved_model.pb
Binary file not shown.
Binary file not shown.
Binary file added models/baseline/variables/variables.index
Binary file not shown.
4 changes: 2 additions & 2 deletions nmrgnn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .metrics import *
from .losses import *

custom_things = [MeanSquaredLogartihmicErrorNames, NameRMSD, NameCorr, MPLayer,
RBFExpansion, EdgeFCBlock, MPBlock, FCBlock, corr_loss]
# Registry of the package's custom Keras classes/metrics, keyed by class name.
# Pass `custom_objects` to tf.keras.models.load_model so serialized models
# that reference these layers, metrics and losses can be deserialized.
custom_things = [NameRMSD, NameCorr, MPLayer, NameLoss, NameCount,
RBFExpansion, EdgeFCBlock, MPBlock, FCBlock]
custom_objects = {o.__name__: o for o in custom_things}
del custom_things
1 change: 1 addition & 0 deletions nmrgnn/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def call(self, inputs):
# m -> atom atom feature output
reduced = tf.einsum('ijn,ijl,lmn,i->im', edges,
sliced_features, self.w, inv_degree)
# TODO break it up to reduce memory
out = self.activation(reduced)
# output -> N x D number of atoms x node feature dimension
if self.mpl_regularizer is not None:
Expand Down
29 changes: 5 additions & 24 deletions nmrgnn/losses.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
import tensorflow as tf
import numpy as np


class MeanSquaredLogartihmicErrorNames(tf.keras.losses.MeanSquaredLogarithmicError):
'''MSLE loss for labels carrying extra columns: only column 0 of y_true is
compared against y_pred; remaining columns are ignored here.
NOTE(review): the class name's "Logartihmic" misspelling is part of the
serialized-model registry — keep it so saved models still load.'''
def call(self, y_true, y_pred):
# y_true[:, 0] holds the target values; other columns are metadata/weights
return super().call(y_true[:, 0], y_pred)


def corr_coeff(x, y, w = None):
if w is None:
w = tf.ones_like(x)
Expand All @@ -20,31 +14,18 @@ def corr_coeff(x, y, w = None):
cor = tf.math.divide_no_nan(cov, m * tf.math.sqrt(tf.clip_by_value((xm2 - xm**2) * (ym2 - ym**2), 0, 1e32)))
return cor

def corr_loss(labels, predictions, sample_weight = None, s=1e-3):
'''
Mostly (1 - correlation), plus a small weighted absolute-difference term.

labels: 2D tensor — column 0 holds target values, last column holds
per-sample weights. (Presumably middle columns are name indicators;
confirm against the callers.)
predictions: predicted values, same length as labels.
sample_weight: accepted for Keras API compatibility but unused — weights
are read from labels[:, -1] instead.
s: scale of the absolute-difference term relative to the correlation term.
'''
x = predictions
y = labels[:,0]
w = labels[:,-1]
# Weighted mean absolute error; divide_no_nan guards against all-zero weights.
l2 = tf.math.divide_no_nan(tf.reduce_sum( w * tf.math.abs( y - x) ), tf.reduce_sum(w))
#loss = tf.keras.losses.mean_squared_logarithmic_error(y, x) * w
#l2 = tf.reduce_mean(loss)
return s * l2 + (1 - corr_coeff(x, y, w))



class NameLoss:
class NameLoss(tf.keras.losses.Loss):
'''Compute L2 loss * s + corr_loss * (1 - s) for specific atom name'''

def __init__(self, label_idx, s=1.):
def __init__(self, label_idx, s=1., name='name-loss', reduction='none'):
super(NameLoss, self).__init__(name=name, reduction=reduction)
self.label_idx = label_idx
self.ln = np.array(label_idx, dtype=np.int32)
self.s = s

def get_config(self):
config = {'label_idx': self.label_idx, 's': self.s}
return config
base_config = super(NameLoss, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

def call(self, y_true, y_pred, sample_weight=None):
# mask diff by which predictions match the label
Expand Down
Loading