Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
tvhahn committed Nov 3, 2021
1 parent 5328e02 commit 4e089b0
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 19 deletions.
8 changes: 4 additions & 4 deletions bash_scripts/train_model_hpc.sh
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#!/bin/bash
#SBATCH --account=rrg-mechefsk
#SBATCH --gres=gpu:1 # request GPU "generic resource"
#SBATCH --gres=gpu:t4:1 # request GPU "generic resource"
#SBATCH --cpus-per-task=4 # maximum CPU cores per GPU request: 6 on Cedar, 16 on Graham.
#SBATCH --mem=10000M # memory per node
#SBATCH --time=0-00:10 # time (DD-HH:MM)
#SBATCH --mem=14000M # memory per node
#SBATCH --time=0-00:20 # time (DD-HH:MM)
#SBATCH --output=%N-%j.out # %N for node name, %j for jobID
#SBATCH --mail-type=ALL # Type of email notification: BEGIN, END, FAIL, ALL
#SBATCH --mail-user=[email protected] # Email to which notifications will be sent
Expand All @@ -27,7 +27,7 @@ cp -r ~/scratch/earth-mantle-surrogate/processed $SLURM_TMPDIR/data
python $PROJECT_DIR/src/models/train_model.py \
--path_data $SLURM_TMPDIR/data/processed \
--proj_dir $PROJECT_DIR \
# --checkpoint 2021_11_03_102524 \
--checkpoint 2021_11_03_102524 \
--batch_size 1 \
--learning_rate 1e-4 \
--critic_iterations 5 \
Expand Down
29 changes: 16 additions & 13 deletions notebooks/scratch/2.0_data-loader-test.ipynb

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions src/models/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def plot_fake_truth(fake, x_truth, x_up, epoch_i, batch_i, time_i):
ax[2, i].pcolormesh(x_truth[bi, vi, ri, :, :].cpu(), cmap=color_scheme)
ax[2, i].get_xaxis().set_visible(False)
ax[2, i].get_yaxis().set_visible(False)
plt.suptitle(f"Epoch {epoch_i}, Batch Index {batch_i}, Time Step {time_i}")
plt.suptitle(f"Epoch {epoch_i}, Batch Index {batch_i}, Time Step {time_i[bi]}")
plt.subplots_adjust(wspace=0, hspace=0)

return fig
Expand Down Expand Up @@ -419,7 +419,7 @@ def train(

if epoch > GEN_PRETRAIN_EPOCHS:
if batch_idx % 10 == 0:

create_tensorboard_fig(
gen,
x_input,
Expand All @@ -445,6 +445,7 @@ def train(
x_up,
epoch,
batch_idx,
time_i,
step,
writer_results,
)
Expand Down

0 comments on commit 4e089b0

Please sign in to comment.