From cbef31f625e15a20a2b027764cbe2d8818bd75a9 Mon Sep 17 00:00:00 2001
From: Pengzhi Gao
Date: Wed, 18 Dec 2019 23:36:59 -0500
Subject: [PATCH 1/2] Add language_model_ptb example

---
 examples/language_model_ptb/.gitignore       |   2 +
 examples/language_model_ptb/README.md        |  25 ++++
 examples/language_model_ptb/config_large.py  |  67 +++++++++
 examples/language_model_ptb/config_medium.py |  67 +++++++++
 examples/language_model_ptb/config_small.py  |  67 +++++++++
 examples/language_model_ptb/lm_ptb.py        | 149 +++++++++++++++++++
 examples/language_model_ptb/ptb_reader.py    |  81 ++++++++++
 7 files changed, 458 insertions(+)
 create mode 100644 examples/language_model_ptb/.gitignore
 create mode 100644 examples/language_model_ptb/README.md
 create mode 100644 examples/language_model_ptb/config_large.py
 create mode 100644 examples/language_model_ptb/config_medium.py
 create mode 100644 examples/language_model_ptb/config_small.py
 create mode 100644 examples/language_model_ptb/lm_ptb.py
 create mode 100644 examples/language_model_ptb/ptb_reader.py

diff --git a/examples/language_model_ptb/.gitignore b/examples/language_model_ptb/.gitignore
new file mode 100644
index 000000000..8ae7ce0aa
--- /dev/null
+++ b/examples/language_model_ptb/.gitignore
@@ -0,0 +1,2 @@
+/simple-examples/
+simple-examples.tgz
diff --git a/examples/language_model_ptb/README.md b/examples/language_model_ptb/README.md
new file mode 100644
index 000000000..ebf737b50
--- /dev/null
+++ b/examples/language_model_ptb/README.md
@@ -0,0 +1,25 @@
+# Language Model on PTB #
+
+This example builds an LSTM language model and trains it on the PTB data. The model and training procedure are described in
+[(Zaremba et al.) Recurrent Neural Network Regularization](https://arxiv.org/pdf/1409.2329.pdf). This is a PyTorch implementation of the official TensorFlow PTB example in [tensorflow/models/rnn/ptb](https://github.com/tensorflow/models/tree/master/tutorials/rnn/ptb).
+
+The example shows:
+
+  * Construction of a simple model using the `Embedder` and `RNN Decoder`.
+  * Use of Texar-PyTorch with an external Python data pipeline ([ptb_reader.py](./ptb_reader.py)).
+  * Specification of various features of the `train op`, such as *gradient clipping* and *learning rate decay*.
+
+## Usage ##
+
+The following command trains a small-size model:
+
+```
+python lm_ptb.py --config config_small --data-path ./
+```
+
+Here:
+
+  * `--config` specifies the configuration file to use. E.g., the above command uses the configuration defined in [config_small.py](./config_small.py).
+  * `--data-path` specifies the directory containing the PTB raw data (e.g., `ptb.train.txt`). If the data files do not exist, the program will automatically download, extract, and pre-process the data.
+
+The model will begin training, evaluate on the validation data periodically, and evaluate on the test data after training is done.
diff --git a/examples/language_model_ptb/config_large.py b/examples/language_model_ptb/config_large.py
new file mode 100644
index 000000000..91de3ebd2
--- /dev/null
+++ b/examples/language_model_ptb/config_large.py
@@ -0,0 +1,67 @@
+# Copyright 2019 The Texar Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PTB LM large size config. +""" + +init_scale = 0.04 +num_epochs = 55 +hidden_size = 1500 +keep_prob = 0.35 +batch_size = 20 +num_steps = 35 + +cell = { + "type": "LSTMCell", + "kwargs": { + "num_units": hidden_size, + "forget_bias": 0. + }, + "dropout": { + "output_keep_prob": keep_prob + }, + "num_layers": 2 +} + +emb = { + "dim": hidden_size, + "initializer": { + "type": "random_uniform_initializer", + "kwargs": { + "minval": -init_scale, + "maxval": init_scale, + "seed": None + } + }, +} + +opt = { + "optimizer": { + "type": "SGD", + "kwargs": { + "lr": 1.0 + } + }, + "gradient_clip": { + "type": "clip_grad_norm_", + "kwargs": { + "max_norm": 10. + } + }, + "learning_rate_decay": { + "type": "ExponentialLR", + "kwargs": { + "gamma": 1. / 1.15, + }, + } +} diff --git a/examples/language_model_ptb/config_medium.py b/examples/language_model_ptb/config_medium.py new file mode 100644 index 000000000..facdfa7a5 --- /dev/null +++ b/examples/language_model_ptb/config_medium.py @@ -0,0 +1,67 @@ +# Copyright 2019 The Texar Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PTB LM medium size config. +""" + +init_scale = 0.05 +num_epochs = 39 +hidden_size = 650 +keep_prob = 0.5 +batch_size = 20 +num_steps = 35 + +cell = { + "type": "LSTMCell", + "kwargs": { + "num_units": hidden_size, + "forget_bias": 0. + }, + "dropout": { + "output_keep_prob": keep_prob + }, + "num_layers": 2 +} + +emb = { + "dim": hidden_size, + "initializer": { + "type": "random_uniform_initializer", + "kwargs": { + "minval": -init_scale, + "maxval": init_scale, + "seed": None + } + }, +} + +opt = { + "optimizer": { + "type": "SGD", + "kwargs": { + "lr": 1.0 + } + }, + "gradient_clip": { + "type": "clip_grad_norm_", + "kwargs": { + "max_norm": 5. + } + }, + "learning_rate_decay": { + "type": "ExponentialLR", + "kwargs": { + "gamma": 0.8, + }, + } +} diff --git a/examples/language_model_ptb/config_small.py b/examples/language_model_ptb/config_small.py new file mode 100644 index 000000000..4588d34e6 --- /dev/null +++ b/examples/language_model_ptb/config_small.py @@ -0,0 +1,67 @@ +# Copyright 2019 The Texar Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""PTB LM small size config. +""" + +init_scale = 0.1 +num_epochs = 13 +hidden_size = 200 +keep_prob = 1.0 +batch_size = 20 +num_steps = 20 + +cell = { + "type": "LSTMCell", + "kwargs": { + "num_units": hidden_size, + "forget_bias": 0. + }, + "dropout": { + "output_keep_prob": keep_prob + }, + "num_layers": 2 +} + +emb = { + "dim": hidden_size, + "dropout_rate": 1 - keep_prob, + "initializer": { + "type": "torch.nn.init.uniform_", + "kwargs": { + "a": -init_scale, + "b": init_scale, + } + }, +} + +opt = { + "optimizer": { + "type": "SGD", + "kwargs": { + "lr": 1.0 + } + }, + "gradient_clip": { + "type": "clip_grad_norm_", + "kwargs": { + "max_norm": 5. + } + }, + "learning_rate_decay": { + "type": "ExponentialLR", + "kwargs": { + "gamma": 0.5, + }, + } +} diff --git a/examples/language_model_ptb/lm_ptb.py b/examples/language_model_ptb/lm_ptb.py new file mode 100644 index 000000000..ab558ac00 --- /dev/null +++ b/examples/language_model_ptb/lm_ptb.py @@ -0,0 +1,149 @@ +# Copyright 2019 The Texar Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Example for building the language model. +""" + +from typing import Any, Dict, Tuple + +import argparse +import importlib +import time + +import numpy as np + +import torch +import torch.nn as nn + +import texar.torch as tx + +from ptb_reader import prepare_data, ptb_iterator + +parser = argparse.ArgumentParser() +parser.add_argument( + '--data-path', type=str, default='./', + help="Directory containing PTB raw data (e.g., ptb.train.txt). " + "E.g., ./simple-examples/data. 
If it does not exist, "
+         "the directory will be created and PTB raw data will be downloaded.")
+parser.add_argument(
+    '--config', type=str, default='config_small',
+    help='The config to use.')
+args = parser.parse_args()
+
+config: Any = importlib.import_module(args.config)
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class PTBLanguageModel(nn.Module):
+
+    def __init__(self, vocab_size: int, hparams: Dict[str, Any]):
+        super().__init__()
+
+        self.embedder = tx.modules.WordEmbedder(
+            vocab_size=vocab_size, hparams=hparams['embedder'])
+        self.decoder = tx.modules.BasicRNNDecoder(
+            token_embedder=self.embedder,
+            input_size=config.hidden_size, vocab_size=vocab_size,
+            hparams=hparams['decoder'])
+
+    def forward(self,  # type: ignore
+                inputs: torch.Tensor, targets: torch.Tensor, state) -> \
+            Tuple[torch.Tensor, torch.Tensor]:
+        outputs, final_state, seq_lengths = self.decoder(
+            decoding_strategy="train_greedy",
+            impute_finished=False,
+            inputs=inputs,
+            sequence_length=torch.tensor([config.num_steps] * config.batch_size),
+            initial_state=state)
+
+        # import pdb; pdb.set_trace()
+
+        mle_loss = tx.losses.sequence_sparse_softmax_cross_entropy(
+            labels=targets,
+            logits=outputs.logits,
+            sequence_length=seq_lengths)
+
+        return mle_loss, final_state
+
+
+def main() -> None:
+    # Data
+    batch_size = config.batch_size
+    num_steps = config.num_steps
+    num_epochs = config.num_epochs
+
+    data = prepare_data(args.data_path)
+
+    hparams = {
+        'embedder': config.emb,
+        'decoder': {"rnn_cell": config.cell},
+    }
+    model = PTBLanguageModel(vocab_size=data["vocab_size"], hparams=hparams)
+    model.to(device)
+    train_op = tx.core.get_train_op(params=model.parameters(),
+                                    hparams=config.opt)
+
+    def _run_epoch(data_iter, is_train=False, verbose=False):
+        start_time = time.time()
+        loss = 0.
+        iters = 0
+        state = None
+
+        if is_train:
+            epoch_size = (len(data["train_text_id"]) // batch_size - 1) \
+                // num_steps
+
+        for step, (x, y) in enumerate(data_iter):
+
+            print("Hello, world!")
+
+            loss_, state_ = model(inputs=torch.tensor(x),
+                                  targets=torch.tensor(y), state=state)
+            if is_train:
+                loss_.backward(retain_graph=True)
+                train_op()
+            loss += loss_
+            state = state_
+            iters += num_steps
+
+            ppl = torch.exp(loss / iters)
+            if verbose and is_train and step % (epoch_size // 10) == 10:
+                print("%.3f perplexity: %.3f speed: %.0f wps" %
+                      ((step + 1) * 1.0 / epoch_size, ppl,
+                       iters * batch_size / (time.time() - start_time)))
+
+        ppl = np.exp(loss / iters)
+        return ppl
+
+    for epoch in range(num_epochs):
+        # Train
+        train_data_iter = ptb_iterator(
+            data["train_text_id"], batch_size, num_steps)
+        train_ppl = _run_epoch(train_data_iter, is_train=True, verbose=True)
+        print("Epoch: %d Train Perplexity: %.3f" % (epoch, train_ppl))
+        # Valid
+        valid_data_iter = ptb_iterator(
+            data["valid_text_id"], batch_size, num_steps)
+        valid_ppl = _run_epoch(valid_data_iter)
+        print("Epoch: %d Valid Perplexity: %.3f" % (epoch, valid_ppl))
+
+    # Test
+    test_data_iter = ptb_iterator(
+        data["test_text_id"], batch_size, num_steps)
+    test_ppl = _run_epoch(test_data_iter)
+    print("Test Perplexity: %.3f" % test_ppl)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/language_model_ptb/ptb_reader.py b/examples/language_model_ptb/ptb_reader.py
new file mode 100644
index 000000000..cba9f70f4
--- /dev/null
+++ b/examples/language_model_ptb/ptb_reader.py
@@ -0,0 +1,81 @@
+# Copyright 2019 The Texar Authors. All Rights Reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for pre-processing and iterating over the PTB data. +""" + +import os +import numpy as np + +import texar.torch as tx + + +def ptb_iterator(data, batch_size, num_steps): + r"""Iterates through the ptb data. + """ + data_length = len(data) + batch_length = data_length // batch_size + + data = np.asarray(data[:batch_size * batch_length]) + data = data.reshape([batch_size, batch_length]) + + epoch_size = (batch_length - 1) // num_steps + if epoch_size == 0: + raise ValueError("epoch_size == 0, decrease batch_size or num_steps") + + for i in range(epoch_size): + x = data[:, i * num_steps: (i + 1) * num_steps] + y = data[:, i * num_steps + 1: (i + 1) * num_steps + 1] + yield (x, y) + + +def prepare_data(data_path): + r"""Pre-process PTB data. + """ + train_path = os.path.join(data_path, "ptb.train.txt") + if not os.path.exists(train_path): + url = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz' + tx.data.maybe_download(url, data_path, extract=True) + data_path = os.path.join(data_path, 'simple-examples', 'data') + + train_path = os.path.join(data_path, "ptb.train.txt") + valid_path = os.path.join(data_path, "ptb.valid.txt") + test_path = os.path.join(data_path, "ptb.test.txt") + + word_to_id = tx.data.make_vocab( + train_path, newline_token="", return_type="dict") + assert len(word_to_id) == 10000 + + train_text = tx.data.read_words( + train_path, newline_token="") + train_text_id = [word_to_id[w] for w in train_text if w in word_to_id] + + valid_text = tx.data.read_words( + valid_path, newline_token="") + valid_text_id = [word_to_id[w] for w in valid_text if w in word_to_id] + + test_text = tx.data.read_words( + test_path, newline_token="") + test_text_id = [word_to_id[w] for w in test_text if w in word_to_id] + + data = { + "train_text": train_text, + "valid_text": valid_text, + "test_text": test_text, + "train_text_id": train_text_id, + "valid_text_id": valid_text_id, + "test_text_id": test_text_id, + "vocab": word_to_id, + "vocab_size": len(word_to_id) + } + return data From bfce6d2f672f23bb07f1a6ec029807ec5666b8bc Mon Sep 17 00:00:00 2001 From: Pengzhi Gao Date: Thu, 26 Dec 2019 16:28:55 -0500 Subject: [PATCH 2/2] Updates on language_model_ptb example --- examples/language_model_ptb/config_large.py | 8 ++--- examples/language_model_ptb/config_medium.py | 8 ++--- examples/language_model_ptb/lm_ptb.py | 32 ++++++++++---------- texar/torch/core/optimization.py | 17 ++++++----- 4 files changed, 33 insertions(+), 32 deletions(-) diff --git a/examples/language_model_ptb/config_large.py b/examples/language_model_ptb/config_large.py index 91de3ebd2..4a0b92ea2 100644 --- a/examples/language_model_ptb/config_large.py +++ b/examples/language_model_ptb/config_large.py @@ -35,12 +35,12 @@ emb = { "dim": hidden_size, + "dropout_rate": 1 - keep_prob, "initializer": { - "type": "random_uniform_initializer", + "type": "torch.nn.init.uniform_", "kwargs": { - "minval": -init_scale, - "maxval": init_scale, - "seed": None + "a": 
-init_scale, + "b": init_scale, } }, } diff --git a/examples/language_model_ptb/config_medium.py b/examples/language_model_ptb/config_medium.py index facdfa7a5..6f56bb270 100644 --- a/examples/language_model_ptb/config_medium.py +++ b/examples/language_model_ptb/config_medium.py @@ -35,12 +35,12 @@ emb = { "dim": hidden_size, + "dropout_rate": 1 - keep_prob, "initializer": { - "type": "random_uniform_initializer", + "type": "torch.nn.init.uniform_", "kwargs": { - "minval": -init_scale, - "maxval": init_scale, - "seed": None + "a": -init_scale, + "b": init_scale, } }, } diff --git a/examples/language_model_ptb/lm_ptb.py b/examples/language_model_ptb/lm_ptb.py index ab558ac00..dc971bdb4 100644 --- a/examples/language_model_ptb/lm_ptb.py +++ b/examples/language_model_ptb/lm_ptb.py @@ -14,14 +14,12 @@ """Example for building the language model. """ -from typing import Any, Dict, Tuple +from typing import Any, Dict, List, Tuple import argparse import importlib import time -import numpy as np - import torch import torch.nn as nn @@ -58,17 +56,16 @@ def __init__(self, vocab_size: int, hparams: Dict[str, Any]): hparams=hparams['decoder']) def forward(self, # type: ignore - inputs: torch.Tensor, targets: torch.Tensor, state) -> \ - Tuple[torch.Tensor, torch.Tensor]: + inputs: torch.Tensor, targets: torch.Tensor, + state: List[Tuple[torch.Tensor]]): outputs, final_state, seq_lengths = self.decoder( decoding_strategy="train_greedy", impute_finished=False, inputs=inputs, - sequence_length=torch.tensor([config.num_steps] * config.batch_size), + sequence_length=torch.tensor( + [config.num_steps] * config.batch_size), initial_state=state) - # import pdb; pdb.set_trace() - mle_loss = tx.losses.sequence_sparse_softmax_cross_entropy( labels=targets, logits=outputs.logits, @@ -101,20 +98,21 @@ def _run_epoch(data_iter, is_train=False, verbose=False): state = None if is_train: - epoch_size = (len(data["train_text_id"]) // batch_size - 1) \ - // num_steps + model.train() + epoch_size = ((len(data["train_text_id"]) // batch_size - 1) // + num_steps) + else: + model.eval() for step, (x, y) in enumerate(data_iter): - - print("Hello, world!") - loss_, state_ = model(inputs=torch.tensor(x), targets=torch.tensor(y), state=state) if is_train: - loss_.backward(retain_graph=True) + loss_.backward() train_op() loss += loss_ - state = state_ + state = [(state_[0][0].detach(), state_[0][1].detach()), + (state_[1][0].detach(), state_[1][1].detach())] iters += num_steps ppl = torch.exp(loss / iters) @@ -123,7 +121,7 @@ def _run_epoch(data_iter, is_train=False, verbose=False): ((step + 1) * 1.0 / epoch_size, ppl, iters * batch_size / (time.time() - start_time))) - ppl = np.exp(loss / iters) + ppl = torch.exp(loss / iters) return ppl for epoch in range(num_epochs): @@ -132,6 +130,8 @@ def _run_epoch(data_iter, is_train=False, verbose=False): data["train_text_id"], batch_size, num_steps) train_ppl = _run_epoch(train_data_iter, is_train=True, verbose=True) print("Epoch: %d Train Perplexity: %.3f" % (epoch, train_ppl)) + train_op(use_scheduler=True) # type: ignore + # Valid valid_data_iter = ptb_iterator( data["valid_text_id"], batch_size, num_steps) diff --git a/texar/torch/core/optimization.py b/texar/torch/core/optimization.py index 427117fb4..a71eba2b5 100644 --- a/texar/torch/core/optimization.py +++ b/texar/torch/core/optimization.py @@ -324,14 +324,15 @@ def get_train_op(params: Optional[Iterable[Union[torch.Tensor, elif isinstance(params, list): params_list += params - def _train_op(): - if grad_clip_fn is not None: - 
grad_clip_fn(parameters=params_list) - optimizer.step() - # TODO: Ideally, scheduler should be used in the epoch level. - if scheduler is not None: - scheduler.step() - optimizer.zero_grad() + def _train_op(use_scheduler=False): + if not use_scheduler: + if grad_clip_fn is not None: + grad_clip_fn(parameters=params_list) + optimizer.step() + optimizer.zero_grad() + else: + if scheduler is not None: + scheduler.step() return _train_op
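
For readers skimming the optimization change above: the second commit splits the returned `_train_op` closure into a per-batch path (clip gradients, `optimizer.step()`, `optimizer.zero_grad()`) and a per-epoch path (`scheduler.step()`), which `lm_ptb.py` now triggers with `train_op(use_scheduler=True)` after each epoch. The sketch below mirrors that calling pattern with plain-PyTorch stand-ins rather than `texar.torch.core.get_train_op`; the tiny model, the fake batches, and the `max_norm`/`gamma` values (borrowed from `config_small.py`) are illustrative only.

```python
# Sketch of the per-batch / per-epoch split introduced by the patch,
# using plain PyTorch stand-ins instead of texar.torch's get_train_op.
import torch
import torch.nn as nn

model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5)

def train_op(use_scheduler: bool = False) -> None:
    if not use_scheduler:
        # Per-batch path: clip, step, and reset gradients.
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()
        optimizer.zero_grad()
    elif scheduler is not None:
        # Per-epoch path: decay the learning rate.
        scheduler.step()

for epoch in range(2):
    for _ in range(5):                 # stand-in for ptb_iterator batches
        x = torch.randn(20, 10)
        loss = model(x).pow(2).mean()
        loss.backward()
        train_op()                     # called once per batch
    train_op(use_scheduler=True)       # called once per epoch
```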
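
Similarly, to make the data pipeline in `ptb_reader.py` concrete: `ptb_iterator` reshapes the flat id sequence into `batch_size` rows and yields `(x, y)` windows where each target window `y` is the corresponding input window `x` advanced by one token, i.e. the next-word targets. The toy check below copies the iterator logic from the patch; the fake id sequence is made up for illustration.

```python
# Toy check of the ptb_iterator windowing: y is x shifted left by one token.
import numpy as np

def ptb_iterator(data, batch_size, num_steps):
    # Same logic as examples/language_model_ptb/ptb_reader.py.
    data_length = len(data)
    batch_length = data_length // batch_size
    data = np.asarray(data[:batch_size * batch_length])
    data = data.reshape([batch_size, batch_length])
    epoch_size = (batch_length - 1) // num_steps
    if epoch_size == 0:
        raise ValueError("epoch_size == 0, decrease batch_size or num_steps")
    for i in range(epoch_size):
        x = data[:, i * num_steps: (i + 1) * num_steps]
        y = data[:, i * num_steps + 1: (i + 1) * num_steps + 1]
        yield (x, y)

fake_ids = list(range(50))             # stand-in for ptb.train.txt token ids
for x, y in ptb_iterator(fake_ids, batch_size=2, num_steps=5):
    assert x.shape == y.shape == (2, 5)
    assert (x[:, 1:] == y[:, :-1]).all()   # targets are inputs shifted by one
```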