
Commit

add unittest, config, fix ernie_gen bugs and add ernie_tiny_couplet (P…
kinghuin authored Jul 28, 2020
1 parent 879383e commit e35ff5e
Showing 13 changed files with 380 additions and 17 deletions.
hub_module/modules/text/text_generation/ernie_gen_couplet/module.py
@@ -50,12 +50,14 @@ def _initialize(self):
         assets_path = os.path.join(self.directory, "assets")
         gen_checkpoint_path = os.path.join(assets_path, "ernie_gen_couplet")
         ernie_cfg_path = os.path.join(assets_path, 'ernie_config.json')
-        ernie_cfg = dict(json.loads(open(ernie_cfg_path).read()))
+        with open(ernie_cfg_path) as ernie_cfg_file:
+            ernie_cfg = dict(json.loads(ernie_cfg_file.read()))
         ernie_vocab_path = os.path.join(assets_path, 'vocab.txt')
-        ernie_vocab = {
-            j.strip().split('\t')[0]: i
-            for i, j in enumerate(open(ernie_vocab_path).readlines())
-        }
+        with open(ernie_vocab_path) as ernie_vocab_file:
+            ernie_vocab = {
+                j.strip().split('\t')[0]: i
+                for i, j in enumerate(ernie_vocab_file.readlines())
+            }

         with fluid.dygraph.guard(fluid.CPUPlace()):
             with fluid.unique_name.guard():
@@ -183,5 +185,5 @@ def serving_method(self, texts, use_gpu=False):

 if __name__ == "__main__":
     module = ErnieGen()
-    for result in module.generate(['人增福寿年增岁', '风吹云乱天垂泪'], beam_width=5):
+    for result in module.generate(['上海自来水来自海上', '风吹云乱天垂泪'], beam_width=5):
         print(result)
hub_module/modules/text/text_generation/ernie_gen_poetry/README.md
@@ -10,7 +10,7 @@ ERNIE-GEN 是面向生成任务的预训练-微调框架,首次在预训练阶
 ## Command-line Prediction
 
 ```shell
-$ hub run ernie_gen_poetry --input_text="宝积峰前露术香,使君行旆照晴阳" --use_gpu True --beam_width 5
+$ hub run ernie_gen_poetry --input_text="昔年旅南服,始识王荆州" --use_gpu True --beam_width 5
 ```
 
 ## API
@@ -38,7 +38,7 @@ import paddlehub as hub

 module = hub.Module(name="ernie_gen_poetry")
 
-test_texts = ["宝积峰前露术香,使君行旆照晴阳。"]
+test_texts = ['昔年旅南服,始识王荆州。', '高名出汉阴,禅阁跨香岑。']
 results = module.genrate(texts=test_texts, use_gpu=True, beam_width=5)
 for result in results:
     print(result)
@@ -69,7 +69,7 @@ import json

 # Send the HTTP request
 
-data = {'texts':["宝积峰前露术香,使君行旆照晴阳。"],
+data = {'texts':['昔年旅南服,始识王荆州。', '高名出汉阴,禅阁跨香岑。'],
         'use_gpu':False, 'beam_width':5}
 headers = {"Content-type": "application/json"}
 url = "http://127.0.0.1:8866/predict/ernie_gen_poetry"
hub_module/modules/text/text_generation/ernie_gen_poetry/model/…
@@ -15,8 +15,8 @@
 import paddle.fluid as F
 import paddle.fluid.layers as L
 
-from ernie_gen_couplet.model.modeling_ernie import ErnieModel
-from ernie_gen_couplet.model.modeling_ernie import _build_linear, _build_ln, append_name
+from ernie_gen_poetry.model.modeling_ernie import ErnieModel
+from ernie_gen_poetry.model.modeling_ernie import _build_linear, _build_ln, append_name
 
 
 class ErnieModelForGeneration(ErnieModel):
hub_module/modules/text/text_generation/ernie_gen_poetry/module.py
@@ -50,12 +50,14 @@ def _initialize(self):
         assets_path = os.path.join(self.directory, "assets")
         gen_checkpoint_path = os.path.join(assets_path, "ernie_gen_poetry")
         ernie_cfg_path = os.path.join(assets_path, 'ernie_config.json')
-        ernie_cfg = dict(json.loads(open(ernie_cfg_path).read()))
+        with open(ernie_cfg_path) as ernie_cfg_file:
+            ernie_cfg = dict(json.loads(ernie_cfg_file.read()))
         ernie_vocab_path = os.path.join(assets_path, 'vocab.txt')
-        ernie_vocab = {
-            j.strip().split('\t')[0]: i
-            for i, j in enumerate(open(ernie_vocab_path).readlines())
-        }
+        with open(ernie_vocab_path) as ernie_vocab_file:
+            ernie_vocab = {
+                j.strip().split('\t')[0]: i
+                for i, j in enumerate(ernie_vocab_file.readlines())
+            }

         with fluid.dygraph.guard(fluid.CPUPlace()):
             with fluid.unique_name.guard():
@@ -183,5 +185,6 @@ def serving_method(self, texts, use_gpu=False):

 if __name__ == "__main__":
     module = ErnieGen()
-    for result in module.generate(['宝积峰前露术香,使君行旆照晴阳。'], beam_width=5):
+    for result in module.generate(['昔年旅南服,始识王荆州。', '高名出汉阴,禅阁跨香岑。'],
+                                  beam_width=5):
         print(result)
hub_module/modules/text/text_generation/ernie_tiny_couplet/README.md
@@ -0,0 +1,93 @@
```shell
$ hub install ernie_tiny_couplet==1.0.0
```
<p align="center">
<img src="https://paddlehub.bj.bcebos.com/paddlehub-img%2Fernie_tiny_framework.PNG" hspace='10'/> <br />
</p>
This prediction module was obtained by fine-tuning with TextGenerationTask; for the conversion procedure, see [How to convert a fine-tuned model into a PaddleHub Module](https://github.com/PaddlePaddle/PaddleHub/blob/develop/docs/tutorial/finetuned_model_to_module.md).

## Command-line Prediction

```shell
$ hub run ernie_tiny_couplet --input_text '风吹云乱天垂泪'
```
Command-line prediction only supports CPU. To predict on a GPU, use the Python API instead.
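
A minimal sketch of GPU prediction through the Python API is shown below. It assumes that the `use_gpu` keyword is forwarded to this module's `_initialize` (see the `use_gpu` parameter in `module.py` later in this commit) and that a GPU build of PaddlePaddle is installed:

```python
import paddlehub as hub

# Sketch only: use_gpu=True is assumed to reach _initialize, where it sets
# use_cuda in the RunConfig; fall back to use_gpu=False on CPU-only machines.
module = hub.Module(name="ernie_tiny_couplet", use_gpu=True)
results = module.generate(["风吹云乱天垂泪"])
print(results)
```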

## API
```python
def generate(texts)
```

Couplet prediction interface: it takes the first line of a couplet as input and generates the second line. The interface wraps the encoding of the input text with `hub.BertTokenizer`, so it is simpler to call than the [predict interface](https://github.com/PaddlePaddle/PaddleHub/blob/develop/demo/text_generation/predict.py#L83) provided in the demo.
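
For reference, a rough sketch of the preprocessing that `generate` performs internally, condensed from `module.py` in this commit (the tokenizer and encoding calls below are taken from that file):

```python
import paddlehub as hub

# Characters are joined with "\002" to match the training data format,
# then encoded with BertTokenizer before being fed to the fine-tuned
# TextGenerationTask inside the module.
ernie_tiny = hub.Module(name="ernie_tiny")
tokenizer = hub.BertTokenizer(vocab_file=ernie_tiny.get_vocab_path())
texts = ["风吹云乱天垂泪"]
formatted = list(map("\002".join, texts))
encoded = [tokenizer.encode(text=t, max_seq_len=128) for t in formatted]
```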

**Parameters**

> texts(list[str]): the first lines (上联) of the couplets.

**Returns**

> result(list[str]): the generated second lines (下联); each input first line yields 10 candidate second lines.

**Code Example**

```python
import paddlehub as hub

# Load ernie pretrained model
module = hub.Module(name="ernie_tiny_couplet")
results = module.generate(["风吹云乱天垂泪", "若有经心风过耳"])
for result in results:
print(result)
```

## Service Deployment

PaddleHub Serving can deploy the module as an online service.

### Step 1: Start PaddleHub Serving

Run the start-up command:
```shell
$ hub serving start -m ernie_tiny_couplet
```

This completes the deployment of a serving API; the default port is 8866.

**NOTE:** Serving deployment only supports CPU. To predict on a GPU, use the Python API instead.

### Step 2: Send a prediction request

Once the server is configured, the following few lines of code send a prediction request and retrieve the prediction results:

```python
import requests
import json

# Send the HTTP request

data = {'texts':["风吹云乱天垂泪", "若有经心风过耳"]}
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:8866/predict/ernie_tiny_couplet"
r = requests.post(url=url, headers=headers, data=json.dumps(data))

# Retrieve the results
results = r.json()["results"]
print(results)
```
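
Each entry of `results` corresponds to one input first line and, per the API description above, holds its 10 candidate second lines. A small sketch that continues the snippet above:

```python
# Pair each input first line with its candidate second lines.
for first_line, candidates in zip(data['texts'], results):
    # The candidate order follows the module's generation output;
    # taking the first one is just a simple choice for display.
    print(first_line, "->", candidates[0])
```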

## View the Code

https://github.com/PaddlePaddle/PaddleHub/blob/develop/demo/text_generation


## Dependencies

paddlepaddle >= 1.8.2

paddlehub >= 1.8.0

## Update History

* 1.0.0

Initial release.
Empty file.
144 changes: 144 additions & 0 deletions hub_module/modules/text/text_generation/ernie_tiny_couplet/module.py
@@ -0,0 +1,144 @@
# coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import ast
import argparse

import paddlehub as hub
from paddlehub.module.module import moduleinfo, serving, runnable
from paddlehub.module.nlp_module import DataFormatError


@moduleinfo(
name="ernie_tiny_couplet",
version="1.0.0",
summary="couplet generation model fine-tuned with ernie_tiny module",
author="paddlehub",
author_email="",
type="nlp/text_generation",
)
class ErnieTinyCouplet(hub.NLPPredictionModule):
def _initialize(self, use_gpu=False):
# Load Paddlehub ERNIE Tiny pretrained model
self.module = hub.Module(name="ernie_tiny")
inputs, outputs, program = self.module.context(
trainable=True, max_seq_len=128)

# Download dataset and get its label list and label num
# If you only need the label information, you can omit the tokenizer parameter to avoid preprocessing the training set.
dataset = hub.dataset.Couplet()
self.label_list = dataset.get_labels()

# Setup RunConfig for PaddleHub Fine-tune API
config = hub.RunConfig(
use_data_parallel=False,
use_cuda=use_gpu,
batch_size=1,
checkpoint_dir=os.path.join(self.directory, "assets", "ckpt"),
strategy=hub.AdamWeightDecayStrategy())

# Construct transfer learning network
# Use "pooled_output" for classification tasks on an entire sentence.
# Use "sequence_output" for token-level output.
pooled_output = outputs["pooled_output"]
sequence_output = outputs["sequence_output"]

# Define a text generation fine-tune task by PaddleHub's API
self.gen_task = hub.TextGenerationTask(
feature=pooled_output,
token_feature=sequence_output,
max_seq_len=128,
num_classes=dataset.num_labels,
config=config,
metrics_choices=["bleu"])

def generate(self, texts):
# Add 0x02 between characters to match the format of training data,
# otherwise the length of prediction results will not match the input string
# if the input string contains non-Chinese characters.
formatted_text_a = list(map("\002".join, texts))

# Use the appropriate tokenizer to preprocess the data
# For ernie_tiny, it uses BertTokenizer too.
tokenizer = hub.BertTokenizer(vocab_file=self.module.get_vocab_path())
encoded_data = [
tokenizer.encode(text=text, max_seq_len=128)
for text in formatted_text_a
]
results = self.gen_task.generate(
data=encoded_data,
label_list=self.label_list,
accelerate_mode=False)
results = [["".join(sample_result) for sample_result in sample_results]
for sample_results in results]
return results

def add_module_config_arg(self):
"""
Add the command config options
"""
self.arg_config_group.add_argument(
'--use_gpu',
type=ast.literal_eval,
default=False,
help="whether use GPU for prediction")

@runnable
def run_cmd(self, argvs):
"""
Run as a command
"""
self.parser = argparse.ArgumentParser(
description='Run the %s module.' % self.name,
prog='hub run %s' % self.name,
usage='%(prog)s',
add_help=True)

self.arg_input_group = self.parser.add_argument_group(
title="Input options", description="Input data. Required")
self.arg_config_group = self.parser.add_argument_group(
title="Config options",
description=
"Run configuration for controlling module behavior, not required.")

self.add_module_config_arg()
self.add_module_input_arg()

args = self.parser.parse_args(argvs)

try:
input_data = self.check_input_data(args)
except (DataFormatError, RuntimeError):
self.parser.print_help()
return None

results = self.generate(texts=input_data)

return results

@serving
def serving_method(self, texts):
"""
Run as a service.
"""
results = self.generate(texts)
return results


if __name__ == '__main__':
module = ErnieTinyCouplet()
results = module.generate(["风吹云乱天垂泪", "若有经心风过耳"])
for result in results:
print(result)
9 changes: 9 additions & 0 deletions hub_module/scripts/configs/ernie_gen_couplet.yml
@@ -0,0 +1,9 @@
name: ernie_gen_couplet
dir: "modules/text/text_generation/ernie_gen_couplet"
exclude:
- README.md
resources:
-
url: https://paddlehub.bj.bcebos.com/model/nlp/ernie_gen_couplet/assets.tar.gz
dest: assets
uncompress: True
9 changes: 9 additions & 0 deletions hub_module/scripts/configs/ernie_gen_poetry.yml
@@ -0,0 +1,9 @@
name: ernie_gen_poetry
dir: "modules/text/text_generation/ernie_gen_poetry"
exclude:
- README.md
resources:
-
url: https://paddlehub.bj.bcebos.com/model/nlp/ernie_gen_poetry/assets.tar.gz
dest: assets
uncompress: True
9 changes: 9 additions & 0 deletions hub_module/scripts/configs/ernie_tiny_couplet.yml
@@ -0,0 +1,9 @@
name: ernie_tiny_couplet
dir: "modules/text/text_generation/ernie_tiny_couplet"
exclude:
- README.md
resources:
-
url: https://paddlehub.bj.bcebos.com/model/nlp/ernie_tiny_couplet/assets.tar.gz
dest: assets
uncompress: True
32 changes: 32 additions & 0 deletions hub_module/tests/unittests/test_ernie_gen_couplet.py
@@ -0,0 +1,32 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import TestCase, main
import paddlehub as hub


class ErnieGenCoupletTestCase(TestCase):
def setUp(self):
self.module = hub.Module(name='ernie_gen_couplet')
self.left = ["风吹云乱天垂泪", "若有经心风过耳"]

def test_predict(self):
rights = self.module.generate(self.left)
self.assertEqual(len(rights), 2)
self.assertEqual(len(rights[0]), 5)
self.assertEqual(len(rights[0][0]), 7)
self.assertEqual(len(rights[1][0]), 7)


if __name__ == '__main__':
main()
