Skip to content

Commit

Permalink
update chatglm3-csc
Browse files Browse the repository at this point in the history
  • Loading branch information
shibing624 committed Oct 14, 2024
1 parent db61859 commit 6b7bc84
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
4 changes: 3 additions & 1 deletion examples/evaluate_models/evaluate_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,17 @@ def main(args):
model_type='chatglm',
peft_name="shibing624/chatglm3-6b-csc-chinese-lora")
if args.data == 'sighan':
eval_model_batch(m.correct_batch, prompt_template_name='vicuna')
eval_model_batch(m.correct_batch, prefix_prompt="对这个句子语法纠错\n\n", prompt_template_name='vicuna')
# Sentence Level: acc:0.5564, precision:0.5574, recall:0.4917, f1:0.5225, cost time:1572.49 s, total num: 1100
#
elif args.data == 'ec_law':
eval_model_batch(m.correct_batch, input_tsv_file=os.path.join(pwd_path, "../data/ec_law_test.tsv"),
prefix_prompt="对这个句子语法纠错\n\n",
prompt_template_name='vicuna')
#
elif args.data == 'mcsc':
eval_model_batch(m.correct_batch, input_tsv_file=os.path.join(pwd_path, "../data/mcsc_test.tsv"),
prefix_prompt="对这个句子语法纠错\n\n",
prompt_template_name='vicuna')
#
elif args.model == 'qwen1.5b':
Expand Down
6 changes: 4 additions & 2 deletions examples/gpt/training_chatglm_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def main():
parser.add_argument('--train_file', default='../data/grammar/train_sharegpt.jsonl', type=str, help='Train file')
parser.add_argument('--test_file', default='../data/grammar/test_sharegpt.jsonl', type=str, help='Test file')
parser.add_argument('--model_type', default='chatglm', type=str, help='Transformers model type')
parser.add_argument('--model_name', default='THUDM/chatglm-6b', type=str,
parser.add_argument('--model_name', default='THUDM/chatglm3-6b', type=str,
help='Transformers model or path')
parser.add_argument('--do_train', action='store_true', help='Whether to run training.')
parser.add_argument('--do_predict', action='store_true', help='Whether to run predict.')
Expand Down Expand Up @@ -68,7 +68,9 @@ def main():
peft_name=args.output_dir,
args={'use_peft': True, 'eval_batch_size': args.batch_size, "max_length": args.max_length, }
)
result = m.correct_batch(error_sentences)
result = m.correct_batch(error_sentences,
prefix_prompt="对这个句子语法纠错\n\n",
prompt_template_name=args.prompt_template_name)
for res_dict in result:
print(res_dict)

Expand Down

0 comments on commit 6b7bc84

Please sign in to comment.