From 83525eedf33ae331dbea7480bc91188250545ba3 Mon Sep 17 00:00:00 2001 From: Ondrej Cifka Date: Wed, 20 Jun 2018 23:19:34 +0200 Subject: [PATCH 1/5] Add dump_sentences.py --- examples/dump_sentences.py | 71 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100755 examples/dump_sentences.py diff --git a/examples/dump_sentences.py b/examples/dump_sentences.py new file mode 100755 index 00000000..298ade7a --- /dev/null +++ b/examples/dump_sentences.py @@ -0,0 +1,71 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +A script to dump all sentences (tokenized) to standard output. +""" + +from __future__ import absolute_import, division, unicode_literals + +import argparse +import logging +import sys + +# Set PATHs +PATH_TO_SENTEVAL = '../' +PATH_TO_DATA = '../data' + +sys.path.insert(0, PATH_TO_SENTEVAL) +import senteval + + +def main(): + logging.basicConfig(format='%(asctime)s : %(message)s', + level=logging.DEBUG) + + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("-t", "--tasks", + help="a comma-separated list of tasks") + args = parser.parse_args() + + def prepare(params, samples): + for sent in samples: + if sys.version_info < (3, 0): + sent = [w.decode('utf-8') if isinstance(w, str) else w for w in sent] + print(' '.join(sent).encode('utf-8')) + else: + sent = [w.decode('utf-8') if isinstance(w, bytes) else w for w in sent] + print(' '.join(sent)) + + def batcher(params, batch): + # Block evaluation and continue with the next task. + raise Done + + params_senteval = { + 'task_path': PATH_TO_DATA + } + + se = senteval.engine.SE(params_senteval, batcher, prepare) + if args.tasks is not None: + transfer_tasks = args.tasks.split(',') + else: + transfer_tasks = se.list_tasks + + for task in transfer_tasks: + try: + se.eval([task]) + raise RuntimeError(task + " not completed") + except Done: + pass + + +class Done(Exception): + pass + + +if __name__ == "__main__": + main() From 2549fc7db9f976b4d5677d67fc1f1e2c675103d6 Mon Sep 17 00:00:00 2001 From: Ondrej Cifka Date: Thu, 21 Jun 2018 00:34:03 +0200 Subject: [PATCH 2/5] Add eval_saved.py --- examples/eval_saved.py | 91 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100755 examples/eval_saved.py diff --git a/examples/eval_saved.py b/examples/eval_saved.py new file mode 100755 index 00000000..195e817c --- /dev/null +++ b/examples/eval_saved.py @@ -0,0 +1,91 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +A script to run SentEval on pre-computed embeddings from a file. +""" + +from __future__ import absolute_import, division, unicode_literals + +import argparse +import json +import logging +import sys + +import numpy as np + +# Set PATHs +PATH_TO_SENTEVAL = '../' +PATH_TO_DATA = '../data' + +sys.path.insert(0, PATH_TO_SENTEVAL) +import senteval + + +def main(): + logging.basicConfig(format='%(asctime)s : %(message)s', + level=logging.DEBUG) + + parser = argparse.ArgumentParser() + parser.add_argument('sentences', + help='a text file containing all SentEval sentences') + parser.add_argument('embeddings', + help='a NumPy binary file containing the corresponding embeddings') + parser.add_argument('-t', '--tasks', + help='a comma-separated list of tasks') + args = parser.parse_args() + + sent2emb = {} + + def join_sentence(sent): + if sys.version_info < (3, 0): + sent = [w.decode('utf-8') if isinstance(w, str) else w for w in sent] + else: + sent = [w.decode('utf-8') if isinstance(w, bytes) else w for w in sent] + return ' '.join(sent) + + def prepare(params, samples): + # Build the mapping from sentences to embeddings + sent2emb.clear() + samples_set = set(join_sentence(sent) for sent in samples) + all_embeddings = np.load(args.embeddings, mmap_mode='r') + with open(args.sentences) as f_sent: + for i, sent in enumerate(f_sent): + if sys.version_info < (3, 0): + sent = sent.decode('utf-8') + sent = sent.rstrip('\n') + if sent in samples_set: + sent2emb[sent] = all_embeddings[i] + + def batcher(params, batch): + embeddings = np.stack( + [sent2emb[join_sentence(sent)] for sent in batch]) + if len(embeddings.shape) != 2: + embeddings = embeddings.reshape(len(embeddings), -1) + assert len(embeddings.shape) == 2 + return embeddings + + params_senteval = { + 'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10 + } + params_senteval['classifier'] = { + 'nhid': 0, 'optim': 'adam', 'batch_size': 64, 'tenacity': 5, + 'epoch_size': 4 + } + + se = senteval.engine.SE(params_senteval, batcher, prepare) + if args.tasks is not None: + transfer_tasks = args.tasks.split(',') + else: + transfer_tasks = se.list_tasks + + results = se.eval(transfer_tasks) + json.dump(results, sys.stdout, skipkeys=True) + + +if __name__ == '__main__': + main() From 20815ad7cc3055ae8ebe94be5251c2db50eae820 Mon Sep 17 00:00:00 2001 From: Ondrej Cifka Date: Tue, 28 Aug 2018 19:18:26 +0200 Subject: [PATCH 3/5] Path relative to script file --- examples/dump_sentences.py | 5 +++-- examples/eval_saved.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/dump_sentences.py b/examples/dump_sentences.py index 298ade7a..cecdc582 100755 --- a/examples/dump_sentences.py +++ b/examples/dump_sentences.py @@ -13,11 +13,12 @@ import argparse import logging +import os import sys # Set PATHs -PATH_TO_SENTEVAL = '../' -PATH_TO_DATA = '../data' +PATH_TO_SENTEVAL = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +PATH_TO_DATA = os.path.join(PATH_TO_SENTEVAL, 'data') sys.path.insert(0, PATH_TO_SENTEVAL) import senteval diff --git a/examples/eval_saved.py b/examples/eval_saved.py index 195e817c..0a0f5f57 100755 --- a/examples/eval_saved.py +++ b/examples/eval_saved.py @@ -14,13 +14,14 @@ import argparse import json import logging +import os import sys import numpy as np # Set PATHs -PATH_TO_SENTEVAL = '../' -PATH_TO_DATA = '../data' +PATH_TO_SENTEVAL = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +PATH_TO_DATA = os.path.join(PATH_TO_SENTEVAL, 'data') sys.path.insert(0, PATH_TO_SENTEVAL) import senteval From 53e4a0f637b4feacf8a7144e4daddf018af16c63 Mon Sep 17 00:00:00 2001 From: Ondrej Cifka Date: Sun, 12 Aug 2018 14:08:43 +0200 Subject: [PATCH 4/5] Add --no-gpu flag --- examples/eval_saved.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/eval_saved.py b/examples/eval_saved.py index 0a0f5f57..d9f52cd3 100755 --- a/examples/eval_saved.py +++ b/examples/eval_saved.py @@ -38,6 +38,8 @@ def main(): help='a NumPy binary file containing the corresponding embeddings') parser.add_argument('-t', '--tasks', help='a comma-separated list of tasks') + parser.add_argument('--no-gpu', action='store_true', + help='do not use GPU (turn off PyTorch)') args = parser.parse_args() sent2emb = {} @@ -71,7 +73,7 @@ def batcher(params, batch): return embeddings params_senteval = { - 'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 10 + 'task_path': PATH_TO_DATA, 'usepytorch': not args.no_gpu, 'kfold': 10 } params_senteval['classifier'] = { 'nhid': 0, 'optim': 'adam', 'batch_size': 64, 'tenacity': 5, From ee22ac37f7db0f8c69278bcc01b9fc7c2cbe5305 Mon Sep 17 00:00:00 2001 From: Ondrej Cifka Date: Sat, 1 Sep 2018 11:22:54 +0200 Subject: [PATCH 5/5] End JSON with newline --- examples/eval_saved.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/eval_saved.py b/examples/eval_saved.py index d9f52cd3..fbe21358 100755 --- a/examples/eval_saved.py +++ b/examples/eval_saved.py @@ -88,6 +88,7 @@ def batcher(params, batch): results = se.eval(transfer_tasks) json.dump(results, sys.stdout, skipkeys=True) + sys.stdout.write('\n') if __name__ == '__main__':