forked from minhtriet/gas_market
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_event.py
executable file
·72 lines (63 loc) · 2.53 KB
/
train_event.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import argparse
import numpy as np
import spacy
import yaml
from nltk.corpus import stopwords
from sklearn.feature_extraction.text import CountVectorizer
from models import lstm_event
from util import data_generator, io
# spaCy English model; its lemmatizer and word vectors are used below to
# build the embedding matrix for the news vocabulary.
nlp = spacy.load('en_core_web_lg')

# Run configuration (e.g. `window`, read below).
# yaml.safe_load: plain yaml.load without an explicit Loader is unsafe and
# raises TypeError on PyYAML >= 6.
with open('config.yaml', encoding='utf-8') as stream:
    try:
        config = yaml.safe_load(stream)
    except yaml.YAMLError as exc:
        # Previously the error was only printed and the script crashed later
        # with a NameError on `config`; stop here with the real cause instead.
        raise SystemExit('Failed to parse config.yaml: {}'.format(exc)) from exc
# Command-line options forwarded to the data generator (and the model).
parser = argparse.ArgumentParser(description='data related parameters')
# Default stride is 5 to match the value that used to be hard-coded in the
# data_generator.generate call, so running without flags is unchanged.
parser.add_argument('--stride', type=int, default=5,
                    help='stride passed to data_generator.generate')
parser.add_argument('--predict_length', type=int, default=3,
                    help='prediction length used by both the data generator and the model')
args = parser.parse_args()

window = config['window']

# Fit a binary bag-of-words vocabulary on the news up to 2013-04-11.
news = io.load_news(embed='none')
corpus = news.loc[:'2013-04-11', 0].values  # 60%
vectorizer = CountVectorizer(binary=True, stop_words=stopwords.words('english'),
                             lowercase=True, min_df=3, max_df=0.9, max_features=128000)
x_train_onehot = vectorizer.fit_transform(corpus)

# get_feature_names() was removed in scikit-learn 1.2; prefer the newer
# get_feature_names_out() when it exists, and compute the list only once.
if hasattr(vectorizer, 'get_feature_names_out'):
    feature_names = list(vectorizer.get_feature_names_out())
else:
    feature_names = vectorizer.get_feature_names()

# Embedding matrix: row i is the spaCy vector of the lemma of vocabulary word i.
# One extra all-zero row is kept at the end (presumably padding/OOV for
# lstm_event -- TODO confirm). Width comes from the loaded vectors rather than
# a hard-coded 300.
embeddings_index = np.zeros((len(feature_names) + 1, nlp.vocab.vectors_length))
# Fill rows by index directly: the old lemma->idx dict collided whenever two
# words shared a lemma, leaving all but one of their rows zero.
# nlp.pipe batches the documents instead of running the pipeline per word.
for idx, doc in enumerate(nlp.pipe(feature_names)):
    embeddings_index[idx] = nlp.vocab[doc[0].lemma_].vector

# Mapping handed to the model: surface form (not lemma) -> embedding row.
word2idx = {word: idx for idx, word in enumerate(feature_names)}

# Use the CLI values for both data generation and training so the prediction
# length cannot disagree between the generated targets and the model.
x_train, x_test, y_train, y_test = data_generator.generate(
    window, future=True, news=False, train_percentage=0.6,
    stride=args.stride,
    predict_length=args.predict_length,
)
lstm_event.train(x_train, y_train, x_test, y_test, time_steps=window, layer_shape=[128, 32],
                 learning_rate=0.0000001, epoch=5000, predict_length=args.predict_length,
                 embed=embeddings_index, words_per_news=15, wordindex=word2idx)
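# Reference notes: percentiles of the text length in characters of news[0]
# (news[0].str.len()), computed interactively:
# >>> np.percentile(news[0].str.len(), 10)
# 38.0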
# >>> np.percentile(news[0].str.len(), 20)
# 45.0
# >>> np.percentile(news[0].str.len(), 30)
# 54.0
# >>> np.percentile(news[0].str.len(), 40)
# 65.0
# >>> np.percentile(news[0].str.len(), 50)
# 79.5
# >>> np.percentile(news[0].str.len(), 60)
# 96.0
# >>> np.percentile(news[0].str.len(), 70)
# 121.0
# >>> np.percentile(news[0].str.len(), 80)
# 153.0
# >>> np.percentile(news[0].str.len(), 90)
# 202.0
# >>> np.percentile(news[0].str.len(), 100)
# 730.0
# >>> np.percentile(news[0].str.len(), 95)
# 251.0
# >>> np.percentile(news[0].str.len(), 97)
# 290.0
# >>> np.percentile(news[0].str.len(), 98)
# 315.0