-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathtokens.py
39 lines (30 loc) · 1.4 KB
/
tokens.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
"""Load tokenized documents and their class labels from Elasticsearch (see Tokens)."""
# import modules & set up logging
import logging
# NOTE(review): basicConfig configures the *root* logger as a side effect of
# importing this module, which also affects any application importing it;
# consider moving this call to the entry-point script.
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
import numpy as np
from elasticsearch import Elasticsearch
from elasticsearch.helpers import scan
# Public API of this module.
__all__ = ['Tokens']
class Tokens:
    """Fetch per-document tokens and class labels from an Elasticsearch index.

    The Elasticsearch client and the 'elasticsearch' logger level are set up
    once, when the class is defined, and are shared by every instance.
    """

    es = Elasticsearch([{'host': 'localhost', 'port': 9200}])
    es_logger = logging.getLogger('elasticsearch')
    es_logger.setLevel(logging.WARNING)

    # Supported data sources; each one is stored in an index of the same name.
    _DATA_SOURCES = ('twenty-news', 'acl-imdb')

    def __init__(self, dataSource):
        """Select the Elasticsearch index to read from.

        Args:
            dataSource: ``'twenty-news'`` or ``'acl-imdb'``.

        Raises:
            ValueError: if *dataSource* is not a supported name. (Without this
                check ``esIndex`` stayed unset and the mistake only surfaced
                later as an AttributeError inside ``getTokens``.)
        """
        if dataSource not in self._DATA_SOURCES:
            raise ValueError(
                'Unknown dataSource %r; expected one of: %s'
                % (dataSource, ', '.join(self._DATA_SOURCES)))
        self.esIndex = dataSource

    def getTokens(self, tokenType, split=None):
        """Return the token field, labels and class names of every document.

        Args:
            tokenType: name of the ``_source`` field holding each document's
                tokens.
            split: if truthy, only documents whose ``split`` field matches it
                (Elasticsearch ``term`` query) are returned; otherwise all
                documents are (``match_all``).

        Returns:
            Tuple ``(X, y, classNames)``:
              - ``X``: one row per document built from ``_source[tokenType]``.
                Equal-shaped rows give a regular N-d array; ragged rows give a
                1-D ``dtype=object`` array of per-document arrays.
              - ``y``: NumPy array of each document's ``groupIndex``.
              - ``classNames``: sorted list of distinct ``groupName`` values.
        """
        fields = [tokenType, 'groupIndex', 'groupName']
        if split:
            query = {"query": {"term": {"split": split}}, "_source": fields}
        else:
            query = {"query": {"match_all": {}}, "_source": fields}

        # NOTE(review): mapping types (doc_type) are deprecated in Elasticsearch
        # 7 and removed in 8; drop this argument if the cluster/client is upgraded.
        hits = scan(self.es, query=query, index=self.esIndex,
                    doc_type='article', request_timeout=120)

        X, y, classNames = [], [], set()
        for hit in hits:
            source = hit['_source']
            X.append(source[tokenType])
            y.append(source['groupIndex'])
            classNames.add(source['groupName'])
        # rows: Docs. columns: words
        return self._rowsToArray(X), np.array(y), sorted(classNames)

    @staticmethod
    def _rowsToArray(rows):
        """Stack per-document token sequences into a single NumPy array.

        Equal-shaped rows produce a regular N-d array, exactly as a plain
        ``np.array`` call would. Ragged rows (documents with different token
        counts) produce a 1-D ``dtype=object`` array whose elements are the
        per-document arrays. NumPy >= 1.24 raises ValueError instead of
        inferring ``dtype=object`` for ragged input, so that case is built
        explicitly here.
        """
        arrays = [np.array(row) for row in rows]
        try:
            return np.array(arrays)
        except ValueError:
            stacked = np.empty(len(arrays), dtype=object)
            # Element-wise assignment keeps each row as its own array instead
            # of letting NumPy try to broadcast the rows together.
            for i, arr in enumerate(arrays):
                stacked[i] = arr
            return stacked