
Commit

Resolve #126
Co-authored-by: Félix <[email protected]>
Linux-cpp-lisp and felixmusil committed Dec 20, 2022
1 parent 1c63525 commit c92705e
Showing 3 changed files with 34 additions and 21 deletions.
21 changes: 7 additions & 14 deletions nequip/nn/_convnetlayer.py
@@ -10,18 +10,10 @@
     GraphModuleMixin,
     InteractionBlock,
 )
-from nequip.nn.nonlinearities import ShiftedSoftPlus
+from nequip.nn.nonlinearities import get_nonlinearity
 from nequip.utils.tp_utils import tp_path_exists
 
 
-acts = {
-    "abs": torch.abs,
-    "tanh": torch.tanh,
-    "ssp": ShiftedSoftPlus,
-    "silu": torch.nn.functional.silu,
-}
-
-
 class ConvNetLayer(GraphModuleMixin, torch.nn.Module):
     """
     Args:
@@ -96,15 +88,16 @@ def __init__(
         )
         irreps_gates = o3.Irreps([(mul, ir) for mul, _ in irreps_gated])
 
-        # TO DO, it's not that safe to directly use the
-        # dictionary
         equivariant_nonlin = Gate(
             irreps_scalars=irreps_scalars,
             act_scalars=[
-                acts[nonlinearity_scalars[ir.p]] for _, ir in irreps_scalars
+                get_nonlinearity(nonlinearity_scalars[ir.p])
+                for _, ir in irreps_scalars
             ],
             irreps_gates=irreps_gates,
-            act_gates=[acts[nonlinearity_gates[ir.p]] for _, ir in irreps_gates],
+            act_gates=[
+                get_nonlinearity(nonlinearity_gates[ir.p]) for _, ir in irreps_gates
+            ],
             irreps_gated=irreps_gated,
         )

@@ -116,7 +109,7 @@ def __init__(
             equivariant_nonlin = NormActivation(
                 irreps_in=conv_irreps_out,
                 # norm is an even scalar, so use nonlinearity_scalars[1]
-                scalar_nonlinearity=acts[nonlinearity_scalars[1]],
+                scalar_nonlinearity=get_nonlinearity(nonlinearity_scalars[1]),
                 normalize=True,
                 epsilon=1e-8,
                 bias=False,
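For context, a minimal sketch (not part of this commit) of how the per-parity lookup above resolves activations through get_nonlinearity; the parity-keyed config dict and the irreps below are hypothetical example values:

# Minimal sketch; config and irreps are made-up examples.
from e3nn import o3

from nequip.nn.nonlinearities import get_nonlinearity

nonlinearity_scalars = {1: "silu", -1: "tanh"}  # parity -> activation name (example)
irreps_scalars = o3.Irreps("16x0e + 8x0o")

# Same pattern as act_scalars above: one callable per irrep group, keyed by its parity ir.p.
act_scalars = [
    get_nonlinearity(nonlinearity_scalars[ir.p]) for _, ir in irreps_scalars
]
# -> [SiLU(), Tanh()]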
7 changes: 2 additions & 5 deletions nequip/nn/_interaction_block.py
@@ -10,7 +10,7 @@
 from e3nn.o3 import TensorProduct, Linear, FullyConnectedTensorProduct
 
 from nequip.data import AtomicDataDict
-from nequip.nn.nonlinearities import ShiftedSoftPlus
+from nequip.nn.nonlinearities import get_nonlinearity
 from ._graph_mixin import GraphModuleMixin
 
 
@@ -115,10 +115,7 @@ def __init__(
             [self.irreps_in[AtomicDataDict.EDGE_EMBEDDING_KEY].num_irreps]
             + invariant_layers * [invariant_neurons]
             + [tp.weight_numel],
-            {
-                "ssp": ShiftedSoftPlus,
-                "silu": torch.nn.functional.silu,
-            }[nonlinearity_scalars["e"]],
+            get_nonlinearity(nonlinearity_scalars["e"]),
         )
 
         self.tp = tp
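For illustration only (not part of this commit), the same pattern in isolation: the scalar MLP's activation is now resolved by name through get_nonlinearity instead of an inline dict; the layer widths here are hypothetical:

# Minimal sketch; widths are made-up example values.
import torch
from e3nn.nn import FullyConnectedNet

from nequip.nn.nonlinearities import get_nonlinearity

fc = FullyConnectedNet(
    [8] + 2 * [64] + [32],       # input, hidden, hidden, output widths (example)
    get_nonlinearity("silu"),    # replaces {"ssp": ..., "silu": ...}[name]
)
out = fc(torch.randn(5, 8))      # -> shape (5, 32)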
27 changes: 25 additions & 2 deletions nequip/nn/nonlinearities.py
@@ -1,8 +1,31 @@
-import torch
+from typing import Callable
 
 import math
 
+import torch
+
+from e3nn.util.jit import compile_mode
+
 
 @torch.jit.script
-def ShiftedSoftPlus(x):
+def shifted_soft_plus(x: torch.Tensor) -> torch.Tensor:
     return torch.nn.functional.softplus(x) - math.log(2.0)
+
+
+@compile_mode("script")
+class ShiftedSoftPlus(torch.nn.Module):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return shifted_soft_plus(x)
+
+
+def get_nonlinearity(name: str) -> Callable:
+    if name == "abs":
+        return torch.abs
+    elif name == "tanh":
+        return torch.nn.Tanh()
+    elif name == "ssp":
+        return ShiftedSoftPlus()
+    elif name == "silu":
+        return torch.nn.SiLU()
+    else:
+        raise KeyError(f"No such nonlinearity `{name}`")
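
A brief usage sketch of the new helper (not part of this commit); "gelu" is just an example of a name that is not registered:

# Minimal usage sketch.
import torch

from nequip.nn.nonlinearities import get_nonlinearity

ssp = get_nonlinearity("ssp")   # ShiftedSoftPlus module
x = torch.zeros(3)
print(ssp(x))                   # softplus(0) - log(2) ≈ 0

get_nonlinearity("gelu")        # raises KeyError: No such nonlinearity `gelu`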
