doc changes #16

Merged
merged 1 commit on Nov 30, 2024

282 changes: 8 additions & 274 deletions README.md
@@ -18,296 +18,30 @@ starting from supplied training data.
[LANs](https://elifesciences.org/articles/65074), although more general in potential scope of applications, were conceived in the context of sequential sampling modeling
to account for cognitive processes giving rise to *choice* and *reaction time* data in *n-alternative forced choice experiments* commonly encountered in the cognitive sciences.

For a basic tutorial on how to use the `LANfactory` package, please refer to the [basic tutorial notebook](docs/basic_tutorial/basic_tutorial.ipynb).

In this quick tutorial we will use the [`ssms`](https://github.com/AlexanderFengler/ssm_simulators) package to generate training data from such a sequential sampling model (SSM). `LANfactory` is, however, in no way bound to the `ssms` package.

#### Install

To install the `ssms` package, type

`pip install git+https://github.com/AlexanderFengler/ssm_simulators`
`pip install ssm-simulators`

To install the `LANfactory` package, type

`pip install git+https://github.com/AlexanderFengler/LANfactory`
`pip install lanfactory`

Necessary dependencies should be installed automatically in the process.

#### Basic Tutorial


```python
# Load necessary packages
import ssms
import lanfactory
import os
import numpy as np
from copy import deepcopy
import torch
```

#### Generate Training Data
First we need to generate some training data. As mentioned above, we will do so using the `ssms` Python package, without delving into a detailed explanation of that package here.
Please refer to the [basic ssms tutorial](https://github.com/AlexanderFengler/ssm_simulators) in case you want to learn more.


```python
# MAKE CONFIGS

# Initialize the generator config (for MLP LANs)
generator_config = deepcopy(ssms.config.data_generator_config['lan']['mlp'])
# Specify generative model (one from the list of included models mentioned above)
generator_config['model'] = 'angle'
# Specify number of parameter sets to simulate
generator_config['n_parameter_sets'] = 100
# Specify how many samples a simulation run should entail
generator_config['n_samples'] = 1000
# Specify folder in which to save generated data
generator_config['output_folder'] = 'data/lan_mlp/'

# Make model config dict
model_config = ssms.config.model_config['angle']
```

```python
# MAKE DATA

my_dataset_generator = ssms.dataset_generators.data_generator(generator_config = generator_config,
                                                              model_config = model_config)

training_data = my_dataset_generator.generate_data_training_uniform(save = True)
```

n_cpus used: 6
checking: data/lan_mlp/
simulation round: 1 of 10
simulation round: 2 of 10
simulation round: 3 of 10
simulation round: 4 of 10
simulation round: 5 of 10
simulation round: 6 of 10
simulation round: 7 of 10
simulation round: 8 of 10
simulation round: 9 of 10
simulation round: 10 of 10
Writing to file: data/lan_mlp/training_data_0_nbins_0_n_1000/angle/training_data_angle_ef5b9e0eb76c11eca684acde48001122.pickle
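
If you want a quick look at what ended up on disk, you can load the pickle directly. This is just an optional sanity check; the glob pattern below mirrors the output path printed in the log above, and the exact structure of the stored object depends on your `ssms` version.

```python
# OPTIONAL: inspect the generated training data file
# (the structure of the stored object depends on the ssms version)
import glob
import pickle

data_file = glob.glob('data/lan_mlp/**/*.pickle', recursive = True)[0]
with open(data_file, 'rb') as f:
    training_data_from_disk = pickle.load(f)

print(type(training_data_from_disk))
if isinstance(training_data_from_disk, dict):
    print(list(training_data_from_disk.keys()))
```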


#### Prepare for Training

Next we set up dataloaders for training with PyTorch. `LANfactory` uses custom dataloaders that take into account the particularities of the expected training data.
Specifically, we expect a collection of training data files (the present example generates only one), where each file hosts a large number of training examples.
We therefore want a dataloader that yields batches from the currently loaded training data file and keeps track of when to load in a new file.
This is implemented via the `DatasetTorch` class in `lanfactory.trainers`, which inherits from `torch.utils.data.Dataset` and prespecifies a `batch_size`. This dataset is then supplied to a [`DataLoader`](https://pytorch.org/docs/stable/data.html), for which we keep the `batch_size` argument at `None`, since batching is handled by the dataset itself.

The `DatasetTorch` class is then iterated over via the `DataLoader` and takes care of batching as well as file loading internally.

You may of course define the `DataLoader` in your own way; downstream, you are simply expected to supply one.


```python
# MAKE DATALOADERS

# List of datafiles (here only one)
folder_ = 'data/lan_mlp/training_data_0_nbins_0_n_1000/angle/'
file_list_ = [folder_ + file_ for file_ in os.listdir(folder_)]

# Training dataset
torch_training_dataset = lanfactory.trainers.DatasetTorch(file_IDs = file_list_,
                                                          batch_size = 128)

torch_training_dataloader = torch.utils.data.DataLoader(torch_training_dataset,
                                                        shuffle = True,
                                                        batch_size = None,
                                                        num_workers = 1,
                                                        pin_memory = True)

# Validation dataset
torch_validation_dataset = lanfactory.trainers.DatasetTorch(file_IDs = file_list_,
                                                            batch_size = 128)

torch_validation_dataloader = torch.utils.data.DataLoader(torch_validation_dataset,
                                                          shuffle = True,
                                                          batch_size = None,
                                                          num_workers = 1,
                                                          pin_memory = True)
```

Now we define two configuration dictionaries:

1. The `network_config` dictionary defines the architecture and properties of the network
2. The `train_config` dictionary defines properties concerning training hyperparameters

Two examples (taken here as provided by the package, but adjustable to your needs) are shown below.


```python
# SPECIFY NETWORK CONFIGS AND TRAINING CONFIGS

network_config = lanfactory.config.network_configs.network_config_mlp

print('Network config: ')
print(network_config)

train_config = lanfactory.config.network_configs.train_config_mlp

print('Train config: ')
print(train_config)
```

Network config:
{'layer_types': ['dense', 'dense', 'dense'], 'layer_sizes': [100, 100, 1], 'activations': ['tanh', 'tanh', 'linear'], 'loss': ['huber'], 'callbacks': ['checkpoint', 'earlystopping', 'reducelr']}
Train config:
{'batch_size': 128, 'n_epochs': 10, 'optimizer': 'adam', 'learning_rate': 0.002, 'loss': 'huber', 'save_history': True, 'metrics': [<keras.losses.MeanSquaredError object at 0x12c403d30>, <keras.losses.Huber object at 0x12c1c78e0>], 'callbacks': ['checkpoint', 'earlystopping', 'reducelr']}
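
If the defaults do not fit your use case, you can edit these dictionaries in place before building the network. The keys used below are taken from the printed configs above.

```python
# Example adjustments (keys as shown in the printed configs above)
train_config['n_epochs'] = 20
train_config['learning_rate'] = 0.001
network_config['layer_sizes'] = [120, 120, 1]
```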


We can now load a network and save the configuration files for convenience.


```python
# LOAD NETWORK
net = lanfactory.trainers.TorchMLP(network_config = deepcopy(network_config),
                                   input_shape = torch_training_dataset.input_dim,
                                   save_folder = '/data/torch_models/',
                                   generative_model_id = 'angle')

# SAVE CONFIGS
lanfactory.utils.save_configs(model_id = net.model_id + '_torch_',
                              save_folder = 'data/torch_models/angle/',
                              network_config = network_config,
                              train_config = train_config,
                              allow_abs_path_folder_generation = True)
```

To finally train the network, we supply our network, the dataloaders, and the training config to the `ModelTrainerTorchMLP` class from `lanfactory.trainers`.
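
The trainer itself needs to be instantiated before `train_model` can be called. A minimal sketch of this step is shown below; the constructor argument names are an assumption and should be checked against the `ModelTrainerTorchMLP` signature of your installed `lanfactory` version.

```python
# INSTANTIATE TRAINER
# Note: argument names below are illustrative assumptions -- check the
# ModelTrainerTorchMLP signature of your installed lanfactory version.
model_trainer = lanfactory.trainers.ModelTrainerTorchMLP(model = net,
                                                         train_config = deepcopy(train_config),
                                                         train_dl = torch_training_dataloader,
                                                         valid_dl = torch_validation_dataloader,
                                                         allow_abs_path_folder_generation = True)
```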


```python
# TRAIN MODEL
model_trainer.train_model(save_history = True,
                          save_model = True,
                          verbose = 0)
```

Epoch took 0 / 10, took 11.54538607597351 seconds
epoch 0 / 10, validation_loss: 0.3431
Epoch took 1 / 10, took 13.032279014587402 seconds
epoch 1 / 10, validation_loss: 0.2732
Epoch took 2 / 10, took 12.421074867248535 seconds
epoch 2 / 10, validation_loss: 0.1941
Epoch took 3 / 10, took 12.097641229629517 seconds
epoch 3 / 10, validation_loss: 0.2028
Epoch took 4 / 10, took 12.030233144760132 seconds
epoch 4 / 10, validation_loss: 0.184
Epoch took 5 / 10, took 12.695374011993408 seconds
epoch 5 / 10, validation_loss: 0.1433
Epoch took 6 / 10, took 12.177874326705933 seconds
epoch 6 / 10, validation_loss: 0.1115
Epoch took 7 / 10, took 11.908828258514404 seconds
epoch 7 / 10, validation_loss: 0.1084
Epoch took 8 / 10, took 12.066670179367065 seconds
epoch 8 / 10, validation_loss: 0.0864
Epoch took 9 / 10, took 12.37562108039856 seconds
epoch 9 / 10, validation_loss: 0.07484
Saving training history
Saving model state dict
Training finished successfully...


#### Load Model for Inference and Call

`LANfactory` provides some convenience functions for using networks for inference after training.
We can load a model using the `LoadTorchMLPInfer` class, which allows us to run fast inference via either
a direct call, which expects a `torch.tensor` as input, or the `predict_on_batch` method, which expects a `numpy.array`
of dtype `np.float32`.


```python
network_path_list = os.listdir('data/torch_models/angle')
network_file_path = ['data/torch_models/angle/' + file_ for file_ in network_path_list if 'state_dict' in file_][0]

network = lanfactory.trainers.LoadTorchMLPInfer(model_file_path = network_file_path,
                                                network_config = network_config,
                                                input_dim = torch_training_dataset.input_dim)
```


```python

# Two ways to call the network

# Direct call --> need tensor input
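# Input layout matches the data frame used further below: [v, a, z, t, theta, rt, choice] for the 'angle' model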
direct_out = network(torch.from_numpy(np.array([1, 1.5, 0.5, 1.0, 0.1, 0.65, 1], dtype = np.float32)))
print('direct call out: ', direct_out)

# predict_on_batch method
predict_on_batch_out = network.predict_on_batch(np.array([1, 1.5, 0.5, 1.0, 0.1, 0.65, 1], dtype = np.float32))
print('predict_on_batch out: ', predict_on_batch_out)

```

direct call out: tensor([-16.4997])
predict_on_batch out: [-16.499687]
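
Note that these outputs live on the log-likelihood scale (the plotting code below exponentiates them accordingly). To move back to the likelihood scale, for example:

```python
# Convert log-likelihoods returned by the network to likelihoods
likelihood_out = np.exp(predict_on_batch_out)
print('likelihood: ', likelihood_out)
```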


#### A peek into the first passage distribution computed by the network

We can compare the learned likelihood function in our `network` with simulation data from the underlying generative model.
For this purpose we recruit the [`ssms`](https://github.com/AlexanderFengler/ssm_simulators) package again.


```python
import pandas as pd
import matplotlib.pyplot as plt

data = pd.DataFrame(np.zeros((2000, 7), dtype = np.float32), columns = ['v', 'a', 'z', 't', 'theta', 'rt', 'choice'])
data['v'] = 0.5
data['a'] = 0.75
data['z'] = 0.5
data['t'] = 0.2
data['theta'] = 0.1
# Use .loc instead of chained .iloc assignment to avoid pandas chained-assignment issues
data.loc[:999, 'rt'] = np.linspace(5, 0, 1000)
data.loc[1000:, 'rt'] = np.linspace(0, 5, 1000)
data.loc[:999, 'choice'] = -1
data.loc[1000:, 'choice'] = 1

# Network predictions
predict_on_batch_out = network.predict_on_batch(data.values.astype(np.float32))

# Simulations
from ssms.basic_simulators import simulator
sim_out = simulator(model = 'angle',
                    theta = data.values[0, :-2],
                    n_samples = 2000)
```


```python
# Plot network predictions
plt.plot(data['rt'] * data['choice'], np.exp(predict_on_batch_out), color = 'black', label = 'network')

# Plot simulations
plt.hist(sim_out['rts'] * sim_out['choices'], bins = 30, histtype = 'step', label = 'simulations', color = 'blue', density = True)
plt.legend()
plt.title('SSM likelihood')
plt.xlabel('rt')
plt.ylabel('likelihood')
```









![png](basic_tutorial_files/basic_tutorial_22_1.png)

### Basic Tutorial

Check the basic tutorial [here](docs/basic_tutorial/basic_tutorial.ipynb).

### TorchMLP to ONNX Converter

Once you have trained your model, you can convert it to the ONNX format using the `transform_onnx.py` script.

The `transform_onnx.py` script converts a TorchMLP model to the ONNX format. It takes a network configuration file (in pickle format), a state dictionary file (Torch model weights), the size of the input tensor, and the desired output ONNX file path.
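
For orientation, the snippet below is a minimal sketch of such a conversion using `torch.onnx.export` directly. It is not the `transform_onnx.py` interface; the dummy input shape and output file name are assumptions.

```python
# Minimal ONNX export sketch (NOT the transform_onnx.py interface; output path is illustrative)
import torch

net.eval()  # 'net' is the trained TorchMLP from the training step above
dummy_input = torch.zeros(1, torch_training_dataset.input_dim, dtype = torch.float32)
torch.onnx.export(net, dummy_input, 'data/torch_models/angle/angle_lan.onnx')
```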

### Usage
2 changes: 1 addition & 1 deletion docs/overrides/main.html
@@ -8,7 +8,7 @@
Navigate the site here!
</span>
<span class="right-margin">
v0.4.4 is out!
v0.4.7 is out!
</span>
<span>
<span class="twemoji">
2 changes: 1 addition & 1 deletion lan_pipeline_env.yml
@@ -92,7 +92,7 @@ dependencies:
- jupyterlab-server==2.25.0
- jupyterlab-widgets==3.0.9
- kiwisolver==1.4.5
- lanfactory==0.4.1
- lanfactory==0.4.7
- markdown-it-py==3.0.0
- markupsafe==2.1.2
- matplotlib==3.8.0
@@ -92,7 +92,7 @@ dependencies:
- jupyterlab-server==2.25.0
- jupyterlab-widgets==3.0.9
- kiwisolver==1.4.5
- lanfactory==6
- lanfactory==0.4.7
- markdown-it-py==3.0.0
- markupsafe==2.1.2
- matplotlib==3.8.0