Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix pbc conv args Final Draft #306

Merged
merged 15 commits into from
Nov 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions examples/LennardJones/LJ_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ def __init__(self, dirpath, config, dist=False, sampling=None):
self.dataset.append(self.transform_input_to_data_object_base(filepath))

def transform_input_to_data_object_base(self, filepath):

# Using readline()
file = open(filepath, "r")

Expand Down Expand Up @@ -174,6 +173,11 @@ def transform_input_to_data_object_base(self, filepath):
.unsqueeze(0)
.to(torch.float32),
energy=torch.tensor(total_energy).unsqueeze(0).to(torch.float32),
pbc=[
True,
True,
True,
], # LJ example always has periodic boundary conditions
)

# Create pbc edges and lengths
Expand Down Expand Up @@ -205,7 +209,6 @@ def deterministic_graph_data(
unit_cell_z_range: list = [3, 4],
relative_maximum_atomic_displacement: float = 1e-1,
):

comm = MPI.COMM_WORLD
comm_size = comm.Get_size()
comm_rank = comm.Get_rank()
Expand Down Expand Up @@ -330,6 +333,7 @@ def create_configuration(
data.supercell_size = torch.diag(
torch.tensor([supercell_size_x, supercell_size_y, supercell_size_z])
)
data.pbc = [True, True, True]

create_graph_connectivity_pbc = get_radius_graph_pbc(
radius_cutoff, max_num_neighbors
Expand Down Expand Up @@ -379,27 +383,23 @@ class AtomicStructureHandler:
def __init__(
self, list_atom_types, bravais_lattice_constants, radius_cutoff, formula
):

self.bravais_lattice_constants = bravais_lattice_constants
self.radius_cutoff = radius_cutoff
self.formula = formula

def compute(self, data):

assert data.pos.shape[0] == data.x.shape[0]

interatomic_potential = torch.zeros([data.pos.shape[0], 1])
interatomic_forces = torch.zeros([data.pos.shape[0], 3])

for node_id in range(data.pos.shape[0]):

neighbor_list_indices = torch.where(data.edge_index[0, :] == node_id)[
0
].tolist()
neighbor_list = data.edge_index[1, neighbor_list_indices]

for neighbor_id, edge_id in zip(neighbor_list, neighbor_list_indices):

neighbor_pos = data.pos[neighbor_id, :]
distance_vector = data.pos[neighbor_id, :] - data.pos[node_id, :]

Expand Down
24 changes: 18 additions & 6 deletions hydragnn/models/DIMEStack.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from torch_geometric.utils import scatter

from .Base import Base
from hydragnn.utils.model.operations import get_edge_vectors_and_lengths


class DIMEStack(Base):
Expand Down Expand Up @@ -144,23 +145,34 @@ def get_conv(self, input_dim, output_dim):
)

def _embedding(self, data):
super()._embedding(data)

assert (
data.pos is not None
), "DimeNet requires node positions (data.pos) to be set."

# Calculate triplet indices
i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = triplets(
data.edge_index, num_nodes=data.x.size(0)
)
dist = (data.pos[i] - data.pos[j]).pow(2).sum(dim=-1).sqrt()

# Calculate angles.
pos_i = data.pos[idx_i]
pos_ji, pos_ki = data.pos[idx_j] - pos_i, data.pos[idx_k] - pos_i
# Calculate edge_vec and edge_dist
edge_vec, edge_dist = get_edge_vectors_and_lengths(
data.pos, data.edge_index, data.edge_shifts
)

# Calculate angles
pos_ji = edge_vec[idx_ji]
pos_kj = edge_vec[idx_kj]
pos_ki = (
pos_kj + pos_ji
) # It's important to calculate the vectors separately and then add in case of periodic boundary conditions
a = (pos_ji * pos_ki).sum(dim=-1)
b = torch.cross(pos_ji, pos_ki).norm(dim=-1)
angle = torch.atan2(b, a)

rbf = self.rbf(dist)
sbf = self.sbf(dist, angle, idx_kj)
rbf = self.rbf(edge_dist.squeeze())
sbf = self.sbf(edge_dist.squeeze(), angle, idx_kj)

conv_args = {
"rbf": rbf,
Expand Down
25 changes: 13 additions & 12 deletions hydragnn/models/EGCLStack.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .Base import Base

from hydragnn.utils.model import unsorted_segment_mean
from hydragnn.utils.model.operations import get_edge_vectors_and_lengths


class EGCLStack(Base):
Expand Down Expand Up @@ -89,6 +90,12 @@ def get_conv(self, input_dim, output_dim, last_layer=False):
)

def _embedding(self, data):
super()._embedding(data)

data.edge_shifts = torch.zeros(
(data.edge_index.size(1), 3), device=data.edge_index.device
) # Override. pbc edge shifts are currently not supported in positional update models

if self.edge_dim > 0:
conv_args = {
"edge_index": data.edge_index,
Expand Down Expand Up @@ -229,20 +236,14 @@ def coord_model(self, coord, edge_index, coord_diff, edge_feat):
coord = coord + agg * self.coords_weight
return coord

def coord2radial(self, edge_index, coord):
row, col = edge_index
coord_diff = coord[row] - coord[col]
radial = torch.sum((coord_diff) ** 2, 1).unsqueeze(1)

if self.norm_diff:
norm = torch.sqrt(radial) + 1
coord_diff = coord_diff / (norm)

return radial, coord_diff

def forward(self, x, coord, edge_index, edge_attr, node_attr=None):
row, col = edge_index
radial, coord_diff = self.coord2radial(edge_index, coord)
edge_shifts = torch.zeros(
(edge_index.size(1), 3), device=edge_index.device
) # pbc edge shifts are currently not supported in positional update models
coord_diff, radial = get_edge_vectors_and_lengths(
coord, edge_index, edge_shifts, normalize=self.norm_diff, eps=1.0
)
# Message Passing
edge_feat = self.edge_model(x[row], x[col], radial, edge_attr)
if self.equivariant:
Expand Down
1 change: 1 addition & 0 deletions hydragnn/models/MACEStack.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
NonLinearMultiheadDecoderBlock,
LinearMultiheadDecoderBlock,
)
from hydragnn.utils.model.operations import get_edge_vectors_and_lengths

# Etc
import numpy as np
Expand Down
33 changes: 16 additions & 17 deletions hydragnn/models/PAINNStack.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from torch.utils.checkpoint import checkpoint

from .Base import Base
from hydragnn.utils.model.operations import get_edge_vectors_and_lengths


class PAINNStack(Base):
Expand Down Expand Up @@ -125,24 +126,25 @@ def get_conv(self, input_dim, output_dim, last_layer=False):
)

