模型名称 | CPM_LM |
---|---|
类别 | 文本-文本生成 |
网络 | GPT-2 |
数据集 | 自建数据集 |
是否支持Fine-tuning | 否 |
模型大小 | 5.31G |
最新更新日期 | 2021-02-26 |
数据指标 | - |
-
- CPM-LM 是一个基于 GPT-2 的预训练生成模型,参数规模达 26 亿,预训练中文数据规模 100 GB,使用了 64 块 V100 GPU,训练时间约为 3 周。能够在多种自然语言处理任务上进行零次学习或少次学习,并达到较好的效果。基于给定上文,模型可以续写出一致性高、可读性强的文本,达到现有中文生成模型的领先效果。
-
-
paddlepaddle >= 2.0.0
-
paddlehub >= 2.0.0 | 如何安装PaddleHub
-
sentencepiece==0.1.92
-
注意本模型对sentencepiece版本要求严格,在使用前请确认您所使用的版本是正确的
-
$ pip install sentencepiece==0.1.92
-
-
-
-
$ hub install CPM_LM
-
如您安装时遇到问题,可参考:零基础windows安装 | 零基础Linux安装 | 零基础MacOS安装
-
-
-
import paddlehub as hub model = hub.Module(name='CPM_LM') # 加载模型
-
Note:模型参数转换至官方开源项目,由于模型较大,推荐在GPU环境下运行,并且请确保运行环境的内存大于20G且显卡显存大于12G,否则可能无法正常运行
-
使用 Greedy Search 生成文本:
-
inputs = '''默写古诗: 日照香炉生紫烟,遥看瀑布挂前川。 飞流直下三千尺,''' outputs = model.predict(inputs, max_len=10, end_word='\n') print(inputs+outputs) # 默写古诗: # 日照香炉生紫烟,遥看瀑布挂前川。 # 飞流直下三千尺,疑是银河落九天。
-
inputs = '''问题:西游记是谁写的? 答案:''' outputs = model.predict(inputs, max_len=10, end_word='\n') print(inputs+outputs) # 问题:西游记是谁写的? # 答案:吴承恩。
-
inputs = '''小明决定去吃饭,小红继续写作业 问题:去吃饭的人是谁? 答案:''' outputs = model.predict(inputs, max_len=10, end_word='\n') print(inputs+outputs) # 小明决定去吃饭,小红继续写作业 # 问题:去吃饭的人是谁? # 答案:小明
-
inputs = '''默写英文: 狗:dog 猫:''' outputs = model.predict(inputs, max_len=10, end_word='\n') print(inputs+outputs) # 默写英文: # 狗:dog # 猫:cat
-
-
-
-
def predict(text, max_len=32, end_word=None):
- 预测 API ,根据输入的文字进行文本生成,使用 Greedy Search 进行解码。
- 参数
- text (str) : 输入文本
- max_len (int) : 生成文本的最大长度
- end_word (str or None) : 终止生成的标志词
- 返回
- results (str): 生成的文本
-
def tokenizer.encode(text):
- 编码 API
- 参数
- text (str) : 输入文本
- 返回
- results (listint) : 输出编码
-
def tokenizer.decode(ids):
- 解码 API
- 参数
- ids (listint) : 输入编码
- 返回
- results (str) : 输出文本
-
def model(x, kv_cache=None, use_cache=False):
- 模型前向计算 API
- 参数
- x (tensor) : 输入编码
- kv_cache (tensor) : 输入的缓存
- use_cache (bool) : 是否使用缓存
- 返回
- results (tensor) : 模型输出
-
-
PaddleHub Serving可以部署一个在线文本生成服务。
-
-
运行启动命令:
-
$ hub serving start --modules GPT2_CPM_LM -p 8866
-
这样就完成了一个对话机器人服务化API的部署,默认端口号为8866。
-
NOTE: 如使用GPU预测,则需要在启动服务之前,请设置CUDA_VISIBLE_DEVICES环境变量,否则不用设置。
-
-
-
配置好服务端,以下数行代码即可实现发送预测请求,获取预测结果
-
import requests import json text = "今天是个好日子" data = { "text": text, "mode": "sample", # 'search' or 'sample' # 可以更加需要设置上述 API 中提到的其他参数 } url = "http://127.0.0.1:8866/predict/GPT2_CPM_LM" headers = {"Content-Type": "application/json"} r = requests.post(url=url, headers=headers, data=json.dumps(data))
-
关于PaddleHub Serving更多信息参考服务部署
-
-
1.0.0
初始发布