discriminator_nets.py
# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Discriminator networks for text data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sonnet as snt
import tensorflow.compat.v1 as tf
from scratchgan import utils
class LSTMEmbedDiscNet(snt.AbstractModule):
"""An LSTM discriminator that operates on word indexes."""
def __init__(self,
feature_sizes,
vocab_size,
use_layer_norm,
trainable_embedding_size,
dropout,
pad_token,
embedding_source=None,
vocab_file=None,
name='LSTMEmbedDiscNet'):
super(LSTMEmbedDiscNet, self).__init__(name=name)
self._feature_sizes = feature_sizes
self._vocab_size = vocab_size
self._use_layer_norm = use_layer_norm
self._trainable_embedding_size = trainable_embedding_size
self._embedding_source = embedding_source
self._vocab_file = vocab_file
self._dropout = dropout
self._pad_token = pad_token
if self._embedding_source:
assert vocab_file
  def _build(self, sequence, sequence_length, is_training=True):
    """Connects the discriminator to the graph.

    Args:
      sequence: A [batch_size, max_sequence_length] tensor of int. For example
        the indices of words as sampled by the generator.
      sequence_length: A [batch_size] tensor of int. Length of the sequence.
      is_training: Boolean, False to disable dropout.

    Returns:
      A [batch_size, max_sequence_length, feature_size] tensor of floats. For
      each sequence in the batch, the features should (hopefully) allow one to
      distinguish whether the value at each timestep is real or generated.
    """
    batch_size, max_sequence_length = sequence.shape.as_list()
    keep_prob = (1.0 - self._dropout) if is_training else 1.0

    if self._embedding_source:
      all_embeddings = utils.make_partially_trainable_embeddings(
          self._vocab_file, self._embedding_source, self._vocab_size,
          self._trainable_embedding_size)
    else:
      all_embeddings = tf.get_variable(
          'trainable_embedding',
          shape=[self._vocab_size, self._trainable_embedding_size],
          trainable=True)
    _, self._embedding_size = all_embeddings.shape.as_list()
    input_embeddings = tf.nn.dropout(all_embeddings, keep_prob=keep_prob)
    embeddings = tf.nn.embedding_lookup(input_embeddings, sequence)
    embeddings.shape.assert_is_compatible_with(
        [batch_size, max_sequence_length, self._embedding_size])

    # Append a fixed-size position signal so per-timestep features can depend
    # on where a word occurs in the sequence.
    position_dim = 8
    embeddings_pos = utils.append_position_signal(embeddings, position_dim)
    embeddings_pos = tf.reshape(
        embeddings_pos,
        [batch_size * max_sequence_length, self._embedding_size + position_dim])
    lstm_inputs = snt.Linear(self._feature_sizes[0])(embeddings_pos)
    lstm_inputs = tf.reshape(
        lstm_inputs, [batch_size, max_sequence_length, self._feature_sizes[0]])
    lstm_inputs.shape.assert_is_compatible_with(
        [batch_size, max_sequence_length, self._feature_sizes[0]])

    encoder_cells = []
    for feature_size in self._feature_sizes:
      encoder_cells += [
          snt.LSTM(feature_size, use_layer_norm=self._use_layer_norm)
      ]
    # Sonnet's DeepRNN uses skip connections by default, so the per-timestep
    # output is the concatenation of every layer's output; hence the sum of
    # feature sizes in the shape check below.
    encoder_cell = snt.DeepRNN(encoder_cells)
    initial_state = encoder_cell.initial_state(batch_size)

    hidden_states, _ = tf.nn.dynamic_rnn(
        cell=encoder_cell,
        inputs=lstm_inputs,
        sequence_length=sequence_length,
        initial_state=initial_state,
        swap_memory=True)
    hidden_states.shape.assert_is_compatible_with(
        [batch_size, max_sequence_length,
         sum(self._feature_sizes)])
    logits = snt.BatchApply(snt.Linear(1))(hidden_states)
    logits.shape.assert_is_compatible_with([batch_size, max_sequence_length, 1])
    logits_flat = tf.reshape(logits, [batch_size, max_sequence_length])

    # Mask past the first PAD symbol.
    #
    # Note that we still rely on tf.nn.dynamic_rnn taking sequence_length into
    # account properly, because otherwise the logits at a given timestep would
    # depend on the inputs at all other timesteps, including the ones that
    # should be masked.
    mask = utils.get_mask_past_symbol(sequence, self._pad_token)
    masked_logits_flat = logits_flat * tf.cast(mask, tf.float32)
    return masked_logits_flat
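

# Usage sketch (editor's addition, not part of the original module): a minimal
# example of how this discriminator might be connected in a TF1 graph. The
# batch size, sequence length, vocabulary size, layer sizes, dropout rate and
# PAD index below are illustrative assumptions, not the values used by
# ScratchGAN.
if __name__ == '__main__':
  tf.disable_v2_behavior()  # Placeholders require graph mode under TF2.

  batch_size, max_sequence_length = 4, 16

  discriminator = LSTMEmbedDiscNet(
      feature_sizes=[256, 256],      # two stacked LSTM layers (assumed sizes)
      vocab_size=1000,               # assumed vocabulary size
      use_layer_norm=True,
      trainable_embedding_size=64,   # assumed embedding size
      dropout=0.1,
      pad_token=0)                   # assumed index of the PAD symbol

  # Integer word indices (e.g. generator samples) and their true lengths.
  sequence = tf.placeholder(tf.int32, [batch_size, max_sequence_length])
  sequence_length = tf.placeholder(tf.int32, [batch_size])

  # Per-timestep logits of shape [batch_size, max_sequence_length], already
  # masked past the first PAD symbol in each sequence.
  logits = discriminator(sequence, sequence_length, is_training=True)
  print(logits)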