Skip to content

Commit

Permalink
Keras 3 semantic similarity
Browse files Browse the repository at this point in the history
  • Loading branch information
mattdangerw committed Nov 23, 2023
1 parent 055dc90 commit e55a9e7
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 31 deletions.
5 changes: 3 additions & 2 deletions examples/nlp/ipynb/semantic_similarity_with_keras_nlp.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@
},
"outputs": [],
"source": [
"!pip install -q keras-nlp"
"!pip install -q --upgrade keras-nlp\n",
"!pip install -q --upgrade keras # Upgrade to Keras 3."
]
},
{
Expand All @@ -69,7 +70,7 @@
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"import keras_core as keras\n",
"import keras\n",
"import keras_nlp\n",
"import tensorflow_datasets as tfds"
]
Expand Down
52 changes: 25 additions & 27 deletions examples/nlp/md/semantic_similarity_with_keras_nlp.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,21 +41,19 @@ give us a particularly fast train step below.


```python
!pip install -q keras-nlp
!pip install -q --upgrade keras-nlp
!pip install -q --upgrade keras # Upgrade to Keras 3.
```


```python
import numpy as np
import tensorflow as tf
import keras_core as keras
import keras
import keras_nlp
import tensorflow_datasets as tfds
```

<div class="k-default-codeblock">
```
Using JAX backend.
```
</div>
Expand Down Expand Up @@ -190,9 +188,9 @@ bert_classifier.fit(train_ds, validation_data=val_ds, epochs=1)

<div class="k-default-codeblock">
```
6867/6867 ━━━━━━━━━━━━━━━━━━━━ 49s 7ms/step - loss: 0.8584 - sparse_categorical_accuracy: 0.6049 - val_loss: 0.5857 - val_sparse_categorical_accuracy: 0.7608
6867/6867 ━━━━━━━━━━━━━━━━━━━━ 61s 8ms/step - loss: 0.8732 - sparse_categorical_accuracy: 0.5864 - val_loss: 0.5900 - val_sparse_categorical_accuracy: 0.7602
<keras_core.src.callbacks.history.History at 0x7fc6cbf41ea0>
<keras.src.callbacks.history.History at 0x7f4660171fc0>
```
</div>
Expand All @@ -208,9 +206,9 @@ bert_classifier.evaluate(test_ds)

<div class="k-default-codeblock">
```
614/614 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.5709 - sparse_categorical_accuracy: 0.7742
614/614 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.5815 - sparse_categorical_accuracy: 0.7628
[0.5832399725914001, 0.7678135633468628]
[0.5895748734474182, 0.7618078589439392]
```
</div>
Expand All @@ -235,10 +233,10 @@ bert_classifier.evaluate(test_ds)

<div class="k-default-codeblock">
```
6867/6867 ━━━━━━━━━━━━━━━━━━━━ 49s 7ms/step - accuracy: 0.5944 - loss: 0.8679 - val_accuracy: 0.7645 - val_loss: 0.5811
614/614 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - accuracy: 0.7676 - loss: 0.5742
6867/6867 ━━━━━━━━━━━━━━━━━━━━ 59s 8ms/step - accuracy: 0.6007 - loss: 0.8636 - val_accuracy: 0.7648 - val_loss: 0.5800
614/614 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - accuracy: 0.7700 - loss: 0.5692
[0.5850245356559753, 0.762723982334137]
[0.578984260559082, 0.7686278820037842]
```
</div>
Expand Down Expand Up @@ -297,13 +295,13 @@ bert_classifier.fit(train_ds, validation_data=val_ds, epochs=epochs)
<div class="k-default-codeblock">
```
Epoch 1/3
6867/6867 ━━━━━━━━━━━━━━━━━━━━ 49s 7ms/step - accuracy: 0.5340 - loss: 0.9392 - val_accuracy: 0.7620 - val_loss: 0.5826
6867/6867 ━━━━━━━━━━━━━━━━━━━━ 59s 8ms/step - accuracy: 0.5457 - loss: 0.9317 - val_accuracy: 0.7633 - val_loss: 0.5825
Epoch 2/3
6867/6867 ━━━━━━━━━━━━━━━━━━━━ 46s 7ms/step - accuracy: 0.7314 - loss: 0.6511 - val_accuracy: 0.7871 - val_loss: 0.5338
6867/6867 ━━━━━━━━━━━━━━━━━━━━ 55s 8ms/step - accuracy: 0.7291 - loss: 0.6515 - val_accuracy: 0.7809 - val_loss: 0.5399
Epoch 3/3
6867/6867 ━━━━━━━━━━━━━━━━━━━━ 46s 7ms/step - accuracy: 0.7719 - loss: 0.5683 - val_accuracy: 0.7913 - val_loss: 0.5251
6867/6867 ━━━━━━━━━━━━━━━━━━━━ 55s 8ms/step - accuracy: 0.7708 - loss: 0.5695 - val_accuracy: 0.7918 - val_loss: 0.5214
<keras_core.src.callbacks.history.History at 0x7fc5d069b850>
<keras.src.callbacks.history.History at 0x7f45645b3370>
```
</div>
Expand All @@ -319,9 +317,9 @@ bert_classifier.evaluate(test_ds)

