Skip to content

Commit

Permalink
Update MacBERT model saver.
Browse files Browse the repository at this point in the history
  • Loading branch information
shibing624 committed Jan 15, 2024
1 parent c2135c4 commit b8bd222
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 37 deletions.
27 changes: 9 additions & 18 deletions examples/macbert/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@
import argparse
import os
import sys
from collections import OrderedDict

import pytorch_lightning as pl
import torch
from loguru import logger
from pytorch_lightning.callbacks import ModelCheckpoint
from transformers import BertTokenizerFast, BertForMaskedLM
from transformers import BertTokenizer

sys.path.append('../..')

Expand Down Expand Up @@ -55,7 +54,7 @@ def args_parse(config_file=''):
def main():
cfg = args_parse()
logger.info(f'load model, model arch: {cfg.MODEL.NAME}')
tokenizer = BertTokenizerFast.from_pretrained(cfg.MODEL.BERT_CKPT)
tokenizer = BertTokenizer.from_pretrained(cfg.MODEL.BERT_CKPT)
collator = DataCollator(tokenizer=tokenizer)
# 加载数据
train_loader, valid_loader, test_loader = make_loaders(
Expand Down Expand Up @@ -106,27 +105,19 @@ def main():
# 模型转为transformers可加载
if ckpt_callback and len(ckpt_callback.best_model_path) > 0:
ckpt_path = ckpt_callback.best_model_path
elif cfg.MODEL.WEIGHTS and os.path.exists(cfg.MODEL.WEIGHTS):
ckpt_path = cfg.MODEL.WEIGHTS
else:
ckpt_path = ''
logger.info(f'ckpt_path: {ckpt_path}')
if ckpt_path and os.path.exists(ckpt_path):
model.load_state_dict(torch.load(ckpt_path)['state_dict'])
# 先保存原始transformer bert model
tokenizer.save_pretrained(cfg.OUTPUT_DIR)
bert = BertForMaskedLM.from_pretrained(cfg.MODEL.BERT_CKPT)
bert.save_pretrained(cfg.OUTPUT_DIR)
state_dict = torch.load(ckpt_path)['state_dict']
new_state_dict = OrderedDict()
if cfg.MODEL.NAME in ['macbert4csc']:
for k, v in state_dict.items():
if k.startswith('bert.'):
new_state_dict[k[5:]] = v
if cfg.MODEL.NAME == 'softmaskedbert4csc':
m = SoftMaskedBert4Csc.load_from_checkpoint(ckpt_path)
else:
new_state_dict = state_dict
# 再保存finetune训练后的模型文件,替换原始的pytorch_model.bin
torch.save(new_state_dict, os.path.join(cfg.OUTPUT_DIR, 'pytorch_model.bin'))
m = MacBert4Csc.load_from_checkpoint(ckpt_path)
# 保存finetune训练后的模型文件pytorch_model.bin
pt_file = os.path.join(cfg.OUTPUT_DIR, 'pytorch_model.bin')
m.bert.save_pretrained(pt_file)
del m
# 进行测试的逻辑同训练
if 'test' in cfg.MODE and test_loader and len(test_loader) > 0:
trainer.test(model, test_loader)
Expand Down
4 changes: 2 additions & 2 deletions pycorrector/macbert/macbert_corrector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import torch
from loguru import logger
from tqdm import tqdm
from transformers import BertTokenizerFast, BertForMaskedLM
from transformers import BertTokenizer, BertForMaskedLM

sys.path.append('../..')
from pycorrector.utils.tokenizer import split_text_into_sentences_by_length
Expand All @@ -24,7 +24,7 @@
class MacBertCorrector:
def __init__(self, model_name_or_path="shibing624/macbert4csc-base-chinese"):
t1 = time.time()
self.tokenizer = BertTokenizerFast.from_pretrained(model_name_or_path)
self.tokenizer = BertTokenizer.from_pretrained(model_name_or_path)
self.model = BertForMaskedLM.from_pretrained(model_name_or_path)
self.model.to(device)
logger.debug("Use device: {}".format(device))
Expand Down
41 changes: 24 additions & 17 deletions pycorrector/macbert/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@
@author:XuMing([email protected]), Abtion([email protected])
@description:
"""
import os
import json
import os

import torch
from torch.utils.data import Dataset
from transformers import BertTokenizerFast
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from transformers import BertTokenizer


class DataCollator:
def __init__(self, tokenizer: BertTokenizerFast):
def __init__(self, tokenizer: BertTokenizer):
self.tokenizer = tokenizer

def __call__(self, data):
Expand Down Expand Up @@ -48,21 +49,27 @@ def make_loaders(collate_fn, train_path='', valid_path='', test_path='',
batch_size=32, num_workers=4):
train_loader = None
if train_path and os.path.exists(train_path):
train_loader = DataLoader(CscDataset(train_path),
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
collate_fn=collate_fn)
train_loader = DataLoader(
CscDataset(train_path),
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
collate_fn=collate_fn
)
valid_loader = None
if valid_path and os.path.exists(valid_path):
valid_loader = DataLoader(CscDataset(valid_path),
batch_size=batch_size,
num_workers=num_workers,
collate_fn=collate_fn)
valid_loader = DataLoader(
CscDataset(valid_path),
batch_size=batch_size,
num_workers=num_workers,
collate_fn=collate_fn
)
test_loader = None
if test_path and os.path.exists(test_path):
test_loader = DataLoader(CscDataset(test_path),
batch_size=batch_size,
num_workers=num_workers,
collate_fn=collate_fn)
test_loader = DataLoader(
CscDataset(test_path),
batch_size=batch_size,
num_workers=num_workers,
collate_fn=collate_fn
)
return train_loader, valid_loader, test_loader

0 comments on commit b8bd222

Please sign in to comment.