Flexible property prediction heads #362 (Draft)

wants to merge 47 commits into base: main

Changes from all commits (47 commits)
e8e9ef8
first sketches
M-R-Schaefer Oct 30, 2024
79872b6
removed atomistic model, added properties dict
M-R-Schaefer Nov 1, 2024
c3c950b
removed atomisticmodel import
M-R-Schaefer Nov 1, 2024
9569461
first draft of latent ewald sum
M-R-Schaefer Nov 1, 2024
2509291
working implementation of property heads
M-R-Schaefer Nov 1, 2024
70e0fc2
added property head to builder
M-R-Schaefer Nov 1, 2024
14e7f52
fixed stress and added property shift
M-R-Schaefer Nov 1, 2024
6fc2a98
moved divisor out of inner loss function
M-R-Schaefer Nov 1, 2024
a1b2957
fixed properties head
M-R-Schaefer Nov 2, 2024
ce4da9c
removed outdated comment
M-R-Schaefer Nov 2, 2024
b7d35cd
removed old additional properties flag
M-R-Schaefer Nov 2, 2024
596cd49
added additional properties to dataset
M-R-Schaefer Nov 2, 2024
232742d
fixed loss fns, added additinoal prop to convert
M-R-Schaefer Nov 2, 2024
897c3af
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 4, 2024
501a4bb
fixed bal and shallow ens
M-R-Schaefer Nov 5, 2024
36ca219
remove atomistic model
M-R-Schaefer Nov 5, 2024
af2aaf2
fixed template
M-R-Schaefer Nov 5, 2024
f339cdf
fixed jaxmd and shallow ens compat
M-R-Schaefer Nov 5, 2024
f623669
apax nodes eager mode
M-R-Schaefer Nov 5, 2024
bf8808a
no longer error when restarting finished training
M-R-Schaefer Nov 5, 2024
0a73502
fixed remaining tests
M-R-Schaefer Nov 5, 2024
14f1e89
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 5, 2024
64f77a3
Merge branch 'main' into dtensor
M-R-Schaefer Nov 6, 2024
0a568e5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 6, 2024
40e283c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 6, 2024
a0d4589
removed barrier wait
M-R-Schaefer Nov 17, 2024
cb78200
shallow ensemble compatibility for property head
M-R-Schaefer Nov 17, 2024
945c745
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2024
2bc9f57
moved dtype parsing into models
M-R-Schaefer Nov 17, 2024
29a4554
more dtype refactoring
M-R-Schaefer Nov 17, 2024
6a7b28a
atom mask compatibility with arbitrary sized arrays
M-R-Schaefer Nov 17, 2024
b6681fd
removed old additional proeprteis docstring
M-R-Schaefer Nov 17, 2024
bb7359f
remove debug comments
M-R-Schaefer Nov 17, 2024
073aba8
detect which labels to get via loss config
M-R-Schaefer Nov 17, 2024
9fdd618
Merge branch 'dtensor' of https://github.com/apax-hub/apax into dtensor
M-R-Schaefer Nov 17, 2024
bda5a04
Merge branch 'main' into dtensor
M-R-Schaefer Nov 17, 2024
9bd3075
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2024
f6705f2
fixed tests
M-R-Schaefer Nov 17, 2024
4f9f58f
Merge branch 'dtensor' of https://github.com/apax-hub/apax into dtensor
M-R-Schaefer Nov 17, 2024
25c6e31
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2024
522494d
Merge branch 'main' into dtensor
M-R-Schaefer Nov 19, 2024
748a21f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 19, 2024
6c51f15
Merge branch 'main' into dtensor
M-R-Schaefer Nov 26, 2024
13cdaef
remove direct coulomb
M-R-Schaefer Nov 26, 2024
7f4093b
fix additional properties in batchprocessor
M-R-Schaefer Nov 26, 2024
b9fd69c
full compatibility with pbp dataset
M-R-Schaefer Nov 26, 2024
fd23aba
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 26, 2024
8 changes: 5 additions & 3 deletions apax/bal/api.py
@@ -16,6 +16,7 @@
check_for_ensemble,
restore_parameters,
)
from apax.utils.transform import make_energy_only_model


def create_feature_fn(
@@ -150,12 +151,13 @@ def kernel_selection(
_, init_box = dataset.init_input()

Builder = config.model.get_builder()
builder = Builder(config.model.get_dict(), n_species=119)
builder = Builder(config.model.model_dump(), n_species=119)

model = builder.build_energy_model(apply_mask=True, init_box=init_box)
energy_model = builder.build_energy_model(apply_mask=True, init_box=init_box)
energy_model = make_energy_only_model(energy_model.apply)

feature_fn = create_feature_fn(
model, params, base_feature_map, feature_transforms, is_ensemble
energy_model, params, base_feature_map, feature_transforms, is_ensemble
)
g = compute_features(feature_fn, dataset)
km = kernel.KernelMatrix(g, n_train)
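The new `make_energy_only_model` helper is only imported and used here; its body is not part of the diff. Below is a minimal sketch of what such a wrapper could look like, assuming the model's apply function now returns a dict of predictions (one entry per property head) while the BAL feature maps only need the scalar energy. The real helper in `apax.utils.transform` may differ.

```python
# Hypothetical sketch of apax.utils.transform.make_energy_only_model, inferred
# solely from its usage above.
def make_energy_only_model(apply_fn):
    def energy_only_model(params, R, Z, idx, box, offsets):
        predictions = apply_fn(params, R, Z, idx, box, offsets)
        if isinstance(predictions, dict):
            # multi-head model: keep only the energy entry
            return predictions["energy"]
        # plain energy model: pass the value through unchanged
        return predictions

    return energy_only_model
```

The wrapper returns a plain function, which is why the feature maps in `feature_maps.py` below now call `model(...)` directly instead of `model.apply(...)`.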
8 changes: 4 additions & 4 deletions apax/bal/feature_maps.py
@@ -65,7 +65,7 @@ def inner(ll_params):
inputs["box"],
inputs["offsets"],
)
out = model.apply(full_params, R, Z, idx, box, offsets)
out = model(full_params, R, Z, idx, box, offsets)
# take mean in case of shallow ensemble
# no effect for single model
out = jnp.mean(out)
@@ -108,7 +108,7 @@ def apply(self, model: EnergyModel) -> FeatureMap:
def ll_grad(params, inputs):
ll_params, remaining_params = extract_feature_params(params, self.layer_name)

energy_fn = lambda *inputs: jnp.mean(model.apply(*inputs))
energy_fn = lambda *inputs: jnp.mean(model(*inputs))
force_fn = jax.grad(energy_fn, 1)

def inner(ll_params):
@@ -184,7 +184,7 @@ def inner(params):
inputs["box"],
inputs["offsets"],
)
out = model.apply(params, R, Z, idx, box, offsets)
out = model(params, R, Z, idx, box, offsets)
# take mean in case of shallow ensemble
# no effect for single model
out = jnp.mean(out)
@@ -214,7 +214,7 @@ class IdentityFeatures(FeatureTransformation, extra="forbid"):
name: Literal["identity"]

def apply(self, model: EnergyModel) -> FeatureMap:
return model.apply
return model


FeatureMapOptions = TypeAdapter(
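A small, self-contained illustration of the calling convention after this change; `toy_energy_model` is a stand-in, not apax code. It mimics a shallow ensemble by returning one energy per member, which is what the `jnp.mean(...)` calls in the feature maps above are there for.

```python
import jax.numpy as jnp

def toy_energy_model(params, R, Z, idx, box, offsets):
    # stand-in for the wrapped energy-only model: three shallow-ensemble members
    return params["scale"] * jnp.ones(3)

# IdentityFeatures.apply now returns the callable itself instead of model.apply
feature_map = toy_energy_model
out = feature_map({"scale": 2.0}, None, None, None, None, None)
energy = jnp.mean(out)  # mean over ensemble members; a no-op for a single model
```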
2 changes: 1 addition & 1 deletion apax/cli/apax_app.py
@@ -251,7 +251,7 @@ def visualize_model(

R, Z, idx, box, offsets = make_minimal_input()
Builder = config.model.get_builder()
builder = Builder(config.model.get_dict(), n_species=10)
builder = Builder(config.model.model_dump(), n_species=10)
model = builder.build_energy_model()
print(model.tabulate(jax.random.PRNGKey(0), R, Z, idx, box, offsets))

2 changes: 0 additions & 2 deletions apax/cli/templates/train_config_full.yaml
@@ -18,8 +18,6 @@ data:
processing: cached
shuffle_buffer_size: 1000

additional_properties_info: {}

n_train: 1000
n_valid: 100

36 changes: 24 additions & 12 deletions apax/config/model_config.py
@@ -114,7 +114,28 @@ class ExponentialRepulsion(Correction, extra="forbid"):
r_max: NonNegativeFloat = 1.5


EmpiricalCorrection = Union[ZBLRepulsion, ExponentialRepulsion]
class LatentEwald(Correction, extra="forbid"):
name: Literal["latent_ewald"]
kgrid: list
sigma: float = 1.0


EmpiricalCorrection = Union[ZBLRepulsion, ExponentialRepulsion, LatentEwald]


class PropertyHead(BaseModel, extra="forbid"):
""" """

name: str
aggregation: str = "none"
mode: str = "l0"

nn: List[PositiveInt] = [128, 128]
n_shallow_members: int = 0
w_init: Literal["normal", "lecun"] = "lecun"
b_init: Literal["normal", "zeros"] = "zeros"
use_ntk: bool = False
dtype: Literal["fp32", "fp64"] = "fp32"


class BaseModelConfig(BaseModel, extra="forbid"):
@@ -156,6 +177,8 @@ class BaseModelConfig(BaseModel, extra="forbid"):

ensemble: Optional[EnsembleConfig] = None

property_heads: list[PropertyHead] = []

# corrections
empirical_corrections: list[EmpiricalCorrection] = []

@@ -165,17 +188,6 @@ class BaseModelConfig(BaseModel, extra="forbid"):
readout_dtype: Literal["fp32", "fp64"] = "fp32"
scale_shift_dtype: Literal["fp32", "fp64"] = "fp64"

def get_dict(self):
import jax.numpy as jnp

model_dict = self.model_dump()
prec_dict = {"fp32": jnp.float32, "fp64": jnp.float64}
model_dict["descriptor_dtype"] = prec_dict[model_dict["descriptor_dtype"]]
model_dict["readout_dtype"] = prec_dict[model_dict["readout_dtype"]]
model_dict["scale_shift_dtype"] = prec_dict[model_dict["scale_shift_dtype"]]

return model_dict


class GMNNConfig(BaseModelConfig, extra="forbid"):
"""
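A hedged usage sketch for the two new config blocks defined above. Only the field names and defaults come from this diff; the concrete values (the `"charges"` label name, the k-grid, the layer sizes) are illustrative assumptions, not documented choices.

```python
from apax.config.model_config import LatentEwald, PropertyHead

# extra output head trained against a "charges" label (hypothetical label name)
charge_head = PropertyHead(
    name="charges",
    aggregation="none",   # default: keep per-atom predictions
    mode="l0",            # default: scalar (l0) output per atom
    nn=[128, 128],
    n_shallow_members=0,  # > 0 would make this head a shallow ensemble
)

# long-range correction via a latent Ewald sum, registered as an empirical correction
lr_correction = LatentEwald(name="latent_ewald", kgrid=[4, 4, 4], sigma=1.0)

model_section = {
    "property_heads": [charge_head.model_dump()],
    "empirical_corrections": [lr_correction.model_dump()],
}
```

In a YAML train config these would presumably appear as lists under the model section, mirroring the new `property_heads` and `empirical_corrections` fields of `BaseModelConfig`.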
3 changes: 0 additions & 3 deletions apax/config/train_config.py
@@ -116,8 +116,6 @@ class DataConfig(BaseModel, extra="forbid"):
| Number of validation examples to be evaluated at once.
shuffle_buffer_size : int, default = 1000
| Size of the `tf.data` shuffle buffer.
additional_properties_info : dict, optional
| dict of property name, shape (ragged or fixed) pairs. Currently unused.
energy_regularisation :
| Magnitude of the regularization in the per-element energy regression.
pos_unit : str, default = "Ang"
@@ -141,7 +139,6 @@ class DataConfig(BaseModel, extra="forbid"):
n_valid: PositiveInt = 100
batch_size: PositiveInt = 32
valid_batch_size: PositiveInt = 100
additional_properties_info: dict[str, str] = {}

shift_method: str = "per_element_regression_shift"
shift_options: dict = {"energy_regularisation": 1.0}
47 changes: 43 additions & 4 deletions apax/data/input_pipeline.py
@@ -90,6 +90,7 @@ def __init__(
n_jit_steps=1,
pos_unit: str = "Ang",
energy_unit: str = "eV",
additional_properties: list[tuple] = [],
pre_shuffle=False,
shuffle_buffer_size=1000,
ignore_labels=False,
@@ -102,6 +103,7 @@ def __init__(
self.n_data = len(atoms_list)
self.batch_size = self.validate_batch_size(bs)
self.pos_unit = pos_unit
self.additional_properties = additional_properties

if pre_shuffle:
shuffle(atoms_list)
@@ -112,7 +114,9 @@ def __init__(
self.max_atoms = max_atoms
self.max_nbrs = max_nbrs
if atoms_list[0].calc and not ignore_labels:
self.labels = atoms_to_labels(atoms_list, pos_unit, energy_unit)
self.labels = atoms_to_labels(
atoms_list, pos_unit, energy_unit, additional_properties
)
else:
self.labels = None

@@ -161,8 +165,16 @@ def prepare_data(self, i):
labels["forces"] = np.pad(
labels["forces"], ((0, zeros_to_add), (0, 0)), "constant"
)

for prop in self.additional_properties:
name, shape = prop
if shape[0] == "natoms":
pad_shape = [(0, zeros_to_add)] + [(0, 0)] * (len(shape) - 1)
labels[name] = np.pad(labels[name], pad_shape, "constant")

inputs = {k: tf.constant(v) for k, v in inputs.items()}
labels = {k: tf.constant(v) for k, v in labels.items()}

return (inputs, labels)

def enqueue(self, num_elements):
@@ -202,6 +214,14 @@ def make_signature(self) -> tf.TensorSpec:
label_signature["stress"] = tf.TensorSpec(
(3, 3), dtype=tf.float64, name="stress"
)

for prop in self.additional_properties:
name, shape = prop
if shape[0] == "natoms":
shape[0] = self.max_atoms

sig = tf.TensorSpec(tuple(shape), dtype=tf.float64, name=name)
label_signature[name] = sig
signature = (input_signature, label_signature)
return signature

@@ -377,14 +397,17 @@ def round_up_to_multiple(value, multiple):

class BatchProcessor:
    def __init__(
        self, cutoff, atom_padding: int, nl_padding: int, forces=True, stress=False,
        additional_properties=[]
    ) -> None:
self.cutoff = cutoff
self.atom_padding = atom_padding
self.nl_padding = nl_padding

self.forces = forces
self.stress = stress
self.additional_properties = additional_properties

def __call__(self, samples: list[dict]):
n_samples = len(samples)
Expand All @@ -401,7 +424,12 @@ def __call__(self, samples: list[dict]):
labels = {
"energy": np.zeros(n_samples, dtype=np.float64),
}

for prop in self.additional_properties:
name, shape = prop
if shape[0] == "natoms":
shape = [max_atoms] + shape[1:]
shape = [n_samples] + shape
labels[name] = np.zeros(shape, dtype=np.float64)
if self.forces:
labels["forces"] = np.zeros((n_samples, max_atoms, 3), dtype=np.float64)
if self.stress:
@@ -425,6 +453,13 @@ def __call__(self, samples: list[dict]):
if self.stress:
labels["stress"][i] = lab["stress"]

for prop in self.additional_properties:
name, shape = prop
if shape[0] == "natoms":
labels[name][i, : inp["n_atoms"]] = lab[name]
else:
labels[name][i] = lab[name]

max_nbrs = np.max([idx.shape[1] for idx in idxs])
max_nbrs = round_up_to_multiple(max_nbrs, self.nl_padding)

@@ -470,6 +505,7 @@ def __init__(
nl_padding: int = 2000,
pos_unit: str = "Ang",
energy_unit: str = "eV",
additional_properties=[],
pre_shuffle=False,
) -> None:
self.cutoff = cutoff
@@ -478,6 +514,7 @@ def __init__(
self.n_data = len(atoms_list)
self.batch_size = self.validate_batch_size(bs)
self.pos_unit = pos_unit
self.additional_properties = additional_properties

if num_workers:
self.num_workers = num_workers
@@ -488,7 +525,9 @@ def __init__(

# Transform atoms into inputs and labels
self.inputs = atoms_to_inputs(atoms_list, pos_unit)
self.labels = atoms_to_labels(atoms_list, pos_unit, energy_unit)
self.labels = atoms_to_labels(
atoms_list, pos_unit, energy_unit, additional_properties
)
label_keys = self.labels.keys()

self.data = list(
Expand All @@ -500,7 +539,7 @@ def __init__(
forces = "forces" in label_keys
stress = "stress" in label_keys
self.prepare_batch = BatchProcessor(
cutoff, atom_padding, nl_padding, forces, stress
cutoff, atom_padding, nl_padding, forces, stress, additional_properties
)

self.count = 0
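The `additional_properties` argument threaded through both dataset classes above is a list of `(name, shape)` tuples. The convention below is inferred from this diff alone: a leading `"natoms"` entry marks a per-atom label that gets padded to the padded atom count, anything else is treated as a fixed-shape per-structure label. The property names are made up for illustration.

```python
import numpy as np

additional_properties = [
    ("charges", ["natoms"]),  # one scalar per atom -> padded along the atom axis
    ("dipole", [3]),          # one vector per structure -> stored as-is
]

# padding as done in prepare_data above, for a structure with 5 real atoms
# padded up to max_atoms = 8
charges = np.random.rand(5)
zeros_to_add = 3
name, shape = additional_properties[0]
if shape[0] == "natoms":
    pad_shape = [(0, zeros_to_add)] + [(0, 0)] * (len(shape) - 1)
    charges = np.pad(charges, pad_shape, "constant")  # shape (8,), zeros for padding atoms
```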
22 changes: 13 additions & 9 deletions apax/layers/descriptor/basis_functions.py
@@ -6,6 +6,7 @@
import numpy as np

from apax.layers.initializers import uniform_range
from apax.utils.convert import str_to_dtype


class GaussianBasis(nn.Module):
@@ -15,6 +16,8 @@ class GaussianBasis(nn.Module):
dtype: Any = jnp.float32

def setup(self):
dtype = str_to_dtype(self.dtype)

self.betta = self.n_basis**2 / self.r_max**2
self.rad_norm = (2.0 * self.betta / np.pi) ** 0.25
shifts = self.r_min + (self.r_max - self.r_min) / self.n_basis * np.arange(
@@ -23,7 +26,7 @@ def setup(self):

# shape: 1 x n_basis
shifts = einops.repeat(shifts, "n_basis -> 1 n_basis")
self.shifts = jnp.asarray(shifts, dtype=self.dtype)
self.shifts = jnp.asarray(shifts, dtype=dtype)

def __call__(self, dr):
dr = einops.repeat(dr, "neighbors -> neighbors 1")
@@ -47,7 +50,8 @@ class BesselBasis(nn.Module):
dtype: Any = jnp.float32

def setup(self):
self.n = jnp.arange(self.n_basis, dtype=self.dtype)
dtype = str_to_dtype(self.dtype)
self.n = jnp.arange(self.n_basis, dtype=dtype)

def __call__(self, dr):
dr = einops.repeat(dr, "neighbors -> neighbors 1")
@@ -69,10 +73,9 @@ class RadialFunction(nn.Module):
dtype: Any = jnp.float32

def setup(self):
dtype = str_to_dtype(self.dtype)
self.r_max = self.basis_fn.r_max
self.embed_norm = jnp.array(
1.0 / np.sqrt(self.basis_fn.n_basis), dtype=self.dtype
)
self.embed_norm = jnp.array(1.0 / np.sqrt(self.basis_fn.n_basis), dtype=dtype)
if self.one_sided_dist:
lower_bound = 0.0
else:
@@ -81,7 +84,7 @@ def setup(self):
if self.emb_init is not None:
self._n_radial = self.n_radial
if self.emb_init == "uniform":
emb_initializer = uniform_range(lower_bound, 1.0, dtype=self.dtype)
emb_initializer = uniform_range(lower_bound, 1.0, dtype=dtype)
self.embeddings = self.param(
"atomic_type_embedding",
emb_initializer,
@@ -91,7 +94,7 @@ def setup(self):
self.n_radial,
self.basis_fn.n_basis,
),
self.dtype,
dtype,
)
else:
raise ValueError(
@@ -102,7 +105,8 @@ def setup(self):
self._n_radial = self.basis_fn.n_basis

def __call__(self, dr, Z_i, Z_j):
dr = dr.astype(self.dtype)
dtype = str_to_dtype(self.dtype)
dr = dr.astype(dtype)
# basis shape: neighbors x n_basis
basis = self.basis_fn(dr)

@@ -128,6 +132,6 @@ def __call__(self, dr, Z_i, Z_j):

radial_function = radial_function * cutoff

assert radial_function.dtype == self.dtype
assert radial_function.dtype == dtype

return radial_function
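`str_to_dtype` replaces the dtype conversion that previously lived in `ModelConfig.get_dict` (removed in `model_config.py` above). The sketch below shows what the helper in `apax.utils.convert` plausibly does, based on that removed mapping; its actual implementation is not part of this diff.

```python
import jax.numpy as jnp

def str_to_dtype(dtype):
    """Map precision strings such as "fp32"/"fp64" to JAX dtypes."""
    mapping = {"fp32": jnp.float32, "fp64": jnp.float64}
    if isinstance(dtype, str):
        return mapping[dtype]
    return dtype  # already a dtype (the previous convention), pass it through unchanged
```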