Skip to content

Commit

Permalink
Refactor special tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
Stonesjtu committed Jul 10, 2018
1 parent 9e21cd9 commit 0a3bc40
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
12 changes: 6 additions & 6 deletions data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from vocab import get_vocab
from vocab import get_vocab, BOS, EOS

def zero_padding(sentences, length):
"""
Expand Down Expand Up @@ -59,7 +59,7 @@ def tokenize(self, path):
def __getitem__(self, index):
raw_sentence = self.data[index]
# truncate the sequence length to maximum of BPTT
sentence = ['<s>'] + raw_sentence[:self.bptt] + ['</s>']
sentence = [BOS] + raw_sentence[:self.bptt] + [EOS]
return [self.vocab.word2idx[word] for word in sentence]

def __len__(self):
Expand All @@ -80,12 +80,12 @@ class ContLMDataset(LMDataset):
def tokenize(self, path):
"""Tokenizes a text file."""
assert os.path.exists(path)
# add the end of sentence token
EOS = ['</s>']
# add the start of sentence token
sentence_sep = [BOS]
with open(path, 'r') as f:
sentences = []
sentences = [BOS]
for sentence in tqdm(f, desc='Processing file: {}'.format(path)):
sentences += sentence.split() + EOS
sentences += sentence.split() + sentence_sep
# split into list of tokens
self.data = sentences

Expand Down
7 changes: 5 additions & 2 deletions vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from tqdm import tqdm
logger = logging.getLogger(__name__)

UNK = '<unk>' # unknown word
BOS = '<s>' # sentence start
EOS = '</s>' # sentence end

def _default_unk_index():
return 0
Expand Down Expand Up @@ -64,7 +67,7 @@ def __init__(self, counter, max_size=None, min_freq=1):
self.freqs = counter
self.max_size = max_size
self.min_freq = min_freq
self.specials = ['<unk>', '<s>', '</s>']
self.specials = [UNK, BOS, EOS]
self.build()


Expand Down Expand Up @@ -179,7 +182,7 @@ def get_vocab(base_path, file_list, min_freq=1, force_recount=False, vocab_file=
full_path = os.path.join(base_path, filename)
for line in tqdm(open(full_path, 'r'), desc='Building vocabulary: '):
counter.update(line.split())
counter.update(['<s>', '</s>'])
counter.update([BOS, EOS])
vocab = Vocab(counter, min_freq=min_freq)
logger.debug('Refreshing vocabulary finished')

Expand Down

0 comments on commit 0a3bc40

Please sign in to comment.