Skip to content

Commit

Permalink
Fix example
Browse files Browse the repository at this point in the history
  • Loading branch information
mattdangerw committed Nov 23, 2023
1 parent 69d3b41 commit 936dd23
Show file tree
Hide file tree
Showing 5 changed files with 283 additions and 300 deletions.
31 changes: 26 additions & 5 deletions examples/vision/captcha_ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,6 @@
import keras
from keras import layers

# TODO restore traceback filtering.
keras.config.disable_traceback_filtering()

"""
## Load the data: [Captcha Images](https://www.kaggle.com/fournierp/captcha-version-2-images)
Let's download the data.
Expand Down Expand Up @@ -357,9 +354,33 @@ def build_model():
"""


def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
    """Decode CTC predictions into dense label tensors.

    Args:
        y_pred: Rank-3 tensor whose first two axes are treated as
            `(samples, steps)`. Presumably per-class probabilities (e.g. a
            softmax output), since the log is taken after adding epsilon --
            confirm against the model's output layer.
        input_length: Per-sample sequence lengths; cast to `int32` and passed
            to the decoder as `sequence_length`.
        greedy: If true, use `tf.nn.ctc_greedy_decoder`; otherwise use
            `tf.compat.v1.nn.ctc_beam_search_decoder`.
        beam_width: Forwarded to the beam search decoder (unused when
            `greedy` is true).
        top_paths: Forwarded to the beam search decoder (unused when
            `greedy` is true).

    Returns:
        A tuple `(decoded_dense, log_prob)`. `decoded_dense` is a list with
        one dense tensor of shape `(samples, steps)` per decoded path, with
        unfilled positions set to -1. `log_prob` is returned unchanged from
        the decoder.
    """
    pred_shape = tf.shape(y_pred)
    dense_shape = (pred_shape[0], pred_shape[1])

    # Swap the first two axes before handing the log-values to the decoders.
    steps_first = tf.transpose(y_pred, perm=[1, 0, 2])
    log_inputs = tf.math.log(steps_first + keras.backend.epsilon())
    seq_lengths = tf.cast(input_length, tf.int32)

    if not greedy:
        sparse_paths, log_prob = tf.compat.v1.nn.ctc_beam_search_decoder(
            inputs=log_inputs,
            sequence_length=seq_lengths,
            beam_width=beam_width,
            top_paths=top_paths,
        )
    else:
        sparse_paths, log_prob = tf.nn.ctc_greedy_decoder(
            inputs=log_inputs, sequence_length=seq_lengths
        )

    # Re-wrap each sparse result with a fixed (samples, steps) shape so that
    # every densified path has the same dimensions, padded with -1.
    dense_paths = [
        tf.sparse.to_dense(
            sp_input=tf.SparseTensor(path.indices, path.values, dense_shape),
            default_value=-1,
        )
        for path in sparse_paths
    ]
    return (dense_paths, log_prob)


# Get the prediction model by extracting layers till the output layer
prediction_model = keras.models.Model(
model.get_layer(name="image").input, model.get_layer(name="dense2").output
model.input[0], model.get_layer(name="dense2").output
)
prediction_model.summary()

Expand All @@ -368,7 +389,7 @@ def build_model():
def decode_batch_predictions(pred):
input_len = np.ones(pred.shape[0]) * pred.shape[1]
# Use greedy search. For complex tasks, you can use beam search
results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
results = ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
:, :max_length
]
# Iterate over the results and get back the text
Expand Down
Binary file modified examples/vision/img/captcha_ocr/captcha_ocr_13_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
33 changes: 28 additions & 5 deletions examples/vision/ipynb/captcha_ocr.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,7 @@
"\n",
"import tensorflow as tf\n",
"import keras\n",
"from keras import layers\n",
"\n",
"keras.config.disable_traceback_filtering()"
"from keras import layers"
]
},
{
Expand Down Expand Up @@ -450,6 +448,7 @@
"outputs": [],
"source": [
"\n",
"# TODO restore epoch count.\n",
"epochs = 2\n",
"early_stopping_patience = 10\n",
"# Add early stopping\n",
Expand Down Expand Up @@ -487,10 +486,34 @@
},
"outputs": [],
"source": [
"\n",
"def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):\n",
" input_shape = tf.shape(y_pred)\n",
" num_samples, num_steps = input_shape[0], input_shape[1]\n",
" y_pred = tf.math.log(tf.transpose(y_pred, perm=[1, 0, 2]) + keras.backend.epsilon())\n",
" input_length = tf.cast(input_length, tf.int32)\n",
"\n",
" if greedy:\n",
" (decoded, log_prob) = tf.nn.ctc_greedy_decoder(\n",
" inputs=y_pred, sequence_length=input_length\n",
" )\n",
" else:\n",
" (decoded, log_prob) = tf.compat.v1.nn.ctc_beam_search_decoder(\n",
" inputs=y_pred,\n",
" sequence_length=input_length,\n",
" beam_width=beam_width,\n",
" top_paths=top_paths,\n",
" )\n",
" decoded_dense = []\n",
" for st in decoded:\n",
" st = tf.SparseTensor(st.indices, st.values, (num_samples, num_steps))\n",
" decoded_dense.append(tf.sparse.to_dense(sp_input=st, default_value=-1))\n",
" return (decoded_dense, log_prob)\n",
"\n",
"\n",
"# Get the prediction model by extracting layers till the output layer\n",
"prediction_model = keras.models.Model(\n",
" model.get_layer(name=\"image\").input, model.get_layer(name=\"dense2\").output\n",
" model.input[0], model.get_layer(name=\"dense2\").output\n",
")\n",
"prediction_model.summary()\n",
"\n",
Expand All @@ -499,7 +522,7 @@
"def decode_batch_predictions(pred):\n",
" input_len = np.ones(pred.shape[0]) * pred.shape[1]\n",
" # Use greedy search. For complex tasks, you can use beam search\n",
" results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][\n",
" results = ctc_decode(pred, input_length=input_len, greedy=True)[0][0][\n",
" :, :max_length\n",
" ]\n",
" # Iterate over the results and get back the text\n",
Expand Down
Loading

0 comments on commit 936dd23

Please sign in to comment.