Merge pull request #1 from tkipf/clean

Cleanup
ethanfetaya authored Mar 9, 2018
2 parents f31ad4b + 04d2feb commit 94b0d47
Showing 10 changed files with 125 additions and 269 deletions.
2 changes: 1 addition & 1 deletion LICENSE
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2018 ethanfetaya
Copyright (c) 2018 Ethan Fetaya, Thomas Kipf

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
55 changes: 50 additions & 5 deletions README.md
@@ -1,12 +1,57 @@
# Neural relational inference for interacting systems

This repository contains the official PyTorch implementation of:

**Neural relational inference for interacting systems.**
Thomas Kipf*, Ethan Fetaya*, Kuan-Chieh Wang, Max Welling, Richard Zemel.
https://arxiv.org/abs/1802.04687 (*: equal contribution)

![Neural Relational Inference (NRI)](nri.png)

**Abstract:** Interacting systems are prevalent in nature, from dynamical systems in physics to complex societal dynamics. The interplay of components can give rise to complex behavior, which can often be explained using a simple model of the system's constituent parts. In this work, we introduce the neural relational inference (NRI) model: an unsupervised model that learns to infer interactions while simultaneously learning the dynamics purely from observational data. Our model takes the form of a variational auto-encoder, in which the latent code represents the underlying interaction graph and the reconstruction is based on graph neural networks. In experiments on simulated physical systems, we show that our NRI model can accurately recover ground-truth interactions in an unsupervised manner. We further demonstrate that we can find an interpretable structure and predict complex dynamics in real motion capture and sports tracking data.
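
Below is a minimal, illustrative PyTorch sketch of the pipeline described in the abstract: an encoder scores every ordered pair of particles over a small set of edge types, a discrete latent interaction graph is sampled with a Gumbel-softmax relaxation, and a message-passing decoder predicts the next state conditioned on that graph. This is *not* the implementation in this repository; the class name `ToyNRI`, the layer sizes, the single-step decoder, and the inclusion of self-pairs are simplifying assumptions for illustration only.

```
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyNRI(nn.Module):
    """Illustrative sketch only: encoder -> latent graph -> one-step decoder."""

    def __init__(self, n_timesteps=49, n_dims=4, n_edge_types=2, n_hid=64):
        super(ToyNRI, self).__init__()
        # Encoder: embed each particle's full trajectory, then score ordered pairs.
        self.node_emb = nn.Sequential(nn.Linear(n_timesteps * n_dims, n_hid), nn.ReLU())
        self.edge_scorer = nn.Linear(2 * n_hid, n_edge_types)
        # Decoder: one message function per edge type, plus a node-update MLP.
        self.msg_fc = nn.ModuleList([nn.Linear(2 * n_dims, n_hid) for _ in range(n_edge_types)])
        self.out_fc = nn.Sequential(nn.Linear(n_dims + n_hid, n_hid), nn.ReLU(),
                                    nn.Linear(n_hid, n_dims))

    def forward(self, trajs, tau=0.5):
        # trajs: [batch, n_atoms, n_timesteps, n_dims], e.g. 2-D position + velocity.
        b, n, t, d = trajs.shape
        h = self.node_emb(trajs.reshape(b, n, t * d))                 # [b, n, n_hid]
        pair_h = torch.cat([h.unsqueeze(2).expand(b, n, n, -1),
                            h.unsqueeze(1).expand(b, n, n, -1)], -1)  # every ordered (i, j) pair
        logits = self.edge_scorer(pair_h)                             # edge-type logits (self-pairs kept for brevity)
        # Sample a soft discrete latent graph via the Gumbel-softmax relaxation.
        gumbels = -torch.empty_like(logits).exponential_().log()
        edges = F.softmax((logits + gumbels) / tau, dim=-1)           # [b, n, n, n_edge_types]
        # One-step decoding: predict the last observed state from the one before it.
        x = trajs[:, :, -2, :]                                        # [b, n, n_dims]
        pair_x = torch.cat([x.unsqueeze(2).expand(b, n, n, -1),
                            x.unsqueeze(1).expand(b, n, n, -1)], -1)
        msgs = sum(edges[..., k:k + 1] * F.relu(fc(pair_x))           # messages weighted by sampled edge type
                   for k, fc in enumerate(self.msg_fc))
        agg = msgs.sum(dim=1)                                         # aggregate incoming messages per node
        pred = x + self.out_fc(torch.cat([x, agg], dim=-1))           # residual next-step prediction
        return pred, logits


# Example: 8 simulations of 5 particles, 49 time steps, 4-D state.
model = ToyNRI()
pred, logits = model(torch.randn(8, 5, 49, 4))
```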

### Data generation

To replicate the experiments on simulated physical data, first generate training, validation and test data by running:

```
cd data
python generate_dataset.py
```
This generates the springs dataset; use `--simulation charged` for charged particles.
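
For example, to also generate the charged-particle dataset (run from within the `data` folder):
```
python generate_dataset.py --simulation charged
```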

### Run experiments

From the project's root folder, simply run
```
python train.py
```
to train a Neural Relational Inference (NRI) model on the springs dataset. You can specify a different dataset by modifying the `suffix` argument: `--suffix charged5` will run the model on the charged particle simulation with 5 particles (if it has been generated).
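
For example, assuming the charged-particle data with 5 particles has been generated as described above:
```
python train.py --suffix charged5
```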

To train the encoder or decoder separately, run

```
python train_enc.py
```
or

```
python train_dec.py
```
respectively. We provide a number of training options which are documented in the respective training files.
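
For example, to train only the encoder on the charged-particle data (assuming `train_enc.py` exposes the same `--suffix` option as `train.py`):
```
python train_enc.py --suffix charged5
```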

Additionally, we provide code for an LSTM baseline (denoted *LSTM (joint)* in the paper), which you can run as follows:
```
python lstm_baseline.py
```
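
For example, using the `--suffix` option defined in `lstm_baseline.py` (the exact suffix string, e.g. `_charged`, must match the generated data files):
```
python lstm_baseline.py --suffix _charged
```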

### Cite
If you make use of this code in your own work, please cite our paper:
```
@article{kipf2018neural,
  title={Neural Relational Inference for Interacting Systems},
  author={Kipf, Thomas and Fetaya, Ethan and Wang, Kuan-Chieh and Welling, Max and Zemel, Richard},
  journal={arXiv preprint arXiv:1802.04687},
  year={2018}
}
```
Empty file added data/__init__.py
12 changes: 0 additions & 12 deletions data/synthetic_sim.py
@@ -3,9 +3,6 @@
import time


# np.random.seed(0)


class SpringSim(object):
def __init__(self, n_balls=5, box_size=5., loc_std=.5, vel_norm=.5,
interaction_strength=.1, noise_var=0.):
@@ -305,17 +302,8 @@ def sample_trajectory(self, T=10000, sample_freq=10,
for i in range(loc.shape[-1]):
plt.plot(loc[:, 0, i], loc[:, 1, i])
plt.plot(loc[0, 0, i], loc[0, 1, i], 'd')
# #plt.plot(vel_norm[:,i])
plt.figure()
energies = [sim._energy(loc[i, :, :], vel[i, :, :], edges) for i in
range(loc.shape[0])]
plt.plot(energies)
# mom = vel.sum(axis=2)
# mom_diff = (mom[1:,:]-mom[:-1,:]).sum(axis=1)
# plt.figure()
# plt.plot(mom_diff)
plt.show()

# np.save("loc.npy", loc)
# np.save("vel.npy", vel)
# np.save("edges.npy", edges)
51 changes: 7 additions & 44 deletions lstm_baseline.py
@@ -9,7 +9,6 @@

import torch.optim as optim
from torch.optim import lr_scheduler
from torch import autograd

from utils import *
from modules import *
@@ -20,15 +19,15 @@
parser.add_argument('--seed', type=int, default=42, help='Random seed.')
parser.add_argument('--epochs', type=int, default=500,
help='Number of epochs to train.')
parser.add_argument('--batch_size', type=int, default=128,
parser.add_argument('--batch-size', type=int, default=128,
help='Number of samples per batch.')
parser.add_argument('--lr', type=float, default=0.0005,
help='Initial learning rate.')
parser.add_argument('--hidden', type=int, default=256,
help='Number of hidden units.')
parser.add_argument('--num_atoms', type=int, default=5,
help='Number of atoms in simulation.')
parser.add_argument('--num_layers', type=int, default=2,
parser.add_argument('--num-layers', type=int, default=2,
help='Number of LSTM layers.')
parser.add_argument('--suffix', type=str, default='_springs',
help='Suffix for training data (e.g. "_charged").')
@@ -58,8 +57,6 @@
parser.add_argument('--var', type=float, default=5e-5,
help='Output variance.')

print("NOTE: For Kuramoto model, set variance to 0.01.")

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
print(args)
@@ -93,24 +90,8 @@
print("WARNING: No save_folder provided!" +
"Testing (within this script) will throw an error.")

if args.motion:
train_loader, valid_loader, test_loader = load_motion_data(args.batch_size,
args.suffix)
elif args.suffix == "_kuramoto5" or args.suffix == "_kuramoto10":
train_loader, valid_loader, test_loader = load_kuramoto_data(
args.batch_size,
args.suffix)
else:
train_loader, valid_loader, test_loader, loc_max, loc_min, vel_max, vel_min = load_data(
args.batch_size, args.suffix)


# data, relations = train_loader.__iter__().next()
# data, relations = data.cuda(), relations.cuda()
# data, relations = Variable(data), Variable(relations, requires_grad=True)
# logits = encoder(data, rel_rec, rel_send)
# edges = gumbel_softmax(logits, tau=args.temp, hard=False)
# np.save("data/motion_edges.npy", edges.data.cpu().numpy())
train_loader, valid_loader, test_loader, loc_max, loc_min, vel_max, vel_min = load_data(
args.batch_size, args.suffix)


class RecurrentBaseline(nn.Module):
@@ -146,7 +127,7 @@ def batch_norm(self, inputs):
def step(self, ins, hidden=None):
# Input shape: [num_sims, n_atoms, n_in]
x = F.relu(self.fc1_1(ins))
# x = F.dropout(x, self.dropout_prob, training=self.training)
x = F.dropout(x, self.dropout_prob, training=self.training)
x = F.relu(self.fc1_2(x))
x = x.view(ins.size(0), -1)
# [num_sims, n_atoms*n_hid]
@@ -205,7 +186,6 @@ def forward(self, inputs, prediction_steps, burn_in=False, burn_in_steps=1):
model.load_state_dict(torch.load(model_file))
args.save_folder = False


optimizer = optim.Adam(list(model.parameters()), lr=args.lr)
scheduler = lr_scheduler.StepLR(optimizer, step_size=args.lr_decay,
gamma=args.gamma)
@@ -233,8 +213,6 @@ def train(epoch, best_val_loss):
mse_baseline_val = []
mse_train = []
mse_val = []
mse_last_train = []
mse_last_val = []

model.train()
scheduler.step()
@@ -246,8 +224,6 @@ def train(epoch, best_val_loss):

optimizer.zero_grad()

# output = model(data, args.prediction_steps)

output = model(data, 100,
burn_in=True,
burn_in_steps=args.timesteps - args.prediction_steps)
@@ -256,15 +232,13 @@
loss = nll_gaussian(output, target, args.var)

mse = F.mse_loss(output, target)
mse_last = F.mse_loss(output[:, :, -1, :], target[:, :, -1, :])
mse_baseline = F.mse_loss(data[:, :, :-1, :], data[:, :, 1:, :])

loss.backward()
optimizer.step()

loss_train.append(loss.data[0])
mse_train.append(mse.data[0])
mse_last_train.append(mse_last.data[0])
mse_baseline_train.append(mse_baseline.data[0])

model.eval()
@@ -281,22 +255,18 @@
loss = nll_gaussian(output, target, args.var)

mse = F.mse_loss(output, target)
mse_last = F.mse_loss(output[:, :, -1, :], target[:, :, -1, :])
mse_baseline = F.mse_loss(data[:, :, :-1, :], data[:, :, 1:, :])

loss_val.append(loss.data[0])
mse_val.append(mse.data[0])
mse_last_val.append(mse_last.data[0])
mse_baseline_val.append(mse_baseline.data[0])

print('Epoch: {:04d}'.format(epoch),
'nll_train: {:.10f}'.format(np.mean(loss_train)),
'mse_train: {:.12f}'.format(np.mean(mse_train)),
# 'mse_last_train: {:.12f}'.format(np.mean(mse_last_train)),
'mse_baseline_train: {:.10f}'.format(np.mean(mse_baseline_train)),
'nll_val: {:.10f}'.format(np.mean(loss_val)),
'mse_val: {:.12f}'.format(np.mean(mse_val)),
# 'mse_last_val: {:.12f}'.format(np.mean(mse_last_val)),
'mse_baseline_val: {:.10f}'.format(np.mean(mse_baseline_val)),
'time: {:.4f}s'.format(time.time() - t))
if args.save_folder and np.mean(loss_val) < best_val_loss:
@@ -305,11 +275,9 @@
print('Epoch: {:04d}'.format(epoch),
'nll_train: {:.10f}'.format(np.mean(loss_train)),
'mse_train: {:.12f}'.format(np.mean(mse_train)),
# 'mse_last_train: {:.12f}'.format(np.mean(mse_last_train)),
'mse_baseline_train: {:.10f}'.format(np.mean(mse_baseline_train)),
'nll_val: {:.10f}'.format(np.mean(loss_val)),
'mse_val: {:.12f}'.format(np.mean(mse_val)),
# 'mse_last_val: {:.12f}'.format(np.mean(mse_last_val)),
'mse_baseline_val: {:.10f}'.format(np.mean(mse_baseline_val)),
'time: {:.4f}s'.format(time.time() - t), file=log)
log.flush()
@@ -320,7 +288,6 @@ def test():
loss_test = []
mse_baseline_test = []
mse_test = []
mse_last_test = []
tot_mse = 0
tot_mse_baseline = 0
counter = 0
@@ -346,20 +313,18 @@ def test():
loss = nll_gaussian(output, target, args.var)

mse = F.mse_loss(output, target)
mse_last = F.mse_loss(output[:, :, -1, :], target[:, :, -1, :])
mse_baseline = F.mse_loss(ins_cut[:, :, :-1, :], ins_cut[:, :, 1:, :])

loss_test.append(loss.data[0])
mse_test.append(mse.data[0])
mse_last_test.append(mse_last.data[0])
mse_baseline_test.append(mse_baseline.data[0])

if args.motion or args.non_markov:
# RNN decoder evaluation setting

# For plotting purposes
output = model(inputs, 100, burn_in=True,
burn_in_steps=args.timesteps)

output = output[:, :, args.timesteps:, :]
target = inputs[:, :, -args.timesteps:, :]
@@ -380,7 +345,7 @@

# For plotting purposes
output = model(inputs, 100, burn_in=True,
burn_in_steps=args.timesteps)

output = output[:, :, args.timesteps:args.timesteps + 20, :]
target = inputs[:, :, args.timesteps + 1:args.timesteps + 21, :]
@@ -417,7 +382,6 @@ def test():
print('--------------------------------')
print('nll_test: {:.10f}'.format(np.mean(loss_test)),
'mse_test: {:.12f}'.format(np.mean(mse_test)),
# 'mse_last_test: {:.12f}'.format(np.mean(mse_last_test)),
'mse_baseline_test: {:.10f}'.format(np.mean(mse_baseline_test)))
print('MSE: {}'.format(mse_str))
print('MSE Baseline: {}'.format(mse_baseline_str))
@@ -427,7 +391,6 @@ def test():
print('--------------------------------', file=log)
print('nll_test: {:.10f}'.format(np.mean(loss_test)),
'mse_test: {:.12f}'.format(np.mean(mse_test)),
# 'mse_last_test: {:.12f}'.format(np.mean(mse_last_test)),
'mse_baseline_test: {:.10f}'.format(np.mean(mse_baseline_test)),
file=log)
print('MSE: {}'.format(mse_str), file=log)