Skip to content

Commit

Permalink
Merge pull request #20 from AutoResearch/17-path_checkpoint-does-not-…
Browse files Browse the repository at this point in the history
…function-pip-release

17 path checkpoint does not function pip release
  • Loading branch information
chadcwilliams authored Jul 25, 2023
2 parents 60d6fe5 + 49ab841 commit 2660923
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 28 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

.vscode/launch.json
6 changes: 3 additions & 3 deletions eeggan/Generate_Samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import pandas as pd
import torch

from eeggan.helpers import system_inputs
from eeggan.helpers.trainer import Trainer
from eeggan.nn_architecture.models import TtsGenerator, TtsGeneratorFiltered
from helpers import system_inputs
from helpers.trainer import Trainer
from nn_architecture.models import TtsGenerator, TtsGeneratorFiltered

def generate_samples(argv = []):

Expand Down
23 changes: 13 additions & 10 deletions eeggan/Train_Gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
import torch
import torch.multiprocessing as mp

from eeggan.helpers.trainer import Trainer
from eeggan.helpers.get_master import find_free_port
from eeggan.helpers.ddp_training import run, DDPTrainer
from eeggan.nn_architecture.models import TtsDiscriminator, TtsGenerator, TtsGeneratorFiltered
from eeggan.helpers.dataloader import Dataloader
from eeggan.helpers import system_inputs
from helpers.trainer import Trainer
from helpers.get_master import find_free_port
from helpers.ddp_training import run, DDPTrainer
from nn_architecture.models import TtsDiscriminator, TtsGenerator, TtsGeneratorFiltered
from helpers.dataloader import Dataloader
from helpers import system_inputs

"""Implementation of the training process of a GAN for the generation of synthetic sequential data.
Expand Down Expand Up @@ -207,10 +207,13 @@ def train_gan(argv = []):
gen_samples = trainer.training(dataset)

# save final models, optimizer states, generated samples, losses and configuration as final result
path = 'trained_models'
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
filename = f'gan_{trainer.epochs}ep_' + timestamp + '.pt'
trainer.save_checkpoint(path_checkpoint=os.path.join(path, filename), generated_samples=gen_samples)
if default_args['path_checkpoint']:
trainer.save_checkpoint(path_checkpoint=default_args['path_checkpoint'], generated_samples=gen_samples)
else:
path = 'trained_models'
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
filename = f'gan_{trainer.epochs}ep_' + timestamp + '.pt'
trainer.save_checkpoint(path_checkpoint=os.path.join(path, filename), generated_samples=gen_samples)

print("GAN training finished.")
print("Generated samples saved to file.")
Expand Down
10 changes: 5 additions & 5 deletions eeggan/Visualize_Gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import numpy as np
import torch

from eeggan.helpers import system_inputs
from eeggan.nn_architecture import models
from eeggan.helpers.dataloader import Dataloader
from eeggan.helpers.visualize_pca import visualization_dim_reduction
from eeggan.helpers.visualize_spectogram import plot_spectogram, plot_fft_hist
from helpers import system_inputs
from nn_architecture import models
from helpers.dataloader import Dataloader
from helpers.visualize_pca import visualization_dim_reduction
from helpers.visualize_spectogram import plot_spectogram, plot_fft_hist

class PlotterGanTraining:
"""This class is used to read samples from a csv-file and plot them.
Expand Down
15 changes: 9 additions & 6 deletions eeggan/helpers/ddp_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

import eeggan.helpers.trainer as trainer
from eeggan.helpers.dataloader import Dataloader
import helpers.trainer as trainer
from helpers.dataloader import Dataloader


class DDPTrainer(trainer.Trainer):
Expand Down Expand Up @@ -127,10 +127,13 @@ def _ddp_training(training: DDPTrainer, opt):

# save checkpoint
if training.rank == 0:
path = 'trained_models'
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
filename = f'gan_ddp_{training.epochs}ep_' + timestamp + '.pt'
training.save_checkpoint(path_checkpoint=os.path.join(path, filename), generated_samples=gen_samples)
if opt['path_checkpoint']:
training.save_checkpoint(path_checkpoint=opt['path_checkpoint'], generated_samples=gen_samples)
else:
path = 'trained_models'
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
filename = f'gan_ddp_{training.epochs}ep_' + timestamp + '.pt'
training.save_checkpoint(path_checkpoint=os.path.join(path, filename), generated_samples=gen_samples)

print("GAN training finished.")
print("Model states and generated samples saved to file.")
4 changes: 2 additions & 2 deletions eeggan/helpers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import torch
import numpy as np

from eeggan.nn_architecture import losses, models
from eeggan.nn_architecture.losses import WassersteinGradientPenaltyLoss as Loss
from nn_architecture import losses, models
from nn_architecture.losses import WassersteinGradientPenaltyLoss as Loss

# https://machinelearningmastery.com/how-to-implement-wasserstein-loss-for-generative-adversarial-networks/
# For implementation of Wasserstein-GAN see link above
Expand Down
2 changes: 1 addition & 1 deletion eeggan/helpers/visualize_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

from eeggan.helpers.dataloader import Dataloader
from helpers.dataloader import Dataloader


def visualization_dim_reduction(ori_data, generated_data, analysis, save, save_name=None, perplexity=40, iterations=1000, return_result=False):
Expand Down
2 changes: 1 addition & 1 deletion eeggan/nn_architecture/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from scipy import signal
import numpy as np

from eeggan.nn_architecture.ttsgan_components import *
from nn_architecture.ttsgan_components import *

# insert here all different kinds of generators and discriminators

Expand Down

0 comments on commit 2660923

Please sign in to comment.