From b8bd22206086422f23ebc3006da91c78020fe6d8 Mon Sep 17 00:00:00 2001
From: shibing624 <xuming624@qq.com>
Date: Mon, 15 Jan 2024 12:09:47 +0800
Subject: [PATCH] update macbert model saver.

---
 examples/macbert/train.py                | 27 ++++++----------
 pycorrector/macbert/macbert_corrector.py |  4 +--
 pycorrector/macbert/reader.py            | 41 ++++++++++++++----------
 3 files changed, 35 insertions(+), 37 deletions(-)

diff --git a/examples/macbert/train.py b/examples/macbert/train.py
index e70eb0e8..82264f58 100644
--- a/examples/macbert/train.py
+++ b/examples/macbert/train.py
@@ -6,13 +6,12 @@
 import argparse
 import os
 import sys
-from collections import OrderedDict
 
 import pytorch_lightning as pl
 import torch
 from loguru import logger
 from pytorch_lightning.callbacks import ModelCheckpoint
-from transformers import BertTokenizerFast, BertForMaskedLM
+from transformers import BertTokenizer
 
 sys.path.append('../..')
 
@@ -55,7 +54,7 @@ def args_parse(config_file=''):
 def main():
     cfg = args_parse()
     logger.info(f'load model, model arch: {cfg.MODEL.NAME}')
-    tokenizer = BertTokenizerFast.from_pretrained(cfg.MODEL.BERT_CKPT)
+    tokenizer = BertTokenizer.from_pretrained(cfg.MODEL.BERT_CKPT)
     collator = DataCollator(tokenizer=tokenizer)
     # load the data
     train_loader, valid_loader, test_loader = make_loaders(
@@ -106,27 +105,19 @@ def main():
     # convert the model so transformers can load it directly
     if ckpt_callback and len(ckpt_callback.best_model_path) > 0:
         ckpt_path = ckpt_callback.best_model_path
-    elif cfg.MODEL.WEIGHTS and os.path.exists(cfg.MODEL.WEIGHTS):
-        ckpt_path = cfg.MODEL.WEIGHTS
     else:
         ckpt_path = ''
     logger.info(f'ckpt_path: {ckpt_path}')
     if ckpt_path and os.path.exists(ckpt_path):
-        model.load_state_dict(torch.load(ckpt_path)['state_dict'])
-        # first save the original transformers bert model
         tokenizer.save_pretrained(cfg.OUTPUT_DIR)
-        bert = BertForMaskedLM.from_pretrained(cfg.MODEL.BERT_CKPT)
-        bert.save_pretrained(cfg.OUTPUT_DIR)
-        state_dict = torch.load(ckpt_path)['state_dict']
-        new_state_dict = OrderedDict()
-        if cfg.MODEL.NAME in ['macbert4csc']:
-            for k, v in state_dict.items():
-                if k.startswith('bert.'):
-                    new_state_dict[k[5:]] = v
+        if cfg.MODEL.NAME == 'softmaskedbert4csc':
+            m = SoftMaskedBert4Csc.load_from_checkpoint(ckpt_path)
         else:
-            new_state_dict = state_dict
-        # then save the finetuned model file, replacing the original pytorch_model.bin
-        torch.save(new_state_dict, os.path.join(cfg.OUTPUT_DIR, 'pytorch_model.bin'))
+            m = MacBert4Csc.load_from_checkpoint(ckpt_path)
+        # save the finetuned model file pytorch_model.bin
+        pt_file = os.path.join(cfg.OUTPUT_DIR, 'pytorch_model.bin')
+        m.bert.save_pretrained(pt_file)
+        del m
     # testing follows the same logic as training
     if 'test' in cfg.MODE and test_loader and len(test_loader) > 0:
         trainer.test(model, test_loader)
diff --git a/pycorrector/macbert/macbert_corrector.py b/pycorrector/macbert/macbert_corrector.py
index 437c1ec0..b9d76814 100644
--- a/pycorrector/macbert/macbert_corrector.py
+++ b/pycorrector/macbert/macbert_corrector.py
@@ -11,7 +11,7 @@
 import torch
 from loguru import logger
 from tqdm import tqdm
-from transformers import BertTokenizerFast, BertForMaskedLM
+from transformers import BertTokenizer, BertForMaskedLM
 
 sys.path.append('../..')
 from pycorrector.utils.tokenizer import split_text_into_sentences_by_length
@@ -24,7 +24,7 @@ class MacBertCorrector:
     def __init__(self, model_name_or_path="shibing624/macbert4csc-base-chinese"):
         t1 = time.time()
-        self.tokenizer = BertTokenizerFast.from_pretrained(model_name_or_path)
+        self.tokenizer = BertTokenizer.from_pretrained(model_name_or_path)
         self.model = BertForMaskedLM.from_pretrained(model_name_or_path)
         self.model.to(device)
         logger.debug("Use device: {}".format(device))
 
diff --git a/pycorrector/macbert/reader.py b/pycorrector/macbert/reader.py
index 5e47d8d8..edb38899 100644
--- a/pycorrector/macbert/reader.py
+++ b/pycorrector/macbert/reader.py
@@ -3,16 +3,17 @@
 @author:XuMing(xuming624@qq.com), Abtion(abtion@outlook.com)
 @description:
 """
-import os
 import json
+import os
+
 import torch
-from torch.utils.data import Dataset
-from transformers import BertTokenizerFast
 from torch.utils.data import DataLoader
+from torch.utils.data import Dataset
+from transformers import BertTokenizer
 
 
 class DataCollator:
-    def __init__(self, tokenizer: BertTokenizerFast):
+    def __init__(self, tokenizer: BertTokenizer):
         self.tokenizer = tokenizer
 
     def __call__(self, data):
@@ -48,21 +49,27 @@ def make_loaders(collate_fn, train_path='', valid_path='', test_path='',
                  batch_size=32, num_workers=4):
     train_loader = None
     if train_path and os.path.exists(train_path):
-        train_loader = DataLoader(CscDataset(train_path),
-                                  batch_size=batch_size,
-                                  shuffle=False,
-                                  num_workers=num_workers,
-                                  collate_fn=collate_fn)
+        train_loader = DataLoader(
+            CscDataset(train_path),
+            batch_size=batch_size,
+            shuffle=False,
+            num_workers=num_workers,
+            collate_fn=collate_fn
+        )
     valid_loader = None
     if valid_path and os.path.exists(valid_path):
-        valid_loader = DataLoader(CscDataset(valid_path),
-                                  batch_size=batch_size,
-                                  num_workers=num_workers,
-                                  collate_fn=collate_fn)
+        valid_loader = DataLoader(
+            CscDataset(valid_path),
+            batch_size=batch_size,
+            num_workers=num_workers,
+            collate_fn=collate_fn
+        )
     test_loader = None
     if test_path and os.path.exists(test_path):
-        test_loader = DataLoader(CscDataset(test_path),
-                                  batch_size=batch_size,
-                                  num_workers=num_workers,
-                                  collate_fn=collate_fn)
+        test_loader = DataLoader(
+            CscDataset(test_path),
+            batch_size=batch_size,
+            num_workers=num_workers,
+            collate_fn=collate_fn
+        )
     return train_loader, valid_loader, test_loader