Commit 72accb1

Merge pull request #73 from AutoResearch/58-comparison-main-pip-CW

58 comparison main pip cw

chadcwilliams authored Feb 9, 2024
2 parents 876463d + 6c4cb40
Showing 19 changed files with 5,208 additions and 744 deletions.
.gitignore (3 changes: 2 additions & 1 deletion)
@@ -4,4 +4,5 @@ __pycache__
trained_ae
trained_models
plots
generated_samples
data
README.md (8 changes: 4 additions & 4 deletions)
@@ -18,10 +18,10 @@ Launching the desktop will take you to a virtual desktop. Open terminal and navi…

## Load modules
```
module load python/3.9.0
module load gcc/10.2
module load cuda/11.7.1
module load cudnn/8.2.0
module load cuda/11.8.0-lpttyok
module load cudnn/8.7.0.84-11.8-lg2dpd5
module load gcc/10.1.0-mojgbnp
```
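If these exact module names or versions are not available on your cluster, the Environment Modules tooling can list what is installed. A minimal check, assuming your system uses Environment Modules or Lmod (as the `module load` lines imply):

```
module avail cuda    # list installed CUDA modules
module list          # show modules currently loaded
```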

## Create and activate virtual environment
autoencoder_training_main.py (123 changes: 80 additions & 43 deletions)
@@ -33,7 +33,6 @@ def main():
        'save_name': default_args['save_name'],
        'target': default_args['target'],
        'sample_interval': default_args['sample_interval'],
        # 'conditions': default_args['conditions'],
        'channel_label': default_args['channel_label'],
        'channels_out': default_args['channels_out'],
        'timeseries_out': default_args['timeseries_out'],
@@ -50,7 +49,6 @@ def main():
        'num_layers': default_args['num_layers'],
        'ddp': default_args['ddp'],
        'ddp_backend': default_args['ddp_backend'],
        # 'n_conditions': len(default_args['conditions']) if default_args['conditions'][0] != '' else 0,
        'norm_data': True,
        'std_data': False,
        'diff_data': False,
@@ -64,18 +62,11 @@ def main():
    # ----------------------------------------------------------------------------------------------------------------------
    # Load, process, and split data
    # ----------------------------------------------------------------------------------------------------------------------

    # Scale function -> Not necessary; already in dataloader -> param: norm_data=True
    # def scale(dataset):
    #     x_min, x_max = dataset.min(), dataset.max()
    #     return (dataset-x_min)/(x_max-x_min)

    data = Dataloader(path=opt['path_dataset'],
                      channel_label=opt['channel_label'], kw_timestep=opt['kw_timestep'],
                      norm_data=opt['norm_data'], std_data=opt['std_data'], diff_data=opt['diff_data'],)
    dataset = data.get_data()
    # dataset = dataset[:, opt['n_conditions']:, :].to(opt['device'])  # Remove labels
    # dataset = scale(dataset)

    # Split data function
    def split_data(dataset, train_size=.8):
@@ -92,6 +83,8 @@ def split_data(dataset, train_size=.8):
    # Determine n_channels, output_dim, and seq_length
    opt['n_channels'] = dataset.shape[-1]
    opt['sequence_length'] = dataset.shape[1]
    opt['channels_in'] = opt['n_channels']
    opt['timeseries_in'] = opt['sequence_length']

    # Split dataset and convert to pytorch dataloader class
    test_dataset, train_dataset = split_data(dataset, opt['train_ratio'])
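The body of split_data is collapsed in this diff view. As a reading aid only, a hypothetical sketch consistent with the call above (note that the call site unpacks the test split first) might look like this; it is not the repository's actual implementation:

```
import torch

def split_data(dataset, train_size=.8):
    # Hypothetical sketch: shuffle the samples, then return (test, train)
    # to match the unpacking order at the call site above.
    shuffled = dataset[torch.randperm(dataset.shape[0])]
    cutoff = int(dataset.shape[0] * train_size)
    return shuffled[cutoff:], shuffled[:cutoff]
```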
@@ -106,7 +99,6 @@ def split_data(dataset, train_size=.8):
    model_dict = None
    if default_args['load_checkpoint'] and os.path.isfile(opt['path_checkpoint']):
        model_dict = torch.load(opt['path_checkpoint'])
        # model_state = model_dict['state_dict']

        target_old = opt['target']
        channels_out_old = opt['channels_out']
@@ -123,45 +115,54 @@ def split_data(dataset, train_size=.8):
print(f"channels_out:\t{channels_out_old} -> {opt['channels_out']}")
print(f"timeseries_out:\t{timeseries_out_old} -> {opt['timeseries_out']}")
print('-----------------------------------\n')
# print(f"Target: {opt['target']}")
# if (opt['target'] == 'channels') | (opt['target'] == 'full'):
# print(f"channels_out: {opt['channels_out']}")
# if (opt['target'] == 'timeseries') | (opt['target'] == 'full'):
# print(f"timeseries_out: {opt['timeseries_out']}")
# print('-----------------------------------\n')

elif default_args['load_checkpoint'] and not os.path.isfile(opt['path_checkpoint']):
raise FileNotFoundError(f"Checkpoint file {opt['path_checkpoint']} not found.")

# Add parameters for tracking
opt['input_dim'] = opt['n_channels'] if opt['target'] in ['channels', 'full'] else opt['sequence_length']
opt['output_dim'] = opt['channels_out'] if opt['target'] in ['channels', 'full'] else opt['n_channels']
opt['output_dim_2'] = opt['sequence_length'] if opt['target'] in ['channels'] else opt['timeseries_out']
opt['output_dim'] = opt['channels_out'] if opt['target'] in ['channels', 'full'] else opt['timeseries_out']
opt['output_dim_2'] = opt['sequence_length'] if opt['target'] in ['channels'] else opt['n_channels']

if opt['target'] == 'channels':
model = TransformerAutoencoder(input_dim=opt['n_channels'],
output_dim=opt['channels_out'],
output_dim_2=opt['sequence_length'],
model = TransformerAutoencoder(input_dim=opt['input_dim'],
output_dim=opt['output_dim'],
output_dim_2=opt['output_dim_2'],
target=TransformerAutoencoder.TARGET_CHANNELS,
hidden_dim=opt['hidden_dim'],
num_layers=opt['num_layers'],
num_heads=opt['num_heads'],).to(opt['device'])
num_heads=opt['num_heads'],
activation=opt['activation']).to(opt['device'])
elif opt['target'] == 'time':
model = TransformerAutoencoder(input_dim=opt['sequence_length'],
output_dim=opt['timeseries_out'],
output_dim_2=opt['n_channels'],
model = TransformerAutoencoder(input_dim=opt['input_dim'],
output_dim=opt['output_dim'],
output_dim_2=opt['output_dim_2'],
target=TransformerAutoencoder.TARGET_TIMESERIES,
hidden_dim=opt['hidden_dim'],
num_layers=opt['num_layers'],
num_heads=opt['num_heads'],).to(opt['device'])
num_heads=opt['num_heads'],
activation=opt['activation']).to(opt['device'])
elif opt['target'] == 'full':
model = TransformerDoubleAutoencoder(input_dim=opt['n_channels'],
output_dim=opt['output_dim'],
output_dim_2=opt['output_dim_2'],
sequence_length=opt['sequence_length'],
model_1 = TransformerDoubleAutoencoder(channels_in=opt['channels_in'],
timeseries_in=opt['timeseries_in'],
channels_out=opt['channels_out'],
timeseries_out=opt['timeseries_out'],
hidden_dim=opt['hidden_dim'],
num_layers=opt['num_layers'],
num_heads=opt['num_heads'],
activation=opt['activation'],
training_level=1).to(opt['device'])

model_2 = TransformerDoubleAutoencoder(channels_in=opt['channels_in'],
timeseries_in=opt['timeseries_in'],
channels_out=opt['channels_out'],
timeseries_out=opt['timeseries_out'],
hidden_dim=opt['hidden_dim'],
num_layers=opt['num_layers'],
num_heads=opt['num_heads'],).to(opt['device'])
num_heads=opt['num_heads'],
activation=opt['activation'],
training_level=2).to(opt['device'])

else:
raise ValueError(f"Encode target '{opt['target']}' not recognized, options are 'channels', 'time', or 'full'.")

@@ -179,30 +180,66 @@ def split_data(dataset, train_size=.8):

    opt['history'] = history

    training_levels = 2 if opt['target'] == 'full' else 1

    opt['training_levels'] = training_levels

    if opt['ddp']:
        trainer = AEDDPTrainer(model, opt)
        if default_args['load_checkpoint']:
            trainer.load_checkpoint(default_args['path_checkpoint'])
        mp.spawn(run, args=(opt['world_size'], find_free_port(), opt['ddp_backend'], trainer, opt),
                 nprocs=opt['world_size'], join=True)
        for training_level in range(1, training_levels + 1):
            if training_levels == 2 and training_level == 1:
                print('Training the first level of the autoencoder...')
                model = model_1
            elif training_levels == 2 and training_level == 2:
                print('Training the second level of the autoencoder...')
                model = model_2
            trainer = AEDDPTrainer(model, opt)
            if default_args['load_checkpoint']:
                trainer.load_checkpoint(default_args['path_checkpoint'])
            mp.spawn(run, args=(opt['world_size'], find_free_port(), opt['ddp_backend'], trainer, opt),
                     nprocs=opt['world_size'], join=True)

            if training_levels == 2 and training_level == 1:
                model_1 = trainer.model
                model_2.model_1 = model_1
                model_2.model_1.eval()

            elif training_levels == 2 and training_level == 2:
                model_2 = trainer.model
    else:
        trainer = AETrainer(model, opt)
        if default_args['load_checkpoint']:
            trainer.load_checkpoint(default_args['path_checkpoint'])
        samples = trainer.training(train_dataloader, test_dataloader)
        model = trainer.model
        for training_level in range(1, training_levels + 1):
            opt['training_level'] = training_level

            if training_levels == 2 and training_level == 1:
                print('Training the first level of the autoencoder...')
                model = model_1
            elif training_levels == 2 and training_level == 2:
                print('Training the second level of the autoencoder...')
                model = model_2
            trainer = AETrainer(model, opt)
            if default_args['load_checkpoint']:
                trainer.load_checkpoint(default_args['path_checkpoint'])
            samples = trainer.training(train_dataloader, test_dataloader)

            if training_levels == 2 and training_level == 1:
                model_1 = trainer.model
                model_2.model_1 = model_1
                model_2.model_1.eval()

            elif training_levels == 2 and training_level == 2:
                model_2 = trainer.model

        model = trainer.model

    print("Training finished.")

    # ----------------------------------------------------------------------------------------------------------------------
    # Save autoencoder
    # ----------------------------------------------------------------------------------------------------------------------

    # Save model
    # model_dict = dict(state_dict=model.state_dict(), config=model.config)
    if opt['save_name'] is None:
        fn = opt['path_dataset'].split('/')[-1].split('.csv')[0]
        opt['save_name'] = os.path.join("trained_ae", f"ae_{fn}_{str(time.time()).split('.')[0]}.pt")
    # save(model_dict, save_name)

    trainer.save_checkpoint(opt['save_name'], update_history=True, samples=samples)
    print(f"Model and configuration saved in {opt['save_name']}")