def _embedding(self, data):
super()._embedding(data)

assert (
data.pos is not None
), "PAINNNet requires node positions (data.pos) to be set."
), "PAINN requires node positions (data.pos) to be set."

# Calculate relative vectors and distances
i, j = data.edge_index[0], data.edge_index[1]
diff = data.pos[i] - data.pos[j]
dist = diff.pow(2).sum(dim=-1).sqrt()
norm_diff = diff / dist.unsqueeze(-1)
# Get normalized edge vectors and lengths
norm_edge_vec, edge_dist = get_edge_vectors_and_lengths(
data.pos, data.edge_index, data.edge_shifts, normalize=True
)

# Instantiate tensor to hold equivariant traits
v = torch.zeros(data.x.size(0), 3, data.x.size(1), device=data.x.device)
data.v = v

conv_args = {
"edge_index": data.edge_index.t().to(torch.long),
"diff": norm_diff,
"dist": dist,
"diff": norm_edge_vec,
"dist": edge_dist,
}

return data.x, data.v, conv_args
Expand Down Expand Up @@ -171,9 +173,8 @@ def forward(self, node_scalar, node_vector, edge, edge_diff, edge_dist):
filter_weight = self.filter_layer(
sinc_expansion(edge_dist, self.edge_size, self.cutoff)
)
filter_weight = filter_weight * cosine_cutoff(edge_dist, self.cutoff).unsqueeze(
-1
)
filter_weight = filter_weight * cosine_cutoff(edge_dist, self.cutoff)

scalar_out = self.scalar_message_mlp(node_scalar)
filter_out = filter_weight * scalar_out[edge[:, 1]]

Expand All @@ -185,9 +186,9 @@ def forward(self, node_scalar, node_vector, edge, edge_diff, edge_dist):

# num_pairs * 3 * node_size, num_pairs * node_size
message_vector = node_vector[edge[:, 1]] * gate_state_vector.unsqueeze(1)
edge_vector = gate_edge_vector.unsqueeze(1) * (
edge_diff / edge_dist.unsqueeze(-1)
).unsqueeze(-1)
edge_vector = gate_edge_vector.unsqueeze(1) * (edge_diff / edge_dist).unsqueeze(
-1
)
message_vector = message_vector + edge_vector

