This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Reduce memory usage when loading embeddings from txt #191

Open · wants to merge 2 commits into base: main
20 changes: 13 additions & 7 deletions src/utils.py
@@ -266,12 +266,19 @@ def read_txt_embeddings(params, source, full_vocab):
     Reload pretrained embeddings from a text file.
     """
     word2id = {}
-    vectors = []

     # load pretrained embeddings
     lang = params.src_lang if source else params.tgt_lang
     emb_path = params.src_emb if source else params.tgt_emb
     _emb_dim_file = params.emb_dim
+
+    with io.open(emb_path, 'r', encoding='utf-8', newline='\n', errors='ignore') as f:
+        vocab_size = sum([1 for _ in f]) - 1
+
+    if not full_vocab and params.max_vocab < vocab_size:
+        embeddings = np.empty([params.max_vocab, params.emb_dim], dtype=np.float32)
+    else:
+        embeddings = np.empty([vocab_size, params.emb_dim], dtype=np.float32)
     with io.open(emb_path, 'r', encoding='utf-8', newline='\n', errors='ignore') as f:
         for i, line in enumerate(f):
             if i == 0:
@@ -282,7 +289,7 @@ def read_txt_embeddings(params, source, full_vocab):
                 word, vect = line.rstrip().split(' ', 1)
                 if not full_vocab:
                     word = word.lower()
-                vect = np.fromstring(vect, sep=' ')
+                vect = np.fromstring(vect, sep=' ', dtype=np.float32)
                 if np.linalg.norm(vect) == 0:  # avoid to have null embeddings
                     vect[0] = 0.01
                 if word in word2id:
@@ -296,17 +303,16 @@ def read_txt_embeddings(params, source, full_vocab):
                         continue
                     assert vect.shape == (_emb_dim_file,), i
                     word2id[word] = len(word2id)
-                    vectors.append(vect[None])
+                    embeddings[i-1] = vect
             if params.max_vocab > 0 and len(word2id) >= params.max_vocab and not full_vocab:
                 break

-    assert len(word2id) == len(vectors)
-    logger.info("Loaded %i pre-trained word embeddings." % len(vectors))
+    assert len(word2id) == len(embeddings)
+    logger.info("Loaded %i pre-trained word embeddings." % len(embeddings))

     # compute new vocabulary / embeddings
     id2word = {v: k for k, v in word2id.items()}
     dico = Dictionary(id2word, word2id, lang)
-    embeddings = np.concatenate(vectors, 0)
     embeddings = torch.from_numpy(embeddings).float()
     embeddings = embeddings.cuda() if (params.cuda and not full_vocab) else embeddings

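For context, below is a minimal standalone sketch of the preallocation strategy this patch proposes: count the data rows in a first pass, allocate one float32 matrix, and parse each vector directly into it, rather than appending per-row arrays and concatenating at the end. The function and parameter names (load_txt_embeddings, max_vocab) are illustrative, not part of this repository, and the sketch deliberately departs from the diff in two places: it indexes by the running vocabulary size instead of the file line number i-1, so rows skipped as duplicates or malformed leave no uninitialized gaps, and it parses with np.array(vect.split(), ...) because the text mode of np.fromstring is deprecated in recent NumPy.

    import io
    import numpy as np

    def load_txt_embeddings(emb_path, max_vocab=200000):
        """Load word2vec-style text embeddings ("<count> <dim>" header, then
        one "<word> <float> ... <float>" row per line) into a preallocated
        float32 matrix."""
        # Pass 1: read the dimension from the header and count the data rows,
        # so the matrix can be allocated once up front.
        with io.open(emb_path, 'r', encoding='utf-8', newline='\n', errors='ignore') as f:
            emb_dim = int(f.readline().split()[1])
            n_rows = sum(1 for _ in f)
        if max_vocab > 0:
            n_rows = min(max_vocab, n_rows)
        embeddings = np.empty((n_rows, emb_dim), dtype=np.float32)

        # Pass 2: parse each row straight into the preallocated matrix.
        word2id = {}
        with io.open(emb_path, 'r', encoding='utf-8', newline='\n', errors='ignore') as f:
            next(f)  # skip the header line
            for line in f:
                word, vect = line.rstrip().split(' ', 1)
                # text-mode np.fromstring is deprecated; parse via split()
                vect = np.array(vect.split(), dtype=np.float32)
                if word in word2id or vect.shape != (emb_dim,):
                    continue  # skip duplicates and malformed rows
                if np.linalg.norm(vect) == 0:  # avoid all-zero embeddings
                    vect[0] = 0.01
                # index by the running vocabulary size, not the file line
                # number, so skipped rows leave no uninitialized gaps
                embeddings[len(word2id)] = vect
                word2id[word] = len(word2id)
                if len(word2id) >= n_rows:
                    break

        # trim rows that were reserved for lines we ended up skipping
        return word2id, embeddings[:len(word2id)]

Roughly speaking, the savings come from two sides: np.fromstring without an explicit dtype produces float64, so parsing straight to float32 halves each vector, and writing into a preallocated matrix avoids the peak where the Python list of rows and the np.concatenate copy coexist. For a fastText-sized file of about 2.5M vectors at dimension 300, that is the difference between a transient footprint well above 6 GB and a single matrix of roughly 3 GB.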