Skip to content

Commit

Permalink
Update macbert model saver (switch from BertTokenizer to BertTokenizerFast).
Browse files: browse the repository at this point in the history
  • Loading branch information
shibing624 committed Jan 15, 2024
1 parent 6fea687 commit 3992519
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions examples/macbert/train.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -11,7 +11,7 @@
import torch
from loguru import logger
from pytorch_lightning.callbacks import ModelCheckpoint
from transformers import BertTokenizer
from transformers import BertTokenizerFast

sys.path.append('../..')

Expand Down Expand Up @@ -54,7 +54,7 @@ def args_parse(config_file=''):
def main():
cfg = args_parse()
logger.info(f'load model, model arch: {cfg.MODEL.NAME}')
tokenizer = BertTokenizer.from_pretrained(cfg.MODEL.BERT_CKPT)
tokenizer = BertTokenizerFast.from_pretrained(cfg.MODEL.BERT_CKPT)
collator = DataCollator(tokenizer=tokenizer)
# 加载数据
train_loader, valid_loader, test_loader = make_loaders(
Expand Down
4 changes: 2 additions & 2 deletions pycorrector/macbert/macbert_corrector.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -11,7 +11,7 @@
import torch
from loguru import logger
from tqdm import tqdm
from transformers import BertTokenizer, BertForMaskedLM
from transformers import BertTokenizerFast, BertForMaskedLM

sys.path.append('../..')
from pycorrector.utils.tokenizer import split_text_into_sentences_by_length
Expand All @@ -24,7 +24,7 @@
class MacBertCorrector:
def __init__(self, model_name_or_path="shibing624/macbert4csc-base-chinese"):
t1 = time.time()
self.tokenizer = BertTokenizer.from_pretrained(model_name_or_path)
self.tokenizer = BertTokenizerFast.from_pretrained(model_name_or_path)
self.model = BertForMaskedLM.from_pretrained(model_name_or_path)
self.model.to(device)
logger.debug("Use device: {}".format(device))
Expand Down
4 changes: 2 additions & 2 deletions pycorrector/macbert/reader.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -9,11 +9,11 @@
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from transformers import BertTokenizer
from transformers import BertTokenizerFast


class DataCollator:
def __init__(self, tokenizer: BertTokenizer):
def __init__(self, tokenizer: BertTokenizerFast):
self.tokenizer = tokenizer

def __call__(self, data):
Expand Down

0 comments on commit 3992519

Please sign in to comment.