Commit: Add embedding finetune demo (PaddlePaddle#1204)

* Add embedding seq-cls finetune demo and update api
* Update docs of pad_sequence and trunc_sequence (usage sketched below)
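For reference, a minimal sketch of how the two sequence utilities are invoked in this demo. The call signatures match the uses inside BoWModel._batchify below; the sample ids, max_seq_len value, and pad id are illustrative assumptions, not values from the commit:

from paddlehub.utils.utils import pad_sequence, trunc_sequence

ids = [101, 233, 57, 102]    # hypothetical token ids produced by a tokenizer
max_seq_len = 8
pad_token_id = 0             # assumed pad id; the demo looks it up from the tokenizer vocab

if len(ids) > max_seq_len:
    # Truncate the sequence to max_seq_len tokens.
    ids = trunc_sequence(ids, max_seq_len)
else:
    # Pad the sequence with the pad id up to max_seq_len.
    ids = pad_sequence(ids, max_seq_len, pad_token_id)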
Showing 7 changed files with 484 additions and 52 deletions.
@@ -0,0 +1,175 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import List

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

import paddlenlp as nlp
from paddlenlp.embeddings import TokenEmbedding
from paddlenlp.data import JiebaTokenizer

from paddlehub.utils.log import logger
from paddlehub.utils.utils import pad_sequence, trunc_sequence

class BoWModel(nn.Layer):
    """
    This class implements the Bag of Words classification network to classify texts.
    At a high level, the model embeds the tokens with a word embedding, encodes these
    representations with a `BoWEncoder`, and passes the encoder output through some
    feed-forward layers to produce logits (`output_layer`).

    Args:
        num_classes (obj:`int`, optional, defaults to 2): The number of target classes.
        embedder (obj:`TokenEmbedding`): The token embedding module used to embed the input ids.
        tokenizer (obj:`JiebaTokenizer`, optional): The tokenizer used by `predict` to encode raw text.
        hidden_size (obj:`int`, optional, defaults to 128): The first fully-connected layer hidden size.
        fc_hidden_size (obj:`int`, optional, defaults to 96): The second fully-connected layer hidden size.
        load_checkpoint (obj:`str`, optional): Path to a checkpoint (.pdparams file) to load parameters from.
        label_map (obj:`dict`, optional): Mapping from class index to label name, used by `predict`.
    """

    def __init__(self,
                 num_classes: int = 2,
                 embedder: TokenEmbedding = None,
                 tokenizer: JiebaTokenizer = None,
                 hidden_size: int = 128,
                 fc_hidden_size: int = 96,
                 load_checkpoint: str = None,
                 label_map: dict = None):
        super().__init__()
        self.embedder = embedder
        self.tokenizer = tokenizer
        self.label_map = label_map

        emb_dim = self.embedder.embedding_dim
        self.bow_encoder = nlp.seq2vec.BoWEncoder(emb_dim)
        self.fc1 = nn.Linear(self.bow_encoder.get_output_dim(), hidden_size)
        self.fc2 = nn.Linear(hidden_size, fc_hidden_size)
        self.dropout = nn.Dropout(p=0.3, axis=1)
        self.output_layer = nn.Linear(fc_hidden_size, num_classes)
        self.criterion = nn.loss.CrossEntropyLoss()
        self.metric = paddle.metric.Accuracy()

        if load_checkpoint is not None and os.path.isfile(load_checkpoint):
            state_dict = paddle.load(load_checkpoint)
            self.set_state_dict(state_dict)
            logger.info('Loaded parameters from %s' % os.path.abspath(load_checkpoint))

    def training_step(self, batch: List[paddle.Tensor], batch_idx: int):
        """
        One step for training, which should be called as forward computation.

        Args:
            batch (obj:`List[paddle.Tensor]`): One batch of data, containing the input ids and labels.
            batch_idx (obj:`int`): The index of the batch.

        Returns:
            results (obj:`Dict`): The model outputs, such as loss and metrics.
        """
        _, avg_loss, metric = self(ids=batch[0], labels=batch[1])
        self.metric.reset()
        return {'loss': avg_loss, 'metrics': metric}

    def validation_step(self, batch: List[paddle.Tensor], batch_idx: int):
        """
        One step for validation, which should be called as forward computation.

        Args:
            batch (obj:`List[paddle.Tensor]`): One batch of data, containing the input ids and labels.
            batch_idx (obj:`int`): The index of the batch.

        Returns:
            results (obj:`Dict`): The model outputs, such as metrics.
        """
        _, _, metric = self(ids=batch[0], labels=batch[1])
        self.metric.reset()
        return {'metrics': metric}

    def forward(self, ids: paddle.Tensor, labels: paddle.Tensor = None):
        # Shape: (batch_size, num_tokens, embedding_dim)
        embedded_text = self.embedder(ids)

        # Shape: (batch_size, embedding_dim)
        summed = self.bow_encoder(embedded_text)
        summed = self.dropout(summed)
        encoded_text = paddle.tanh(summed)

        # Shape: (batch_size, hidden_size)
        fc1_out = paddle.tanh(self.fc1(encoded_text))
        # Shape: (batch_size, fc_hidden_size)
        fc2_out = paddle.tanh(self.fc2(fc1_out))
        # Shape: (batch_size, num_classes)
        logits = self.output_layer(fc2_out)

        probs = F.softmax(logits, axis=1)
        if labels is not None:
            loss = self.criterion(logits, labels)
            correct = self.metric.compute(probs, labels)
            acc = self.metric.update(correct)
            return probs, loss, {'acc': acc}
        else:
            return probs

    def _batchify(self, data: List[List[str]], max_seq_len: int, batch_size: int):
        examples = []
        for item in data:
            ids = self.tokenizer.encode(sentence=item[0])

            if len(ids) > max_seq_len:
                ids = trunc_sequence(ids, max_seq_len)
            else:
                pad_token = self.tokenizer.vocab.pad_token
                pad_token_id = self.tokenizer.vocab.to_indices(pad_token)
                ids = pad_sequence(ids, max_seq_len, pad_token_id)

            examples.append(ids)

        # Separates data into batches.
        one_batch = []
        for example in examples:
            one_batch.append(example)
            if len(one_batch) == batch_size:
                yield one_batch
                one_batch = []
        if one_batch:
            # The last batch whose size is less than the configured batch_size.
            yield one_batch

    def predict(
            self,
            data: List[List[str]],
            max_seq_len: int = 128,
            batch_size: int = 1,
            use_gpu: bool = False,
            return_result: bool = True,
    ):
        paddle.set_device('gpu') if use_gpu else paddle.set_device('cpu')

        batches = self._batchify(data, max_seq_len, batch_size)
        results = []
        self.eval()
        for batch in batches:
            ids = paddle.to_tensor(batch)
            probs = self(ids)
            idx = paddle.argmax(probs, axis=1).numpy()

            if return_result:
                idx = idx.tolist()
                labels = [self.label_map[i] for i in idx]
                results.extend(labels)
            else:
                results.extend(probs.numpy())

        return results
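As a quick sanity check of the shape comments in forward() above, here is a minimal sketch of the bag-of-words flow with plain tensor ops. The batch size, sequence length, embedding dimension, and the plain summation standing in for nlp.seq2vec.BoWEncoder are illustrative assumptions, not the demo's actual configuration:

import paddle
import paddle.nn.functional as F

batch_size, num_tokens, emb_dim, num_classes = 4, 128, 300, 2   # assumed sizes for illustration
embedded_text = paddle.randn([batch_size, num_tokens, emb_dim])

# A bag-of-words encoder reduces over the token axis, leaving one vector per example.
summed = embedded_text.sum(axis=1)    # (batch_size, emb_dim)
encoded_text = paddle.tanh(summed)    # (batch_size, emb_dim)

logits = paddle.nn.Linear(emb_dim, num_classes)(encoded_text)   # (batch_size, num_classes)
probs = F.softmax(logits, axis=1)     # each row sums to 1
print(probs.shape)                    # [4, 2]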
@@ -0,0 +1,55 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddlehub as hub
from paddlenlp.data import JiebaTokenizer
from model import BoWModel

import ast
import argparse


parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--hub_embedding_name", type=str, default='w2v_baidu_encyclopedia_target_word-word_dim300', help="The name of the PaddleHub embedding module to load.")
parser.add_argument("--max_seq_len", type=int, default=128, help="Number of words of the longest sequence.")
parser.add_argument("--batch_size", type=int, default=64, help="Total examples' number in a batch for prediction.")
parser.add_argument("--checkpoint", type=str, default='./checkpoint/best_model/model.pdparams', help="Path of the model checkpoint to load.")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=True, help="Whether to use GPU for prediction; input should be True or False.")

args = parser.parse_args()

if __name__ == '__main__':
    # Data to be predicted
    data = [
        ["这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般"],
        ["交通方便;环境很好;服务态度很好 房间较小"],
        ["还稍微重了点,可能是硬盘大的原故,还要再轻半斤就好了。其他要进一步验证。贴的几种膜气泡较多,用不了多久就要更换了,屏幕膜稍好点,但比没有要强多了。建议配赠几张膜让用用户自己贴。"],
        ["前台接待太差,酒店有A B楼之分,本人check-in后,前台未告诉B楼在何处,并且B楼无明显指示;房间太小,根本不像4星级设施,下次不会再选择入住此店啦"],
        ["19天硬盘就罢工了~~~算上运来的一周都没用上15天~~~可就是不能换了~~~唉~~~~你说这算什么事呀~~~"],
    ]

    label_map = {0: 'negative', 1: 'positive'}

    embedder = hub.Module(name=args.hub_embedding_name)
    tokenizer = embedder.get_tokenizer()
    model = BoWModel(
        embedder=embedder,
        tokenizer=tokenizer,
        load_checkpoint=args.checkpoint,
        label_map=label_map)

    # return_result=True maps each prediction through label_map, so the printed value is a label.
    results = model.predict(data, max_seq_len=args.max_seq_len, batch_size=args.batch_size, use_gpu=args.use_gpu, return_result=True)
    for idx, text in enumerate(data):
        print('Data: {} \t Label: {}'.format(text[0], results[idx]))
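For reference, a brief sketch of the two output modes of predict() as implemented by BoWModel.predict above; the example values in the comments are illustrative, not actual model outputs:

# With return_result=True (as used above), predict() maps each argmax index through
# label_map, so results is a list of label strings, e.g. ['negative', 'positive', ...].
labels = model.predict(data, return_result=True)

# With return_result=False, predict() returns the raw softmax probabilities per example,
# e.g. arrays such as [0.93, 0.07].
probs = model.predict(data, return_result=False)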
@@ -0,0 +1,57 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddlehub as hub
from paddlehub.datasets import ChnSentiCorp
from paddlenlp.data import JiebaTokenizer
from model import BoWModel

import ast
import argparse


parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--hub_embedding_name", type=str, default='w2v_baidu_encyclopedia_target_word-word_dim300', help="The name of the PaddleHub embedding module to load.")
parser.add_argument("--num_epoch", type=int, default=10, help="Number of epochs for fine-tuning.")
parser.add_argument("--learning_rate", type=float, default=5e-4, help="Learning rate used to train with warmup.")
parser.add_argument("--max_seq_len", type=int, default=128, help="Number of words of the longest sequence.")
parser.add_argument("--batch_size", type=int, default=64, help="Total examples' number in a batch for training.")
parser.add_argument("--checkpoint_dir", type=str, default='./checkpoint', help="Directory to save the model checkpoint.")
parser.add_argument("--save_interval", type=int, default=5, help="Save a checkpoint every n epochs.")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=True, help="Whether to use GPU for fine-tuning; input should be True or False.")

args = parser.parse_args()

if __name__ == '__main__':
    embedder = hub.Module(name=args.hub_embedding_name)
    tokenizer = embedder.get_tokenizer()

    train_dataset = ChnSentiCorp(tokenizer=tokenizer, max_seq_len=args.max_seq_len, mode='train')
    dev_dataset = ChnSentiCorp(tokenizer=tokenizer, max_seq_len=args.max_seq_len, mode='dev')
    test_dataset = ChnSentiCorp(tokenizer=tokenizer, max_seq_len=args.max_seq_len, mode='test')

    model = BoWModel(embedder=embedder)
    optimizer = paddle.optimizer.AdamW(
        learning_rate=args.learning_rate, parameters=model.parameters())
    trainer = hub.Trainer(model, optimizer, checkpoint_dir=args.checkpoint_dir, use_gpu=args.use_gpu)
    trainer.train(
        train_dataset,
        epochs=args.num_epoch,
        batch_size=args.batch_size,
        eval_dataset=dev_dataset,
        save_interval=args.save_interval,
    )
    trainer.evaluate(test_dataset, batch_size=args.batch_size)