-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
598b5cc
commit 81f6498
Showing
12 changed files
with
220 additions
and
174 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# BART文本纠错-中文-通用领域-large | ||
|
||
|
||
#### 中文文本纠错模型介绍 | ||
输入一句中文文本,文本纠错技术对句子中存在拼写、语法、语义等错误进行自动纠正,输出纠正后的文本。主流的方法为seq2seq和seq2edits,常用的中文纠错数据集包括NLPCC18和CGED等,我们最新的工作提供了高质量、多答案的测试集MuCGEC。 | ||
|
||
我们采用基于transformer的seq2seq方法建模文本纠错任务。模型训练上,我们使用中文BART作为预训练模型,然后在Lang8和HSK训练数据上进行finetune。不引入额外资源的情况下,本模型在NLPCC18测试集上达到了SOTA。 | ||
|
||
模型效果如下: | ||
输入:这洋的话,下一年的福气来到自己身上。 | ||
输出:这样的话,下一年的福气就会来到自己身上。 | ||
|
||
#### 期望模型使用方式以及适用范围 | ||
本模型主要用于对中文文本进行错误诊断,输出符合拼写、语法要求的文本。该纠错模型是一个句子级别的模型,模型效果会受到文本长度、分句粒度的影响,建议是每次输入一句话。具体调用方式请参考代码示例。 | ||
|
||
|
||
|
||
|
||
|
||
## Usage | ||
#### 安装依赖 | ||
```shell | ||
pip install pycorrector difflib modelscope==1.16.0 fairseq==0.12.2 | ||
``` | ||
#### pycorrector快速预测 | ||
|
||
example: [examples/mucgec_bart/demo.py](https://github.com/shibing624/pycorrector/blob/master/examples/mucgec_bart/demo.py) | ||
```python | ||
from pycorrector.mucgec_bart.mucgec_bart_corrector import MuCGECBartCorrector | ||
|
||
|
||
if __name__ == "__main__": | ||
m = MuCGECBartCorrector() | ||
result = m.correct_batch(['这洋的话,下一年的福气来到自己身上。', | ||
'在拥挤时间,为了让人们尊守交通规律,派至少两个警察或者交通管理者。', | ||
'随着中国经济突飞猛近,建造工业与日俱增', | ||
"北京是中国的都。", | ||
"他说:”我最爱的运动是打蓝球“", | ||
"我每天大约喝5次水左右。", | ||
"今天,我非常开开心。"]) | ||
print(result) | ||
``` | ||
|
||
output: | ||
```shell | ||
[{'source': '今天新情很好', 'target': '今天心情很好', 'errors': [('新', '心', 2)]}, | ||
{'source': '你找到你最喜欢的工作,我也很高心。', 'target': '你找到你最喜欢的工作,我也很高兴。', 'errors': [('心', '兴', 15)]}] | ||
``` | ||
|
||
## Reference | ||
- https://modelscope.cn/models/iic/nlp_bart_text-error-correction_chinese/summary | ||
- 苏大:Tang et al. Chinese grammatical error correction enhanced by data augmentation from word and character levels. 2021. | ||
- 北大 & MSRA & CUHK:Sun et al. A Unified Strategy for Multilingual Grammatical Error Correction with Pre-trained Cross-Lingual Language Model. 2021. | ||
- Ours:Zhang et al. MuCGEC: a Multi-Reference Multi-Source Evaluation Dataset for Chinese Grammatical Error Correction. 2022. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
@author:XuMing([email protected]) | ||
@description: | ||
""" | ||
import sys | ||
|
||
sys.path.append('../..') | ||
|
||
from pycorrector.mucgec_bart.mucgec_bart_corrector import MuCGECBartCorrector | ||
|
||
if __name__ == "__main__": | ||
m = MuCGECBartCorrector() | ||
result = m.correct_batch(['这洋的话,下一年的福气来到自己身上。', | ||
'在拥挤时间,为了让人们尊守交通规律,派至少两个警察或者交通管理者。', | ||
'随着中国经济突飞猛近,建造工业与日俱增', | ||
"北京是中国的都。", | ||
"他说:”我最爱的运动是打蓝球“", | ||
"我每天大约喝5次水左右。", | ||
"今天,我非常开开心。"]) | ||
print(result) | ||
|
||
# [{'source': '这洋的话,下一年的福气来到自己身上。', 'target': '这样的话,下一年的福气就会来到自己身上。', 'errors': [('洋', '样', 1), ('', '就会', 11)]}, | ||
# {'source': '在拥挤时间,为了让人们尊守交通规律,派至少两个警察或者交通管理者。', 'target': '在拥挤时间,为了让人们遵守交通规则,应该派至少两个警察或者交通管理者。', 'errors': [('尊', '遵', 11), ('律', '则', 16), ('', '应该', 18)]}, | ||
# {'source': '随着中国经济突飞猛近,建造工业与日俱增', 'target': '随着中国经济突飞猛进,建造工业与日俱增', 'errors': [('近', '进', 9)]}, | ||
# {'source': '北京是中国的都。', 'target': '北京是中国的首都。', 'errors': [('', '首', 6)]}, | ||
# {'source': '他说:”我最爱的运动是打蓝球“', 'target': '他说:“我最爱的运动是打篮球”', 'errors': [('”', '“', 3), ('蓝', '篮', 12), ('“', '”', 14)]}, | ||
# {'source': '我每天大约喝5次水左右。', 'target': '我每天大约喝5杯水左右。', 'errors': [('次', '杯', 7)]}, | ||
# {'source': '今天,我非常开开心。', 'target': '今天,我非常开心。', 'errors': [('开', '', 7)]}] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,70 +1,68 @@ | ||
from typing import Any, Dict, List | ||
import torch | ||
from modelscope.pipelines import Pipeline | ||
from typing import Any, Dict, List | ||
from modelscope.utils.constant import Frameworks | ||
from modelscope.utils.device import device_placement | ||
|
||
# 批量推理问题 | ||
def _process_batch(self, input: List, batch_size, | ||
**kwargs) -> Dict[str, Any]: | ||
preprocess_params = kwargs.get('preprocess_params') | ||
forward_params = kwargs.get('forward_params') | ||
postprocess_params = kwargs.get('postprocess_params') | ||
|
||
# batch data | ||
output_list = [] | ||
for i in range(0, len(input), batch_size): | ||
end = min(i + batch_size, len(input)) | ||
real_batch_size = end - i | ||
preprocessed_list = [ | ||
self.preprocess(i, **preprocess_params) for i in input[i:end] | ||
] | ||
def _process_batch(self, input: List, batch_size, **kwargs): | ||
preprocess_params = kwargs.get('preprocess_params') | ||
forward_params = kwargs.get('forward_params') | ||
postprocess_params = kwargs.get('postprocess_params') | ||
|
||
with device_placement(self.framework, self.device_name): | ||
if self.framework == Frameworks.torch: | ||
with torch.no_grad(): | ||
batched_out = self._batch(preprocessed_list) | ||
if self._auto_collate: | ||
batched_out = self._collate_fn(batched_out) | ||
batched_out = self.forward(batched_out, | ||
**forward_params) | ||
else: | ||
# batch data | ||
output_list = [] | ||
for i in range(0, len(input), batch_size): | ||
end = min(i + batch_size, len(input)) | ||
real_batch_size = end - i | ||
preprocessed_list = [ | ||
self.preprocess(i, **preprocess_params) for i in input[i:end] | ||
] | ||
|
||
with device_placement(self.framework, self.device_name): | ||
if self.framework == Frameworks.torch: | ||
with torch.no_grad(): | ||
batched_out = self._batch(preprocessed_list) | ||
batched_out = self.forward(batched_out, **forward_params) | ||
model_name = kwargs.get("model_name") | ||
# print("model_name", model_name) | ||
if model_name=="batch_correct": | ||
for batch_idx in range(real_batch_size): | ||
out = {} | ||
for k, element in batched_out.items(): | ||
if element is not None: | ||
if isinstance(element, (tuple, list)): | ||
out[k] = element[batch_idx] | ||
else: | ||
out[k] = element[batch_idx:batch_idx + 1] | ||
out = self.postprocess(out, **postprocess_params) | ||
self._check_output(out) | ||
output_list.append(out) | ||
if self._auto_collate: | ||
batched_out = self._collate_fn(batched_out) | ||
batched_out = self.forward(batched_out, | ||
**forward_params) | ||
else: | ||
for batch_idx in range(real_batch_size): | ||
out = {} | ||
for k, element in batched_out.items(): | ||
if element is not None: | ||
if isinstance(element, (tuple, list)): | ||
if isinstance(element[0], torch.Tensor): | ||
out[k] = type(element)( | ||
e[batch_idx:batch_idx + 1] | ||
for e in element) | ||
else: | ||
# Compatible with traditional pipelines | ||
out[k] = element[batch_idx] | ||
batched_out = self._batch(preprocessed_list) | ||
batched_out = self.forward(batched_out, **forward_params) | ||
model_name = kwargs.get("model_name") | ||
# print("model_name", model_name) | ||
if model_name == "batch_correct": | ||
for batch_idx in range(real_batch_size): | ||
out = {} | ||
for k, element in batched_out.items(): | ||
if element is not None: | ||
if isinstance(element, (tuple, list)): | ||
out[k] = element[batch_idx] | ||
else: | ||
out[k] = element[batch_idx:batch_idx + 1] | ||
out = self.postprocess(out, **postprocess_params) | ||
self._check_output(out) | ||
output_list.append(out) | ||
else: | ||
for batch_idx in range(real_batch_size): | ||
out = {} | ||
for k, element in batched_out.items(): | ||
if element is not None: | ||
if isinstance(element, (tuple, list)): | ||
if isinstance(element[0], torch.Tensor): | ||
out[k] = type(element)( | ||
e[batch_idx:batch_idx + 1] | ||
for e in element) | ||
else: | ||
out[k] = element[batch_idx:batch_idx + 1] | ||
out = self.postprocess(out, **postprocess_params) | ||
self._check_output(out) | ||
output_list.append(out) | ||
|
||
return output_list | ||
# Compatible with traditional pipelines | ||
out[k] = element[batch_idx] | ||
else: | ||
out[k] = element[batch_idx:batch_idx + 1] | ||
out = self.postprocess(out, **postprocess_params) | ||
self._check_output(out) | ||
output_list.append(out) | ||
return output_list | ||
|
||
|
||
Pipeline._process_batch = _process_batch | ||
Pipeline._process_batch = _process_batch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.