Skip to content

Commit

Permalink
Add eval_saved.py
Browse files Browse the repository at this point in the history
  • Loading branch information
cifkao committed Aug 12, 2018
1 parent 2e01172 commit e59bc5c
Showing 1 changed file with 91 additions and 0 deletions.
91 changes: 91 additions & 0 deletions examples/eval_saved.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

"""
A script to run SentEval on pre-computed embeddings from a file.
"""

from __future__ import absolute_import, division, unicode_literals

import argparse
import json
import logging
import sys

import numpy as np

# Set PATHs
PATH_TO_SENTEVAL = '../'
PATH_TO_DATA = '../data'

sys.path.insert(0, PATH_TO_SENTEVAL)
import senteval


def main():
logging.basicConfig(format='%(asctime)s : %(message)s',
level=logging.DEBUG)

parser = argparse.ArgumentParser()
parser.add_argument('sentences',
help='a text file containing all SentEval sentences')
parser.add_argument('embeddings',
help='a NumPy binary file containing the corresponding embeddings')
parser.add_argument('-t', '--tasks',
help='a comma-separated list of tasks')
args = parser.parse_args()

sent2emb = {}

def join_sentence(sent):
if sys.version_info < (3, 0):
sent = [w.decode('utf-8') if isinstance(w, str) else w for w in sent]
else:
sent = [w.decode('utf-8') if isinstance(w, bytes) else w for w in sent]
return ' '.join(sent)

def prepare(params, samples):
# Build the mapping from sentences to embeddings
sent2emb.clear()
samples_set = set(join_sentence(sent) for sent in samples)
all_embeddings = np.load(args.embeddings, mmap_mode='r')
with open(args.sentences) as f_sent:
for i, sent in enumerate(f_sent):
if sys.version_info < (3, 0):
sent = sent.decode('utf-8')
sent = sent.rstrip('\n')
if sent in samples_set:
sent2emb[sent] = all_embeddings[i]

def batcher(params, batch):
embeddings = np.stack(
[sent2emb[join_sentence(sent)] for sent in batch])
if len(embeddings.shape) != 2:
embeddings = embeddings.reshape(len(embeddings), -1)
assert len(embeddings.shape) == 2
return embeddings

params_senteval = {
'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10
}
params_senteval['classifier'] = {
'nhid': 0, 'optim': 'adam', 'batch_size': 64, 'tenacity': 5,
'epoch_size': 4
}

se = senteval.engine.SE(params_senteval, batcher, prepare)
if args.tasks is not None:
transfer_tasks = args.tasks.split(',')
else:
transfer_tasks = se.list_tasks

results = se.eval(transfer_tasks)
json.dump(results, sys.stdout, skipkeys=True)


if __name__ == '__main__':
main()

0 comments on commit e59bc5c

Please sign in to comment.