This is the code for a modified version of RL-Tuner. The goal of this project is to incorporate the marginal probabilities provided by the Mini-CPBP solver into the reward function in order to enforce constraints.
`RL-Tuner/preprocessing/bach_dataset_loader.py` is used to create note sequences from the Bach dataset. `path_root` is the path to the folder containing the MIDI dataset and can be changed to process a different dataset. `quarter_note` is the number of events in a quarter-note, which depends on the BPM of the MIDI files. The chosen resolution is passed to `load_all_midi_files_in_folder`. Running this Python file creates a pickle file containing the note sequences for the whole dataset.
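A minimal sketch of what this step amounts to, assuming `load_all_midi_files_in_folder` (defined in the same file) takes the dataset path and the resolution and returns a list of note sequences; the exact signature, the dataset path, and the output filename are assumptions:

```python
import pickle

# Hypothetical values: adjust to your dataset and MIDI resolution.
path_root = 'datasets/bach_midi'  # folder containing the MIDI dataset
quarter_note = 4                  # events per quarter-note; depends on the BPM of the files

# Assumed signature: returns one note sequence per MIDI file in the folder.
sequences = load_all_midi_files_in_folder(path_root, quarter_note)

# Pickle the sequences for the next preprocessing step (filename is an assumption).
with open('bach_note_sequences.pkl', 'wb') as f:
    pickle.dump(sequences, f)
```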
`RL-Tuner/preprocessing/bach_note_sequences.py` is used to convert the pickled note sequences into the SequenceExamples used to train the RNN. Three different modes can be specified (see the sketch after this list):
- melodic lines: removes the silences and any duration information (for example, the sequence [4 4 4 5 5] becomes [4 5]). All notes take a value between 0 and 28.
- note sequences: keeps the silences and expresses durations with a hold token (for example, the sequence [4 4 4 5 5] becomes [4 1 1 5 1]). All notes take a value between 2 and 30; 0 represents silence and 1 is the hold token.
- no hold: keeps the silences and expresses durations by repeating the note (for example, the sequence [4 4 4 5 5] stays the same). All notes take a value between 1 and 29; 0 represents silence.
`input_file_name` and `output_file_name` specify the paths of the pickle file and the TFRecord file.
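The hypothetical helper below only mirrors the examples above; it is not the actual implementation in `bach_note_sequences.py`, and it ignores the per-mode value ranges described in the list:

```python
SILENCE = 0
HOLD = 1

def encode(sequence, mode):
    """Illustrates the three encodings on a list of per-event note values."""
    if mode == 'melodic lines':
        # Drop duration information (consecutive repeats) and silences.
        return [n for i, n in enumerate(sequence)
                if (i == 0 or n != sequence[i - 1]) and n != SILENCE]
    if mode == 'note sequences':
        # Replace each repeated (held) note with the hold token.
        return [HOLD if i > 0 and n == sequence[i - 1] and n != SILENCE else n
                for i, n in enumerate(sequence)]
    if mode == 'no hold':
        # Durations stay expressed by repeating the note.
        return list(sequence)
    raise ValueError(f'unknown mode: {mode}')

print(encode([4, 4, 4, 5, 5], 'melodic lines'))   # [4, 5]
print(encode([4, 4, 4, 5, 5], 'note sequences'))  # [4, 1, 1, 5, 1]
print(encode([4, 4, 4, 5, 5], 'no hold'))         # [4, 4, 4, 5, 5]
```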
To train the RNN from the TFRecords, refer to https://github.com/magenta/magenta/tree/main/magenta/models/melody_rnn
A few parameters can be specified in `script.py` to train the DQN agent (a configuration sketch follows the list):
- Seed: The random seed to make results reproducible
- Algorithm: 'q' for Q-Learning
- Reward Scaler: Multiplies the reward from the CP model before adding the RNN reward
- Reward Mode: Refer to the paper or to the `collect_reward` function for the available reward functions, or create your own
- Restrict Domain: If True, filters the domain based on the number of violations before picking an action
- Output Every Nth: Number of training iterations between evaluations of the agent
- Num Steps: The number of training iterations
- Num Notes in Composition: The length of each composition (32 by default)
- Prime with MIDI: If True, starts the composition with a MIDI primer
- Checkpoint Dir and Checkpoint: The location of the pretrained RNN checkpoint
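A minimal sketch of how these parameters might be set in `script.py`; the variable names and values below are assumptions based on the descriptions above:

```python
# All names and values are assumptions; match them to the actual script.py.
seed = 42                       # random seed, for reproducibility
algorithm = 'q'                 # 'q' for Q-Learning
reward_scaler = 1.0             # step reward is roughly reward_scaler * cp_reward + rnn_reward
reward_mode = 'cp_marginals'    # hypothetical name; see collect_reward for the options
restrict_domain = True          # filter the domain on violation counts before picking an action
output_every_nth = 1000         # training iterations between evaluations of the agent
num_steps = 100000              # total number of training iterations
num_notes_in_composition = 32   # default composition length
prime_with_midi = False         # if True, start each composition with a MIDI primer
checkpoint_dir = 'checkpoints'  # directory containing the pretrained RNN checkpoint
checkpoint = 'melody_rnn.ckpt'  # checkpoint filename
```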