<div class="k-default-codeblock">
```
614/614 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - accuracy: 0.7963 - loss: 0.5189
614/614 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - accuracy: 0.7956 - loss: 0.5128
[0.5268478393554688, 0.791530966758728]
[0.5245093703269958, 0.7890879511833191]
```
</div>
Expand All @@ -346,14 +344,14 @@ restored_model.evaluate(test_ds)

<div class="k-default-codeblock">
```
/home/matt/miniconda3/envs/gpu/lib/python3.10/site-packages/keras_core/src/saving/serialization_lib.py:684: UserWarning: `compile()` was not called as part of model loading because the model's `compile()` method is custom. All subclassed Models that have `compile()` overridden should also override `get_compile_config()` and `compile_from_config(config)`. Alternatively, you can call `compile()` manually after loading.
/home/matt/miniconda3/envs/keras-io/lib/python3.10/site-packages/keras/src/saving/serialization_lib.py:723: UserWarning: `compile()` was not called as part of model loading because the model's `compile()` method is custom. All subclassed Models that have `compile()` overridden should also override `get_compile_config()` and `compile_from_config(config)`. Alternatively, you can call `compile()` manually after loading.
instance.compile_from_config(compile_config)
/home/matt/miniconda3/envs/gpu/lib/python3.10/site-packages/keras_core/src/saving/saving_lib.py:338: UserWarning: Skipping variable loading for optimizer 'adam', because it has 2 variables whereas the saved optimizer has 83 variables.
/home/matt/miniconda3/envs/keras-io/lib/python3.10/site-packages/keras/src/saving/saving_lib.py:355: UserWarning: Skipping variable loading for optimizer 'adam', because it has 2 variables whereas the saved optimizer has 83 variables.
trackable.load_own_variables(weights_store.get(inner_path))
614/614 ━━━━━━━━━━━━━━━━━━━━ 3s 4ms/step - loss: 0.5189 - sparse_categorical_accuracy: 0.7963
614/614 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.5128 - sparse_categorical_accuracy: 0.7956
[0.5268478393554688, 0.791530966758728]
[0.5245093703269958, 0.7890879511833191]
```
</div>
Expand Down Expand Up @@ -406,7 +404,7 @@ predictions = softmax(predictions)

<div class="k-default-codeblock">
```
1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 734ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 711ms/step
```
</div>
Expand All @@ -431,10 +429,10 @@ roberta_classifier.evaluate(test_ds)

<div class="k-default-codeblock">
```
6867/6867 ━━━━━━━━━━━━━━━━━━━━ 2027s 294ms/step - loss: 0.5688 - sparse_categorical_accuracy: 0.7601 - val_loss: 0.3243 - val_sparse_categorical_accuracy: 0.8820
614/614 ━━━━━━━━━━━━━━━━━━━━ 59s 93ms/step - loss: 0.3250 - sparse_categorical_accuracy: 0.8851
6867/6867 ━━━━━━━━━━━━━━━━━━━━ 2049s 297ms/step - loss: 0.5509 - sparse_categorical_accuracy: 0.7740 - val_loss: 0.3292 - val_sparse_categorical_accuracy: 0.8789
614/614 ━━━━━━━━━━━━━━━━━━━━ 56s 88ms/step - loss: 0.3307 - sparse_categorical_accuracy: 0.8784
[0.3305884897708893, 0.8821254372596741]
[0.33771008253097534, 0.874796450138092]
```
</div>
Expand All @@ -456,7 +454,7 @@ print(tf.math.argmax(predictions, axis=1).numpy())

<div class="k-default-codeblock">
```
1/1 ━━━━━━━━━━━━━━━━━━━━ 3s 3s/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 4s 4s/step
[0 0 0 0]
```
Expand Down
5 changes: 3 additions & 2 deletions examples/nlp/semantic_similarity_with_keras_nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,13 @@
"""

"""shell
pip install -q keras-nlp
pip install -q --upgrade keras-nlp
pip install -q --upgrade keras # Upgrade to Keras 3.
"""

import numpy as np
import tensorflow as tf
import keras_core as keras
import keras
import keras_nlp
import tensorflow_datasets as tfds

Expand Down

0 comments on commit e55a9e7

Please sign in to comment.