language_model_ptb example #281

Draft · wants to merge 3 commits into base: master
2 changes: 2 additions & 0 deletions examples/language_model_ptb/.gitignore
@@ -0,0 +1,2 @@
/simple-examples/
simple-examples.tgz
25 changes: 25 additions & 0 deletions examples/language_model_ptb/README.md
@@ -0,0 +1,25 @@
# Language Model on PTB #

This example builds an LSTM language model and trains it on PTB data. The model and training procedure are described in
[Recurrent Neural Network Regularization (Zaremba et al.)](https://arxiv.org/pdf/1409.2329.pdf). This is a PyTorch implementation of the official TensorFlow PTB example in [tensorflow/models/rnn/ptb](https://github.com/tensorflow/models/tree/master/tutorials/rnn/ptb).

The example shows:

* Construction of a simple model, involving the `Embedder` and `RNN Decoder` (see the sketch after this list).
* Use of Texar-PyTorch with external Python data pipeline ([ptb_reader.py](./ptb_reader.py)).
* Specification of various features of `train op`, like *gradient clipping* and *learning rate decay*.
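
For reference, here is a minimal sketch of how these pieces fit together, using the same `texar.torch` APIs as [lm_ptb.py](./lm_ptb.py). The config module and the vocabulary size are assumptions of the sketch (10000 is the standard PTB vocabulary size):

```python
import importlib

import texar.torch as tx

# Assumptions for this sketch: one of the config modules from this example
# and the standard PTB vocabulary size.
config = importlib.import_module("config_small")
vocab_size = 10000

# Word embedder and RNN decoder, configured from the hyperparameter dicts.
embedder = tx.modules.WordEmbedder(
    vocab_size=vocab_size, hparams=config.emb)
decoder = tx.modules.BasicRNNDecoder(
    token_embedder=embedder,
    input_size=config.hidden_size, vocab_size=vocab_size,
    hparams={"rnn_cell": config.cell})

# lm_ptb.py wraps the two modules in a single nn.Module and builds the
# train op (SGD, gradient clipping, learning-rate decay) from config.opt:
# train_op = tx.core.get_train_op(params=model.parameters(),
#                                 hparams=config.opt)
```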

## Usage ##

The following command trains a small-size model:

```
python lm_ptb.py --config config_small --data-path ./
```

Here:

* `--config` specifies the configuration file to use, e.g., the above command uses the configuration defined in [config_small.py](./config_small.py).
* `--data-path` specifies the directory containing PTB raw data (e.g., `ptb.train.txt`). If the data files do not exist, the program will automatically download, extract, and pre-process the data; see the additional example below.
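
For example, to train the medium-size model on data that has already been downloaded and extracted to `./simple-examples/data`:

```
python lm_ptb.py --config config_medium --data-path ./simple-examples/data
```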

The model will begin training, evaluate on the validation data periodically, and evaluate on the test data once training is done.
67 changes: 67 additions & 0 deletions examples/language_model_ptb/config_large.py
@@ -0,0 +1,67 @@
# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PTB LM large size config.
"""

init_scale = 0.04
num_epochs = 55
hidden_size = 1500
keep_prob = 0.35
batch_size = 20
num_steps = 35

cell = {
    "type": "LSTMCell",
    "kwargs": {
        "num_units": hidden_size,
        "forget_bias": 0.
    },
    "dropout": {
        "output_keep_prob": keep_prob
    },
    "num_layers": 2
}

emb = {
    "dim": hidden_size,
    "dropout_rate": 1 - keep_prob,
    "initializer": {
        "type": "torch.nn.init.uniform_",
        "kwargs": {
            "a": -init_scale,
            "b": init_scale,
        }
    },
}

opt = {
    "optimizer": {
        "type": "SGD",
        "kwargs": {
            "lr": 1.0
        }
    },
    "gradient_clip": {
        "type": "clip_grad_norm_",
        "kwargs": {
            "max_norm": 10.
        }
    },
    "learning_rate_decay": {
        "type": "ExponentialLR",
        "kwargs": {
            "gamma": 1. / 1.15,
        },
    }
}
67 changes: 67 additions & 0 deletions examples/language_model_ptb/config_medium.py
@@ -0,0 +1,67 @@
# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PTB LM medium size config.
"""

init_scale = 0.05
num_epochs = 39
hidden_size = 650
keep_prob = 0.5
batch_size = 20
num_steps = 35

cell = {
    "type": "LSTMCell",
    "kwargs": {
        "num_units": hidden_size,
        "forget_bias": 0.
    },
    "dropout": {
        "output_keep_prob": keep_prob
    },
    "num_layers": 2
}

emb = {
    "dim": hidden_size,
    "dropout_rate": 1 - keep_prob,
    "initializer": {
        "type": "torch.nn.init.uniform_",
        "kwargs": {
            "a": -init_scale,
            "b": init_scale,
        }
    },
}

opt = {
    "optimizer": {
        "type": "SGD",
        "kwargs": {
            "lr": 1.0
        }
    },
    "gradient_clip": {
        "type": "clip_grad_norm_",
        "kwargs": {
            "max_norm": 5.
        }
    },
    "learning_rate_decay": {
        "type": "ExponentialLR",
        "kwargs": {
            "gamma": 0.8,
        },
    }
}
67 changes: 67 additions & 0 deletions examples/language_model_ptb/config_small.py
@@ -0,0 +1,67 @@
# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PTB LM small size config.
"""

init_scale = 0.1
num_epochs = 13
hidden_size = 200
keep_prob = 1.0
batch_size = 20
num_steps = 20

cell = {
    "type": "LSTMCell",
    "kwargs": {
        "num_units": hidden_size,
        "forget_bias": 0.
    },
    "dropout": {
        "output_keep_prob": keep_prob
    },
    "num_layers": 2
}

emb = {
    "dim": hidden_size,
    "dropout_rate": 1 - keep_prob,
    "initializer": {
        "type": "torch.nn.init.uniform_",
        "kwargs": {
            "a": -init_scale,
            "b": init_scale,
        }
    },
}

opt = {
    "optimizer": {
        "type": "SGD",
        "kwargs": {
            "lr": 1.0
        }
    },
    "gradient_clip": {
        "type": "clip_grad_norm_",
        "kwargs": {
            "max_norm": 5.
        }
    },
    "learning_rate_decay": {
        "type": "ExponentialLR",
        "kwargs": {
            "gamma": 0.5,
        },
    }
}
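
For orientation, the `opt` block above maps roughly onto the following plain-PyTorch setup. This is only an illustrative sketch with a toy stand-in model, not code from the example:

```python
import torch
import torch.nn as nn

# Toy stand-in model, only to make the sketch self-contained.
model = nn.Linear(10, 10)

# "optimizer": SGD with lr = 1.0
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
# "learning_rate_decay": ExponentialLR with gamma = 0.5
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5)

# One training step:
loss = model(torch.randn(4, 10)).pow(2).mean()
loss.backward()
# "gradient_clip": clip_grad_norm_ with max_norm = 5.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
optimizer.step()
optimizer.zero_grad()

# lm_ptb.py advances the decay schedule once per epoch via
# train_op(use_scheduler=True).
scheduler.step()
```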
149 changes: 149 additions & 0 deletions examples/language_model_ptb/lm_ptb.py
@@ -0,0 +1,149 @@
# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example for building the language model.
"""

from typing import Any, Dict, List, Tuple

import argparse
import importlib
import time

import torch
import torch.nn as nn

import texar.torch as tx

from ptb_reader import prepare_data, ptb_iterator

parser = argparse.ArgumentParser()
parser.add_argument(
    '--data-path', type=str, default='./',
    help="Directory containing PTB raw data (e.g., ptb.train.txt), "
         "such as ./simple-examples/data. If it does not exist, "
         "the directory will be created and PTB raw data will be downloaded.")
parser.add_argument(
    '--config', type=str, default='config_small',
    help='The config to use.')
args = parser.parse_args()

config: Any = importlib.import_module(args.config)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class PTBLanguageModel(nn.Module):
    """LSTM language model: a word embedder followed by an RNN decoder,
    trained with teacher forcing ("train_greedy" decoding).
    """

    def __init__(self, vocab_size: int, hparams: Dict[str, Any]):
        super().__init__()

        self.embedder = tx.modules.WordEmbedder(
            vocab_size=vocab_size, hparams=hparams['embedder'])
        self.decoder = tx.modules.BasicRNNDecoder(
            token_embedder=self.embedder,
            input_size=config.hidden_size, vocab_size=vocab_size,
            hparams=hparams['decoder'])

    def forward(self,  # type: ignore
                inputs: torch.Tensor, targets: torch.Tensor,
                state: List[Tuple[torch.Tensor, torch.Tensor]]):
        # Teacher-forcing decoding over a fixed window of `num_steps` tokens.
        outputs, final_state, seq_lengths = self.decoder(
            decoding_strategy="train_greedy",
            impute_finished=False,
            inputs=inputs,
            sequence_length=torch.tensor(
                [config.num_steps] * config.batch_size),
            initial_state=state)

        # Sequence cross-entropy loss on the target tokens.
        mle_loss = tx.losses.sequence_sparse_softmax_cross_entropy(
            labels=targets,
            logits=outputs.logits,
            sequence_length=seq_lengths)

        return mle_loss, final_state


def main() -> None:
    # Data
    batch_size = config.batch_size
    num_steps = config.num_steps
    num_epochs = config.num_epochs

    data = prepare_data(args.data_path)

    hparams = {
        'embedder': config.emb,
        'decoder': {"rnn_cell": config.cell},
    }
    model = PTBLanguageModel(vocab_size=data["vocab_size"], hparams=hparams)
    model.to(device)
    train_op = tx.core.get_train_op(params=model.parameters(),
                                    hparams=config.opt)

    def _run_epoch(data_iter, is_train=False, verbose=False):
        start_time = time.time()
        loss = 0.
        iters = 0
        state = None

        if is_train:
            model.train()
            epoch_size = ((len(data["train_text_id"]) // batch_size - 1) //
                          num_steps)
        else:
            model.eval()

        for step, (x, y) in enumerate(data_iter):
            # Create the input tensors on the same device as the model.
            loss_, state_ = model(inputs=torch.tensor(x, device=device),
                                  targets=torch.tensor(y, device=device),
                                  state=state)
            if is_train:
                loss_.backward()
                train_op()
            loss += loss_
            # Detach the LSTM state so gradients do not flow across windows
            # (truncated back-propagation through time).
            state = [(state_[0][0].detach(), state_[0][1].detach()),
                     (state_[1][0].detach(), state_[1][1].detach())]
            iters += num_steps

            ppl = torch.exp(loss / iters)
            if verbose and is_train and step % (epoch_size // 10) == 10:
                print("%.3f perplexity: %.3f speed: %.0f wps" %
                      ((step + 1) * 1.0 / epoch_size, ppl,
                       iters * batch_size / (time.time() - start_time)))

        ppl = torch.exp(loss / iters)
        return ppl

    for epoch in range(num_epochs):
        # Train
        train_data_iter = ptb_iterator(
            data["train_text_id"], batch_size, num_steps)
        train_ppl = _run_epoch(train_data_iter, is_train=True, verbose=True)
        print("Epoch: %d Train Perplexity: %.3f" % (epoch, train_ppl))
        # Decay the learning rate per the schedule in config.opt.
        train_op(use_scheduler=True)  # type: ignore

        # Valid
        valid_data_iter = ptb_iterator(
            data["valid_text_id"], batch_size, num_steps)
        valid_ppl = _run_epoch(valid_data_iter)
        print("Epoch: %d Valid Perplexity: %.3f" % (epoch, valid_ppl))

    # Test
    test_data_iter = ptb_iterator(
        data["test_text_id"], batch_size, num_steps)
    test_ppl = _run_epoch(test_data_iter)
    print("Test Perplexity: %.3f" % test_ppl)


if __name__ == '__main__':
    main()