Modularizing encoder/decoder in Autoencoder class #10

Merged on Oct 26, 2024 (2 commits)
Changes from all commits
src/lasdi/latent_space.py (245 changes: 118 additions & 127 deletions)
@@ -1,6 +1,32 @@
import torch
import numpy as np

# activation dict
act_dict = {'ELU': torch.nn.ELU,
            'hardshrink': torch.nn.Hardshrink,
            'hardsigmoid': torch.nn.Hardsigmoid,
            'hardtanh': torch.nn.Hardtanh,
            'hardswish': torch.nn.Hardswish,
            'leakyReLU': torch.nn.LeakyReLU,
            'logsigmoid': torch.nn.LogSigmoid,
            'multihead': torch.nn.MultiheadAttention,
            'PReLU': torch.nn.PReLU,
            'ReLU': torch.nn.ReLU,
            'ReLU6': torch.nn.ReLU6,
            'RReLU': torch.nn.RReLU,
            'SELU': torch.nn.SELU,
            'CELU': torch.nn.CELU,
            'GELU': torch.nn.GELU,
            'sigmoid': torch.nn.Sigmoid,
            'SiLU': torch.nn.SiLU,
            'mish': torch.nn.Mish,
            'softplus': torch.nn.Softplus,
            'softshrink': torch.nn.Softshrink,
            'tanh': torch.nn.Tanh,
            'tanhshrink': torch.nn.Tanhshrink,
            'threshold': torch.nn.Threshold,
            }
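Each entry maps a config string to a torch.nn activation class, which the code below instantiates via act_dict[act_type](...). A minimal sketch using the 'tanh' key from the dict above:

act = act_dict['tanh']()     # instantiate torch.nn.Tanh
y = act(torch.randn(4))      # apply the activation elementwise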

def initial_condition_latent(param_grid, physics, autoencoder):

    '''
@@ -23,157 +49,122 @@ def initial_condition_latent(param_grid, physics, autoencoder):
        Z0.append(z0)

    return Z0
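The docstring and body of initial_condition_latent are collapsed behind the hunk header above; going only by the visible signature and return value, a usage sketch (param_grid, physics, and autoencoder as defined elsewhere in the repo) would be:

Z0 = initial_condition_latent(param_grid, physics, autoencoder)  # list of latent initial conditions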

class MultiLayerPerceptron(torch.nn.Module):

    def __init__(self, layer_sizes,
                 act_type='sigmoid', reshape_index=None, reshape_shape=None,
                 threshold=0.1, value=0.0, num_heads=1):
        super(MultiLayerPerceptron, self).__init__()

        # including input, hidden, output layers
        self.n_layers = len(layer_sizes)
        self.layer_sizes = layer_sizes

        # Linear features between layers
        self.fcs = []
        for k in range(self.n_layers - 1):
            self.fcs += [torch.nn.Linear(layer_sizes[k], layer_sizes[k + 1])]
        self.fcs = torch.nn.ModuleList(self.fcs)
        self.init_weight()

        # Reshape input or output layer
        assert((reshape_index is None) or (reshape_index in [0, -1]))
        assert((reshape_shape is None) or (np.prod(reshape_shape) == layer_sizes[reshape_index]))
        self.reshape_index = reshape_index
        self.reshape_shape = reshape_shape

        # Initialize activation function
        self.act_type = act_type
        self.use_multihead = False
        if act_type == "threshold":
            self.act = act_dict[act_type](threshold, value)

        elif act_type == "multihead":
            self.use_multihead = True
            if (self.n_layers > 3): # if you have more than one hidden layer
                self.act = []
                for i in range(self.n_layers - 2):
                    self.act += [act_dict[act_type](layer_sizes[i + 1], num_heads)]
            else:
                self.act = [torch.nn.Identity()] # No additional activation
            self.act = torch.nn.ModuleList(self.act) # register attention layers as submodules

        # all other activation functions initialized here
        else:
            self.act = act_dict[act_type]()
        return

    def forward(self, x):
        if (self.reshape_index == 0):
            # make sure the input has a proper shape
            assert(list(x.shape[-len(self.reshape_shape):]) == self.reshape_shape)
            # we use torch.Tensor.view instead of torch.Tensor.reshape,
            # in order to avoid data copying.
            x = x.view(list(x.shape[:-len(self.reshape_shape)]) + [self.layer_sizes[self.reshape_index]])

        for i in range(self.n_layers - 2):
            x = self.fcs[i](x) # apply linear layer
            if (self.use_multihead):
                x = self.apply_attention(x, i)
            else:
                x = self.act(x) # apply activation function

        x = self.fcs[-1](x)

        if (self.reshape_index == -1):
            # we use torch.Tensor.view instead of torch.Tensor.reshape,
            # in order to avoid data copying.
            x = x.view(list(x.shape[:-1]) + self.reshape_shape)

        return x

    def apply_attention(self, x, act_idx):
        x = x.unsqueeze(1) # Add sequence dimension for attention
        x, _ = self.act[act_idx](x, x, x) # apply attention
        x = x.squeeze(1) # Remove sequence dimension
        return x

    def init_weight(self):
        # TODO(kevin): support other initializations?
        for fc in self.fcs:
            torch.nn.init.xavier_uniform_(fc.weight)
        return
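A quick sanity sketch for the class above, with made-up layer sizes; the reshape arguments let the MLP consume grid-shaped inputs directly:

mlp = MultiLayerPerceptron([12, 8, 3], act_type='ReLU',
                           reshape_index=0, reshape_shape=[3, 4])
x = torch.randn(5, 3, 4)   # batch of 5 samples on a hypothetical 3x4 grid
z = mlp(x)                 # input is flattened to (5, 12), mapped to (5, 3)
assert z.shape == (5, 3)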

class Autoencoder(torch.nn.Module):

    def __init__(self, physics, config):
        super(Autoencoder, self).__init__()

        self.qgrid_size = physics.qgrid_size
        self.space_dim = np.prod(self.qgrid_size)
        hidden_units = config['hidden_units']
        n_z = config['latent_dimension']
        self.n_z = n_z

        layer_sizes = [self.space_dim] + hidden_units + [n_z]
        # grab relevant initialization values from config
        act_type = config['activation'] if 'activation' in config else 'sigmoid'
        threshold = config["threshold"] if "threshold" in config else 0.1
        value = config["value"] if "value" in config else 0.0
        num_heads = config['num_heads'] if 'num_heads' in config else 1

        self.encoder = MultiLayerPerceptron(layer_sizes, act_type,
                                            reshape_index=0, reshape_shape=self.qgrid_size,
                                            threshold=threshold, value=value, num_heads=num_heads)

        self.decoder = MultiLayerPerceptron(layer_sizes[::-1], act_type,
                                            reshape_index=-1, reshape_shape=self.qgrid_size,
                                            threshold=threshold, value=value, num_heads=num_heads)

        return

    def forward(self, x):

        x = self.encoder(x)
        x = self.decoder(x)

        return x


    def export(self):
        dict_ = {'autoencoder_param': self.cpu().state_dict()}
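End to end, the refactored Autoencoder is configured entirely through the config dict. A minimal sketch; the _ToyPhysics stub is hypothetical, standing in for any object that exposes qgrid_size:

class _ToyPhysics:
    qgrid_size = [10, 10]  # hypothetical 10x10 spatial grid

config = {'hidden_units': [100, 20], 'latent_dimension': 5, 'activation': 'tanh'}
ae = Autoencoder(_ToyPhysics(), config)

u = torch.randn(7, 10, 10)  # batch of 7 snapshots on the grid
u_hat = ae(u)               # encoder flattens and compresses to 5 latent dims; decoder restores the grid
assert u_hat.shape == u.shape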