Skip to content

Commit

Permalink
Keras 3 imbalanced classification
Browse files Browse the repository at this point in the history
Marked as tf only for some reason, but this is actually all backends
  • Loading branch information
mattdangerw committed Nov 11, 2023
1 parent ea7f24e commit 17632d4
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 83 deletions.
9 changes: 4 additions & 5 deletions examples/structured_data/imbalanced_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,12 @@
## Build a binary classification model
"""

from tensorflow import keras
import keras

model = keras.Sequential(
[
keras.layers.Dense(
256, activation="relu", input_shape=(train_features.shape[-1],)
),
keras.Input(shape=train_features.shape[1:]),
keras.layers.Dense(256, activation="relu"),
keras.layers.Dense(256, activation="relu"),
keras.layers.Dropout(0.3),
keras.layers.Dense(256, activation="relu"),
Expand All @@ -118,7 +117,7 @@
optimizer=keras.optimizers.Adam(1e-2), loss="binary_crossentropy", metrics=metrics
)

callbacks = [keras.callbacks.ModelCheckpoint("fraud_model_at_epoch_{epoch}.h5")]
callbacks = [keras.callbacks.ModelCheckpoint("fraud_model_at_epoch_{epoch}.keras")]
class_weight = {0: weight_for_0, 1: weight_for_1}

model.fit(
Expand Down
23 changes: 11 additions & 12 deletions examples/structured_data/ipynb/imbalanced_classification.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
"features = np.array(all_features, dtype=\"float32\")\n",
"targets = np.array(all_targets, dtype=\"uint8\")\n",
"print(\"features.shape:\", features.shape)\n",
"print(\"targets.shape:\", targets.shape)\n"
"print(\"targets.shape:\", targets.shape)"
]
},
{
Expand All @@ -94,7 +94,7 @@
"val_targets = targets[-num_val_samples:]\n",
"\n",
"print(\"Number of training samples:\", len(train_features))\n",
"print(\"Number of validation samples:\", len(val_features))\n"
"print(\"Number of validation samples:\", len(val_features))"
]
},
{
Expand Down Expand Up @@ -122,7 +122,7 @@
")\n",
"\n",
"weight_for_0 = 1.0 / counts[0]\n",
"weight_for_1 = 1.0 / counts[1]\n"
"weight_for_1 = 1.0 / counts[1]"
]
},
{
Expand All @@ -147,7 +147,7 @@
"val_features -= mean\n",
"std = np.std(train_features, axis=0)\n",
"train_features /= std\n",
"val_features /= std\n"
"val_features /= std"
]
},
{
Expand All @@ -167,21 +167,20 @@
},
"outputs": [],
"source": [
"from tensorflow import keras\n",
"import keras\n",
"\n",
"model = keras.Sequential(\n",
" [\n",
" keras.layers.Dense(\n",
" 256, activation=\"relu\", input_shape=(train_features.shape[-1],)\n",
" ),\n",
" keras.Input(shape=train_features.shape[1:]),\n",
" keras.layers.Dense(256, activation=\"relu\"),\n",
" keras.layers.Dense(256, activation=\"relu\"),\n",
" keras.layers.Dropout(0.3),\n",
" keras.layers.Dense(256, activation=\"relu\"),\n",
" keras.layers.Dropout(0.3),\n",
" keras.layers.Dense(1, activation=\"sigmoid\"),\n",
" ]\n",
")\n",
"model.summary()\n"
"model.summary()"
]
},
{
Expand Down Expand Up @@ -214,7 +213,7 @@
" optimizer=keras.optimizers.Adam(1e-2), loss=\"binary_crossentropy\", metrics=metrics\n",
")\n",
"\n",
"callbacks = [keras.callbacks.ModelCheckpoint(\"fraud_model_at_epoch_{epoch}.h5\")]\n",
"callbacks = [keras.callbacks.ModelCheckpoint(\"fraud_model_at_epoch_{epoch}.keras\")]\n",
"class_weight = {0: weight_for_0, 1: weight_for_1}\n",
"\n",
"model.fit(\n",
Expand All @@ -226,7 +225,7 @@
" callbacks=callbacks,\n",
" validation_data=(val_features, val_targets),\n",
" class_weight=class_weight,\n",
")\n"
")"
]
},
{
Expand All @@ -252,7 +251,7 @@
"\n",
"| Trained Model | Demo |\n",
"| :--: | :--: |\n",
"| [![Generic badge](https://img.shields.io/badge/🤗%20Model-Imbalanced%20Classification-black.svg)](https://huggingface.co/keras-io/imbalanced_classification) | [![Generic badge](https://img.shields.io/badge/🤗%20Spaces-Imbalanced%20Classification-black.svg)](https://huggingface.co/spaces/keras-io/Credit_Card_Fraud_Detection) |\n"
"| [![Generic badge](https://img.shields.io/badge/\ud83e\udd17%20Model-Imbalanced%20Classification-black.svg)](https://huggingface.co/keras-io/imbalanced_classification) | [![Generic badge](https://img.shields.io/badge/\ud83e\udd17%20Spaces-Imbalanced%20Classification-black.svg)](https://huggingface.co/spaces/keras-io/Credit_Card_Fraud_Detection) |"
]
}
],
Expand Down
Loading

0 comments on commit 17632d4

Please sign in to comment.