Skip to content

Commit

Permalink
Merge pull request #3 from tkipf/clean
Browse files Browse the repository at this point in the history
Fix missing argument error
  • Loading branch information
ethanfetaya authored Mar 12, 2018
2 parents 94b0d47 + e1089bb commit 9a8f3ff
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
7 changes: 4 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,8 @@ def train(epoch, best_val_loss):
if args.prior:
loss_kl = kl_categorical(prob, log_prior, args.num_atoms)
else:
loss_kl = kl_categorical_uniform(prob, args.num_atoms)
loss_kl = kl_categorical_uniform(prob, args.num_atoms,
args.edge_types)

loss = loss_nll + loss_kl

Expand Down Expand Up @@ -254,7 +255,7 @@ def train(epoch, best_val_loss):

target = data[:, :, 1:, :]
loss_nll = nll_gaussian(output, target, args.var)
loss_kl = kl_categorical_uniform(prob, args.num_atoms)
loss_kl = kl_categorical_uniform(prob, args.num_atoms, args.edge_types)

acc = edge_accuracy(logits, relations)
acc_val.append(acc)
Expand Down Expand Up @@ -323,7 +324,7 @@ def test():

target = data_decoder[:, :, 1:, :]
loss_nll = nll_gaussian(output, target, args.var)
loss_kl = kl_categorical_uniform(prob, args.num_atoms)
loss_kl = kl_categorical_uniform(prob, args.num_atoms, args.edge_types)

acc = edge_accuracy(logits, relations)
acc_test.append(acc)
Expand Down
14 changes: 7 additions & 7 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,11 @@ def load_data(batch_size=1, suffix=''):
loc_train = np.load('data/loc_train' + suffix + '.npy')
vel_train = np.load('data/vel_train' + suffix + '.npy')
edges_train = np.load('data/edges_train' + suffix + '.npy')

loc_valid = np.load('data/loc_valid' + suffix + '.npy')
vel_valid = np.load('data/vel_valid' + suffix + '.npy')
edges_valid = np.load('data/edges_valid' + suffix + '.npy')

loc_test = np.load('data/loc_test' + suffix + '.npy')
vel_test = np.load('data/vel_test' + suffix + '.npy')
edges_test = np.load('data/edges_test' + suffix + '.npy')
Expand Down Expand Up @@ -152,7 +152,7 @@ def load_data(batch_size=1, suffix=''):
feat_train = np.concatenate([loc_train, vel_train], axis=3)
edges_train = np.reshape(edges_train, [-1, num_atoms ** 2])
edges_train = np.array((edges_train + 1) / 2, dtype=np.int64)

loc_valid = np.transpose(loc_valid, [0, 3, 1, 2])
vel_valid = np.transpose(vel_valid, [0, 3, 1, 2])
feat_valid = np.concatenate([loc_valid, vel_valid], axis=3)
Expand Down Expand Up @@ -205,7 +205,7 @@ def load_kuramoto_data(batch_size=1, suffix=''):
# Normalize each feature dim. individually
feat_max = feat_train.max(0).max(0).max(0)
feat_min = feat_train.min(0).min(0).min(0)

feat_max = np.expand_dims(np.expand_dims(np.expand_dims(feat_max, 0), 0), 0)
feat_min = np.expand_dims(np.expand_dims(np.expand_dims(feat_min, 0), 0), 0)

Expand Down Expand Up @@ -453,10 +453,11 @@ def kl_categorical(preds, log_prior, num_atoms, eps=1e-16):
return kl_div.sum() / (num_atoms * preds.size(0))


def kl_categorical_uniform(preds, num_atoms, add_const=False, eps=1e-16):
def kl_categorical_uniform(preds, num_atoms, num_edge_types, add_const=False,
eps=1e-16):
kl_div = preds * torch.log(preds + eps)
if add_const:
const = np.log(args.edge_types)
const = np.log(num_edge_types)
kl_div += const
return kl_div.sum() / (num_atoms * preds.size(0))

Expand All @@ -474,4 +475,3 @@ def edge_accuracy(preds, target):
correct = preds.float().data.eq(
target.float().data.view_as(preds)).cpu().sum()
return np.float(correct) / (target.size(0) * target.size(1))

0 comments on commit 9a8f3ff

Please sign in to comment.