Now works with Pytorch 1.8
vansky committed Mar 25, 2021
1 parent a91691e commit 088f046
Showing 2 changed files with 7 additions and 11 deletions.
10 changes: 3 additions & 7 deletions README.md
```diff
@@ -1,20 +1,16 @@
 # Neural complexity
 A neural language model that computes various information-theoretic processing complexity measures (e.g., surprisal) for each word given the preceding context. Also, it can function as an adaptive language model ([van Schijndel and Linzen, 2018](http://aclweb.org/anthology/D18-1499)) which adapts to test domains.
 
+**Note**: Recent updates remove dependencies but break compatibility with pre-2021 models. To use older models, use version 1.1.0: `git checkout tags/v1.1.0`
+
 ### Dependencies
 Requires the following python packages (available through pip):
-* [pytorch](https://pytorch.org/) v1.0.0
-* nltk
+* [pytorch](https://pytorch.org/)
 
 The following python packages are optional:
 * progress
 * dill (to handle binarized vocabularies)
 
-Requires the `punkt` nltk module. Install it from within python:
-
-    import nltk
-    nltk.download('punkt')
-
 ### Quick Usage
 The below all use GPUs. To use CPUs instead, omit the `--cuda` flag.
```

8 changes: 4 additions & 4 deletions main.py
```diff
@@ -451,7 +451,7 @@ def test_evaluate(test_sentences, data_source):
                 target = targets[word_index].unsqueeze(0)
                 output, hidden = model(word_input, hidden)
                 output_flat = output.view(-1, ntokens)
-                loss = criterion(output_flat, target)
+                loss = criterion(output_flat, target.long())
                 total_loss += loss.item()
                 input_word = corpus.dictionary.idx2word[int(word_input.data)]
                 targ_word = corpus.dictionary.idx2word[int(target.data)]
```
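The `.long()` casts are the substance of this commit: PyTorch's classification criteria (e.g. `nn.CrossEntropyLoss`) require targets to be integer (Long) class indices, and PyTorch 1.8 raises an error for other dtypes instead of casting silently. A pure-Python sketch of that contract (the function and names are illustrative, not from the repository):

```python
import math

def cross_entropy(logits, target):
    """Negative log-likelihood of one target class under a softmax.

    Like nn.CrossEntropyLoss, the target must be an integer class
    index; any other type is rejected up front.
    """
    if not isinstance(target, int):
        raise TypeError("target must be an integer class index")
    z = max(logits)  # subtract the max for numerical stability
    log_sum_exp = z + math.log(sum(math.exp(x - z) for x in logits))
    return log_sum_exp - logits[target]

# Two equally likely classes: the loss is log(2) for either index.
loss = cross_entropy([0.0, 0.0], 0)
```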
```diff
@@ -482,7 +482,7 @@ def test_evaluate(test_sentences, data_source):
             except RuntimeError:
                 print("Vocabulary Error! Most likely there weren't unks in training and unks are now needed for testing")
                 raise
-            loss = criterion(output_flat, targets)
+            loss = criterion(output_flat, targets.long())
             total_loss += loss.item()
             if args.words:
                 # output word-level complexity metrics
```
```diff
@@ -527,7 +527,7 @@ def evaluate(data_source):
         data, targets = get_batch(data_source, i)
         output, hidden = model(data, hidden)
         output_flat = output.view(-1, ntokens)
-        total_loss += len(data) * criterion(output_flat, targets).item()
+        total_loss += len(data) * criterion(output_flat, targets.long()).item()
         hidden = repackage_hidden(hidden)
     return total_loss / len(data_source)
```
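In the hunk above, `evaluate()` accumulates each batch's loss weighted by its length and divides by the corpus length, i.e. a per-token mean loss; exponentiating that mean gives perplexity. A minimal sketch of the accumulation (function and names are illustrative):

```python
import math

def mean_token_loss(batches):
    """Length-weighted mean loss, mirroring evaluate()'s accumulation.

    `batches` is a list of (num_tokens, loss) pairs; each batch adds
    num_tokens * loss, and the sum is divided by the total token count.
    """
    total_loss = sum(n * loss for n, loss in batches)
    total_tokens = sum(n for n, _ in batches)
    return total_loss / total_tokens

# Perplexity is exp of the mean per-token loss (loss measured in nats).
ppl = math.exp(mean_token_loss([(10, 5.0), (30, 4.0)]))
```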

```diff
@@ -546,7 +546,7 @@ def train():
         hidden = repackage_hidden(hidden)
         model.zero_grad()
         output, hidden = model(data, hidden)
-        loss = criterion(output.view(-1, ntokens), targets)
+        loss = criterion(output.view(-1, ntokens), targets.long())
         loss.backward()
 
         # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
```
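The `clip_grad_norm` comment in the hunk above refers to rescaling all gradients so their global L2 norm stays below a threshold, which keeps a single large update from destabilizing RNN/LSTM training (in current PyTorch this is `torch.nn.utils.clip_grad_norm_`). A pure-Python sketch of the operation on a flat list of gradient values (names are illustrative):

```python
import math

def clip_grad_norm(grads, max_norm):
    """Scale `grads` in place so their L2 norm is at most `max_norm`.

    Returns the pre-clipping norm, matching the behavior of
    torch.nn.utils.clip_grad_norm_.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        grads[:] = [g * scale for g in grads]
    return total_norm

grads = [3.0, 4.0]               # L2 norm 5.0
norm = clip_grad_norm(grads, 1.0)  # rescales to norm 1.0
```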
