Commit

fixed example
tcoroller committed Dec 13, 2024
1 parent 6ed7d8c commit 179b634
Showing 2 changed files with 25 additions and 23 deletions.
1 change: 1 addition & 0 deletions docs/notebooks/helpers_momentum.py
@@ -163,6 +163,7 @@ def train_dataloader(self):
num_workers=self.num_workers,
persistent_workers=True,
shuffle=True,
+ drop_last=True,
)

def val_dataloader(self):
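For context, drop_last=True tells the DataLoader to skip the final, smaller batch whenever the dataset size is not divisible by the batch size, so every training step sees a batch of identical size; this matters for batch-sensitive losses such as the Cox partial log-likelihood used in this notebook and for the momentum memory below. A minimal, self-contained sketch of the flag's effect (the toy dataset is purely illustrative):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset of 25 samples: with batch_size=10 the last batch would normally
# contain only 5 samples; drop_last=True discards that remainder.
dataset = TensorDataset(torch.randn(25, 3), torch.randint(0, 2, (25,)))
loader = DataLoader(dataset, batch_size=10, shuffle=True, drop_last=True)
print([features.shape[0] for features, _ in loader])  # [10, 10]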
47 changes: 24 additions & 23 deletions docs/notebooks/momentum.ipynb
@@ -45,6 +45,7 @@
"source": [
"import matplotlib.pyplot as plt\n",
"import torch\n",
+ "import copy\n",
"import lightning as L\n",
"from torchvision.models import resnet18\n",
"from torchvision.transforms import v2\n",
@@ -121,13 +122,13 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": null,
"id": "794004c5-588c-4590-ae96-c6d9e52109ff",
"metadata": {},
"outputs": [],
"source": [
"EPOCHS = 2 # number of epochs to train\n",
- "FAST_DEV_RUN = 5 # Quick prototype, comment line for full epochs training"
+ "FAST_DEV_RUN = None # Quick prototype, comment line for full epochs training"
]
},
{
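The cell above only defines the EPOCHS and FAST_DEV_RUN constants; the Trainer construction itself is outside this diff. Presumably FAST_DEV_RUN feeds Lightning's fast_dev_run argument, which accepts a bool or an int (run only that many batches of train/val/test). A hedged sketch of how such a toggle is typically wired; the None-to-False guard is an assumption for illustration, not the notebook's code:

import lightning as L

EPOCHS = 2            # number of epochs to train
FAST_DEV_RUN = None   # set to an int (e.g. 10) for a quick prototyping run

# fast_dev_run expects False (disabled), True (1 batch), or an int (n batches),
# so a None value is mapped to False here (assumption, for illustration only).
trainer = L.Trainer(
    max_epochs=EPOCHS,
    fast_dev_run=FAST_DEV_RUN if FAST_DEV_RUN is not None else False,
)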
@@ -278,7 +279,7 @@
"outputs": [],
"source": [
"# Train first model (regular training) using our backbone\n",
- "model_regular = LitMNIST(backbone=resnet)"
+ "model_regular = LitMNIST(backbone=copy.deepcopy(resnet))"
]
},
{
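The copy.deepcopy call above is the heart of the fix: the regular model and the momentum model further down are both built from the same resnet backbone, and without copying they would share one set of parameters and train on each other's updates, making the later comparison meaningless. A small sketch of the difference, using a stand-in module rather than the notebook's resnet:

import copy
import torch
import torch.nn as nn

backbone = nn.Linear(4, 2)               # stand-in for the resnet backbone

shared = backbone                        # same object: training one model updates the other
independent = copy.deepcopy(backbone)    # separate parameters, safe for a side-by-side comparison

print(shared is backbone)                                # True
print(independent is backbone)                           # False
print(torch.equal(independent.weight, backbone.weight))  # True now, but later updates stay separate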
@@ -295,7 +296,7 @@
"TPU available: False, using: 0 TPU cores\n",
"IPU available: False, using: 0 IPUs\n",
"HPU available: False, using: 0 HPUs\n",
- "Running in `fast_dev_run` mode: will run the requested loop using 3 batch(es). Logging and checkpointing is suppressed.\n"
+ "Running in `fast_dev_run` mode: will run the requested loop using 10 batch(es). Logging and checkpointing is suppressed.\n"
]
}
],
@@ -337,21 +338,21 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Epoch 0: 100%|██████████| 3/3 [00:50<00:00, 0.06it/s, loss_step=274.0, val_loss_step=260.0, cindex_step=0.615, val_loss_epoch=261.0, cindex_epoch=0.618, loss_epoch=300.0]"
+ "Epoch 0: 100%|██████████| 10/10 [06:00<00:00, 0.03it/s, loss_step=242.0, val_loss_step=260.0, cindex_step=0.585, val_loss_epoch=260.0, cindex_epoch=0.579, loss_epoch=265.0]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
- "`Trainer.fit` stopped: `max_steps=3` reached.\n"
+ "`Trainer.fit` stopped: `max_steps=10` reached.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Epoch 0: 100%|██████████| 3/3 [00:50<00:00, 0.06it/s, loss_step=274.0, val_loss_step=260.0, cindex_step=0.615, val_loss_epoch=261.0, cindex_epoch=0.618, loss_epoch=300.0]\n"
+ "Epoch 0: 100%|██████████| 10/10 [06:00<00:00, 0.03it/s, loss_step=242.0, val_loss_step=260.0, cindex_step=0.585, val_loss_epoch=260.0, cindex_epoch=0.579, loss_epoch=265.0]\n"
]
}
],
@@ -370,19 +371,19 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Testing DataLoader 0: 100%|██████████| 3/3 [00:20<00:00, 0.15it/s]\n",
+ "Testing DataLoader 0: 100%|██████████| 10/10 [03:09<00:00, 0.05it/s]\n",
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
" Test metric DataLoader 0\n",
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
- " cindex_epoch 0.612917959690094\n",
- " val_loss_epoch 34.167057037353516\n",
+ " cindex_epoch 0.5841928124427795\n",
+ " val_loss_epoch -90.17676544189453\n",
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n"
]
},
{
"data": {
"text/plain": [
- "[{'val_loss_epoch': 34.167057037353516, 'cindex_epoch': 0.612917959690094}]"
+ "[{'val_loss_epoch': -90.17676544189453, 'cindex_epoch': 0.5841928124427795}]"
]
},
"execution_count": 15,
@@ -415,7 +416,7 @@
"outputs": [],
"source": [
"FACTOR = 10 # Number of batch to keep in memory. Increase our training batch size artificially by factor of 10 here\n",
- "resnet_momentum = Momentum(resnet, neg_partial_log_likelihood, steps=FACTOR, rate=0.999)\n",
+ "resnet_momentum = Momentum(copy.deepcopy(resnet), neg_partial_log_likelihood, steps=FACTOR, rate=0.999)\n",
"model_momentum = LitMomentum(backbone=resnet_momentum)\n",
"\n",
"# By using momentum, we can in theory reduce our batch size by factor and still have the same effective sample size\n",
@@ -438,7 +439,7 @@
"TPU available: False, using: 0 TPU cores\n",
"IPU available: False, using: 0 IPUs\n",
"HPU available: False, using: 0 HPUs\n",
- "Running in `fast_dev_run` mode: will run the requested loop using 3 batch(es). Logging and checkpointing is suppressed.\n",
+ "Running in `fast_dev_run` mode: will run the requested loop using 10 batch(es). Logging and checkpointing is suppressed.\n",
"\n",
" | Name | Type | Params\n",
"-----------------------------------\n",
@@ -454,21 +455,21 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Epoch 0: 100%|██████████| 3/3 [00:15<00:00, 0.19it/s, loss_step=44.70, val_loss_step=74.10, cindex_step=0.614, val_loss_epoch=74.00, cindex_epoch=0.631, loss_epoch=23.30]"
+ "Epoch 0: 100%|██████████| 10/10 [01:27<00:00, 0.11it/s, loss_step=65.60, val_loss_step=67.00, cindex_step=0.524, val_loss_epoch=66.40, cindex_epoch=0.519, loss_epoch=52.70]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
- "`Trainer.fit` stopped: `max_steps=3` reached.\n"
+ "`Trainer.fit` stopped: `max_steps=10` reached.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Epoch 0: 100%|██████████| 3/3 [00:15<00:00, 0.19it/s, loss_step=44.70, val_loss_step=74.10, cindex_step=0.614, val_loss_epoch=74.00, cindex_epoch=0.631, loss_epoch=23.30]\n"
+ "Epoch 0: 100%|██████████| 10/10 [01:27<00:00, 0.11it/s, loss_step=65.60, val_loss_step=67.00, cindex_step=0.524, val_loss_epoch=66.40, cindex_epoch=0.519, loss_epoch=52.70]\n"
]
}
],
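The Momentum wrapper constructed a few hunks above is used as a black box in this notebook, so the sketch below is only a schematic of the general idea behind such momentum training (an EMA "target" copy of the network plus a memory of outputs from the last few batches), not torchsurv's implementation; every name in it is illustrative:

import copy
from collections import deque
import torch
import torch.nn as nn

# Schematic of the momentum idea: an online network is trained as usual, a frozen
# target copy follows it via an exponential moving average, and detached outputs
# from earlier batches are pooled with the current batch so a batch-sensitive loss
# (e.g. the Cox partial likelihood) sees roughly batch_size * STEPS samples.
online = nn.Linear(8, 1)
target = copy.deepcopy(online)
for p in target.parameters():
    p.requires_grad_(False)

RATE, STEPS = 0.999, 10
bank = deque(maxlen=STEPS - 1)           # outputs remembered from earlier batches

def ema_update(target, online, rate=RATE):
    with torch.no_grad():
        for p_t, p_o in zip(target.parameters(), online.parameters()):
            p_t.mul_(rate).add_(p_o, alpha=1.0 - rate)

x = torch.randn(16, 8)                         # one small batch
pooled = torch.cat([online(x)] + list(bank))   # current outputs + remembered outputs
bank.append(target(x).detach())                # store target outputs for later steps
ema_update(target, online)                     # target slowly tracks the online net
print(pooled.shape)                            # grows toward (16 * STEPS, 1) over steps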
@@ -497,19 +498,19 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Testing DataLoader 0: 100%|██████████| 3/3 [00:04<00:00, 0.61it/s]\n",
+ "Testing DataLoader 0: 100%|██████████| 10/10 [00:29<00:00, 0.34it/s]\n",
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
" Test metric DataLoader 0\n",
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
- " cindex_epoch 0.609957754611969\n",
- " val_loss_epoch 65.41207122802734\n",
+ " cindex_epoch 0.521008312702179\n",
+ " val_loss_epoch 66.2925796508789\n",
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n"
]
},
{
"data": {
"text/plain": [
- "[{'val_loss_epoch': 65.41207122802734, 'cindex_epoch': 0.609957754611969}]"
+ "[{'val_loss_epoch': 66.2925796508789, 'cindex_epoch': 0.521008312702179}]"
]
},
"execution_count": 18,
@@ -580,9 +581,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Cindex (regular) = 0.637220025062561\n",
- "Cindex (momentum) = 0.6211331486701965\n",
- "Compare (p-value) = 0.9429321885108948\n"
+ "Cindex (regular) = 0.5925866365432739\n",
+ "Cindex (momentum) = 0.5304314494132996\n",
+ "Compare (p-value) = 0.9395550489425659\n"
]
}
],
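The printed comparison presumably comes from torchsurv's concordance-index metric evaluated on each model's test predictions. In the sketch below, the import path and method names are recalled from torchsurv's documentation and should be treated as assumptions rather than verified API, and the tensors are random stand-ins rather than the models' outputs:

import torch
from torchsurv.metrics.cindex import ConcordanceIndex  # assumed import path

# Random stand-ins for the two models' predicted log hazards and the test labels.
n = 200
time = torch.rand(n) * 100
event = torch.randint(0, 2, (n,)).bool()
log_hz_regular = torch.randn(n)
log_hz_momentum = torch.randn(n)

cindex_regular, cindex_momentum = ConcordanceIndex(), ConcordanceIndex()
print("Cindex (regular) =", cindex_regular(log_hz_regular, event, time))
print("Cindex (momentum) =", cindex_momentum(log_hz_momentum, event, time))
# compare() is assumed to return a p-value for the difference between the two C-indexes.
print("Compare (p-value) =", cindex_regular.compare(cindex_momentum))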
