diff --git a/temporary_testing_suite/TESTNG-AE-GAN.docx b/temporary_testing_suite/TESTNG-AE-GAN.docx deleted file mode 100644 index 1bf1c55..0000000 Binary files a/temporary_testing_suite/TESTNG-AE-GAN.docx and /dev/null differ diff --git a/temporary_testing_suite/ae_visualize_main.py b/temporary_testing_suite/ae_visualize_main.py deleted file mode 100644 index 691faec..0000000 --- a/temporary_testing_suite/ae_visualize_main.py +++ /dev/null @@ -1,57 +0,0 @@ - -import numpy as np -import matplotlib.pyplot as plt -import torch -from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present -from nn_architecture.ae_networks import TransformerAutoencoder, TransformerDoubleAutoencoder, TransformerFlattenAutoencoder -from helpers.dataloader import Dataloader - -# another comment - -#User input -data_checkpoint = 'data/ganTrialElectrodeERP_p500_e18_len100.csv' -ae_checkpoint = 'trained_ae/ae_ddp_4000ep_20230824_145643.pt' - -#Load -ae_dict = torch.load(ae_checkpoint, map_location=torch.device('cuda')) -dataloader = Dataloader(data_checkpoint, col_label='Condition', channel_label='Electrode') -dataset = dataloader.get_data() -sequence_length = dataset.shape[1] - dataloader.labels.shape[1] - -#Initiate -if ae_dict['configuration']['model_class'] == 'TransformerAutoencoder': - autoencoder = TransformerAutoencoder(**ae_dict['configuration'], sequence_length=sequence_length) -elif ae_dict['configuration']['model_class'] == 'TransformerDoubleAutoencoder': - autoencoder = TransformerDoubleAutoencoder(**ae_dict['configuration'], sequence_length=sequence_length) -elif ae_dict['configuration']['model_class'] == 'TransformerFlattenAutoencoder': - autoencoder = TransformerFlattenAutoencoder(**ae_dict['configuration'], sequence_length=sequence_length) -else: - raise ValueError(f"Autoencoder class {ae_dict['configuration']['model_class']} not recognized.") -consume_prefix_in_state_dict_if_present(ae_dict['model'],'module.') -autoencoder.load_state_dict(ae_dict['model']) -autoencoder.device = torch.device('cpu') -print(ae_dict["configuration"]["history"]) - -#Test -plt.figure() -plt.plot(ae_dict['train_loss'], label='Train Loss') -plt.plot(ae_dict['test_loss'], label = 'Test Loss') -plt.title('Losses') -plt.xlabel('Epoch') -plt.legend() -plt.show() - -def norm(data): - return (data-np.min(data)) / (np.max(data) - np.min(data)) - -dataset = norm(dataset.detach().numpy()) - -fig, axs = plt.subplots(5,1) -for i in range(5): - sample = np.random.choice(len(dataset), 1) - data = dataset[sample,1:,:] - axs[i].plot(data[0,:,0], label='Original') - axs[i].plot(autoencoder.decode(autoencoder.encode(torch.from_numpy(data)))[0,:,0].detach().numpy(), label='Reconstructed') - axs[i].legend() -plt.show() - diff --git a/temporary_testing_suite/gan_visualize_erps_main.py b/temporary_testing_suite/gan_visualize_erps_main.py deleted file mode 100644 index ac5a7bf..0000000 --- a/temporary_testing_suite/gan_visualize_erps_main.py +++ /dev/null @@ -1,43 +0,0 @@ - -import pandas as pd -import matplotlib.pyplot as plt - -#### Define parameters #### -gan_type = 'gan' #gan or aegan -participants = 100 #500 or 100 -epochs = 100 #100, 1000, or 4000 -electrodes = 2 #1, 2, 8 - -#### Load generated data #### -filename = f'{gan_type}_p{participants}_ep{epochs}_e{electrodes}' -c0_syn = pd.read_csv(f'generated_samples/{filename}_c0.csv') -c1_syn = pd.read_csv(f'generated_samples/{filename}_c1.csv') -gen_data_index = 2 - -#### Load empirical data #### -filename_emp = filename.replace('aegan','ganTrialElectrodeERP').replace('gan','ganTrialElectrodeERP').replace('_ep1000','').replace('_ep100','').replace('_ep4000','')+'_len100.csv' -c_emp = pd.read_csv(f'data/{filename_emp}') -c0_emp = c_emp[c_emp['Condition']==0] -c1_emp = c_emp[c_emp['Condition']==1] -emp_data_index = 4 - -#### Plot data #### -fig, ax = plt.subplots(2,len(c0_syn['Electrode'].unique())) - -for electrode in c0_syn['Electrode'].unique(): - #Empirical - ax[0,int(electrode-1)].plot(c0_emp[c0_emp['Electrode']==electrode].mean()[emp_data_index:], label='C0') - ax[0,int(electrode-1)].plot(c1_emp[c1_emp['Electrode']==electrode].mean()[emp_data_index:], label='C1') - ax[0,int(electrode-1)].set_title(f'Empirical (E: {int(electrode)})') - ax[0,int(electrode-1)].get_xaxis().set_visible(False) - ax[0,int(electrode-1)].get_yaxis().set_visible(False) - ax[0,int(electrode-1)].spines[['right', 'top']].set_visible(False) - - #Synthetic - ax[1,int(electrode-1)].plot(c0_syn[c0_syn['Electrode']==electrode].mean()[emp_data_index:], label='C0') - ax[1,int(electrode-1)].plot(c1_syn[c1_syn['Electrode']==electrode].mean()[emp_data_index:], label='C1') - ax[1,int(electrode-1)].set_title(f'Synthetic (E: {int(electrode)})') - ax[1,int(electrode-1)].get_xaxis().set_visible(False) - ax[1,int(electrode-1)].get_yaxis().set_visible(False) - ax[1,int(electrode-1)].spines[['right', 'top']].set_visible(False) -plt.show() \ No newline at end of file