diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..dc92603 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ + +.vscode/launch.json diff --git a/eeggan/Generate_Samples.py b/eeggan/Generate_Samples.py index 1f2e655..7d4a34c 100644 --- a/eeggan/Generate_Samples.py +++ b/eeggan/Generate_Samples.py @@ -5,9 +5,9 @@ import pandas as pd import torch -from eeggan.helpers import system_inputs -from eeggan.helpers.trainer import Trainer -from eeggan.nn_architecture.models import TtsGenerator, TtsGeneratorFiltered +from helpers import system_inputs +from helpers.trainer import Trainer +from nn_architecture.models import TtsGenerator, TtsGeneratorFiltered def generate_samples(argv = []): diff --git a/eeggan/Train_Gan.py b/eeggan/Train_Gan.py index dcb8b60..626f33b 100644 --- a/eeggan/Train_Gan.py +++ b/eeggan/Train_Gan.py @@ -5,12 +5,12 @@ import torch import torch.multiprocessing as mp -from eeggan.helpers.trainer import Trainer -from eeggan.helpers.get_master import find_free_port -from eeggan.helpers.ddp_training import run, DDPTrainer -from eeggan.nn_architecture.models import TtsDiscriminator, TtsGenerator, TtsGeneratorFiltered -from eeggan.helpers.dataloader import Dataloader -from eeggan.helpers import system_inputs +from helpers.trainer import Trainer +from helpers.get_master import find_free_port +from helpers.ddp_training import run, DDPTrainer +from nn_architecture.models import TtsDiscriminator, TtsGenerator, TtsGeneratorFiltered +from helpers.dataloader import Dataloader +from helpers import system_inputs """Implementation of the training process of a GAN for the generation of synthetic sequential data. @@ -207,10 +207,13 @@ def train_gan(argv = []): gen_samples = trainer.training(dataset) # save final models, optimizer states, generated samples, losses and configuration as final result - path = 'trained_models' - timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') - filename = f'gan_{trainer.epochs}ep_' + timestamp + '.pt' - trainer.save_checkpoint(path_checkpoint=os.path.join(path, filename), generated_samples=gen_samples) + if default_args['path_checkpoint']: + trainer.save_checkpoint(path_checkpoint=default_args['path_checkpoint'], generated_samples=gen_samples) + else: + path = 'trained_models' + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + filename = f'gan_{trainer.epochs}ep_' + timestamp + '.pt' + trainer.save_checkpoint(path_checkpoint=os.path.join(path, filename), generated_samples=gen_samples) print("GAN training finished.") print("Generated samples saved to file.") diff --git a/eeggan/Visualize_Gan.py b/eeggan/Visualize_Gan.py index 07f6643..0a7bee0 100644 --- a/eeggan/Visualize_Gan.py +++ b/eeggan/Visualize_Gan.py @@ -7,11 +7,11 @@ import numpy as np import torch -from eeggan.helpers import system_inputs -from eeggan.nn_architecture import models -from eeggan.helpers.dataloader import Dataloader -from eeggan.helpers.visualize_pca import visualization_dim_reduction -from eeggan.helpers.visualize_spectogram import plot_spectogram, plot_fft_hist +from helpers import system_inputs +from nn_architecture import models +from helpers.dataloader import Dataloader +from helpers.visualize_pca import visualization_dim_reduction +from helpers.visualize_spectogram import plot_spectogram, plot_fft_hist class PlotterGanTraining: """This class is used to read samples from a csv-file and plot them. diff --git a/eeggan/helpers/ddp_training.py b/eeggan/helpers/ddp_training.py index a6dff94..2609a9f 100644 --- a/eeggan/helpers/ddp_training.py +++ b/eeggan/helpers/ddp_training.py @@ -6,8 +6,8 @@ import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP -import eeggan.helpers.trainer as trainer -from eeggan.helpers.dataloader import Dataloader +import helpers.trainer as trainer +from helpers.dataloader import Dataloader class DDPTrainer(trainer.Trainer): @@ -127,10 +127,13 @@ def _ddp_training(training: DDPTrainer, opt): # save checkpoint if training.rank == 0: - path = 'trained_models' - timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') - filename = f'gan_ddp_{training.epochs}ep_' + timestamp + '.pt' - training.save_checkpoint(path_checkpoint=os.path.join(path, filename), generated_samples=gen_samples) + if opt['path_checkpoint']: + training.save_checkpoint(path_checkpoint=opt['path_checkpoint'], generated_samples=gen_samples) + else: + path = 'trained_models' + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + filename = f'gan_ddp_{training.epochs}ep_' + timestamp + '.pt' + training.save_checkpoint(path_checkpoint=os.path.join(path, filename), generated_samples=gen_samples) print("GAN training finished.") print("Model states and generated samples saved to file.") diff --git a/eeggan/helpers/trainer.py b/eeggan/helpers/trainer.py index 133ef0d..848c851 100644 --- a/eeggan/helpers/trainer.py +++ b/eeggan/helpers/trainer.py @@ -3,8 +3,8 @@ import torch import numpy as np -from eeggan.nn_architecture import losses, models -from eeggan.nn_architecture.losses import WassersteinGradientPenaltyLoss as Loss +from nn_architecture import losses, models +from nn_architecture.losses import WassersteinGradientPenaltyLoss as Loss # https://machinelearningmastery.com/how-to-implement-wasserstein-loss-for-generative-adversarial-networks/ # For implementation of Wasserstein-GAN see link above diff --git a/eeggan/helpers/visualize_pca.py b/eeggan/helpers/visualize_pca.py index 3a50071..e25e3c0 100644 --- a/eeggan/helpers/visualize_pca.py +++ b/eeggan/helpers/visualize_pca.py @@ -19,7 +19,7 @@ from sklearn.manifold import TSNE from sklearn.decomposition import PCA -from eeggan.helpers.dataloader import Dataloader +from helpers.dataloader import Dataloader def visualization_dim_reduction(ori_data, generated_data, analysis, save, save_name=None, perplexity=40, iterations=1000, return_result=False): diff --git a/eeggan/nn_architecture/models.py b/eeggan/nn_architecture/models.py index 800637d..0504641 100644 --- a/eeggan/nn_architecture/models.py +++ b/eeggan/nn_architecture/models.py @@ -5,7 +5,7 @@ from scipy import signal import numpy as np -from eeggan.nn_architecture.ttsgan_components import * +from nn_architecture.ttsgan_components import * # insert here all different kinds of generators and discriminators