Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Model specific Crystal MLP proxy and wrapper #81

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions config/proxy/crystals.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
_target_: gflownet.proxy.crystals.SendekMLPWrapper
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Proxy and user config file for crystal runs


feature_set: "comp"
path_to_proxy: ${user.data.crystals}
scale: True
4 changes: 4 additions & 0 deletions config/user/divya.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
logdir:
root: /network/scratch/d/divya.sharma/logs/gflownet
data:
crystals: /network/scratch/d/divya.sharma/gflownet/data/crystals
Binary file added data/crystals/all.ckpt
Binary file not shown.
Binary file added data/crystals/all_mean.pt
Binary file not shown.
Binary file added data/crystals/all_std.pt
Binary file not shown.
Binary file added data/crystals/comp.ckpt
Binary file not shown.
Binary file added data/crystals/comp_mean.pt
Binary file not shown.
Binary file added data/crystals/comp_std.pt
Binary file not shown.
10 changes: 7 additions & 3 deletions gflownet/envs/crystals.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ def __init__(
self.idx2elem = {i: e for i, e in enumerate(self.elements)}
self.eos = -1
self.action_space = self.get_actions_space()
self.fixed_policy_output = self.get_fixed_policy_output()
self.random_policy_output = self.get_fixed_policy_output()
self.policy_output_dim = len(self.fixed_policy_output)
self.policy_input_dim = len(self.state2policy())

def get_actions_space(self):
"""
Expand Down Expand Up @@ -229,7 +233,7 @@ def reset(self, env_id=None):
self.id = env_id
return self

def get_parents(self, state=None, done=None, actions=None):
def get_parents(self, state=None, done=None, action=None):
"""
Determines all parents and actions that lead to a state.

Expand Down Expand Up @@ -259,7 +263,7 @@ def get_parents(self, state=None, done=None, actions=None):
if done is None:
done = self.done
if done:
return [state], [self.eos]
return [state], [(self.eos, 0)]
else:
parents = []
actions = []
Expand All @@ -269,7 +273,7 @@ def get_parents(self, state=None, done=None, actions=None):
parent = state.copy()
parent[self.elem2idx[element]] -= n
parents.append(parent)
actions.append(idx)
actions.append(action)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return actions as a list of action tuples, instead of indices

return parents, actions

def step(self, action: Tuple[int, int]) -> Tuple[List[int], Tuple[int, int], bool]:
Expand Down
102 changes: 102 additions & 0 deletions gflownet/proxy/crystals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from gflownet.proxy.base import Proxy
import torch
import torch.nn as nn

import os.path as osp


class SendekMLPWrapper(Proxy):
"""
Wrapper for MLP proxy trained on Li-ion SSB ionic conductivities calculated
from Sendek et. al's logistic regression model

Attributes
----------

feature_set: str
supports either "comp" or "all", with the former including only
features that denote the composition of a crystal in the form of
[Li_content, Natoms, e1, e2, e3...]. The latter contains one-hot
encoded space-group values as well as the crystal structure
paramaters [a, b, c, alpha, beta, gamma]

path_to_proxy: str
path to saved torch checkpoint for MLP state dictionary

scale: bool
whether or not the input nereds to be standardised based on
mean and standard deviation of the dataset
"""

def __init__(self, feature_set, path_to_proxy, scale=False, **kwargs):
super().__init__(**kwargs)
# TODO: assert oracle_split in ["D2_target", "D2_target_fid1", "D2_target_fid2"]
# TODO: assert oracle_type in ["MLP"]
if feature_set == "comp":
self.oracle = CrystalMLP(85, [256, 256])
elif feature_set == "all":
self.oracle = CrystalMLP(322, [512, 1024, 1024, 512])
self.oracle.load_state_dict(
torch.load(osp.join(path_to_proxy, feature_set + ".ckpt"))
)
self.oracle.to(self.device)
if scale:
self.scale = {
"mean": torch.load(osp.join(path_to_proxy, feature_set + "_mean.pt")),
"std": torch.load(osp.join(path_to_proxy, feature_set + "_std.pt")),
}
else:
self.scale = None

def __call__(self, crystals):
"""
Returns a vector of size [batch_size] that are calculated
ionic conductivity values between 0 and 1
"""

if self.scale is not None:
crystals = (crystals - self.scale["mean"]) / self.scale["std"]
crystals = torch.nan_to_num(crystals, nan=0.0)
sh-divya marked this conversation as resolved.
Show resolved Hide resolved
with torch.no_grad():
scaled_ionic_conductivity = self.oracle(crystals.to(device=self.device, dtype=self.float))

return scaled_ionic_conductivity


class CrystalMLP(nn.Module):

"""
Skeleton code for an MLP that can be used to train
MLPs on ionic conductivity values of
"""

def __init__(self, in_feat, hidden_layers):
super(CrystalMLP, self).__init__()
self.nn_layers = []
self.modules = []

for i in range(len(hidden_layers)):
if i == 0:
self.nn_layers.append(nn.Linear(in_feat, hidden_layers[i]))
else:
self.nn_layers.append(nn.Linear(hidden_layers[i - 1], hidden_layers[i]))
self.modules.append(self.nn_layers[-1])
self.nn_layers.append(nn.BatchNorm1d(hidden_layers[i]))
self.nn_layers.append(nn.Linear(hidden_layers[-1], 1))
self.modules.append(self.nn_layers[-1])
self.nn_layers = nn.ModuleList(self.nn_layers)
self.hidden_act = nn.LeakyReLU(0.2)
self.drop = nn.Dropout(p=0.5)
self.final_act = nn.Sigmoid()

def forward(self, x):
for l, layer in enumerate(self.nn_layers):
if isinstance(layer, nn.BatchNorm1d):
continue
x = layer(x)
if l == len(self.nn_layers) - 1:
x = self.final_act(x)
if l % 2 == 1:
x = self.hidden_act(x)

return x
37 changes: 37 additions & 0 deletions scripts/get_periodic_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from mendeleev.fetch import fetch_table
import json


def element_maps(table_name, table_key, elems, write=False):
"""
Returns or writes to json file; map created from
"table_name" dataframe in mendeleev, using
column "atomic number" to column "table_key"

Elems are either first n number entries to
"table_name" or list of strings with element
symbols for the "elements table"
"""
base = fetch_table(table_name)
if isinstance(elems, int):
df = base.iloc[:elems]
elif isinstance(elems, tuple):
df = base.iloc[elems[0] - 1 : elems[1] - 1]
elif isinstance(elems, list) and table_name == "elements":
df = base[base[table_key].isin(elems)]

if table_name == "oxidationstates":
df = df.groupby("atomic_number")[table_key].apply(list).to_dict()
elif table_name == "elements":
df = df.set_index("atomic_number")[table_key].to_dict()

if write:
fptr = table_key + ".json"
with open(fptr, "w", encoding="utf-8") as fobj:
json.dump(df, fobj, ensure_ascii=False, indent=4)
else:
print(df)


if __name__ == "__main__":
element_maps("oxidationstates", "oxidation_state", 10)