Add embedding finetune demo (PaddlePaddle#1204)
* Add embedding seq-cls finetune demo and update api

* Update docs of pad_sequence and trunc_sequence
KPatr1ck authored Jan 27, 2021
1 parent 5832b1a commit 045e4e2
Showing 7 changed files with 484 additions and 52 deletions.
175 changes: 175 additions & 0 deletions demo/text_classification/embedding/model.py
@@ -0,0 +1,175 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import List

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

import paddlenlp as nlp
from paddlenlp.embeddings import TokenEmbedding
from paddlenlp.data import JiebaTokenizer

from paddlehub.utils.log import logger
from paddlehub.utils.utils import pad_sequence, trunc_sequence


class BoWModel(nn.Layer):
"""
This class implements the Bag of Words Classification Network model to classify texts.
At a high level, the model starts by embedding the tokens and running them through
a word embedding. Then, we encode these epresentations with a `BoWEncoder`.
Lastly, we take the output of the encoder to create a final representation,
which is passed through some feed-forward layers to output a logits (`output_layer`).
Args:
vocab_size (obj:`int`): The vocabulary size.
emb_dim (obj:`int`, optional, defaults to 300): The embedding dimension.
hidden_size (obj:`int`, optional, defaults to 128): The first full-connected layer hidden size.
fc_hidden_size (obj:`int`, optional, defaults to 96): The second full-connected layer hidden size.
num_classes (obj:`int`): All the labels that the data has.
"""

def __init__(self,
num_classes: int = 2,
embedder: TokenEmbedding = None,
tokenizer: JiebaTokenizer = None,
hidden_size: int = 128,
fc_hidden_size: int = 96,
load_checkpoint: str = None,
label_map: dict = None):
super().__init__()
self.embedder = embedder
self.tokenizer = tokenizer
self.label_map = label_map

emb_dim = self.embedder.embedding_dim
self.bow_encoder = nlp.seq2vec.BoWEncoder(emb_dim)
self.fc1 = nn.Linear(self.bow_encoder.get_output_dim(), hidden_size)
self.fc2 = nn.Linear(hidden_size, fc_hidden_size)
self.dropout = nn.Dropout(p=0.3, axis=1)
self.output_layer = nn.Linear(fc_hidden_size, num_classes)
self.criterion = nn.loss.CrossEntropyLoss()
self.metric = paddle.metric.Accuracy()

if load_checkpoint is not None and os.path.isfile(load_checkpoint):
state_dict = paddle.load(load_checkpoint)
self.set_state_dict(state_dict)
logger.info('Loaded parameters from %s' % os.path.abspath(load_checkpoint))

def training_step(self, batch: List[paddle.Tensor], batch_idx: int):
"""
One step for training, which should be called as forward computation.
Args:
batch(:obj:List[paddle.Tensor]): The one batch data, which contains the model needed,
such as input_ids, sent_ids, pos_ids, input_mask and labels.
batch_idx(int): The index of batch.
Returns:
results(:obj: Dict) : The model outputs, such as loss and metrics.
"""
_, avg_loss, metric = self(ids=batch[0], labels=batch[1])
self.metric.reset()
return {'loss': avg_loss, 'metrics': metric}

def validation_step(self, batch: List[paddle.Tensor], batch_idx: int):
"""
One step for validation, which should be called as forward computation.
Args:
batch(:obj:List[paddle.Tensor]): The one batch data, which contains the model needed,
such as input_ids, sent_ids, pos_ids, input_mask and labels.
batch_idx(int): The index of batch.
Returns:
results(:obj: Dict) : The model outputs, such as metrics.
"""
_, _, metric = self(ids=batch[0], labels=batch[1])
self.metric.reset()
return {'metrics': metric}

def forward(self, ids: paddle.Tensor, labels: paddle.Tensor = None):

# Shape: (batch_size, num_tokens, embedding_dim)
embedded_text = self.embedder(ids)

# Shape: (batch_size, embedding_dim)
summed = self.bow_encoder(embedded_text)
summed = self.dropout(summed)
encoded_text = paddle.tanh(summed)

# Shape: (batch_size, hidden_size)
fc1_out = paddle.tanh(self.fc1(encoded_text))
# Shape: (batch_size, fc_hidden_size)
fc2_out = paddle.tanh(self.fc2(fc1_out))
# Shape: (batch_size, num_classes)
logits = self.output_layer(fc2_out)

probs = F.softmax(logits, axis=1)
if labels is not None:
loss = self.criterion(logits, labels)
correct = self.metric.compute(probs, labels)
acc = self.metric.update(correct)
return probs, loss, {'acc': acc}
else:
return probs

def _batchify(self, data: List[List[str]], max_seq_len: int, batch_size: int):
examples = []
for item in data:
ids = self.tokenizer.encode(sentence=item[0])

if len(ids) > max_seq_len:
ids = trunc_sequence(ids, max_seq_len)
else:
pad_token = self.tokenizer.vocab.pad_token
pad_token_id = self.tokenizer.vocab.to_indices(pad_token)
ids = pad_sequence(ids, max_seq_len, pad_token_id)
examples.append(ids)

# Separates the examples into batches.
one_batch = []
for example in examples:
one_batch.append(example)
if len(one_batch) == batch_size:
yield one_batch
one_batch = []
if one_batch:
# The last batch, whose size is smaller than batch_size.
yield one_batch

def predict(
self,
data: List[List[str]],
max_seq_len: int = 128,
batch_size: int = 1,
use_gpu: bool = False,
return_result: bool = True,
):
paddle.set_device('gpu' if use_gpu else 'cpu')

batches = self._batchify(data, max_seq_len, batch_size)
results = []
self.eval()
for batch in batches:
ids = paddle.to_tensor(batch)
probs = self(ids)
idx = paddle.argmax(probs, axis=1).numpy()

if return_result:
idx = idx.tolist()
labels = [self.label_map[i] for i in idx]
results.extend(labels)
else:
results.extend(probs.numpy())

return results
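Editor's note on the two sequence helpers imported from paddlehub.utils.utils: in `_batchify` they only force every example to exactly `max_seq_len` token ids. Below is a minimal sketch of equivalent behaviour, inferred from the call sites above; the `*_sketch` names and the exact edge-case handling are illustrative assumptions, not the library implementation.

# Illustrative stand-ins for trunc_sequence / pad_sequence as used in _batchify.
def trunc_sequence_sketch(ids, max_seq_len):
    # Keep only the first max_seq_len token ids.
    return ids[:max_seq_len]


def pad_sequence_sketch(ids, max_seq_len, pad_token_id):
    # Right-pad with the pad token id until the sequence reaches max_seq_len.
    return ids + [pad_token_id] * (max_seq_len - len(ids))


# Example: with max_seq_len=5 and pad_token_id=0,
# [3, 7, 9]          -> [3, 7, 9, 0, 0]
# [1, 2, 3, 4, 5, 6] -> [1, 2, 3, 4, 5]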
55 changes: 55 additions & 0 deletions demo/text_classification/embedding/predict.py
@@ -0,0 +1,55 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddlehub as hub
from paddlenlp.data import JiebaTokenizer
from model import BoWModel

import ast
import argparse


parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--hub_embedding_name", type=str, default='w2v_baidu_encyclopedia_target_word-word_dim300', help="")
parser.add_argument("--max_seq_len", type=int, default=128, help="Number of words of the longest seqence.")
parser.add_argument("--batch_size", type=int, default=64, help="Total examples' number in batch for training.")
parser.add_argument("--checkpoint", type=str, default='./checkpoint/best_model/model.pdparams', help="Model checkpoint")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=True, help="Whether use GPU for fine-tuning, input should be True or False")

args = parser.parse_args()


if __name__ == '__main__':
# Data to be predicted: Chinese review texts (hotel and product reviews).
data = [
["这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般"],
["交通方便;环境很好;服务态度很好 房间较小"],
["还稍微重了点,可能是硬盘大的原故,还要再轻半斤就好了。其他要进一步验证。贴的几种膜气泡较多,用不了多久就要更换了,屏幕膜稍好点,但比没有要强多了。建议配赠几张膜让用用户自己贴。"],
["前台接待太差,酒店有A B楼之分,本人check-in后,前台未告诉B楼在何处,并且B楼无明显指示;房间太小,根本不像4星级设施,下次不会再选择入住此店啦"],
["19天硬盘就罢工了~~~算上运来的一周都没用上15天~~~可就是不能换了~~~唉~~~~你说这算什么事呀~~~"],
]

label_map = {0: 'negative', 1: 'positive'}

embedder = hub.Module(name=args.hub_embedding_name)
tokenizer = embedder.get_tokenizer()
model = BoWModel(
embedder=embedder,
tokenizer=tokenizer,
load_checkpoint=args.checkpoint,
label_map=label_map)

results = model.predict(data, max_seq_len=args.max_seq_len, batch_size=args.batch_size, use_gpu=args.use_gpu, return_result=False)
for idx, text in enumerate(data):
print('Data: {} \t Label: {}'.format(text[0], results[idx]))
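Editor's note: because the script passes return_result=False, each entry in results is a softmax probability vector rather than a label string, so the line above prints raw probabilities. A small illustrative snippet (not part of the demo) for mapping those probabilities back through label_map, assuming the same variables as above:

# Illustrative: convert the probability vectors returned with return_result=False
# into human-readable labels using the same label_map defined above.
for text, probs in zip(data, results):
    label = label_map[int(probs.argmax())]
    print('Data: {} \t Label: {} \t Probs: {}'.format(text[0], label, probs))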
57 changes: 57 additions & 0 deletions demo/text_classification/embedding/train.py
@@ -0,0 +1,57 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddlehub as hub
from paddlehub.datasets import ChnSentiCorp
from paddlenlp.data import JiebaTokenizer
from model import BoWModel

import ast
import argparse


parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--hub_embedding_name", type=str, default='w2v_baidu_encyclopedia_target_word-word_dim300', help="")
parser.add_argument("--num_epoch", type=int, default=10, help="Number of epoches for fine-tuning.")
parser.add_argument("--learning_rate", type=float, default=5e-4, help="Learning rate used to train with warmup.")
parser.add_argument("--max_seq_len", type=int, default=128, help="Number of words of the longest seqence.")
parser.add_argument("--batch_size", type=int, default=64, help="Total examples' number in batch for training.")
parser.add_argument("--checkpoint_dir", type=str, default='./checkpoint', help="Directory to model checkpoint")
parser.add_argument("--save_interval", type=int, default=5, help="Save checkpoint every n epoch.")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=True, help="Whether use GPU for fine-tuning, input should be True or False")

args = parser.parse_args()


if __name__ == '__main__':
embedder = hub.Module(name=args.hub_embedding_name)
tokenizer = embedder.get_tokenizer()

train_dataset = ChnSentiCorp(tokenizer=tokenizer, max_seq_len=args.max_seq_len, mode='train')
dev_dataset = ChnSentiCorp(tokenizer=tokenizer, max_seq_len=args.max_seq_len, mode='dev')
test_dataset = ChnSentiCorp(tokenizer=tokenizer, max_seq_len=args.max_seq_len, mode='test')

model = BoWModel(embedder=embedder)
optimizer = paddle.optimizer.AdamW(
learning_rate=args.learning_rate, parameters=model.parameters())
trainer = hub.Trainer(model, optimizer, checkpoint_dir=args.checkpoint_dir, use_gpu=args.use_gpu)
trainer.train(
train_dataset,
epochs=args.num_epoch,
batch_size=args.batch_size,
eval_dataset=dev_dataset,
save_interval=args.save_interval,
)
trainer.evaluate(test_dataset, batch_size=args.batch_size)
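Editor's note: the Trainer writes checkpoints under --checkpoint_dir, and predict.py later points --checkpoint at a saved model.pdparams file, which BoWModel restores via paddle.load and set_state_dict. Below is a hedged sketch of saving and restoring the weights manually at the end of this script; the 'manual_save' path is an assumption, and hub.Trainer's own checkpoint layout may differ.

import os

# Illustrative: save the trained weights and restore them into a fresh model.
save_path = os.path.join(args.checkpoint_dir, 'manual_save', 'model.pdparams')
os.makedirs(os.path.dirname(save_path), exist_ok=True)
paddle.save(model.state_dict(), save_path)

restored_model = BoWModel(embedder=embedder, load_checkpoint=save_path)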
@@ -15,6 +15,7 @@
 from typing import List
 from paddlenlp.embeddings import TokenEmbedding
 from paddlehub.module.module import moduleinfo, serving
+from paddlehub.module.nlp_module import EmbeddingModule


 @moduleinfo(
@@ -23,33 +24,13 @@
     summary="",
     author="paddlepaddle",
     author_email="",
-    type="nlp/semantic_model")
+    type="nlp/semantic_model",
+    meta=EmbeddingModule)
 class Embedding(TokenEmbedding):
     """
     Embedding model
     """
-    def __init__(self, *args, **kwargs):
-        super(Embedding, self).__init__(embedding_name="w2v.baidu_encyclopedia.target.word-word.dim300", *args, **kwargs)
-
-    @serving
-    def calc_similarity(self, data: List[List[str]]):
-        """
-        Calculate similarities of giving word pairs.
-        """
-        results = []
-        for word_pair in data:
-            if len(word_pair) != 2:
-                raise RuntimeError(
-                    f'The input must have two words, but got {len(word_pair)}. Please check your inputs.')
-            if not isinstance(word_pair[0], str) or not isinstance(word_pair[1], str):
-                raise RuntimeError(
-                    f'The types of text pair must be (str, str), but got'
-                    f' ({type(word_pair[0]).__name__}, {type(word_pair[1]).__name__}). Please check your inputs.')
-
-            for word in word_pair:
-                if self.get_idx_from_word(word) == \
-                        self.get_idx_from_word(self.vocab.unk_token):
-                    raise RuntimeError(
-                        f'Word "{word}" is not in vocab. Please check your inputs.')
-            results.append(str(self.cosine_sim(*word_pair)))
-        return results
+    embedding_name = 'w2v.baidu_encyclopedia.target.word-word.dim300'
+
+    def __init__(self, *args, **kwargs):
+        super(Embedding, self).__init__(embedding_name=self.embedding_name, *args, **kwargs)
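Editor's note: the removed @serving calc_similarity endpoint was a thin wrapper around TokenEmbedding.cosine_sim, which the class still inherits. An illustrative snippet for computing word similarity from the updated module directly; the module name matches the demo scripts above, and the example words are assumptions.

# Illustrative: word similarity via the inherited TokenEmbedding API,
# without the removed @serving endpoint.
import paddlehub as hub

embedder = hub.Module(name='w2v_baidu_encyclopedia_target_word-word_dim300')
print(embedder.cosine_sim('苹果', '香蕉'))  # cosine similarity of two in-vocab words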