Commit

update infer eval.
shibing624 committed Oct 14, 2024
1 parent 9d6c2ca commit b949208
Showing 4 changed files with 26 additions and 20 deletions.
README.md: 20 changes (10 additions & 10 deletions)

```diff
@@ -86,15 +86,15 @@ python examples/macbert/gradio_demo.py
 - CTC (Chinese Text Correction): text correction models that support length-preserving fixes such as spelling and grammar errors, and can also handle length-changing errors such as extra or missing characters
 - GPU: Tesla V100, 32 GB VRAM
 
-| Model Name       | Model Link | Base Model | Avg | SIGHAN-2015 | EC-LAW | MCSC | GPU/CPU | QPS |
-|:-----------------|:-----------|:-----------|:----|:------------|:-------|:-----|:--------|:----|
-| Kenlm-CSC        | [shibing624/chinese-kenlm-klm](https://huggingface.co/shibing624/chinese-kenlm-klm) | kenlm | 0.3409 | 0.3147 | 0.3763 | 0.3317 | CPU | 9 |
-| Mengzi-T5-CSC    | [shibing624/mengzi-t5-base-chinese-correction](https://huggingface.co/shibing624/mengzi-t5-base-chinese-correction) | mengzi-t5-base | 0.3984 | 0.7758 | 0.3156 | 0.1039 | GPU | 214 |
-| ERNIE-CSC        | [PaddleNLP/ernie-csc](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/legacy/examples/text_correction/ernie-csc) | PaddlePaddle/ernie-1.0-base-zh | 0.4353 | 0.8383 | 0.3357 | 0.1318 | GPU | 114 |
-| MacBERT-CSC      | [shibing624/macbert4csc-base-chinese](https://huggingface.co/shibing624/macbert4csc-base-chinese) | hfl/chinese-macbert-base | 0.3993 | 0.8314 | 0.1610 | 0.2055 | GPU | **224** |
-| ChatGLM3-6B-CSC  | [shibing624/chatglm3-6b-csc-chinese-lora](https://huggingface.co/shibing624/chatglm3-6b-csc-chinese-lora) | THUDM/chatglm3-6b | - | 0.5225 | - | - | GPU | 1 |
-| Qwen2.5-1.5B-CTC | [shibing624/chinese-text-correction-1.5b](https://huggingface.co/shibing624/chinese-text-correction-1.5b) | Qwen/Qwen2.5-1.5B-Instruct | 0.6802 | 0.3032 | 0.7846 | 0.9529 | GPU | 3 |
-| Qwen2.5-7B-CTC   | [shibing624/chinese-text-correction-7b](https://huggingface.co/shibing624/chinese-text-correction-7b) | Qwen/Qwen2.5-7B-Instruct | **0.8225** | 0.4917 | 0.9798 | 0.9959 | GPU | 2 |
+| Model Name       | Model Link | Base Model | Avg | SIGHAN-2015 | EC-LAW | MCSC | GPU/CPU | QPS |
+|:-----------------|:-----------|:-----------|:----|:------------|:-------|:-----|:--------|:----|
+| Kenlm-CSC        | [shibing624/chinese-kenlm-klm](https://huggingface.co/shibing624/chinese-kenlm-klm) | kenlm | 0.3409 | 0.3147 | 0.3763 | 0.3317 | CPU | 9 |
+| Mengzi-T5-CSC    | [shibing624/mengzi-t5-base-chinese-correction](https://huggingface.co/shibing624/mengzi-t5-base-chinese-correction) | mengzi-t5-base | 0.3984 | 0.7758 | 0.3156 | 0.1039 | GPU | 214 |
+| ERNIE-CSC        | [PaddleNLP/ernie-csc](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/legacy/examples/text_correction/ernie-csc) | PaddlePaddle/ernie-1.0-base-zh | 0.4353 | 0.8383 | 0.3357 | 0.1318 | GPU | 114 |
+| MacBERT-CSC      | [shibing624/macbert4csc-base-chinese](https://huggingface.co/shibing624/macbert4csc-base-chinese) | hfl/chinese-macbert-base | 0.3993 | 0.8314 | 0.1610 | 0.2055 | GPU | **224** |
+| ChatGLM3-6B-CSC  | [shibing624/chatglm3-6b-csc-chinese-lora](https://huggingface.co/shibing624/chatglm3-6b-csc-chinese-lora) | THUDM/chatglm3-6b | 0.4538 | 0.6572 | 0.4369 | 0.2672 | GPU | 3 |
+| Qwen2.5-1.5B-CTC | [shibing624/chinese-text-correction-1.5b](https://huggingface.co/shibing624/chinese-text-correction-1.5b) | Qwen/Qwen2.5-1.5B-Instruct | 0.6802 | 0.3032 | 0.7846 | 0.9529 | GPU | 6 |
+| Qwen2.5-7B-CTC   | [shibing624/chinese-text-correction-7b](https://huggingface.co/shibing624/chinese-text-correction-7b) | Qwen/Qwen2.5-7B-Instruct | **0.8225** | 0.4917 | 0.9798 | 0.9959 | GPU | 3 |
 
 
 ## Install
@@ -125,7 +125,7 @@ docker run -it -v ~/.pycorrector:/root/.pycorrector shibing624/pycorrector:0.0.2
 ## Usage
 One of the original aims of this project is to compare and survey the various Chinese text correction approaches, as a starting point for further work.
 
-The project implements kenlm, macbert, seq2seq, ernie_csc, T5, deepcontext, LLaMA and other models for the text correction task; each model can predict directly from a released correction model, or be trained and run on your own data.
+The project implements kenlm, macbert, seq2seq, ernie_csc, T5, deepcontext, GPT (Qwen/ChatGLM) and other models for the text correction task; each model can predict directly from a released correction model, or be trained and run on your own data.
 
 
 ### kenlm model (statistical model)
```
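As a sanity check on the benchmark table above, the Avg column appears to be the unweighted mean of the three per-dataset scores (SIGHAN-2015, EC-LAW, MCSC). A quick sketch, assuming that interpretation:

```python
# Recompute the Avg column as the plain mean of the three dataset scores.
# Values are taken from the updated table above; the computed means match
# the published Avg figures after rounding to four decimals.
scores = {
    "Kenlm-CSC":      (0.3147, 0.3763, 0.3317),  # table Avg: 0.3409
    "MacBERT-CSC":    (0.8314, 0.1610, 0.2055),  # table Avg: 0.3993
    "Qwen2.5-7B-CTC": (0.4917, 0.9798, 0.9959),  # table Avg: 0.8225
}
for name, (sighan, ec_law, mcsc) in scores.items():
    avg = (sighan + ec_law + mcsc) / 3
    print(f"{name}: {avg:.4f}")
```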
examples/evaluate_models/evaluate_models.py: 14 changes (6 additions & 8 deletions)

```diff
@@ -84,19 +84,17 @@ def main(args):
             model_type='chatglm',
             peft_name="shibing624/chatglm3-6b-csc-chinese-lora")
         if args.data == 'sighan':
-            eval_model_batch(m.correct_batch, prefix_prompt="对这个句子语法纠错\n\n", prompt_template_name='vicuna')
+            eval_model_batch(m.correct_batch, prefix_prompt="对下面文本纠错:", prompt_template_name='vicuna')
             # Sentence Level: acc:0.5564, precision:0.5574, recall:0.4917, f1:0.5225, cost time:1572.49 s, total num: 1100
-            #
+            # Sentence Level: acc:0.6591, precision:0.7000, recall:0.6193, f1:0.6572, cost time:273.06 s, total num: 707
         elif args.data == 'ec_law':
             eval_model_batch(m.correct_batch, input_tsv_file=os.path.join(pwd_path, "../data/ec_law_test.tsv"),
-                             prefix_prompt="对这个句子语法纠错\n\n",
-                             prompt_template_name='vicuna')
-            #
+                             prefix_prompt="对下面文本纠错:", prompt_template_name='vicuna')
+            # Sentence Level: acc:0.4870, precision:0.5182, recall:0.3776, f1:0.4369, cost time:372.46 s, total num: 1000
         elif args.data == 'mcsc':
             eval_model_batch(m.correct_batch, input_tsv_file=os.path.join(pwd_path, "../data/mcsc_test.tsv"),
-                             prefix_prompt="对这个句子语法纠错\n\n",
-                             prompt_template_name='vicuna')
-            #
+                             prefix_prompt="对下面文本纠错:", prompt_template_name='vicuna')
+            # Sentence Level: acc:0.4790, precision:0.4185, recall:0.1963, f1:0.2672, cost time:383.76 s, total num: 1000
     elif args.model == 'qwen1.5b':
         from pycorrector.gpt.gpt_corrector import GptCorrector
         m = GptCorrector(model_name_or_path="shibing624/chinese-text-correction-1.5b")
```
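The "Sentence Level" comments in the diff above log acc, precision, recall, and f1. The logged f1 values are consistent with the standard harmonic mean of precision and recall, which can be checked directly against the ChatGLM3 numbers:

```python
# Verify the logged f1 values against 2*P*R/(P+R) using the
# precision/recall pairs from the evaluation comments above.
def f1(precision: float, recall: float) -> float:
    return 2 * precision * recall / (precision + recall)

print(round(f1(0.7000, 0.6193), 4))  # SIGHAN-2015: 0.6572
print(round(f1(0.5182, 0.3776), 4))  # EC-LAW:      0.4369
print(round(f1(0.4185, 0.1963), 4))  # MCSC:        0.2672
```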
examples/gpt/demo.py: 8 changes (7 additions & 1 deletion)

```diff
@@ -21,7 +21,13 @@
 ]
 m = GptCorrector("shibing624/chinese-text-correction-1.5b")
 
-batch_res = m.correct_batch(error_sentences, system_prompt="你是一个中文文本纠错助手。请根据用户提供的原始文本,生成纠正后的文本。")
+batch_res = m.correct_batch(error_sentences,
+                            system_prompt="你是一个中文文本纠错助手。请根据用户提供的原始文本,生成纠正后的文本。")
 for i in batch_res:
     print(i)
     print()
+
+# batch_res = m.correct_batch(error_sentences, prefix_prompt='文本纠错:\n\n', prompt_template_name='qwen')
+# for i in batch_res:
+#     print(i)
+#     print()
```
pycorrector/gpt/gpt_model.py: 4 changes (3 additions & 1 deletion)

```diff
@@ -579,7 +579,8 @@ def predict(
             prompt_len = len(input_ids[0])
             generated_sequence = generated_sequence[prompt_len:]
             gen_text = self.tokenizer.decode(generated_sequence, skip_special_tokens=True)
-            # logger.error(f"input_text: {input_text}, gen_text: {gen_text}")
+            gen_text = gen_text.strip()
+            # logger.debug(f"input_text: {input_text}, gen_text: {gen_text}")
             all_outputs.append(gen_text)
 
         return all_outputs
@@ -641,6 +642,7 @@ def chat(
         )
         output_tensor = outputs[0][len(input_ids[0]):] if skip_prompt else outputs[0]
         response = self.tokenizer.decode(output_tensor, skip_special_tokens=True)
+        response = response.strip()
         history[-1][1] = response
         return response, history
```
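The change above strips whitespace from decoded generations in both `predict` and `chat` before they are returned. A minimal standalone sketch of that post-processing step (the sample string is hypothetical, standing in for real `tokenizer.decode` output):

```python
# Decoded model output often carries leading/trailing newlines or spaces;
# the commit normalizes it with str.strip() before returning it to callers.
def postprocess(gen_text: str) -> str:
    return gen_text.strip()

raw = "\n少先队员应该为老人让座。 \n"  # hypothetical decoded output
print(repr(postprocess(raw)))
```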
