Commit

Created using Colaboratory
1kaiser committed Dec 31, 2022
1 parent 0441144 commit 992bb5b
Showing 1 changed file with 163 additions and 3 deletions.
166 changes: 163 additions & 3 deletions MLP_Image_Train_Inference_JAX.ipynb
@@ -4,7 +4,7 @@
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyO/VV28tLyXQmzWWnHKkzuo",
"authorship_tag": "ABX9TyOKcfrMg1MLg77aTV44y/4b",
"include_colab_link": true
},
"kernelspec": {
@@ -4494,14 +4494,14 @@
"name": "stderr",
"text": [
"\n",
" 99%|█████████▊| 994/1009 [29:52<00:27, 1.80s/it]\u001b[A"
" 52%|█████▏ | 523/1009 [14:53<13:28, 1.66s/it]\u001b[A"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"loss: 1486.7937 <<< \n"
"loss: 1168.78 <<< \n"
]
}
]
@@ -4583,6 +4583,166 @@
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"**ensemble**"
],
"metadata": {
"id": "ka3_468K9xRZ"
}
},
{
"cell_type": "code",
"source": [
"#✅\n",
"!python -m pip install -q -U flax\n",
"import optax\n",
"from flax.training import train_state\n",
"import jax.numpy as jnp\n",
"import jax\n",
"\n",
"@functools.partial(jax.pmap, static_broadcasted_argnums=(1, 2))\n",
"def Create_train_state(r_key, shape, learning_rate ):\n",
" print(shape)\n",
" model = MLPModel()\n",
" variables = model.init(r_key, jnp.ones(shape)) \n",
" optimizer = optax.adam(learning_rate) \n",
" return train_state.TrainState.create(\n",
" apply_fn = model.apply,\n",
" tx=optimizer,\n",
" params=variables['params']\n",
" )\n",
"\n",
"learning_rate = 1e-4\n",
"batch_size_no = 64\n",
"\n",
"model = MLPModel() # Instantiate the Model"
],
"metadata": {
"id": "Aaat0R0q9Z7F"
},
"execution_count": null,
"outputs": []
},
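{
"cell_type": "markdown",
"source": [
"A quick sanity check (a sketch; assumes the cell above ran): `jax.random.split(init_rng, jax.device_count())` hands each device its own PRNG key, so every ensemble member is initialized with different weights, and each leaf of `state.params` carries a leading device axis."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Sketch: one PRNG key per device, i.e. per ensemble member.\n",
"keys = jax.random.split(jax.random.PRNGKey(0), jax.device_count())\n",
"print(keys.shape)  # (n_devices, 2)\n",
"\n",
"# After create_train_state, every parameter leaf is stacked across devices:\n",
"# jax.tree_util.tree_map(lambda p: p.shape, state.params)  # leading axis == n_devices"
],
"metadata": {},
"execution_count": null,
"outputs": []
},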
{
"cell_type": "code",
"source": [
"@functools.partial(jax.pmap, axis_name='ensemble')\n",
"def apply_model(state, batch: jnp.asarray):\n",
" image, label = batch\n",
" def loss_fn(params):\n",
" logits = MLPModel().apply({'params': params}, image)\n",
" loss = image_difference_loss(logits, label);\n",
" return loss, logits\n",
"\n",
" grad_fn = jax.value_and_grad(loss_fn, has_aux=True)\n",
" (loss, logits), grads = grad_fn(state.params)\n",
" return grads, loss\n",
"\n",
"@jax.pmap\n",
"def update_model(state, grads):\n",
" return state.apply_gradients(grads=grads)"
],
"metadata": {
"id": "QyGP4Fmf-q7q"
},
"execution_count": null,
"outputs": []
},
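{
"cell_type": "markdown",
"source": [
"A minimal, self-contained sketch of the `pmap` pattern above (a hypothetical toy, not the notebook's model): each device sees only its own slice of the leading axis, so per-device parameters updated with per-device gradients are exactly an independent ensemble."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Toy pmap: device i computes with w[i] only; nothing crosses devices.\n",
"@jax.pmap\n",
"def toy_step(w, x):\n",
"    return w * x  # stand-in for one member's forward/update\n",
"\n",
"n = jax.device_count()\n",
"w = jnp.arange(n, dtype=jnp.float32)  # one 'parameter' per device\n",
"x = jnp.ones((n,))\n",
"print(toy_step(w, x))  # member i returns w[i]"
],
"metadata": {},
"execution_count": null,
"outputs": []
},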
{
"cell_type": "code",
"source": [
"def train_epoch(state, train_ds, batch_size, rng):\n",
" train_ds_size = len(train_ds['image'])\n",
" steps_per_epoch = train_ds_size // batch_size\n",
"\n",
" perms = jax.random.permutation(rng, len(train_ds['image']))\n",
" perms = perms[:steps_per_epoch * batch_size]\n",
" perms = perms.reshape((steps_per_epoch, batch_size))\n",
"\n",
" epoch_loss = []\n",
"\n",
" for perm in perms:\n",
" batch_images = jax_utils.replicate(train_ds['image'][perm, ...])\n",
" batch_labels = jax_utils.replicate(train_ds['label'][perm, ...])\n",
" grads, loss = apply_model(state, batch_images, batch_labels)\n",
" state = update_model(state, grads)\n",
" epoch_loss.append(jax_utils.unreplicate(loss))\n",
" train_loss = np.mean(epoch_loss)\n",
" return state, train_loss"
],
"metadata": {
"id": "QQUj3Y3LA9A1"
},
"execution_count": null,
"outputs": []
},
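{
"cell_type": "markdown",
"source": [
"The index bookkeeping in `train_epoch`, on toy numbers (a sketch independent of the real dataset): shuffle, drop the ragged tail, then reshape so each row is one batch."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# 10 examples, batch_size 3 -> keep 9 indices, get 3 disjoint batches.\n",
"demo_rng = jax.random.PRNGKey(42)\n",
"demo_perms = jax.random.permutation(demo_rng, 10)\n",
"demo_perms = demo_perms[:(10 // 3) * 3].reshape((3, 3))\n",
"print(demo_perms)  # each row indexes one batch; no example repeats"
],
"metadata": {},
"execution_count": null,
"outputs": []
},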
{
"cell_type": "code",
"source": [
"train_ds, test_ds = get_datasets()\n",
"test_ds = jax_utils.replicate(test_ds)\n",
"rng = jax.random.PRNGKey(0)\n",
"\n",
"rng, init_rng = jax.random.split(rng)\n",
"\n",
"HxW, Channels = next(batches)[0].shape\n",
"state = create_train_state(jax.random.split(init_rng, jax.device_count()),(HxW, Channels),learning_rate)\n",
"\n",
"for epoch in range(1, num_epochs + 1):\n",
" rng, input_rng = jax.random.split(rng)\n",
" state, train_loss = train_epoch(state, train_ds, batch_size, input_rng)\n",
"\n",
" # _, test_loss = jax_utils.unreplicate(apply_model(state, test_ds['image'], test_ds['label']))\n",
"\n",
" logging.info('epoch:% 3d, train_loss: %.4f ' % (epoch, train_loss))"
],
"metadata": {
"id": "X-CttLscBnDQ"
},
"execution_count": null,
"outputs": []
},
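{
"cell_type": "markdown",
"source": [
"After the loop, `state` still carries the leading device axis. A sketch of getting single-model parameters back (a toy pytree stands in for `state.params`): `jax_utils.unreplicate` keeps device 0's copy, while indexing the leading axis keeps any one member."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Toy replicated pytree; with the trained state, use state.params instead.\n",
"demo = jax_utils.replicate({'w': jnp.zeros((3,))})\n",
"print(jax.tree_util.tree_map(lambda p: p.shape, demo))  # leading axis == n_devices\n",
"print(jax.tree_util.tree_map(lambda p: p.shape, jax_utils.unreplicate(demo)))\n",
"# member_k = jax.tree_util.tree_map(lambda p: p[k], state.params)  # keep member k"
],
"metadata": {},
"execution_count": null,
"outputs": []
},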
{
"cell_type": "code",
"source": [
"correct = total = 0\n",
"for batch in ds.as_numpy_iterator():\n",
" preds = flax.jax_utils.pad_shard_unpad(get_preds)(\n",
" vs, batch['image'], min_device_batch=per_device_batch_size)\n",
" total += len(batch['image'])\n",
" correct += (batch['label'] == preds.argmax(axis=-1)).sum()"
],
"metadata": {
"id": "I_orMqbuD3LL"
},
"execution_count": null,
"outputs": []
},
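{
"cell_type": "markdown",
"source": [
"Why `pad_shard_unpad` is needed above (an arithmetic illustration with assumed numbers): `pmap` requires the batch to split evenly across devices, so a ragged final batch is padded up, sharded, evaluated, and the padded rows are dropped from the result."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import math\n",
"\n",
"n_devices, last_batch = 8, 13  # hypothetical ragged final batch\n",
"per_device = math.ceil(last_batch / n_devices)\n",
"padded = n_devices * per_device\n",
"print(padded, padded - last_batch)  # 16 total, 3 padded rows to drop"
],
"metadata": {},
"execution_count": null,
"outputs": []
},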
{
"cell_type": "code",
"source": [
"def eval_step(metrics, variables, batch):\n",
" print('retrigger compilation', {k: v.shape for k, v in batch.items()})\n",
" preds = model.apply(variables, batch['image'])\n",
" correct = (batch['mask'] & (batch['label'] == preds.argmax(axis=-1))).sum()\n",
" total = batch['mask'].sum()\n",
" return dict(\n",
" correct=metrics['correct'] + jax.lax.psum(correct, axis_name='batch'),\n",
" total=metrics['total'] + jax.lax.psum(total, axis_name='batch'),\n",
" )\n",
"\n",
"eval_step = jax.pmap(eval_step, axis_name='batch')\n",
"eval_step = flax.jax_utils.pad_shard_unpad(\n",
" eval_step, static_argnums=(0, 1), static_return=True)"
],
"metadata": {
"id": "RxhJjRZLD5P-"
},
"execution_count": null,
"outputs": []
}
]
}
