# Source: movielens_test.py, forked from tensorflow/models
# (117 lines, 92 loc, 3.41 KB).
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.datasets import movielens
from official.utils.testing import integration
from official.wide_deep import movielens_dataset
from official.wide_deep import movielens_main
from official.wide_deep import wide_deep_run_loop
# Only emit TensorFlow log records at ERROR severity or higher for this
# test module.
tf.logging.set_verbosity(tf.logging.ERROR)
# Multi-hot genre vector with a single bit set at index 4.
_EXPECTED_GENRE_VECTOR = np.array(
    [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

# Feature values that test_input_fn expects for the first example of the
# first training batch.
TEST_INPUT_VALUES = {
    "genres": _EXPECTED_GENRE_VECTOR,
    "user_id": [3],
    "item_id": [4],
}

# Movies CSV fixture, one literal per row. There is no newline after the
# final row.
TEST_ITEM_DATA = (
    "item_id,titles,genres\n"
    "1,Movie_1,Comedy|Romance\n"
    "2,Movie_2,Adventure|Children's\n"
    "3,Movie_3,Comedy|Drama\n"
    "4,Movie_4,Comedy\n"
    "5,Movie_5,Action|Crime|Thriller\n"
    "6,Movie_6,Action\n"
    "7,Movie_7,Action|Adventure|Thriller"
)

# Ratings CSV fixture, one literal per row. The final row is terminated by a
# newline.
TEST_RATING_DATA = (
    "user_id,item_id,rating,timestamp\n"
    "1,2,5,978300760\n"
    "1,3,3,978302109\n"
    "1,6,3,978301968\n"
    "2,1,4,978300275\n"
    "2,7,5,978824291\n"
    "3,1,3,978302268\n"
    "3,4,5,978302039\n"
    "3,5,5,978300719\n"
)
class BaseTest(tf.test.TestCase):
  """Tests for Wide Deep model.

  Each test runs against small ratings and movies CSV fixtures that `setUp`
  writes into a per-test temporary directory.
  """

  @classmethod
  def setUpClass(cls):  # pylint: disable=invalid-name
    """Defines the MovieLens flags once, before any test in the class runs."""
    super(BaseTest, cls).setUpClass()
    movielens_main.define_movie_flags()

  def setUp(self):
    """Writes the CSV fixtures under `<temp_dir>/<movielens.ML_1M>/`."""
    self.temp_dir = self.get_temp_dir()

    # Both fixture files live in the same dataset sub-directory, so build that
    # path once and reuse it.
    dataset_dir = os.path.join(self.temp_dir, movielens.ML_1M)
    tf.gfile.MakeDirs(dataset_dir)

    self.ratings_csv = os.path.join(dataset_dir, movielens.RATINGS_FILE)
    self.item_csv = os.path.join(dataset_dir, movielens.MOVIES_FILE)

    with tf.gfile.Open(self.ratings_csv, "w") as f:
      f.write(TEST_RATING_DATA)

    with tf.gfile.Open(self.item_csv, "w") as f:
      f.write(TEST_ITEM_DATA)

  def test_input_fn(self):
    """Checks the first example produced by the training input function.

    Builds the input functions with `batch_size=8` and `repeat=1`, evaluates
    one batch, and compares row 0 of every feature listed in
    `TEST_INPUT_VALUES`, plus the row-0 label, against the expected values.
    """
    train_input_fn, _, _ = movielens_dataset.construct_input_fns(
        dataset=movielens.ML_1M, data_dir=self.temp_dir, batch_size=8,
        repeat=1)

    dataset = train_input_fn()
    features, labels = dataset.make_one_shot_iterator().get_next()

    with self.test_session() as sess:
      features, labels = sess.run((features, labels))

      # Compare the two features dictionaries.
      for key, expected in TEST_INPUT_VALUES.items():
        # assertIn names the missing key and the container on failure, where
        # assertTrue would only report "False is not true".
        self.assertIn(key, features)
        self.assertAllClose(expected, features[key][0])

      # NOTE(review): presumably 1.0 marks the fixture's (user 3, item 4)
      # rating of 5 as a positive example -- confirm in movielens_dataset.
      self.assertAllClose(labels[0], [1.0])

  def test_end_to_end_deep(self):
    """Runs `movielens_main.main` for one epoch against the CSV fixtures.

    `synth=False` and `max_train=None` are passed to
    `integration.run_synthetic`; the data directory is pointed at the
    fixtures and downloading is disabled through the extra flags.
    """
    integration.run_synthetic(
        main=movielens_main.main, tmp_root=self.temp_dir,
        extra_flags=[
            "--data_dir", self.temp_dir,
            "--download_if_missing=false",
            "--train_epochs", "1",
            "--epochs_between_evals", "1"
        ],
        synth=False, max_train=None)
if __name__ == "__main__":
  # Hand control to TensorFlow's test runner when executed as a script.
  tf.test.main()