diff --git a/trax/models/reformer/text_generation.ipynb b/trax/models/reformer/text_generation.ipynb index 5b67721b0..7a882cf55 100644 --- a/trax/models/reformer/text_generation.ipynb +++ b/trax/models/reformer/text_generation.ipynb @@ -289,7 +289,7 @@ " yield (inputs, inputs, mask)\n", "\n", "print(\"(device count, tokens per device) = \",\n", - " next(my_inputs(trax.fastmath.device_count()))[0].shape)" + " next(my_inputs(trax.fastmath.local_device_count()))[0].shape)" ], "execution_count": null, "outputs": [ @@ -385,8 +385,8 @@ "ReformerLM.n_heads = %n_heads\n", "ReformerLM.n_layers = %n_layers\n", "ReformerLM.vocab_size = 320\n", - "ReformerLM.axial_pos_shape = (512, 1024)\n", - "ReformerLM.d_axial_pos_embs= (64, 192)\n", + "ReformerLM.pos_axial_shape = (512, 1024)\n", + "ReformerLM.pos_d_axial_embs= (64, 192)\n", "\"\"\")" ], "execution_count": null, @@ -545,4 +545,4 @@ "outputs": [] } ] -} \ No newline at end of file +}