diff --git a/train.py b/train.py
index 1263fed..ba44aab 100644
--- a/train.py
+++ b/train.py
@@ -218,7 +218,8 @@ def train(epoch, best_val_loss):
         if args.prior:
             loss_kl = kl_categorical(prob, log_prior, args.num_atoms)
         else:
-            loss_kl = kl_categorical_uniform(prob, args.num_atoms)
+            loss_kl = kl_categorical_uniform(prob, args.num_atoms,
+                                             args.edge_types)
 
         loss = loss_nll + loss_kl
 
@@ -254,7 +255,7 @@ def train(epoch, best_val_loss):
         target = data[:, :, 1:, :]
         loss_nll = nll_gaussian(output, target, args.var)
 
-        loss_kl = kl_categorical_uniform(prob, args.num_atoms)
+        loss_kl = kl_categorical_uniform(prob, args.num_atoms, args.edge_types)
 
         acc = edge_accuracy(logits, relations)
         acc_val.append(acc)
@@ -323,7 +324,7 @@ def test():
         target = data_decoder[:, :, 1:, :]
         loss_nll = nll_gaussian(output, target, args.var)
 
-        loss_kl = kl_categorical_uniform(prob, args.num_atoms)
+        loss_kl = kl_categorical_uniform(prob, args.num_atoms, args.edge_types)
 
         acc = edge_accuracy(logits, relations)
         acc_test.append(acc)
diff --git a/utils.py b/utils.py
index 5cb4a01..0d9fa5b 100644
--- a/utils.py
+++ b/utils.py
@@ -119,11 +119,11 @@ def load_data(batch_size=1, suffix=''):
     loc_train = np.load('data/loc_train' + suffix + '.npy')
     vel_train = np.load('data/vel_train' + suffix + '.npy')
     edges_train = np.load('data/edges_train' + suffix + '.npy')
-    
+
     loc_valid = np.load('data/loc_valid' + suffix + '.npy')
     vel_valid = np.load('data/vel_valid' + suffix + '.npy')
     edges_valid = np.load('data/edges_valid' + suffix + '.npy')
-    
+
     loc_test = np.load('data/loc_test' + suffix + '.npy')
     vel_test = np.load('data/vel_test' + suffix + '.npy')
     edges_test = np.load('data/edges_test' + suffix + '.npy')
@@ -152,7 +152,7 @@ def load_data(batch_size=1, suffix=''):
     feat_train = np.concatenate([loc_train, vel_train], axis=3)
     edges_train = np.reshape(edges_train, [-1, num_atoms ** 2])
     edges_train = np.array((edges_train + 1) / 2, dtype=np.int64)
-    
+
     loc_valid = np.transpose(loc_valid, [0, 3, 1, 2])
     vel_valid = np.transpose(vel_valid, [0, 3, 1, 2])
     feat_valid = np.concatenate([loc_valid, vel_valid], axis=3)
@@ -205,7 +205,7 @@ def load_kuramoto_data(batch_size=1, suffix=''):
     # Normalize each feature dim. individually
     feat_max = feat_train.max(0).max(0).max(0)
     feat_min = feat_train.min(0).min(0).min(0)
-    
+
     feat_max = np.expand_dims(np.expand_dims(np.expand_dims(feat_max, 0), 0), 0)
     feat_min = np.expand_dims(np.expand_dims(np.expand_dims(feat_min, 0), 0), 0)
@@ -453,10 +453,11 @@ def kl_categorical(preds, log_prior, num_atoms, eps=1e-16):
     return kl_div.sum() / (num_atoms * preds.size(0))
 
 
-def kl_categorical_uniform(preds, num_atoms, add_const=False, eps=1e-16):
+def kl_categorical_uniform(preds, num_atoms, num_edge_types, add_const=False,
+                           eps=1e-16):
     kl_div = preds * torch.log(preds + eps)
     if add_const:
-        const = np.log(args.edge_types)
+        const = np.log(num_edge_types)
         kl_div += const
     return kl_div.sum() / (num_atoms * preds.size(0))
 
@@ -474,4 +475,3 @@ def edge_accuracy(preds, target):
     correct = preds.float().data.eq(
         target.float().data.view_as(preds)).cpu().sum()
     return np.float(correct) / (target.size(0) * target.size(1))
-
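
For reference, a minimal self-contained sketch of the corrected helper together with a call that mirrors the updated train.py call sites. The function body is taken from the diff; the tensor shapes and num_edge_types=2 below are illustrative assumptions, not values from the repository:

    import numpy as np
    import torch


    def kl_categorical_uniform(preds, num_atoms, num_edge_types, add_const=False,
                               eps=1e-16):
        # KL divergence of the predicted edge-type distribution from a uniform
        # prior, up to the constant log(num_edge_types) added via add_const.
        # Passing num_edge_types explicitly removes the previous dependence on
        # train.py's module-level `args`, which was undefined inside utils.py.
        kl_div = preds * torch.log(preds + eps)
        if add_const:
            const = np.log(num_edge_types)
            kl_div += const
        return kl_div.sum() / (num_atoms * preds.size(0))


    # Illustrative shapes: 32 samples, 5 atoms -> 20 directed edges, 2 edge types.
    prob = torch.softmax(torch.randn(32, 20, 2), dim=-1)
    loss_kl = kl_categorical_uniform(prob, num_atoms=5, num_edge_types=2)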