# sum message
Expand Down Expand Up @@ -266,9 +267,7 @@ def sinc_expansion(edge_dist: torch.Tensor, edge_size: int, cutoff: float):
sin(n * pi * d / d_cut) / d
"""
n = torch.arange(edge_size, device=edge_dist.device) + 1
return torch.sin(
edge_dist.unsqueeze(-1) * n * torch.pi / cutoff
) / edge_dist.unsqueeze(-1)
return torch.sin(edge_dist * n * torch.pi / cutoff) / edge_dist


def cosine_cutoff(edge_dist: torch.Tensor, cutoff: float):
Expand Down
23 changes: 16 additions & 7 deletions hydragnn/models/PNAEqStack.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
from torch_geometric.nn.aggr.scaler import DegreeScalerAggregation
from torch_geometric.typing import Adj, OptTensor

# HydraGNN
from .Base import Base
from hydragnn.utils.model.operations import get_edge_vectors_and_lengths


class PNAEqStack(Base):
Expand Down Expand Up @@ -156,16 +158,17 @@ def get_conv(self, input_dim, output_dim, last_layer=False):
)

def _embedding(self, data):
super()._embedding(data)

assert (
data.pos is not None
), "PNAEq requires node positions (data.pos) to be set."

# Calculate relative vectors and distances
i, j = data.edge_index[0], data.edge_index[1]
diff = data.pos[i] - data.pos[j]
dist = diff.pow(2).sum(dim=-1).sqrt()
rbf = self.rbf(dist)
norm_diff = diff / dist.unsqueeze(-1)
# Edge vector and distance features
norm_edge_vec, edge_dist = get_edge_vectors_and_lengths(
data.pos, data.edge_index, data.edge_shifts, normalize=True
)
rbf = self.rbf(edge_dist.squeeze())

# Instantiate tensor to hold equivariant traits
v = torch.zeros(data.x.size(0), 3, data.x.size(1), device=data.x.device)
Expand All @@ -174,9 +177,15 @@ def _embedding(self, data):
conv_args = {
"edge_index": data.edge_index.t().to(torch.long),
"edge_rbf": rbf,
"edge_vec": norm_diff,
"edge_vec": norm_edge_vec,
}

if self.use_edge_attr:
assert (
data.edge_attr is not None
), "Data must have edge attributes if use_edge_attributes is set."
conv_args.update({"edge_attr": data.edge_attr})

return data.x, data.v, conv_args


Expand Down
13 changes: 9 additions & 4 deletions hydragnn/models/PNAPlusStack.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

# HydraGNN
from .Base import Base
from hydragnn.utils.model.operations import get_edge_vectors_and_lengths


class PNAPlusStack(Base):
Expand Down Expand Up @@ -98,14 +99,18 @@ def get_conv(self, input_dim, output_dim):
)

def _embedding(self, data):
super()._embedding(data)

assert (
data.pos is not None
), "PNA+ requires node positions (data.pos) to be set."

j, i = data.edge_index # j->i
dist = (data.pos[i] - data.pos[j]).pow(2).sum(dim=-1).sqrt()
rbf = self.rbf(dist)
# rbf = dist.unsqueeze(-1)
# Radial embedding
_, edge_dist = get_edge_vectors_and_lengths(
data.pos, data.edge_index, data.edge_shifts
)
rbf = self.rbf(edge_dist.squeeze())

conv_args = {"edge_index": data.edge_index.to(torch.long), "rbf": rbf}

if self.use_edge_attr:
Expand Down
24 changes: 12 additions & 12 deletions hydragnn/models/SCFStack.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from torch_geometric.nn import Sequential as PyGSeq
from torch_geometric.nn import MessagePassing
from torch_geometric.nn.models.schnet import (
CFConv,
GaussianSmearing,
RadiusInteractionGraph,
ShiftedSoftplus,
Expand All @@ -27,6 +26,7 @@
from .Base import Base

from hydragnn.utils.model import unsorted_segment_mean
from hydragnn.utils.model.operations import get_edge_vectors_and_lengths


class SCFStack(Base):
Expand Down Expand Up @@ -135,12 +135,17 @@ def get_conv(self, input_dim, output_dim, last_layer):
)

def _embedding(self, data):
super()._embedding(data)

if (self.use_edge_attr) and (self.equivariance):
raise Exception(
"For SchNet if using edge attributes, then E(3)-equivariance cannot be ensured. Please disable equivariance or edge attributes."
)
elif self.use_edge_attr:
edge_index = data.edge_index
data.edge_shifts = torch.zeros(
(data.edge_index.size(1), 3), device=data.edge_index.device
) # Override. pbc edge shifts are currently not supported in positional update models
edge_weight = data.edge_attr.norm(dim=-1)

conv_args = {
Expand Down Expand Up @@ -218,7 +223,12 @@ def forward(
x = self.lin1(x)

if self.equivariant:
radial, coord_diff = self.coord2radial(edge_index, pos)
edge_shifts = torch.zeros(
(edge_index.size(1), 3), device=edge_index.device
) # pbc edge shifts are currently not supported in positional update models
coord_diff, radial = get_edge_vectors_and_lengths(
pos, edge_index, edge_shifts, normalize=True, eps=1.0
)
pos = self.coord_model(pos, edge_index, coord_diff, W)

x = self.propagate(edge_index, x=x, W=W)
Expand All @@ -230,13 +240,3 @@ def forward(

def message(self, x_j: Tensor, W: Tensor) -> Tensor:
return x_j * W

def coord2radial(self, edge_index, coord):
row, col = edge_index
coord_diff = coord[row] - coord[col]
radial = torch.sum((coord_diff) ** 2, 1).unsqueeze(1)

norm = torch.sqrt(radial) + 1
coord_diff = coord_diff / (norm)

return radial, coord_diff
Loading
Loading