-
Notifications
You must be signed in to change notification settings - Fork 0
/
tensorflow.py
70 lines (54 loc) · 2.48 KB
/
tensorflow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import io

import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras
from tensorflow.keras import layers
#embedding_layer = layers.Embedding(1000, 5)
#result = embedding_layer(tf.constant([1, 2, 3]))
#print(result.numpy())
#print(result.numpy().shape)
def get_batch_data(batch_size=10, shuffle_buffer=1000):
    """Load the IMDB reviews dataset (subwords8k config) as padded batches.

    Args:
        batch_size: Number of examples per batch (defaults to the value the
            script originally hard-coded).
        shuffle_buffer: Buffer size passed to ``Dataset.shuffle``.

    Returns:
        A ``(train_batches, test_batches, encoder)`` tuple. The batches are
        ``tf.data`` datasets built from the train and test splits
        respectively, padded per batch; ``encoder`` is
        ``info.features['text'].encoder`` for the dataset.
    """
    (train_data, test_data), info = tfds.load(
        'imdb_reviews/subwords8k',
        split=(tfds.Split.TRAIN, tfds.Split.TEST),
        with_info=True,
        as_supervised=True,
    )
    encoder = info.features['text'].encoder
    # Pad shapes for (text, label) pairs: the text dimension is variable-length
    # and padded per batch; the label is a scalar.
    padded_shapes = ([None], ())
    train_batches = train_data.shuffle(shuffle_buffer).padded_batch(
        batch_size, padded_shapes=padded_shapes)
    # Test batches must come from the *test* split; the original reused
    # train_data here, so "validation" reported training-set accuracy.
    # Shuffling keeps partial passes (e.g. a capped number of validation
    # steps) from always seeing the same leading batches.
    test_batches = test_data.shuffle(shuffle_buffer).padded_batch(
        batch_size, padded_shapes=padded_shapes)
    return train_batches, test_batches, encoder
def get_model(encoder, embedding_dim=16):
    """Build and compile a simple embedding-average binary classifier.

    Layers: token embedding -> global average pooling -> one sigmoid unit.

    Args:
        encoder: Text encoder; only its ``vocab_size`` attribute is read, to
            size the embedding table.
        embedding_dim: Width of each token embedding vector.

    Returns:
        A compiled ``keras.Sequential`` model using the Adam optimizer,
        binary cross-entropy loss and an accuracy metric.
    """
    stack = [
        layers.Embedding(encoder.vocab_size, embedding_dim),
        layers.GlobalAveragePooling1D(),
        layers.Dense(1, activation='sigmoid'),
    ]
    classifier = keras.Sequential(stack)
    classifier.compile(
        loss='binary_crossentropy',
        optimizer='adam',
        metrics=['accuracy'],
    )
    return classifier
def plot_history(history):
    """Plot per-epoch training and validation accuracy.

    Args:
        history: The ``History`` object returned by ``model.fit``, or its
            ``.history`` dict. Must contain ``'accuracy'`` and
            ``'val_accuracy'`` lists, i.e. the model was compiled with
            ``metrics=['accuracy']`` and fit with validation data.
    """
    # Use the argument instead of re-training from module globals. Accept
    # both a keras History object and the plain metrics dict it wraps.
    history_dict = getattr(history, 'history', history)
    acc = history_dict['accuracy']
    val_acc = history_dict['val_accuracy']
    epochs = range(1, len(acc) + 1)

    # figsize must be a (width, height) tuple; the original "(12.9)" is a float.
    plt.figure(figsize=(12, 9))
    plt.plot(epochs, acc, 'bo', label='Training accuracy')
    plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
    plt.title('Training and validation accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend(loc='lower right')
    plt.ylim(0.5, 1)
    plt.show()
def retrive_embeddings(model, encoder, vectors_path='vecs.tsv', metadata_path='meta.tsv'):
    """Export the learned embeddings as two TSV files (vectors + labels).

    For each subword in ``encoder.subwords``, one line of tab-separated vector
    components is written to ``vectors_path`` and the subword itself to
    ``metadata_path``, in the same order.

    Args:
        model: Model whose first layer holds the embedding table; the table
            is read as ``model.layers[0].get_weights()[0]``.
        encoder: Subword text encoder; its ``subwords`` sequence supplies
            the labels.
        vectors_path: Output path for the vectors file.
        metadata_path: Output path for the labels file.
    """
    # get_weights() takes no arguments and returns a list of arrays; the
    # embedding matrix is its first entry. Read it before opening the output
    # files so a failure here leaves no empty files behind.
    weights = model.layers[0].get_weights()[0]
    with open(vectors_path, 'w', encoding='utf-8') as out_vectors, \
            open(metadata_path, 'w', encoding='utf-8') as out_metadata:
        for num, word in enumerate(encoder.subwords):
            # Row 0 is skipped, so subword i maps to row i + 1; presumably
            # id 0 is reserved for padding — confirm against the encoder.
            vec = weights[num + 1]
            out_metadata.write(word + '\n')
            out_vectors.write('\t'.join(str(x) for x in vec) + '\n')
# Script body: load the data, build the model and train it.
# Assignment needs "=", not "-": the original evaluated a tuple expression
# and raised NameError on the undefined train_batches.
train_batches, test_batches, encoder = get_batch_data()
model = get_model(encoder)
# validation_steps=20 caps each epoch's validation pass at 20 batches.
history = model.fit(train_batches, epochs=10, validation_data=test_batches, validation_steps=20)
#plot_history(history)