Skip to content

Commit

Permalink
fix ernie_gen bug. and plato and ddparser config (PaddlePaddle#817)
Browse files Browse the repository at this point in the history
  • Loading branch information
kinghuin authored Aug 13, 2020
1 parent cc78bd1 commit d9d160a
Show file tree
Hide file tree
Showing 8 changed files with 45 additions and 35 deletions.
24 changes: 13 additions & 11 deletions hub_module/modules/text/syntactic_analysis/DDParser/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,18 @@ $ hub run ddparser --input_text="百度是一家高科技公司"

# API

## parse(texts=[])
## parse(texts=[], return\_visual=False)

依存分析接口,输入文本,输出依存关系。

**参数**

* texts(list[list[str] or list[str]]): 待预测数据。各元素可以是未分词的字符串,也可以是已分词的token列表。
* texts(list\[list\[str\] or str\]): 待预测数据。各元素可以是未分词的字符串,也可以是已分词的token列表。
* return\_visual(bool): 是否返回依存分析可视化结果。如果为True,返回结果中将包含'visual'字段。

**返回**

* results(list[dict]): 依存分析结果。每个元素都是dict类型,包含以下信息:
* results(list\[dict\]): 依存分析结果。每个元素都是dict类型,包含以下信息:
```python
{
'word': list[str], 分词结果。
Expand All @@ -34,9 +35,9 @@ $ hub run ddparser --input_text="百度是一家高科技公司"

**参数**

* word(list[list[str]): 分词信息。
* head(list[int]): 当前成分其支配者的id。
* deprel(list[str]): 当前成分与支配者的依存关系。
* word(list\[str\]): 分词信息。
* head(list\[int\]): 当前成分其支配者的id。
* deprel(list\[str\]): 当前成分与支配者的依存关系。

**返回**

Expand All @@ -55,11 +56,12 @@ results = module.parse(texts=test_text)
print(results)

test_tokens = [['百度', '是', '一家', '高科技', '公司']]
results = module.parse(texts=test_text)
results = module.parse(texts=test_tokens, return_visual=True)
print(results)

result = results[0]
data = module.visualize(result['word'],result['head'],result['deprel'])
# or data = result['visual']
cv2.imwrite('test.jpg',data)
```

Expand All @@ -81,7 +83,7 @@ Loading ddparser successful.

这样就完成了服务化API的部署,默认端口号为8866。

**NOTE:** 如使用GPU预测,则需要在启动服务之前,请设置CUDA_VISIBLE_DEVICES环境变量,否则不用设置。
**NOTE:** 如使用GPU预测,则需要在启动服务之前设置CUDA\_VISIBLE\_DEVICES环境变量;否则无需设置。

## 第二步:发送预测请求

Expand All @@ -105,12 +107,12 @@ data = {"texts": text, "return_visual": return_visual}
url = "http://0.0.0.0:8866/predict/ddparser"
headers = {"Content-Type": "application/json"}
r = requests.post(url=url, headers=headers, data=json.dumps(data))
results, visuals = r.json()['results']
results = r.json()['results']

for i in range(len(results)):
print(results[i])
print(results[i]['word'])
# 不同于本地调用parse接口,serving返回的图像是list类型的,需要先用numpy加载再显示或保存。
cv2.imwrite('%s.jpg'%i, np.array(visuals[i]))
cv2.imwrite('%s.jpg'%i, np.array(results[i]['visual']))
```

关于PaddleHub Serving更多信息参考[服务部署](https://github.com/PaddlePaddle/PaddleHub/blob/release/v1.6/docs/tutorial/serving.md)
Expand Down
30 changes: 14 additions & 16 deletions hub_module/modules/text/syntactic_analysis/DDParser/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,16 @@ def _initialize(self):
"""
self.ddp = DDParserModel(prob=True, use_pos=True)
self.font = font_manager.FontProperties(
fname=os.path.join(self.directory, "SimHei.ttf"))
fname=os.path.join(self.directory, "SourceHanSans-Regular.ttf"))

@serving
def serving_parse(self, texts=[], return_visual=False):
results, visuals = self.parse(texts, return_visual)
for i, visual in enumerate(visuals):
visuals[i] = visual.tolist()
results = self.parse(texts, return_visual)
if return_visual:
for i, result in enumerate(results):
result['visual'] = result['visual'].tolist()

return results, visuals
return results

def parse(self, texts=[], return_visual=False):
"""
Expand All @@ -57,11 +58,9 @@ def parse(self, texts=[], return_visual=False):
'head': list[int], the head ids.
'deprel': list[str], the dependency relation.
'prob': list[float], the prediction probability of the dependency relation.
'postag': list[str], the POS tag. If the element of the texts is list, the key 'postag' will not be returned.
'postag': list[str], the POS tag. If the element of the texts is list, the key 'postag' will not be returned.
'visual': numpy.array, the dependency visualization. Use cv2.imshow to show or cv2.imwrite to save it. If return_visual=False, it will not be returned.
}
visuals : list[numpy.array]: the dependency visualization. Use cv2.imshow to show or cv2.imwrite to save it. If return_visual=False, it will not be empty.
"""

if not texts:
Expand All @@ -73,13 +72,11 @@ def parse(self, texts=[], return_visual=False):
else:
raise ValueError("All of the elements should be string or list")
results = do_parse(texts)
visuals = []
if return_visual:
for result in results:
visuals.append(
self.visualize(result['word'], result['head'],
result['deprel']))
return results, visuals
result['visual'] = self.visualize(
result['word'], result['head'], result['deprel'])
return results

@runnable
def run_cmd(self, argvs):
Expand Down Expand Up @@ -194,10 +191,11 @@ def visualize(self, word, head, deprel):
results = module.parse(texts=test_text)
print(results)
test_tokens = [['百度', '是', '一家', '高科技', '公司']]
results = module.parse(texts=test_text)
results = module.parse(texts=test_text, return_visual=True)
print(results)
result = results[0]
data = module.visualize(result['word'], result['head'], result['deprel'])
import cv2
import numpy as np
cv2.imwrite('test.jpg', np.array(data))
cv2.imwrite('test1.jpg', data)
cv2.imwrite('test2.jpg', result['visual'])
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,7 @@ paddlehub >= 1.7.0
* 1.0.0

初始发布

* 1.0.1

修复Windows中的编码问题
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

@moduleinfo(
name="ernie_gen_couplet",
version="1.0.0",
version="1.0.1",
summary=
"ERNIE-GEN is a multi-flow language generation framework for both pre-training and fine-tuning. This module has fine-tuned for couplet generation task.",
author="baidu-nlp",
Expand All @@ -50,10 +50,10 @@ def _initialize(self):
assets_path = os.path.join(self.directory, "assets")
gen_checkpoint_path = os.path.join(assets_path, "ernie_gen_couplet")
ernie_cfg_path = os.path.join(assets_path, 'ernie_config.json')
with open(ernie_cfg_path) as ernie_cfg_file:
with open(ernie_cfg_path, encoding='utf8') as ernie_cfg_file:
ernie_cfg = dict(json.loads(ernie_cfg_file.read()))
ernie_vocab_path = os.path.join(assets_path, 'vocab.txt')
with open(ernie_vocab_path) as ernie_vocab_file:
with open(ernie_vocab_path, encoding='utf8') as ernie_vocab_file:
ernie_vocab = {
j.strip().split('\t')[0]: i
for i, j in enumerate(ernie_vocab_file.readlines())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,7 @@ paddlehub >= 1.7.0
* 1.0.0

初始发布

* 1.0.1

修复Windows中的编码问题
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

@moduleinfo(
name="ernie_gen_poetry",
version="1.0.0",
version="1.0.1",
summary=
"ERNIE-GEN is a multi-flow language generation framework for both pre-training and fine-tuning. This module has fine-tuned for poetry generation task.",
author="baidu-nlp",
Expand All @@ -50,10 +50,10 @@ def _initialize(self):
assets_path = os.path.join(self.directory, "assets")
gen_checkpoint_path = os.path.join(assets_path, "ernie_gen_poetry")
ernie_cfg_path = os.path.join(assets_path, 'ernie_config.json')
with open(ernie_cfg_path) as ernie_cfg_file:
with open(ernie_cfg_path, encoding='utf8') as ernie_cfg_file:
ernie_cfg = dict(json.loads(ernie_cfg_file.read()))
ernie_vocab_path = os.path.join(assets_path, 'vocab.txt')
with open(ernie_vocab_path) as ernie_vocab_file:
with open(ernie_vocab_path, encoding='utf8') as ernie_vocab_file:
ernie_vocab = {
j.strip().split('\t')[0]: i
for i, j in enumerate(ernie_vocab_file.readlines())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ PLATO2是一个超大规模生成式对话系统模型。它承袭了PLATO隐变
## 命令行预测

```shell
$ hub run plato2_en_base --input_text="Hello, how are you" --use_gpu
$ hub run plato2_en_base --input_text="Hello, how are you"
```

## API
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ PLATO2是一个超大规模生成式对话系统模型。它承袭了PLATO隐变

更多详情参考论文[PLATO-2: Towards Building an Open-Domain Chatbot via Curriculum Learning](https://arxiv.org/abs/2006.16779)

**注:plato2\_en\_large 模型大小12GB,下载时间较长,请耐心等候。运行此模型要求显存至少16GB。**

## 命令行预测

```shell
$ hub run plato2_en_large --input_text="Hello, how are you" --use_gpu
$ hub run plato2_en_large --input_text="Hello, how are you"
```

## API
Expand Down

0 comments on commit d9d160a

Please sign in to comment.