Resolve #218: Support soft ids as inputs for BERT/GPT2/RoBERTa/XLNet (#220)

* Support soft_ids as inputs in BERT/GPT2/RoBERTa

* Support soft_ids as inputs in XLNet

* Improve type hints

* Add an example to binary_adversarial_losses

* Resolve comments
gpengzhi authored Oct 1, 2019
1 parent f14fa55 commit c6935c2
Showing 21 changed files with 267 additions and 49 deletions.
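For context, here is a minimal sketch of what the new input contract looks like, mirroring the test cases added in this commit (the `hparams` and random tensors below are illustrative; 30522 is the `bert-base-uncased` vocabulary size):

```python
import torch
import texar.torch as tx

batch_size, max_time, vocab_size = 2, 16, 30522

# Hard ids: a 2D LongTensor of token ids (the existing behavior).
token_ids = torch.randint(vocab_size, (batch_size, max_time))

# Soft ids: a 3D float tensor of per-position weights over the vocabulary,
# e.g. probabilities produced by a generator, so gradients can flow through.
soft_ids = torch.softmax(torch.randn(batch_size, max_time, vocab_size), dim=-1)

classifier = tx.modules.BERTClassifier(hparams={
    "pretrained_model_name": None,   # randomly initialized, as in the tests
    "num_classes": 1,
    "clas_strategy": "time_wise",
})

logits, preds = classifier(inputs=token_ids)  # 2D input: token-id path
logits, preds = classifier(inputs=soft_ids)   # 3D input: soft-id path
# With this strategy both calls return logits/preds of shape
# [batch_size, max_time].
```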
4 changes: 2 additions & 2 deletions examples/xlnet/xlnet_classification_main.py
@@ -100,7 +100,7 @@ def construct_datasets(args) -> Dict[str, tx.data.RecordData]:
class RegressorWrapper(tx.modules.XLNetRegressor):
def forward(self, # type: ignore
batch: tx.data.Batch) -> Dict[str, torch.Tensor]:
- preds = super().forward(token_ids=batch.input_ids,
+ preds = super().forward(inputs=batch.input_ids,
segment_ids=batch.segment_ids,
input_mask=batch.input_mask)
loss = (preds - batch.label_ids) ** 2
@@ -111,7 +111,7 @@ def forward(self, # type: ignore
class ClassifierWrapper(tx.modules.XLNetClassifier):
def forward(self, # type: ignore
batch: tx.data.Batch) -> Dict[str, torch.Tensor]:
- logits, preds = super().forward(token_ids=batch.input_ids,
+ logits, preds = super().forward(inputs=batch.input_ids,
segment_ids=batch.segment_ids,
input_mask=batch.input_mask)
loss = F.cross_entropy(logits, batch.label_ids, reduction='none')
7 changes: 6 additions & 1 deletion stubs/torch/__init__.pyi
@@ -2763,10 +2763,15 @@ def tanh(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ...

def tanh_(input: Tensor) -> Tensor: ...


@overload
def tensordot(input: Tensor, other: Tensor, dims_self: MaybeTuple[builtins.int],
dims_other: MaybeTuple[builtins.int]) -> Tensor: ...

@overload
def tensordot(input: Tensor, other: Tensor,
dims: Union[builtins.int, Tuple[
List[builtins.int], List[builtins.int]]]) -> Tensor: ...


def th_addmm(input: Tensor, mat1: Tensor, mat2: Tensor, *, beta: builtins.float = 1, alpha: builtins.float = 1,
out: Optional[Tensor] = None) -> Tensor: ...
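The two `tensordot` overloads added above cover the different ways the contraction dimensions can be specified in the stub. A quick illustration using the documented Python-level `dims` argument (shapes chosen arbitrarily):

```python
import torch

a = torch.randn(3, 4, 5)
b = torch.randn(5, 6)

# Pass the contracted dims explicitly as a pair of lists:
# contract dim 2 of `a` with dim 0 of `b`.
c1 = torch.tensordot(a, b, dims=([2], [0]))  # shape [3, 4, 6]

# Or pass an integer: contract the last `dims` dims of `a`
# with the first `dims` dims of `b`.
c2 = torch.tensordot(a, b, dims=1)           # shape [3, 4, 6]
```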
18 changes: 13 additions & 5 deletions texar/torch/losses/adv_losses.py
@@ -27,18 +27,26 @@
]


__all__ = [
'binary_adversarial_losses',
]


def binary_adversarial_losses(
real_data: torch.Tensor,
fake_data: torch.Tensor,
discriminator_fn: Callable[[torch.Tensor], MaybeTuple[torch.Tensor]],
mode: str = "max_real") -> Tuple[torch.Tensor, torch.Tensor]:
r"""Computes adversarial losses of real/fake binary discrimination game.
Example:
.. code-block:: python
# Using BERTClassifier as the discriminator, which can accept
# "soft" token ids for gradient backpropagation
discriminator = tx.modules.BERTClassifier('bert-base-uncased')
G_loss, D_loss = tx.losses.binary_adversarial_losses(
real_data=real_token_ids, # [batch_size, max_time]
fake_data=fake_soft_token_ids, # [batch_size, max_time, vocab_size]
discriminator_fn=discriminator)
Args:
real_data (Tensor or array): Real data of shape
`[num_real_examples, ...]`.
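A hedged sketch of where the `fake_soft_token_ids` in the example above could come from: per-step generator logits turned into differentiable distributions (the generator itself is hypothetical here; only the shapes matter):

```python
import torch
import torch.nn.functional as F

batch_size, max_time, vocab_size = 2, 16, 30522  # illustrative sizes

# Stand-in for the per-step vocabulary logits of a text generator.
gen_logits = torch.randn(batch_size, max_time, vocab_size, requires_grad=True)

# Softmax or Gumbel-softmax keeps the samples differentiable, so the
# generator loss returned by binary_adversarial_losses can backpropagate
# through the discriminator's soft-id embedding lookup into the generator.
fake_soft_token_ids = F.gumbel_softmax(gen_logits, tau=1.0, hard=False, dim=-1)
```

This 3D tensor is what the `fake_data` argument in the docstring example refers to.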
11 changes: 7 additions & 4 deletions texar/torch/modules/classifiers/bert_classifier.py
@@ -14,7 +14,7 @@
"""
BERT classifier.
"""
- from typing import Optional, Tuple
+ from typing import Optional, Tuple, Union

import torch
from torch import nn
@@ -196,7 +196,7 @@ def default_hparams():
return hparams

def forward(self, # type: ignore
- inputs: torch.Tensor,
+ inputs: Union[torch.Tensor, torch.LongTensor],
sequence_length: Optional[torch.LongTensor] = None,
segment_ids: Optional[torch.LongTensor] = None) \
-> Tuple[torch.Tensor, torch.LongTensor]:
@@ -206,8 +206,11 @@ def forward(self, # type: ignore
:class:`~texar.torch.modules.BERTEncoder`.
Args:
- inputs: A 2D Tensor of shape `[batch_size, max_time]`,
-     containing the token ids of tokens in input sequences.
+ inputs: Either a **2D Tensor** of shape `[batch_size, max_time]`,
+     containing the ids of tokens in input sequences, or
+     a **3D Tensor** of shape `[batch_size, max_time, vocab_size]`,
+     containing soft token ids (i.e., weights or probabilities)
+     used to mix the embedding vectors.
sequence_length (optional): A 1D Tensor of shape `[batch_size]`.
Input tokens beyond respective sequence lengths are masked
out automatically.
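To make the hard/soft relationship concrete, a small sketch (not part of the commit) showing that one-hot soft ids reduce to the ordinary token-id lookup; `eval()` is needed so dropout does not perturb the comparison:

```python
import torch
import torch.nn.functional as F
import texar.torch as tx

classifier = tx.modules.BERTClassifier(hparams={"pretrained_model_name": None})
classifier.eval()

token_ids = torch.randint(30522, (2, 8))

# One-hot rows are the degenerate case of soft ids: all weight on one token,
# so the mixed embedding equals the usual embedding lookup.
one_hot_ids = F.one_hot(token_ids, num_classes=30522).float()

with torch.no_grad():
    logits_hard, _ = classifier(inputs=token_ids)
    logits_soft, _ = classifier(inputs=one_hot_ids)
# logits_hard and logits_soft should agree up to floating-point noise.
```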
18 changes: 18 additions & 0 deletions texar/torch/modules/classifiers/bert_classifier_test.py
@@ -162,6 +162,24 @@ def test_binary(self):
self.assertEqual(logits.shape, torch.Size([self.batch_size]))
self.assertEqual(preds.shape, torch.Size([self.batch_size]))

def test_soft_ids(self):
r"""Tests soft ids.
"""
inputs = torch.rand(self.batch_size, self.max_length, 30522)

hparams = {
"pretrained_model_name": None,
"num_classes": 1,
"clas_strategy": "time_wise",
}
classifier = BERTClassifier(hparams=hparams)
logits, preds = classifier(inputs)

self.assertEqual(logits.shape, torch.Size(
[self.batch_size, self.max_length]))
self.assertEqual(preds.shape, torch.Size(
[self.batch_size, self.max_length]))


if __name__ == "__main__":
unittest.main()
11 changes: 7 additions & 4 deletions texar/torch/modules/classifiers/gpt2_classifier.py
@@ -14,7 +14,7 @@
"""
GPT2 classifiers.
"""
- from typing import Optional, Tuple
+ from typing import Optional, Tuple, Union

import torch
from torch import nn
@@ -192,7 +192,7 @@ def default_hparams():
return hparams

def forward(self, # type: ignore
- inputs: torch.Tensor,
+ inputs: Union[torch.Tensor, torch.LongTensor],
sequence_length: Optional[torch.LongTensor] = None) \
-> Tuple[torch.Tensor, torch.LongTensor]:
r"""Feeds the inputs through the network and makes classification.
@@ -201,8 +201,11 @@ def forward(self, # type: ignore
:class:`~texar.torch.modules.GPT2Encoder`.
Args:
- inputs: A 2D Tensor of shape `[batch_size, max_time]`,
-     containing the token ids of tokens in input sequences.
+ inputs: Either a **2D Tensor** of shape `[batch_size, max_time]`,
+     containing the ids of tokens in input sequences, or
+     a **3D Tensor** of shape `[batch_size, max_time, vocab_size]`,
+     containing soft token ids (i.e., weights or probabilities)
+     used to mix the embedding vectors.
sequence_length (optional): A 1D Tensor of shape `[batch_size]`.
Input tokens beyond respective sequence lengths are masked
out automatically.
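A sketch of the kind of pipeline the 3D input enables for GPT-2: scoring a generator's soft outputs without breaking the gradient path (randomly initialized modules and stand-in logits; 50257 is the GPT-2 vocabulary size used in the added test):

```python
import torch
import torch.nn.functional as F
import texar.torch as tx

vocab_size = 50257

clas = tx.modules.GPT2Classifier(hparams={
    "pretrained_model_name": None,
    "num_classes": 1,
    "clas_strategy": "time_wise",
})

# Stand-in for decoder logits; in practice these come from a generator.
gen_logits = torch.randn(2, 10, vocab_size, requires_grad=True)
soft_ids = F.softmax(gen_logits, dim=-1)

logits, preds = clas(inputs=soft_ids)
logits.sum().backward()
assert gen_logits.grad is not None  # gradients reach the generator logits
```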
17 changes: 17 additions & 0 deletions texar/torch/modules/classifiers/gpt2_classifier_test.py
@@ -162,6 +162,23 @@ def test_binary(self):
self.assertEqual(logits.shape, torch.Size([self.batch_size]))
self.assertEqual(preds.shape, torch.Size([self.batch_size]))

def test_soft_ids(self):
r"""Tests soft ids.
"""
inputs = torch.rand(self.batch_size, self.max_length, 50257)
hparams = {
"pretrained_model_name": None,
"num_classes": 1,
"clas_strategy": "time_wise",
}
classifier = GPT2Classifier(hparams=hparams)
logits, preds = classifier(inputs)

self.assertEqual(logits.shape, torch.Size(
[self.batch_size, self.max_length]))
self.assertEqual(preds.shape, torch.Size(
[self.batch_size, self.max_length]))


if __name__ == "__main__":
unittest.main()
11 changes: 7 additions & 4 deletions texar/torch/modules/classifiers/roberta_classifier.py
@@ -14,7 +14,7 @@
"""
RoBERTa classifier.
"""
- from typing import Optional, Tuple
+ from typing import Optional, Tuple, Union

import torch

@@ -136,7 +136,7 @@ def default_hparams():
return hparams

def forward(self, # type: ignore
- inputs: torch.Tensor,
+ inputs: Union[torch.Tensor, torch.LongTensor],
sequence_length: Optional[torch.LongTensor] = None) \
-> Tuple[torch.Tensor, torch.LongTensor]:
r"""Feeds the inputs through the network and makes classification.
@@ -145,8 +145,11 @@ def forward(self, # type: ignore
:class:`~texar.torch.modules.RoBERTaEncoder`.
Args:
- inputs: A 2D Tensor of shape `[batch_size, max_time]`,
-     containing the token ids of tokens in input sequences.
+ inputs: Either a **2D Tensor** of shape `[batch_size, max_time]`,
+     containing the ids of tokens in input sequences, or
+     a **3D Tensor** of shape `[batch_size, max_time, vocab_size]`,
+     containing soft token ids (i.e., weights or probabilities)
+     used to mix the embedding vectors.
sequence_length (optional): A 1D Tensor of shape `[batch_size]`.
Input tokens beyond respective sequence lengths are masked
out automatically.
18 changes: 18 additions & 0 deletions texar/torch/modules/classifiers/roberta_classifier_test.py
@@ -163,6 +163,24 @@ def test_binary(self):
self.assertEqual(logits.shape, torch.Size([self.batch_size]))
self.assertEqual(preds.shape, torch.Size([self.batch_size]))

def test_soft_ids(self):
r"""Tests soft ids.
"""
inputs = torch.rand(self.batch_size, self.max_length, 50265)

hparams = {
"pretrained_model_name": None,
"num_classes": 1,
"clas_strategy": "time_wise",
}
classifier = RoBERTaClassifier(hparams=hparams)
logits, preds = classifier(inputs)

self.assertEqual(logits.shape, torch.Size(
[self.batch_size, self.max_length]))
self.assertEqual(preds.shape, torch.Size(
[self.batch_size, self.max_length]))


if __name__ == "__main__":
unittest.main()
14 changes: 9 additions & 5 deletions texar/torch/modules/classifiers/xlnet_classifier.py
@@ -15,7 +15,7 @@
XLNet Classifier.
"""

- from typing import Any, Dict, Optional, Tuple
+ from typing import Any, Dict, Optional, Tuple, Union

import torch
from torch import nn
@@ -243,14 +243,18 @@ def param_groups(self,
return self.parameters()

def forward(self, # type: ignore
- token_ids: torch.LongTensor,
+ inputs: Union[torch.Tensor, torch.LongTensor],
segment_ids: Optional[torch.LongTensor] = None,
input_mask: Optional[torch.Tensor] = None) \
-> Tuple[torch.Tensor, torch.LongTensor]:
r"""Feeds the inputs through the network and makes classification.
Args:
- token_ids: Shape `[batch_size, max_time]`.
+ inputs: Either a **2D Tensor** of shape `[batch_size, max_time]`,
+     containing the ids of tokens in input sequences, or
+     a **3D Tensor** of shape `[batch_size, max_time, vocab_size]`,
+     containing soft token ids (i.e., weights or probabilities)
+     used to mix the embedding vectors.
segment_ids: Shape `[batch_size, max_time]`.
input_mask: Float tensor of shape `[batch_size, max_time]`. Note
that positions with value 1 are masked out.
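One detail worth a small example, since the convention is the opposite of the usual `attention_mask`: `input_mask` marks padded positions with 1. A sketch of building it from sequence lengths (the helper name is ad hoc):

```python
import torch

def make_input_mask(sequence_length: torch.LongTensor,
                    max_time: int) -> torch.Tensor:
    # 1.0 at padded positions (masked out), 0.0 at real tokens.
    positions = torch.arange(max_time, device=sequence_length.device)
    return (positions.unsqueeze(0) >= sequence_length.unsqueeze(1)).float()

lengths = torch.tensor([3, 5])
print(make_input_mask(lengths, max_time=5))
# tensor([[0., 0., 0., 1., 1.],
#         [0., 0., 0., 0., 0.]])
```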
@@ -276,7 +280,7 @@ def forward(self, # type: ignore
shape ``[batch_size, max_time]``.
"""
# output: [batch_size, seq_len, hidden_dim]
- output, _ = self._encoder(token_ids=token_ids,
+ output, _ = self._encoder(inputs=inputs,
segment_ids=segment_ids,
input_mask=input_mask)

@@ -286,7 +290,7 @@ def forward(self, # type: ignore
elif strategy == 'cls_time':
summary = output[:, -1]
elif strategy == 'all_time':
- length_diff = self._hparams.max_seq_length - token_ids.shape[1]
+ length_diff = self._hparams.max_seq_length - inputs.shape[1]
summary_input = F.pad(output, [0, 0, 0, length_diff, 0, 0])
summary_input_dim = (self._encoder.output_size *
self._hparams.max_seq_length)
19 changes: 19 additions & 0 deletions texar/torch/modules/classifiers/xlnet_classifier_test.py
@@ -171,6 +171,25 @@ def test_binary(self):
self.assertEqual(logits.shape, torch.Size([self.batch_size]))
self.assertEqual(preds.shape, torch.Size([self.batch_size]))

def test_soft_ids(self):
r"""Tests soft ids.
"""
inputs = torch.rand(self.batch_size, self.max_length, 32000)

# case 1
hparams = {
"pretrained_model_name": None,
"num_classes": 1,
"clas_strategy": "time_wise",
}
classifier = XLNetClassifier(hparams=hparams)
logits, preds = classifier(inputs)

self.assertEqual(logits.shape, torch.Size(
[self.batch_size, self.max_length]))
self.assertEqual(preds.shape, torch.Size(
[self.batch_size, self.max_length]))


if __name__ == "__main__":
unittest.main()
22 changes: 16 additions & 6 deletions texar/torch/modules/encoders/bert_encoder.py
@@ -15,7 +15,7 @@
BERT encoder.
"""

- from typing import Optional
+ from typing import Optional, Union

import torch
from torch import nn
@@ -281,14 +281,17 @@ def default_hparams():
}

def forward(self, # type: ignore
- inputs: torch.Tensor,
+ inputs: Union[torch.Tensor, torch.LongTensor],
sequence_length: Optional[torch.LongTensor] = None,
segment_ids: Optional[torch.LongTensor] = None):
r"""Encodes the inputs.
Args:
- inputs: A 2D Tensor of shape `[batch_size, max_time]`,
-     containing the token ids of tokens in the input sequences.
+ inputs: Either a **2D Tensor** of shape `[batch_size, max_time]`,
+     containing the ids of tokens in input sequences, or
+     a **3D Tensor** of shape `[batch_size, max_time, vocab_size]`,
+     containing soft token ids (i.e., weights or probabilities)
+     used to mix the embedding vectors.
segment_ids (optional): A 2D Tensor of shape
`[batch_size, max_time]`, containing the segment ids
of tokens in input sequences. If `None` (default), a
@@ -308,16 +311,23 @@ def forward(self, # type: ignore
pre-trained on top of the hidden state associated to the first
character of the input (`CLS`), see BERT's paper.
"""
+ if inputs.dim() == 2:
+     word_embeds = self.word_embedder(ids=inputs)
+ elif inputs.dim() == 3:
+     word_embeds = self.word_embedder(soft_ids=inputs)
+ else:
+     raise ValueError("'inputs' should be a 2D or 3D tensor.")

- word_embeds = self.word_embedder(inputs)
batch_size = inputs.size(0)
pos_length = inputs.new_full((batch_size,), inputs.size(1),
dtype=torch.int64)
pos_embeds = self.position_embedder(sequence_length=pos_length)

if self.segment_embedder is not None:
if segment_ids is None:
- segment_ids = torch.zeros_like(inputs)
+ segment_ids = torch.zeros((inputs.size(0), inputs.size(1)),
+                           dtype=torch.long,
+                           device=inputs.device)
segment_embeds = self.segment_embedder(segment_ids)
inputs_embeds = word_embeds + segment_embeds + pos_embeds
else:
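The dispatch above routes 2D inputs through the embedder's ordinary `ids` lookup and 3D inputs through its `soft_ids` path. Conceptually, the soft lookup is a weighted mixture of embedding rows, which is presumably why the `tensordot` stub overloads are part of this change; a sketch with a stand-in embedding table:

```python
import torch

vocab_size, hidden_dim = 30522, 768
embedding = torch.randn(vocab_size, hidden_dim)  # stand-in embedding table

soft_ids = torch.softmax(torch.randn(2, 8, vocab_size), dim=-1)

# Hard lookup: index rows of the table.
hard_embeds = embedding[soft_ids.argmax(dim=-1)]                     # [2, 8, 768]

# Soft lookup: mix rows of the table with the per-position weights.
soft_embeds = torch.tensordot(soft_ids, embedding, dims=([2], [0]))  # [2, 8, 768]
```

The `segment_ids` default above is likewise built with an explicit `[batch_size, max_time]` shape instead of `torch.zeros_like(inputs)`, since `inputs` may now be a 3D float tensor.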
19 changes: 19 additions & 0 deletions texar/torch/modules/encoders/bert_encoder_test.py
@@ -171,6 +171,25 @@ def test_encode(self):
pooled_output.shape,
torch.Size([self.batch_size, encoder.output_size]))

def test_soft_ids(self):
r"""Tests soft ids.
"""
hparams = {
"pretrained_model_name": None,
}
encoder = BERTEncoder(hparams=hparams)

inputs = torch.rand(self.batch_size, self.max_length, 30522)
outputs, pooled_output = encoder(inputs)

outputs_dim = encoder.hparams.encoder.dim
self.assertEqual(
outputs.shape,
torch.Size([self.batch_size, self.max_length, outputs_dim]))
self.assertEqual(
pooled_output.shape,
torch.Size([self.batch_size, encoder.output_size]))


if __name__ == "__main__":
unittest.main()