From 565423dff083e76ae97ff78a8e2738f0a1303784 Mon Sep 17 00:00:00 2001 From: Brian DuSell Date: Thu, 9 Jun 2022 22:04:04 -0400 Subject: [PATCH] Ignore <pad> when computing cross-entropy loss and sequence accuracy. Fixes clay-lab/transductions #64 --- core/metrics/base_metric.py | 16 ++++++++++------ core/trainer.py | 35 +++++++++++++++++++++-------------- 2 files changed, 31 insertions(+), 20 deletions(-) diff --git a/core/metrics/base_metric.py b/core/metrics/base_metric.py index 33502ff..6dcbf9d 100644 --- a/core/metrics/base_metric.py +++ b/core/metrics/base_metric.py @@ -56,13 +56,18 @@ class SequenceAccuracy(BaseMetric): correct, the sequence scores 1.0; otherwise, it scores 0.0. """ + def __init__(self, pad_token_id=None): + super().__init__() + self._pad = pad_token_id + def compute(self, prediction: Tensor, target: Tensor): prediction = prediction.argmax(1) - - correct = (prediction == target).prod(axis=1) - total = correct.shape[0] - correct = correct.sum() - + correct_tokens = prediction == target + if self._pad is not None: + correct_tokens.logical_or_(target == self._pad) + correct_sequences = torch.all(correct_tokens, dim=1) + correct = correct_sequences.sum() + total = target.size(0) return correct, total @@ -91,7 +96,6 @@ def compute(self, prediction: Tensor, target: Tensor): correct = prediction[:, self.n] == target[:, self.n] total = correct.shape[0] correct = correct.sum() - return correct, total diff --git a/core/trainer.py b/core/trainer.py index a42cbd0..dfb3e2b 100644 --- a/core/trainer.py +++ b/core/trainer.py @@ -169,12 +169,14 @@ def train(self): early_stoping = EarlyStopping(self._cfg.experiment.hyperparameters) + pad_idx = self._model._decoder.vocab.stoi["<pad>"] + # Metrics - seq_acc = SequenceAccuracy() - tok_acc = TokenAccuracy(self._dataset.target_field.vocab.stoi["<pad>"]) - len_acc = LengthAccuracy(self._dataset.target_field.vocab.stoi["<pad>"]) + seq_acc = SequenceAccuracy(pad_idx) + tok_acc = TokenAccuracy(pad_idx) + len_acc = LengthAccuracy(pad_idx) 
first_acc = NthTokenAccuracy(n=1) - avg_loss = LossMetric(F.cross_entropy) + avg_loss = LossMetric(lambda p, t: F.cross_entropy(p, t, ignore_index=pad_idx)) meter = Meter([seq_acc, tok_acc, len_acc, first_acc, avg_loss]) @@ -226,7 +228,7 @@ def train(self): meter(output, target) # Compute average validation loss - val_loss = F.cross_entropy(output, target) + val_loss = F.cross_entropy(output, target, ignore_index=pad_idx) V.set_postfix(val_loss="{:4.3f}".format(val_loss.item())) meter.log(stage="val", step=epoch) @@ -288,12 +290,14 @@ def eval(self, eval_cfg: DictConfig): # Load checkpoint data self._load_checkpoint(eval_cfg.checkpoint_dir) + pad_idx = self._dataset.target_field.vocab.stoi["<pad>"] + # Create meter - seq_acc = SequenceAccuracy() - tok_acc = TokenAccuracy(self._dataset.target_field.vocab.stoi["<pad>"]) - len_acc = LengthAccuracy(self._dataset.target_field.vocab.stoi["<pad>"]) + seq_acc = SequenceAccuracy(pad_idx) + tok_acc = TokenAccuracy(pad_idx) + len_acc = LengthAccuracy(pad_idx) first_acc = NthTokenAccuracy(n=1) - avg_loss = LossMetric(F.cross_entropy) + avg_loss = LossMetric(lambda p, t: F.cross_entropy(p, t, ignore_index=pad_idx)) meter = Meter([seq_acc, tok_acc, len_acc, first_acc, avg_loss]) @@ -339,9 +343,11 @@ def arith_eval(self, eval_cfg: DictConfig): # Load checkpoint data self._load_checkpoint(eval_cfg.checkpoint_dir) + pad_idx = self._dataset.target_field.vocab.stoi["<pad>"] + # Create meter - seq_acc = SequenceAccuracy() - len_acc = LengthAccuracy(self._dataset.target_field.vocab.stoi["<pad>"]) + seq_acc = SequenceAccuracy(pad_idx) + len_acc = LengthAccuracy(pad_idx) object_acc = NthTokenAccuracy(n=5) meter = Meter([seq_acc, len_acc, object_acc]) @@ -609,11 +615,12 @@ def fit_tpdn(self, tpdn_cfg: DictConfig): tpdn.eval() disp_loss = nn.CrossEntropyLoss() + pad_idx = self._dataset.target_field.vocab.stoi["<pad>"] meter = Meter( [ - SequenceAccuracy(), - TokenAccuracy(self._dataset.target_field.vocab.stoi["<pad>"]), - 
LengthAccuracy(self._dataset.target_field.vocab.stoi["<pad>"]), + SequenceAccuracy(pad_idx), + TokenAccuracy(pad_idx), + LengthAccuracy(pad_idx), NthTokenAccuracy(n=1), NthTokenAccuracy(n=3), NthTokenAccuracy(n=5),