Skip to content

Commit

Permalink
update mucgec model.
Browse files Browse the repository at this point in the history
  • Loading branch information
shibing624 committed Oct 28, 2024
1 parent 598b5cc commit 81f6498
Show file tree
Hide file tree
Showing 12 changed files with 220 additions and 174 deletions.
67 changes: 23 additions & 44 deletions README.md

Large diffs are not rendered by default.

54 changes: 54 additions & 0 deletions examples/mucgec_bart/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# BART文本纠错-中文-通用领域-large


#### 中文文本纠错模型介绍
输入一句中文文本,文本纠错技术对句子中存在的拼写、语法、语义等错误进行自动纠正,输出纠正后的文本。主流的方法为seq2seq和seq2edits,常用的中文纠错数据集包括NLPCC18和CGED等,我们最新的工作提供了高质量、多答案的测试集MuCGEC。

我们采用基于transformer的seq2seq方法建模文本纠错任务。模型训练上,我们使用中文BART作为预训练模型,然后在Lang8和HSK训练数据上进行finetune。不引入额外资源的情况下,本模型在NLPCC18测试集上达到了SOTA。

模型效果如下:
输入:这洋的话,下一年的福气来到自己身上。
输出:这样的话,下一年的福气就会来到自己身上。

#### 期望模型使用方式以及适用范围
本模型主要用于对中文文本进行错误诊断,输出符合拼写、语法要求的文本。该纠错模型是一个句子级别的模型,模型效果会受到文本长度、分句粒度的影响,建议是每次输入一句话。具体调用方式请参考代码示例。





## Usage
#### 安装依赖
```shell
pip install pycorrector modelscope==1.16.0 fairseq==0.12.2
```
#### pycorrector快速预测

example: [examples/mucgec_bart/demo.py](https://github.com/shibing624/pycorrector/blob/master/examples/mucgec_bart/demo.py)
```python
from pycorrector.mucgec_bart.mucgec_bart_corrector import MuCGECBartCorrector


if __name__ == "__main__":
m = MuCGECBartCorrector()
result = m.correct_batch(['这洋的话,下一年的福气来到自己身上。',
'在拥挤时间,为了让人们尊守交通规律,派至少两个警察或者交通管理者。',
'随着中国经济突飞猛近,建造工业与日俱增',
"北京是中国的都。",
"他说:”我最爱的运动是打蓝球“",
"我每天大约喝5次水左右。",
"今天,我非常开开心。"])
print(result)
```

output:
```shell
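[{'source': '这洋的话,下一年的福气来到自己身上。', 'target': '这样的话,下一年的福气就会来到自己身上。', 'errors': [('洋', '样', 1), ('', '就会', 11)]},
{'source': '在拥挤时间,为了让人们尊守交通规律,派至少两个警察或者交通管理者。', 'target': '在拥挤时间,为了让人们遵守交通规则,应该派至少两个警察或者交通管理者。', 'errors': [('尊', '遵', 11), ('律', '则', 16), ('', '应该', 18)]},
{'source': '随着中国经济突飞猛近,建造工业与日俱增', 'target': '随着中国经济突飞猛进,建造工业与日俱增', 'errors': [('近', '进', 9)]},
{'source': '北京是中国的都。', 'target': '北京是中国的首都。', 'errors': [('', '首', 6)]},
{'source': '他说:”我最爱的运动是打蓝球“', 'target': '他说:“我最爱的运动是打篮球”', 'errors': [('”', '“', 3), ('蓝', '篮', 12), ('“', '”', 14)]},
{'source': '我每天大约喝5次水左右。', 'target': '我每天大约喝5杯水左右。', 'errors': [('次', '杯', 7)]},
{'source': '今天,我非常开开心。', 'target': '今天,我非常开心。', 'errors': [('开', '', 7)]}]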
```

## Reference
- https://modelscope.cn/models/iic/nlp_bart_text-error-correction_chinese/summary
- 苏大:Tang et al. Chinese grammatical error correction enhanced by data augmentation from word and character levels. 2021.
- 北大 & MSRA & CUHK:Sun et al. A Unified Strategy for Multilingual Grammatical Error Correction with Pre-trained Cross-Lingual Language Model. 2021.
- Ours:Zhang et al. MuCGEC: a Multi-Reference Multi-Source Evaluation Dataset for Chinese Grammatical Error Correction. 2022.
29 changes: 29 additions & 0 deletions examples/mucgec_bart/demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
"""
@author:XuMing([email protected])
@description:
"""
import sys

sys.path.append('../..')

from pycorrector.mucgec_bart.mucgec_bart_corrector import MuCGECBartCorrector

if __name__ == "__main__":
    # Build the corrector first, exactly as the original script does.
    corrector = MuCGECBartCorrector()

    # Sample inputs; the expected corrections are listed in the comment
    # block that follows this script.
    sample_sentences = [
        '这洋的话,下一年的福气来到自己身上。',
        '在拥挤时间,为了让人们尊守交通规律,派至少两个警察或者交通管理者。',
        '随着中国经济突飞猛近,建造工业与日俱增',
        "北京是中国的都。",
        "他说:”我最爱的运动是打蓝球“",
        "我每天大约喝5次水左右。",
        "今天,我非常开开心。",
    ]

    # Correct every sentence in a single batched call and print the result.
    corrections = corrector.correct_batch(sample_sentences)
    print(corrections)

# [{'source': '这洋的话,下一年的福气来到自己身上。', 'target': '这样的话,下一年的福气就会来到自己身上。', 'errors': [('洋', '样', 1), ('', '就会', 11)]},
# {'source': '在拥挤时间,为了让人们尊守交通规律,派至少两个警察或者交通管理者。', 'target': '在拥挤时间,为了让人们遵守交通规则,应该派至少两个警察或者交通管理者。', 'errors': [('尊', '遵', 11), ('律', '则', 16), ('', '应该', 18)]},
# {'source': '随着中国经济突飞猛近,建造工业与日俱增', 'target': '随着中国经济突飞猛进,建造工业与日俱增', 'errors': [('近', '进', 9)]},
# {'source': '北京是中国的都。', 'target': '北京是中国的首都。', 'errors': [('', '首', 6)]},
# {'source': '他说:”我最爱的运动是打蓝球“', 'target': '他说:“我最爱的运动是打篮球”', 'errors': [('”', '“', 3), ('蓝', '篮', 12), ('“', '”', 14)]},
# {'source': '我每天大约喝5次水左右。', 'target': '我每天大约喝5杯水左右。', 'errors': [('次', '杯', 7)]},
# {'source': '今天,我非常开开心。', 'target': '今天,我非常开心。', 'errors': [('开', '', 7)]}]
2 changes: 0 additions & 2 deletions pycorrector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
from pycorrector.proper_corrector import ProperCorrector
from pycorrector.seq2seq.conv_seq2seq_corrector import ConvSeq2SeqCorrector
from pycorrector.t5.t5_corrector import T5Corrector
from pycorrector.mucgec_bart.mucgec_bart_corrector import MuCGECBartCorrector
from pycorrector.nasgec_bart.nasgec_bart_corrector import NaSGECBartCorrector
from pycorrector.utils import text_utils, tokenizer, io_utils, math_utils, evaluate_utils
from pycorrector.utils.evaluate_utils import eval_model_batch
from pycorrector.utils.get_file import get_file
Expand Down
114 changes: 56 additions & 58 deletions pycorrector/mucgec_bart/monkey_pack.py
Original file line number Diff line number Diff line change
@@ -1,70 +1,68 @@
from typing import Any, Dict, List
import torch
from modelscope.pipelines import Pipeline
from typing import Any, Dict, List
from modelscope.utils.constant import Frameworks
from modelscope.utils.device import device_placement

# 批量推理问题
def _process_batch(self, input: List, batch_size,
**kwargs) -> Dict[str, Any]:
preprocess_params = kwargs.get('preprocess_params')
forward_params = kwargs.get('forward_params')
postprocess_params = kwargs.get('postprocess_params')

# batch data
output_list = []
for i in range(0, len(input), batch_size):
end = min(i + batch_size, len(input))
real_batch_size = end - i
preprocessed_list = [
self.preprocess(i, **preprocess_params) for i in input[i:end]
]
def _process_batch(self, input: List, batch_size, **kwargs):
preprocess_params = kwargs.get('preprocess_params')
forward_params = kwargs.get('forward_params')
postprocess_params = kwargs.get('postprocess_params')

with device_placement(self.framework, self.device_name):
if self.framework == Frameworks.torch:
with torch.no_grad():
batched_out = self._batch(preprocessed_list)
if self._auto_collate:
batched_out = self._collate_fn(batched_out)
batched_out = self.forward(batched_out,
**forward_params)
else:
# batch data
output_list = []
for i in range(0, len(input), batch_size):
end = min(i + batch_size, len(input))
real_batch_size = end - i
preprocessed_list = [
self.preprocess(i, **preprocess_params) for i in input[i:end]
]

with device_placement(self.framework, self.device_name):
if self.framework == Frameworks.torch:
with torch.no_grad():
batched_out = self._batch(preprocessed_list)
batched_out = self.forward(batched_out, **forward_params)
model_name = kwargs.get("model_name")
# print("model_name", model_name)
if model_name=="batch_correct":
for batch_idx in range(real_batch_size):
out = {}
for k, element in batched_out.items():
if element is not None:
if isinstance(element, (tuple, list)):
out[k] = element[batch_idx]
else:
out[k] = element[batch_idx:batch_idx + 1]
out = self.postprocess(out, **postprocess_params)
self._check_output(out)
output_list.append(out)
if self._auto_collate:
batched_out = self._collate_fn(batched_out)
batched_out = self.forward(batched_out,
**forward_params)
else:
for batch_idx in range(real_batch_size):
out = {}
for k, element in batched_out.items():
if element is not None:
if isinstance(element, (tuple, list)):
if isinstance(element[0], torch.Tensor):
out[k] = type(element)(
e[batch_idx:batch_idx + 1]
for e in element)
else:
# Compatible with traditional pipelines
out[k] = element[batch_idx]
batched_out = self._batch(preprocessed_list)
batched_out = self.forward(batched_out, **forward_params)
model_name = kwargs.get("model_name")
# print("model_name", model_name)
if model_name == "batch_correct":
for batch_idx in range(real_batch_size):
out = {}
for k, element in batched_out.items():
if element is not None:
if isinstance(element, (tuple, list)):
out[k] = element[batch_idx]
else:
out[k] = element[batch_idx:batch_idx + 1]
out = self.postprocess(out, **postprocess_params)
self._check_output(out)
output_list.append(out)
else:
for batch_idx in range(real_batch_size):
out = {}
for k, element in batched_out.items():
if element is not None:
if isinstance(element, (tuple, list)):
if isinstance(element[0], torch.Tensor):
out[k] = type(element)(
e[batch_idx:batch_idx + 1]
for e in element)
else:
out[k] = element[batch_idx:batch_idx + 1]
out = self.postprocess(out, **postprocess_params)
self._check_output(out)
output_list.append(out)

return output_list
# Compatible with traditional pipelines
out[k] = element[batch_idx]
else:
out[k] = element[batch_idx:batch_idx + 1]
out = self.postprocess(out, **postprocess_params)
self._check_output(out)
output_list.append(out)
return output_list


Pipeline._process_batch = _process_batch
Pipeline._process_batch = _process_batch
47 changes: 23 additions & 24 deletions pycorrector/mucgec_bart/mucgec_bart_corrector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,16 @@

import torch
from loguru import logger
from tqdm import tqdm


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import sys
sys.path.append('../..')
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from pycorrector.mucgec_bart.monkey_pack import Pipeline
from pycorrector.utils.sentence_utils import long_sentence_split
import difflib

device = torch.device("mps" if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
else "cuda" if torch.cuda.is_available() else "cpu")
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


class MuCGECBartCorrector:
def __init__(self, model_name_or_path: str = "damo/nlp_bart_text-error-correction_chinese"):
Expand All @@ -28,9 +25,9 @@ def __init__(self, model_name_or_path: str = "damo/nlp_bart_text-error-correctio

def _predict(self, sentences, batch_size=32, max_length=128, silent=True):
raise NotImplementedError

def correct_batch(self, sentences: List[str], max_length: int = 128, batch_size: int = 32, silent: bool = True, ignore_function=None):

def correct_batch(self, sentences: List[str], max_length: int = 128, batch_size: int = 32, silent: bool = True,
ignore_function=None):
"""
批量句子纠错
:param sentences: list[str], sentence list
Expand All @@ -47,29 +44,37 @@ def correct_batch(self, sentences: List[str], max_length: int = 128, batch_size:
result = [r["output"] for r in result]
for i in range(n):
a, b = sentences[i], result[i]
if len(a)==0 or len(b)==0 or a=="\n":
if len(a) == 0 or len(b) == 0 or a == "\n":
start_idx += len(a)
return
s = difflib.SequenceMatcher(None, a, b)
errors = []
offset = 0
for tag, i1, i2, j1, j2 in s.get_opcodes():
if tag!="equal":
e = [a[i1:i2], b[j1+offset:j2+offset], i1]
if tag != "equal":
e = [a[i1:i2], b[j1 + offset:j2 + offset], i1]
if ignore_function and ignore_function(e):
# 因为不认为是错误, 所以改回原来的偏移值
b = b[:j1] + a[i1:i2] + b[j2:]
offset += i2-i1-j2+j1
offset += i2 - i1 - j2 + j1
continue

errors.append(tuple(e))
data.append({"source": a, "target": b, "errors": errors})
return data


def correct(self, sentence: str, **kwargs):
"""长句改为短句, 可直接调用长文本"""
sentences = long_sentence_split(sentence, max_length=kwargs.pop("max_length", 128), period=kwargs.pop("period", None), comma=kwargs.pop("comma", None))
"""
长句改为短句, 可直接调用长文本
Args:
sentence:
**kwargs:
Returns:
dict
"""
sentences = long_sentence_split(sentence, max_length=kwargs.pop("max_length", 128),
period=kwargs.pop("period", None), comma=kwargs.pop("comma", None))
batch_results = self.correct_batch(sentences, **kwargs)
source, target, errors = "", "", []
for sr in batch_results:
Expand All @@ -83,9 +88,3 @@ def correct(self, sentence: str, **kwargs):
e[2] += ll
errors.append(tuple(e))
return {"source": source, "target": target, "errors": errors, "sentences": batch_results}






Loading

0 comments on commit 81f6498

Please sign in to comment.