Skip to content

Commit

Permalink
Updates
Browse files Browse the repository at this point in the history
  • Loading branch information
magicat128 committed Sep 15, 2020
1 parent d78ce5f commit c2492c2
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,9 @@ def evaluate(features, support, mask, labels, placeholders):
sess.run(tf.global_variables_initializer())

cost_val = []
best_acc = 0
best_val = 0
best_epoch = 0
best_acc = 0
best_cost = 0
test_doc_embeddings = None
preds = None
Expand Down Expand Up @@ -136,9 +137,11 @@ def evaluate(features, support, mask, labels, placeholders):

# Test
test_cost, test_acc, test_duration, embeddings, pred, labels = evaluate(test_feature, test_adj, test_mask, test_y, placeholders)
if test_acc > best_acc:
best_acc = test_acc

if val_acc >= best_val:
best_val = val_acc
best_epoch = epoch
best_acc = test_acc
best_cost = test_cost
test_doc_embeddings = embeddings
preds = pred
Expand All @@ -147,7 +150,7 @@ def evaluate(features, support, mask, labels, placeholders):
print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(train_loss),
"train_acc=", "{:.5f}".format(train_acc), "val_loss=", "{:.5f}".format(val_cost),
"val_acc=", "{:.5f}".format(val_acc), "test_acc=", "{:.5f}".format(test_acc),
"time=", "{:.5f}".format(time.time() - t),"best_acc=", "{:.5f}".format(best_acc))
"time=", "{:.5f}".format(time.time() - t))

if FLAGS.early_stopping > 0 and epoch > FLAGS.early_stopping and cost_val[-1] > np.mean(cost_val[-(FLAGS.early_stopping+1):-1]):
print("Early stopping...")
Expand Down

0 comments on commit c2492c2

Please sign in to comment.