-
Notifications
You must be signed in to change notification settings - Fork 7
/
specification_oscillator2_torch.txt
65 lines (49 loc) · 2.03 KB
/
specification_oscillator2_torch.txt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""
Find the mathematical function skeleton that represents acceleration in a damped nonlinear oscillator system with driving force, given data on time, position, and velocity.
This functional form can only contain differentiable mathematical terms.
"""
import torch
# Upper bound on the number of trainable constants an equation may index.
MAX_NPARAMS = 10
# Starting values: one independent scalar parameter per slot, each set to 1.0.
PRAMS_INIT = [
    torch.nn.Parameter(torch.tensor(1.0))
    for _slot in range(MAX_NPARAMS)
]
@evaluate.run
def evaluate(data: dict) -> float:
    """Fit the equation's parameters to the observations and score the fit.

    Copies of the initial values in ``PRAMS_INIT`` are optimised with Adam to
    minimise the mean squared error between ``equation(t, x, v, params)`` and
    the observed outputs.

    Args:
        data: Mapping with an ``'inputs'`` entry, whose columns 0, 1 and 2
            are read as time, position and velocity, and an ``'outputs'``
            entry holding the target values.

    Returns:
        The negated mean squared error computed in the last optimisation
        iteration (so higher is better), or ``None`` when that loss is NaN
        or infinite.
    """
    # Load data observations
    inputs, outputs = data['inputs'], data['outputs']
    t, x, v = inputs[:, 0], inputs[:, 1], inputs[:, 2]
    # NOTE(review): the loss below assumes `outputs` has the same shape as the
    # equation's result; a trailing singleton dimension would broadcast --
    # confirm against the data loader.

    # Optimize parameters based on data
    learning_rate = 1e-4
    n_iterations = 10000

    class Model(torch.nn.Module):
        """Thin module that exposes the equation's constants as parameters."""

        def __init__(self):
            super().__init__()
            # Clone the initial values so that optimisation never mutates the
            # module-level PRAMS_INIT; otherwise a later call to evaluate()
            # would start from the previous call's fitted values.
            self.params = torch.nn.ParameterList(
                [torch.nn.Parameter(p.detach().clone()) for p in PRAMS_INIT]
            )

        def forward(self, t, x, v):
            return equation(t, x, v, self.params)

    model = Model()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    for _ in range(n_iterations):
        optimizer.zero_grad()
        y_pred = model(t, x, v)
        loss = torch.mean((y_pred - outputs) ** 2)
        loss.backward()
        optimizer.step()

    # Return evaluation score; a non-finite loss marks the candidate invalid.
    if not torch.isfinite(loss).all():
        return None
    return -loss.item()
@equation.evolve
def equation(t: torch.Tensor, x: torch.Tensor, v: torch.Tensor, params: torch.nn.ParameterList) -> torch.Tensor:
    """ Baseline acceleration model: an affine combination of t, x and v.

    Args:
        t (torch.Tensor): time.
        x (torch.Tensor): observations of current position.
        v (torch.Tensor): observations of velocity.
        params (torch.nn.ParameterList): constants to be optimized; only the
            first four entries are used here.

    Return:
        torch.Tensor: acceleration, computed as
            params[0] * t + params[1] * x + params[2] * v + params[3].
    """
    time_term = params[0] * t
    position_term = params[1] * x
    velocity_term = params[2] * v
    offset = params[3]
    # Summed left to right, in the same order as the one-line expression.
    return time_term + position_term + velocity_term + offset