From 6b7bc84323221e3473384e63e32c58621bec2c70 Mon Sep 17 00:00:00 2001 From: shibing624 Date: Mon, 14 Oct 2024 12:10:31 +0800 Subject: [PATCH] update chatglm3-csc --- examples/evaluate_models/evaluate_models.py | 4 +++- examples/gpt/training_chatglm_demo.py | 6 ++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/examples/evaluate_models/evaluate_models.py b/examples/evaluate_models/evaluate_models.py index 1dff3ac1..c4b3df56 100644 --- a/examples/evaluate_models/evaluate_models.py +++ b/examples/evaluate_models/evaluate_models.py @@ -84,15 +84,17 @@ def main(args): model_type='chatglm', peft_name="shibing624/chatglm3-6b-csc-chinese-lora") if args.data == 'sighan': - eval_model_batch(m.correct_batch, prompt_template_name='vicuna') + eval_model_batch(m.correct_batch, prefix_prompt="对这个句子语法纠错\n\n", prompt_template_name='vicuna') # Sentence Level: acc:0.5564, precision:0.5574, recall:0.4917, f1:0.5225, cost time:1572.49 s, total num: 1100 # elif args.data == 'ec_law': eval_model_batch(m.correct_batch, input_tsv_file=os.path.join(pwd_path, "../data/ec_law_test.tsv"), + prefix_prompt="对这个句子语法纠错\n\n", prompt_template_name='vicuna') # elif args.data == 'mcsc': eval_model_batch(m.correct_batch, input_tsv_file=os.path.join(pwd_path, "../data/mcsc_test.tsv"), + prefix_prompt="对这个句子语法纠错\n\n", prompt_template_name='vicuna') # elif args.model == 'qwen1.5b': diff --git a/examples/gpt/training_chatglm_demo.py b/examples/gpt/training_chatglm_demo.py index 38585b4a..dda75554 100644 --- a/examples/gpt/training_chatglm_demo.py +++ b/examples/gpt/training_chatglm_demo.py @@ -18,7 +18,7 @@ def main(): parser.add_argument('--train_file', default='../data/grammar/train_sharegpt.jsonl', type=str, help='Train file') parser.add_argument('--test_file', default='../data/grammar/test_sharegpt.jsonl', type=str, help='Test file') parser.add_argument('--model_type', default='chatglm', type=str, help='Transformers model type') - parser.add_argument('--model_name', default='THUDM/chatglm-6b', type=str, + parser.add_argument('--model_name', default='THUDM/chatglm3-6b', type=str, help='Transformers model or path') parser.add_argument('--do_train', action='store_true', help='Whether to run training.') parser.add_argument('--do_predict', action='store_true', help='Whether to run predict.') @@ -68,7 +68,9 @@ def main(): peft_name=args.output_dir, args={'use_peft': True, 'eval_batch_size': args.batch_size, "max_length": args.max_length, } ) - result = m.correct_batch(error_sentences) + result = m.correct_batch(error_sentences, + prefix_prompt="对这个句子语法纠错\n\n", + prompt_template_name=args.prompt_template_name) for res_dict in result: print(res_dict)