From c6935c2e8ebed4d4e7e917e15ef81e5a21f3407d Mon Sep 17 00:00:00 2001 From: Pengzhi Gao Date: Tue, 1 Oct 2019 17:03:53 -0400 Subject: [PATCH] Resolve #218: Support soft ids as inputs for BERT/GPT2/RoBERTa/XLNet (#220) * Support soft_ids as inputs in BERT/GPT2/RoBERTa * Support soft_ids as inputs in XLNet * Improve type hints * Add an example to binary_adversarial_losses * Resolve comments --- examples/xlnet/xlnet_classification_main.py | 4 ++-- stubs/torch/__init__.pyi | 7 +++++- texar/torch/losses/adv_losses.py | 18 ++++++++++----- .../modules/classifiers/bert_classifier.py | 11 ++++++---- .../classifiers/bert_classifier_test.py | 18 +++++++++++++++ .../modules/classifiers/gpt2_classifier.py | 11 ++++++---- .../classifiers/gpt2_classifier_test.py | 17 ++++++++++++++ .../modules/classifiers/roberta_classifier.py | 11 ++++++---- .../classifiers/roberta_classifier_test.py | 18 +++++++++++++++ .../modules/classifiers/xlnet_classifier.py | 14 +++++++----- .../classifiers/xlnet_classifier_test.py | 19 ++++++++++++++++ texar/torch/modules/encoders/bert_encoder.py | 22 ++++++++++++++----- .../modules/encoders/bert_encoder_test.py | 19 ++++++++++++++++ texar/torch/modules/encoders/gpt2_encoder.py | 19 +++++++++++----- .../modules/encoders/gpt2_encoder_test.py | 15 +++++++++++++ .../torch/modules/encoders/roberta_encoder.py | 11 ++++++---- .../modules/encoders/roberta_encoder_test.py | 19 ++++++++++++++++ texar/torch/modules/encoders/xlnet_encoder.py | 20 +++++++++++++---- .../modules/encoders/xlnet_encoder_test.py | 16 ++++++++++++++ .../modules/regressors/xlnet_regressor.py | 14 +++++++----- .../regressors/xlnet_regressor_test.py | 13 +++++++++++ 21 files changed, 267 insertions(+), 49 deletions(-) diff --git a/examples/xlnet/xlnet_classification_main.py b/examples/xlnet/xlnet_classification_main.py index c9ff45f0c..516f7c16c 100644 --- a/examples/xlnet/xlnet_classification_main.py +++ b/examples/xlnet/xlnet_classification_main.py @@ -100,7 +100,7 @@ def construct_datasets(args) -> Dict[str, tx.data.RecordData]: class RegressorWrapper(tx.modules.XLNetRegressor): def forward(self, # type: ignore batch: tx.data.Batch) -> Dict[str, torch.Tensor]: - preds = super().forward(token_ids=batch.input_ids, + preds = super().forward(inputs=batch.input_ids, segment_ids=batch.segment_ids, input_mask=batch.input_mask) loss = (preds - batch.label_ids) ** 2 @@ -111,7 +111,7 @@ def forward(self, # type: ignore class ClassifierWrapper(tx.modules.XLNetClassifier): def forward(self, # type: ignore batch: tx.data.Batch) -> Dict[str, torch.Tensor]: - logits, preds = super().forward(token_ids=batch.input_ids, + logits, preds = super().forward(inputs=batch.input_ids, segment_ids=batch.segment_ids, input_mask=batch.input_mask) loss = F.cross_entropy(logits, batch.label_ids, reduction='none') diff --git a/stubs/torch/__init__.pyi b/stubs/torch/__init__.pyi index f449cdb42..8e47f7d19 100644 --- a/stubs/torch/__init__.pyi +++ b/stubs/torch/__init__.pyi @@ -2763,10 +2763,15 @@ def tanh(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ... def tanh_(input: Tensor) -> Tensor: ... - +@overload def tensordot(input: Tensor, other: Tensor, dims_self: MaybeTuple[builtins.int], dims_other: MaybeTuple[builtins.int]) -> Tensor: ... +@overload +def tensordot(input: Tensor, other: Tensor, + dims: Union[builtins.int, Tuple[ + List[builtins.int], List[builtins.int]]]) -> Tensor: ... 
+ def th_addmm(input: Tensor, mat1: Tensor, mat2: Tensor, *, beta: builtins.float = 1, alpha: builtins.float = 1, out: Optional[Tensor] = None) -> Tensor: ... diff --git a/texar/torch/losses/adv_losses.py b/texar/torch/losses/adv_losses.py index 8e5eab569..7e966ae74 100644 --- a/texar/torch/losses/adv_losses.py +++ b/texar/torch/losses/adv_losses.py @@ -27,11 +27,6 @@ ] -__all__ = [ - 'binary_adversarial_losses', -] - - def binary_adversarial_losses( real_data: torch.Tensor, fake_data: torch.Tensor, @@ -39,6 +34,19 @@ def binary_adversarial_losses( mode: str = "max_real") -> Tuple[torch.Tensor, torch.Tensor]: r"""Computes adversarial losses of real/fake binary discrimination game. + Example: + + .. code-block:: python + + # Using BERTClassifier as the discriminator, which can accept + # "soft" token ids for gradient backpropagation + discriminator = tx.modules.BERTClassifier('bert-base-uncased') + + G_loss, D_loss = tx.losses.binary_adversarial_losses( + real_data=real_token_ids, # [batch_size, max_time] + fake_data=fake_soft_token_ids, # [batch_size, max_time, vocab_size] + discriminator_fn=discriminator) + Args: real_data (Tensor or array): Real data of shape `[num_real_examples, ...]`. diff --git a/texar/torch/modules/classifiers/bert_classifier.py b/texar/torch/modules/classifiers/bert_classifier.py index 16ad69d59..5303e7259 100644 --- a/texar/torch/modules/classifiers/bert_classifier.py +++ b/texar/torch/modules/classifiers/bert_classifier.py @@ -14,7 +14,7 @@ """ BERT classifier. """ -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch from torch import nn @@ -196,7 +196,7 @@ def default_hparams(): return hparams def forward(self, # type: ignore - inputs: torch.Tensor, + inputs: Union[torch.Tensor, torch.LongTensor], sequence_length: Optional[torch.LongTensor] = None, segment_ids: Optional[torch.LongTensor] = None) \ -> Tuple[torch.Tensor, torch.LongTensor]: @@ -206,8 +206,11 @@ def forward(self, # type: ignore :class:`~texar.torch.modules.BERTEncoder`. Args: - inputs: A 2D Tensor of shape `[batch_size, max_time]`, - containing the token ids of tokens in input sequences. + inputs: Either a **2D Tensor** of shape `[batch_size, max_time]`, + containing the ids of tokens in input sequences, or + a **3D Tensor** of shape `[batch_size, max_time, vocab_size]`, + containing soft token ids (i.e., weights or probabilities) + used to mix the embedding vectors. sequence_length (optional): A 1D Tensor of shape `[batch_size]`. Input tokens beyond respective sequence lengths are masked out automatically. diff --git a/texar/torch/modules/classifiers/bert_classifier_test.py b/texar/torch/modules/classifiers/bert_classifier_test.py index 0837926a0..2084a3791 100644 --- a/texar/torch/modules/classifiers/bert_classifier_test.py +++ b/texar/torch/modules/classifiers/bert_classifier_test.py @@ -162,6 +162,24 @@ def test_binary(self): self.assertEqual(logits.shape, torch.Size([self.batch_size])) self.assertEqual(preds.shape, torch.Size([self.batch_size])) + def test_soft_ids(self): + r"""Tests soft ids. 
+ """ + inputs = torch.rand(self.batch_size, self.max_length, 30522) + + hparams = { + "pretrained_model_name": None, + "num_classes": 1, + "clas_strategy": "time_wise", + } + classifier = BERTClassifier(hparams=hparams) + logits, preds = classifier(inputs) + + self.assertEqual(logits.shape, torch.Size( + [self.batch_size, self.max_length])) + self.assertEqual(preds.shape, torch.Size( + [self.batch_size, self.max_length])) + if __name__ == "__main__": unittest.main() diff --git a/texar/torch/modules/classifiers/gpt2_classifier.py b/texar/torch/modules/classifiers/gpt2_classifier.py index 18c714f05..b6c022102 100644 --- a/texar/torch/modules/classifiers/gpt2_classifier.py +++ b/texar/torch/modules/classifiers/gpt2_classifier.py @@ -14,7 +14,7 @@ """ GPT2 classifiers. """ -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch from torch import nn @@ -192,7 +192,7 @@ def default_hparams(): return hparams def forward(self, # type: ignore - inputs: torch.Tensor, + inputs: Union[torch.Tensor, torch.LongTensor], sequence_length: Optional[torch.LongTensor] = None) \ -> Tuple[torch.Tensor, torch.LongTensor]: r"""Feeds the inputs through the network and makes classification. @@ -201,8 +201,11 @@ def forward(self, # type: ignore :class:`~texar.torch.modules.GPT2Encoder`. Args: - inputs: A 2D Tensor of shape `[batch_size, max_time]`, - containing the token ids of tokens in input sequences. + inputs: Either a **2D Tensor** of shape `[batch_size, max_time]`, + containing the ids of tokens in input sequences, or + a **3D Tensor** of shape `[batch_size, max_time, vocab_size]`, + containing soft token ids (i.e., weights or probabilities) + used to mix the embedding vectors. sequence_length (optional): A 1D Tensor of shape `[batch_size]`. Input tokens beyond respective sequence lengths are masked out automatically. diff --git a/texar/torch/modules/classifiers/gpt2_classifier_test.py b/texar/torch/modules/classifiers/gpt2_classifier_test.py index 7071df0fc..b7e465e0e 100644 --- a/texar/torch/modules/classifiers/gpt2_classifier_test.py +++ b/texar/torch/modules/classifiers/gpt2_classifier_test.py @@ -162,6 +162,23 @@ def test_binary(self): self.assertEqual(logits.shape, torch.Size([self.batch_size])) self.assertEqual(preds.shape, torch.Size([self.batch_size])) + def test_soft_ids(self): + r"""Tests soft ids. + """ + inputs = torch.rand(self.batch_size, self.max_length, 50257) + hparams = { + "pretrained_model_name": None, + "num_classes": 1, + "clas_strategy": "time_wise", + } + classifier = GPT2Classifier(hparams=hparams) + logits, preds = classifier(inputs) + + self.assertEqual(logits.shape, torch.Size( + [self.batch_size, self.max_length])) + self.assertEqual(preds.shape, torch.Size( + [self.batch_size, self.max_length])) + if __name__ == "__main__": unittest.main() diff --git a/texar/torch/modules/classifiers/roberta_classifier.py b/texar/torch/modules/classifiers/roberta_classifier.py index d32a1cc25..2375e9b51 100644 --- a/texar/torch/modules/classifiers/roberta_classifier.py +++ b/texar/torch/modules/classifiers/roberta_classifier.py @@ -14,7 +14,7 @@ """ RoBERTa classifier. 
""" -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch @@ -136,7 +136,7 @@ def default_hparams(): return hparams def forward(self, # type: ignore - inputs: torch.Tensor, + inputs: Union[torch.Tensor, torch.LongTensor], sequence_length: Optional[torch.LongTensor] = None) \ -> Tuple[torch.Tensor, torch.LongTensor]: r"""Feeds the inputs through the network and makes classification. @@ -145,8 +145,11 @@ def forward(self, # type: ignore :class:`~texar.torch.modules.RoBERTaEncoder`. Args: - inputs: A 2D Tensor of shape `[batch_size, max_time]`, - containing the token ids of tokens in input sequences. + inputs: Either a **2D Tensor** of shape `[batch_size, max_time]`, + containing the ids of tokens in input sequences, or + a **3D Tensor** of shape `[batch_size, max_time, vocab_size]`, + containing soft token ids (i.e., weights or probabilities) + used to mix the embedding vectors. sequence_length (optional): A 1D Tensor of shape `[batch_size]`. Input tokens beyond respective sequence lengths are masked out automatically. diff --git a/texar/torch/modules/classifiers/roberta_classifier_test.py b/texar/torch/modules/classifiers/roberta_classifier_test.py index bb897a2ea..3fe0a1a47 100644 --- a/texar/torch/modules/classifiers/roberta_classifier_test.py +++ b/texar/torch/modules/classifiers/roberta_classifier_test.py @@ -163,6 +163,24 @@ def test_binary(self): self.assertEqual(logits.shape, torch.Size([self.batch_size])) self.assertEqual(preds.shape, torch.Size([self.batch_size])) + def test_soft_ids(self): + r"""Tests soft ids. + """ + inputs = torch.rand(self.batch_size, self.max_length, 50265) + + hparams = { + "pretrained_model_name": None, + "num_classes": 1, + "clas_strategy": "time_wise", + } + classifier = RoBERTaClassifier(hparams=hparams) + logits, preds = classifier(inputs) + + self.assertEqual(logits.shape, torch.Size( + [self.batch_size, self.max_length])) + self.assertEqual(preds.shape, torch.Size( + [self.batch_size, self.max_length])) + if __name__ == "__main__": unittest.main() diff --git a/texar/torch/modules/classifiers/xlnet_classifier.py b/texar/torch/modules/classifiers/xlnet_classifier.py index 4ab74dccc..e17e7d465 100644 --- a/texar/torch/modules/classifiers/xlnet_classifier.py +++ b/texar/torch/modules/classifiers/xlnet_classifier.py @@ -15,7 +15,7 @@ XLNet Classifier. """ -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union import torch from torch import nn @@ -243,14 +243,18 @@ def param_groups(self, return self.parameters() def forward(self, # type: ignore - token_ids: torch.LongTensor, + inputs: Union[torch.Tensor, torch.LongTensor], segment_ids: Optional[torch.LongTensor] = None, input_mask: Optional[torch.Tensor] = None) \ -> Tuple[torch.Tensor, torch.LongTensor]: r"""Feeds the inputs through the network and makes classification. Args: - token_ids: Shape `[batch_size, max_time]`. + inputs: Either a **2D Tensor** of shape `[batch_size, max_time]`, + containing the ids of tokens in input sequences, or + a **3D Tensor** of shape `[batch_size, max_time, vocab_size]`, + containing soft token ids (i.e., weights or probabilities) + used to mix the embedding vectors. segment_ids: Shape `[batch_size, max_time]`. input_mask: Float tensor of shape `[batch_size, max_time]`. Note that positions with value 1 are masked out. @@ -276,7 +280,7 @@ def forward(self, # type: ignore shape ``[batch_size, max_time]``. 
""" # output: [batch_size, seq_len, hidden_dim] - output, _ = self._encoder(token_ids=token_ids, + output, _ = self._encoder(inputs=inputs, segment_ids=segment_ids, input_mask=input_mask) @@ -286,7 +290,7 @@ def forward(self, # type: ignore elif strategy == 'cls_time': summary = output[:, -1] elif strategy == 'all_time': - length_diff = self._hparams.max_seq_length - token_ids.shape[1] + length_diff = self._hparams.max_seq_length - inputs.shape[1] summary_input = F.pad(output, [0, 0, 0, length_diff, 0, 0]) summary_input_dim = (self._encoder.output_size * self._hparams.max_seq_length) diff --git a/texar/torch/modules/classifiers/xlnet_classifier_test.py b/texar/torch/modules/classifiers/xlnet_classifier_test.py index 9995c117c..8a9aacc26 100644 --- a/texar/torch/modules/classifiers/xlnet_classifier_test.py +++ b/texar/torch/modules/classifiers/xlnet_classifier_test.py @@ -171,6 +171,25 @@ def test_binary(self): self.assertEqual(logits.shape, torch.Size([self.batch_size])) self.assertEqual(preds.shape, torch.Size([self.batch_size])) + def test_soft_ids(self): + r"""Tests soft ids. + """ + inputs = torch.rand(self.batch_size, self.max_length, 32000) + + # case 1 + hparams = { + "pretrained_model_name": None, + "num_classes": 1, + "clas_strategy": "time_wise", + } + classifier = XLNetClassifier(hparams=hparams) + logits, preds = classifier(inputs) + + self.assertEqual(logits.shape, torch.Size( + [self.batch_size, self.max_length])) + self.assertEqual(preds.shape, torch.Size( + [self.batch_size, self.max_length])) + if __name__ == "__main__": unittest.main() diff --git a/texar/torch/modules/encoders/bert_encoder.py b/texar/torch/modules/encoders/bert_encoder.py index fd4ce890d..8a830767e 100644 --- a/texar/torch/modules/encoders/bert_encoder.py +++ b/texar/torch/modules/encoders/bert_encoder.py @@ -15,7 +15,7 @@ BERT encoder. """ -from typing import Optional +from typing import Optional, Union import torch from torch import nn @@ -281,14 +281,17 @@ def default_hparams(): } def forward(self, # type: ignore - inputs: torch.Tensor, + inputs: Union[torch.Tensor, torch.LongTensor], sequence_length: Optional[torch.LongTensor] = None, segment_ids: Optional[torch.LongTensor] = None): r"""Encodes the inputs. Args: - inputs: A 2D Tensor of shape `[batch_size, max_time]`, - containing the token ids of tokens in the input sequences. + inputs: Either a **2D Tensor** of shape `[batch_size, max_time]`, + containing the ids of tokens in input sequences, or + a **3D Tensor** of shape `[batch_size, max_time, vocab_size]`, + containing soft token ids (i.e., weights or probabilities) + used to mix the embedding vectors. segment_ids (optional): A 2D Tensor of shape `[batch_size, max_time]`, containing the segment ids of tokens in input sequences. If `None` (default), a @@ -308,8 +311,13 @@ def forward(self, # type: ignore pre-trained on top of the hidden state associated to the first character of the input (`CLS`), see BERT's paper. 
""" + if inputs.dim() == 2: + word_embeds = self.word_embedder(ids=inputs) + elif inputs.dim() == 3: + word_embeds = self.word_embedder(soft_ids=inputs) + else: + raise ValueError("'inputs' should be a 2D or 3D tensor.") - word_embeds = self.word_embedder(inputs) batch_size = inputs.size(0) pos_length = inputs.new_full((batch_size,), inputs.size(1), dtype=torch.int64) @@ -317,7 +325,9 @@ def forward(self, # type: ignore if self.segment_embedder is not None: if segment_ids is None: - segment_ids = torch.zeros_like(inputs) + segment_ids = torch.zeros((inputs.size(0), inputs.size(1)), + dtype=torch.long, + device=inputs.device) segment_embeds = self.segment_embedder(segment_ids) inputs_embeds = word_embeds + segment_embeds + pos_embeds else: diff --git a/texar/torch/modules/encoders/bert_encoder_test.py b/texar/torch/modules/encoders/bert_encoder_test.py index 4fd39c41e..fa9479dfc 100644 --- a/texar/torch/modules/encoders/bert_encoder_test.py +++ b/texar/torch/modules/encoders/bert_encoder_test.py @@ -171,6 +171,25 @@ def test_encode(self): pooled_output.shape, torch.Size([self.batch_size, encoder.output_size])) + def test_soft_ids(self): + r"""Tests soft ids. + """ + hparams = { + "pretrained_model_name": None, + } + encoder = BERTEncoder(hparams=hparams) + + inputs = torch.rand(self.batch_size, self.max_length, 30522) + outputs, pooled_output = encoder(inputs) + + outputs_dim = encoder.hparams.encoder.dim + self.assertEqual( + outputs.shape, + torch.Size([self.batch_size, self.max_length, outputs_dim])) + self.assertEqual( + pooled_output.shape, + torch.Size([self.batch_size, encoder.output_size])) + if __name__ == "__main__": unittest.main() diff --git a/texar/torch/modules/encoders/gpt2_encoder.py b/texar/torch/modules/encoders/gpt2_encoder.py index d3045ddf0..003c5540f 100644 --- a/texar/torch/modules/encoders/gpt2_encoder.py +++ b/texar/torch/modules/encoders/gpt2_encoder.py @@ -15,7 +15,7 @@ GPT2 encoders. """ -from typing import Optional +from typing import Optional, Union import torch @@ -260,13 +260,16 @@ def default_hparams(): } def forward(self, # type: ignore - inputs: torch.Tensor, + inputs: Union[torch.Tensor, torch.LongTensor], sequence_length: Optional[torch.LongTensor] = None): r"""Encodes the inputs. Args: - inputs: A 2D Tensor of shape `[batch_size, max_time]`, - containing the token ids of tokens in the input sequences. + inputs: Either a **2D Tensor** of shape `[batch_size, max_time]`, + containing the ids of tokens in input sequences, or + a **3D Tensor** of shape `[batch_size, max_time, vocab_size]`, + containing soft token ids (i.e., weights or probabilities) + used to mix the embedding vectors. sequence_length (optional): A 1D Tensor of shape `[batch_size]`. Input tokens beyond respective sequence lengths are masked out automatically. @@ -275,7 +278,13 @@ def forward(self, # type: ignore outputs: A Tensor of shape `[batch_size, max_time, dim]` containing the encoded vectors. 
""" - word_embeds = self.word_embedder(inputs) + if inputs.dim() == 2: + word_embeds = self.word_embedder(ids=inputs) + elif inputs.dim() == 3: + word_embeds = self.word_embedder(soft_ids=inputs) + else: + raise ValueError("'inputs' should be a 2D or 3D tensor.") + batch_size = inputs.size(0) pos_length = inputs.new_full( (batch_size,), inputs.size(1), dtype=torch.long) diff --git a/texar/torch/modules/encoders/gpt2_encoder_test.py b/texar/torch/modules/encoders/gpt2_encoder_test.py index 0e8687f51..61b1550fb 100644 --- a/texar/torch/modules/encoders/gpt2_encoder_test.py +++ b/texar/torch/modules/encoders/gpt2_encoder_test.py @@ -160,6 +160,21 @@ def test_encode(self): outputs.shape, torch.Size([self.batch_size, self.max_length, encoder.output_size])) + def test_soft_ids(self): + r"""Tests soft ids. + """ + hparams = { + "pretrained_model_name": None, + } + encoder = GPT2Encoder(hparams=hparams) + + inputs = torch.rand(self.batch_size, self.max_length, 50257) + outputs = encoder(inputs) + + self.assertEqual( + outputs.shape, + torch.Size([self.batch_size, self.max_length, encoder.output_size])) + if __name__ == "__main__": unittest.main() diff --git a/texar/torch/modules/encoders/roberta_encoder.py b/texar/torch/modules/encoders/roberta_encoder.py index 6f637b233..f98394963 100644 --- a/texar/torch/modules/encoders/roberta_encoder.py +++ b/texar/torch/modules/encoders/roberta_encoder.py @@ -15,7 +15,7 @@ RoBERTa encoder. """ -from typing import Optional +from typing import Optional, Union import torch @@ -219,7 +219,7 @@ def default_hparams(): } def forward(self, # type: ignore - inputs: torch.Tensor, + inputs: Union[torch.Tensor, torch.LongTensor], sequence_length: Optional[torch.LongTensor] = None, segment_ids: Optional[torch.LongTensor] = None): r"""Encodes the inputs. Differing from the standard BERT, the RoBERTa @@ -227,8 +227,11 @@ def forward(self, # type: ignore require `segment_ids` as an input. Args: - inputs: A 2D Tensor of shape `[batch_size, max_time]`, - containing the token ids of tokens in the input sequences. + inputs: Either a **2D Tensor** of shape `[batch_size, max_time]`, + containing the ids of tokens in input sequences, or + a **3D Tensor** of shape `[batch_size, max_time, vocab_size]`, + containing soft token ids (i.e., weights or probabilities) + used to mix the embedding vectors. sequence_length (optional): A 1D Tensor of shape `[batch_size]`. Input tokens beyond respective sequence lengths are masked out automatically. diff --git a/texar/torch/modules/encoders/roberta_encoder_test.py b/texar/torch/modules/encoders/roberta_encoder_test.py index 8acf99196..843c333a7 100644 --- a/texar/torch/modules/encoders/roberta_encoder_test.py +++ b/texar/torch/modules/encoders/roberta_encoder_test.py @@ -169,6 +169,25 @@ def test_encode(self): pooled_output.shape, torch.Size([self.batch_size, encoder.output_size])) + def test_soft_ids(self): + r"""Tests soft ids. 
+ """ + hparams = { + "pretrained_model_name": None, + } + encoder = RoBERTaEncoder(hparams=hparams) + + inputs = torch.rand(self.batch_size, self.max_length, 50265) + outputs, pooled_output = encoder(inputs) + + outputs_dim = encoder.hparams.encoder.dim + self.assertEqual( + outputs.shape, + torch.Size([self.batch_size, self.max_length, outputs_dim])) + self.assertEqual( + pooled_output.shape, + torch.Size([self.batch_size, encoder.output_size])) + if __name__ == "__main__": unittest.main() diff --git a/texar/torch/modules/encoders/xlnet_encoder.py b/texar/torch/modules/encoders/xlnet_encoder.py index 7fc6c4d8a..44a4fc1b3 100644 --- a/texar/torch/modules/encoders/xlnet_encoder.py +++ b/texar/torch/modules/encoders/xlnet_encoder.py @@ -15,7 +15,7 @@ XLNet encoder. """ -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch import nn @@ -319,7 +319,7 @@ def _create_causal_attn_mask(self, return ret def forward(self, # type: ignore - token_ids: torch.LongTensor, + inputs: Union[torch.Tensor, torch.LongTensor], segment_ids: Optional[torch.LongTensor] = None, input_mask: Optional[torch.Tensor] = None, memory: Optional[List[torch.Tensor]] = None, @@ -335,7 +335,11 @@ def forward(self, # type: ignore r"""Compute XLNet representations for the input. Args: - token_ids: Shape `[batch_size, max_time]`. + inputs: Either a **2D Tensor** of shape `[batch_size, max_time]`, + containing the ids of tokens in input sequences, or + a **3D Tensor** of shape `[batch_size, max_time, vocab_size]`, + containing soft token ids (i.e., weights or probabilities) + used to mix the embedding vectors. segment_ids: Shape `[batch_size, max_time]`. input_mask: Float tensor of shape `[batch_size, max_time]`. Note that positions with value 1 are masked out. @@ -374,7 +378,15 @@ def forward(self, # type: ignore `[batch_size, cache_len, hidden_dim]`. This can be used as the :attr:`memory` argument in the next batch. """ - return self._forward(self.word_embed(token_ids), + if inputs.dim() == 2: + word_embeds = self.word_embed(inputs) + elif inputs.dim() == 3: + word_embeds = torch.tensordot(inputs, self.word_embed.weight, + dims=([-1], [0])) + else: + raise ValueError("'inputs' should be a 2D or 3D tensor.") + + return self._forward(word_embed=word_embeds, segment_ids=segment_ids, input_mask=input_mask, memory=memory, diff --git a/texar/torch/modules/encoders/xlnet_encoder_test.py b/texar/torch/modules/encoders/xlnet_encoder_test.py index 418a7d72d..a12a9b7d8 100644 --- a/texar/torch/modules/encoders/xlnet_encoder_test.py +++ b/texar/torch/modules/encoders/xlnet_encoder_test.py @@ -134,6 +134,22 @@ def test_encode(self): torch.Size([self.batch_size, self.max_length, encoder.output_size])) self.assertEqual(new_memory, None) + def test_soft_ids(self): + r"""Tests soft ids. 
+ """ + hparams = { + "pretrained_model_name": None, + } + encoder = XLNetEncoder(hparams=hparams) + + inputs = torch.rand(self.batch_size, self.max_length, 32000) + outputs, new_memory = encoder(inputs) + + self.assertEqual( + outputs.shape, + torch.Size([self.batch_size, self.max_length, encoder.output_size])) + self.assertEqual(new_memory, None) + if __name__ == "__main__": unittest.main() diff --git a/texar/torch/modules/regressors/xlnet_regressor.py b/texar/torch/modules/regressors/xlnet_regressor.py index 110faf403..376714351 100644 --- a/texar/torch/modules/regressors/xlnet_regressor.py +++ b/texar/torch/modules/regressors/xlnet_regressor.py @@ -15,7 +15,7 @@ XLNet Regressors. """ -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union import torch from torch import nn @@ -224,13 +224,17 @@ def param_groups(self, return self.parameters() def forward(self, # type: ignore - token_ids: torch.LongTensor, + inputs: Union[torch.Tensor, torch.LongTensor], segment_ids: Optional[torch.LongTensor] = None, input_mask: Optional[torch.Tensor] = None) -> torch.Tensor: r"""Feeds the inputs through the network and makes regression. Args: - token_ids: Shape `[batch_size, max_time]`. + inputs: Either a **2D Tensor** of shape `[batch_size, max_time]`, + containing the ids of tokens in input sequences, or + a **3D Tensor** of shape `[batch_size, max_time, vocab_size]`, + containing soft token ids (i.e., weights or probabilities) + used to mix the embedding vectors. segment_ids: Shape `[batch_size, max_time]`. input_mask: Float tensor of shape `[batch_size, max_time]`. Note that positions with value 1 are masked out. @@ -245,7 +249,7 @@ def forward(self, # type: ignore `[batch_size, max_time]`. """ # output: [batch_size, seq_len, hidden_dim] - output, _ = self._encoder(token_ids=token_ids, + output, _ = self._encoder(inputs=inputs, segment_ids=segment_ids, input_mask=input_mask) @@ -255,7 +259,7 @@ def forward(self, # type: ignore elif strategy == 'cls_time': summary = output[:, -1] elif strategy == 'all_time': - length_diff = self._hparams.max_seq_length - token_ids.shape[1] + length_diff = self._hparams.max_seq_length - inputs.shape[1] summary_input = F.pad(output, [0, 0, 0, length_diff, 0, 0]) summary_input_dim = (self._encoder.output_size * self._hparams.max_seq_length) diff --git a/texar/torch/modules/regressors/xlnet_regressor_test.py b/texar/torch/modules/regressors/xlnet_regressor_test.py index 77fe36af7..3125d5692 100644 --- a/texar/torch/modules/regressors/xlnet_regressor_test.py +++ b/texar/torch/modules/regressors/xlnet_regressor_test.py @@ -104,6 +104,19 @@ def test_regression(self): self.assertEqual(preds.shape, torch.Size( [self.batch_size, self.max_length])) + def test_soft_ids(self): + r"""Tests soft ids. + """ + inputs = torch.rand(self.batch_size, self.max_length, 32000) + + # case 1 + hparams = { + "pretrained_model_name": None, + } + regressor = XLNetRegressor(hparams=hparams) + preds = regressor(inputs) + self.assertEqual(preds.shape, torch.Size([self.batch_size])) + if __name__ == "__main__": unittest.main()
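
Usage sketch of the behavior this patch adds, mirroring the new `test_soft_ids` cases and the `BERTClassifier` example added to `binary_adversarial_losses`: the same `forward` call now accepts either hard token ids (a 2D LongTensor) or soft token ids (a 3D float tensor over the vocabulary). The batch size, sequence length, and the softmax-based construction of the soft ids below are illustrative assumptions, not values taken from the patch.

    import torch
    import texar.torch as tx

    # Randomly initialized BERT classifier (no pre-trained weights), using the
    # same hyperparameters as the new test_soft_ids test case.
    classifier = tx.modules.BERTClassifier(hparams={
        "pretrained_model_name": None,
        "num_classes": 1,
        "clas_strategy": "time_wise",
    })

    batch_size, max_time, vocab_size = 2, 16, 30522   # 30522: BERT vocab size

    # Hard ids: 2D LongTensor of token ids.
    hard_ids = torch.randint(vocab_size, (batch_size, max_time))
    logits, preds = classifier(hard_ids)    # both of shape [batch_size, max_time]

    # Soft ids: 3D float tensor of per-position weights over the vocabulary,
    # e.g. a generator's softmax output; the embedding lookup becomes a weighted
    # mix of embedding vectors, so gradients can flow back into the soft ids.
    soft_ids = torch.softmax(
        torch.randn(batch_size, max_time, vocab_size), dim=-1)
    logits, preds = classifier(soft_ids)    # both of shape [batch_size, max_time]

Because the classifier is differentiable with respect to the soft ids, it can serve as the `discriminator_fn` of `tx.losses.binary_adversarial_losses`, as in the docstring example added by this patch.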