-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathretrieve.py
109 lines (85 loc) · 3.71 KB
/
retrieve.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# https://www.tensorflow.org/recommenders/examples/quickstart
# https://www.tensorflow.org/recommenders/examples/basic_retrieval
import tensorflow_recommenders as tfrs
from typing import Dict, Text
import tempfile
import os
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
# Env Var: https://cloud.google.com/vertex-ai/docs/training/code-requirements#environment-variables
def _env_dir(var_name: str) -> str:
    """Return the directory named by environment variable *var_name*.

    Vertex AI custom training sets the ``AIP_*`` directory variables (see the
    link above). When the variable is unset or empty, e.g. when running
    locally, a fresh temporary directory is created instead. The temporary
    directory is only created when it is actually needed.
    """
    value = os.getenv(var_name)
    return value if value else tempfile.mkdtemp()


# You can set AIP_MODEL_DIR to /gcs/<bucket>/<path> to save the model to GCS.
MODEL_DIR = _env_dir("AIP_MODEL_DIR")
# NOTE: MODEL_VERSION is not used elsewhere in this script. To save versioned
# models under a gs:// MODEL_DIR, join it onto the path:
#     if MODEL_DIR.startswith("gs://"):
#         MODEL_DIR = os.path.join(MODEL_DIR, MODEL_VERSION)
MODEL_VERSION = os.getenv("MODEL_VERSION", "1")
# Read from the environment so the directories Vertex AI provides are honored.
CHECKPOINT_DIR = _env_dir("AIP_CHECKPOINT_DIR")
TENSORBOARD_LOG_DIR = _env_dir("AIP_TENSORBOARD_LOG_DIR")
# Load MovieLens 100K: the rating events and the catalogue of all available
# movies. Only the fields needed for retrieval are kept from each ratings row.
ratings = tfds.load("movielens/100k-ratings", split="train").map(
    lambda row: {"movie_title": row["movie_title"], "user_id": row["user_id"]}
)
# A MapDataset in which each element is a movie-title Tensor.
movies = tfds.load("movielens/100k-movies", split="train").map(
    lambda row: row["movie_title"]
)

# String -> integer lookup tables (no mask token). User IDs are learned from
# the ratings; titles are learned from the full movie catalogue.
user_ids_vocabulary = tf.keras.layers.StringLookup(mask_token=None)
user_ids_vocabulary.adapt(ratings.map(lambda row: row["user_id"]))
movie_titles_vocabulary = tf.keras.layers.StringLookup(mask_token=None)
movie_titles_vocabulary.adapt(movies)
class MovieLensModel(tfrs.Model):
    """Two-tower retrieval model: a user tower, a movie tower and a retrieval task.

    Deriving from ``tfrs.Model`` reduces boilerplate; under the hood this is
    still a plain Keras model, and subclasses only need ``compute_loss``.
    """

    def __init__(
        self,
        user_model: tf.keras.Model,
        movie_model: tf.keras.Model,
        task: tfrs.tasks.Retrieval,
    ):
        """Store the two towers and the retrieval task.

        Args:
            user_model: tower applied to the ``"user_id"`` feature.
            movie_model: tower applied to the ``"movie_title"`` feature.
            task: retrieval task that turns the two towers' outputs into a loss.
        """
        super().__init__()
        # Assigning sub-models as attributes lets Keras track their weights.
        self.user_model = user_model
        self.movie_model = movie_model
        self.task = task

    def compute_loss(self, features: Dict[Text, tf.Tensor], training=False) -> tf.Tensor:
        """Return the retrieval loss for one batch of ``features``.

        ``features`` must provide ``"user_id"`` and ``"movie_title"`` entries;
        ``training`` is accepted for the base-class signature but unused here.
        """
        # The user tower runs first, then the movie tower (argument order).
        return self.task(
            self.user_model(features["user_id"]),
            self.movie_model(features["movie_title"]),
        )
# User (user_id) and movie (movie_title) towers: each maps a raw string to an
# embedding vector. Both towers emit the same width because the retrieval task
# scores user vectors against movie vectors.
EMBEDDING_DIM = 64

user_model = tf.keras.Sequential([
    user_ids_vocabulary,  # user_id string -> integer index
    tf.keras.layers.Embedding(  # integer index -> embedding vector
        input_dim=user_ids_vocabulary.vocabulary_size(),
        output_dim=EMBEDDING_DIM,
    ),
])
movie_model = tf.keras.Sequential([
    movie_titles_vocabulary,  # movie_title string -> integer index
    tf.keras.layers.Embedding(  # integer index -> embedding vector
        input_dim=movie_titles_vocabulary.vocabulary_size(),
        output_dim=EMBEDDING_DIM,
    ),
])
# Objective: a TFRS Retrieval task whose FactorizedTopK metrics are computed
# against the movie-tower embeddings of every movie in the catalogue.
topk_candidates = movies.batch(128).map(movie_model)
task = tfrs.tasks.Retrieval(
    metrics=tfrs.metrics.FactorizedTopK(topk_candidates),
)

# Assemble the retrieval model and train it for 3 epochs on batches of 4096.
model = MovieLensModel(user_model=user_model, movie_model=movie_model, task=task)
model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.5))
model.fit(x=ratings.batch(4096), epochs=3)
# Build a serving index over all movie embeddings. ScaNN (approximate search)
# is tried first; it needs the optional `scann` package (`pip install -q scann`).
# Any failure there falls back to exact brute-force search.
is_scann = False
# (title, embedding) pairs for every movie, produced in a single pass so each
# title is aligned with its own embedding by construction.
index_candidates = movies.batch(100).map(lambda title: (title, model.movie_model(title)))
try:
    index = tfrs.layers.factorized_top_k.ScaNN(model.user_model)
    index.index_from_dataset(index_candidates)
    is_scann = True
except Exception as err:  # deliberate best-effort fallback; lets Ctrl-C/SystemExit through
    print(f"ScaNN index unavailable ({err!r}); falling back to brute-force search.")
    # Use brute-force search to set up retrieval using the trained representations.
    index = tfrs.layers.factorized_top_k.BruteForce(model.user_model)
    index.index_from_dataset(index_candidates)
# Smoke-test the index: retrieve movies for user "42". The index returns a
# pair whose second element holds the recommended titles; the first is unused.
query = np.array(["42"])
_, titles = index(query)
print(f"Top 3 recommendations for user 42: {titles[0, :3]}")

if is_scann:
    # SavedModel refuses to serialize namespaced ops unless their namespace is
    # whitelisted; the ScaNN index relies on ops in the "Scann" namespace.
    save_options = tf.saved_model.SaveOptions(namespace_whitelist=["Scann"])
else:
    save_options = None
index.save(MODEL_DIR, options=save_options)
print(f"Model saved to {MODEL_DIR}")