This is the code for a modified version of the RL-Tuner. The goal of this project is to incorporate the marginal probabilities provided by the Mini-CPBP solver into the reward function in order to enforce constraints.
is used to create note sequences from the Bach dataset. path_root is the path to the folder containing the MIDI dataset; it can be changed to process a different dataset. quarter_note is the number of events in a quarter note, which depends on the bpm of the MIDI files. The chosen resolution is passed to load_all_midi_files_in_folder. Running this Python file creates a pickle file containing the note sequences for the whole dataset.
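As a rough illustration of what the quarter_note resolution controls, the sketch below turns notes into fixed-length events and pickles the result. It is not the repository's code: it assumes notes have already been extracted from the MIDI files as hypothetical (pitch, duration in quarter notes) pairs, with None for a rest, and notes_to_frames and save_sequences are illustrative names. The internals of load_all_midi_files_in_folder are not reproduced.

```python
import pickle
from typing import List, Optional, Sequence, Tuple

# Hypothetical input format: (pitch index or None for a rest, duration in quarter notes)
Note = Tuple[Optional[int], float]


def notes_to_frames(notes: Sequence[Note], quarter_note: int) -> List[Optional[int]]:
    """Repeat each pitch once per event it covers at the given resolution.

    Assumption: durations are quantized by rounding, so a note lasting
    half an event or less may be dropped entirely.
    """
    frames: List[Optional[int]] = []
    for pitch, quarters in notes:
        n_events = round(quarters * quarter_note)
        frames.extend([pitch] * n_events)
    return frames


def save_sequences(sequences: List[List[Optional[int]]], path: str) -> None:
    """Pickle the note sequences of the whole dataset into a single file."""
    with open(path, "wb") as f:
        pickle.dump(sequences, f)
```

With quarter_note = 4, a dotted eighth (0.75 quarter notes) spans 3 events, so notes_to_frames([(4, 0.75), (5, 0.5)], 4) returns [4, 4, 4, 5, 5], the kind of sequence used in the examples below.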
is used to convert the pickled note sequences into sequence examples used to train the RNN. Three different modes can be specified:
- melodic lines: removes the silences and all duration information (for example, the sequence [4 4 4 5 5] becomes [4 5]). All notes take values between 0 and 28.
- note sequences: keeps the silences and expresses durations with the hold token (for example, the sequence [4 4 4 5 5] becomes [4 1 1 5 1]). All notes take values between 2 and 30; 0 represents a silence and 1 is the hold token.
- no hold: keeps the silences and expresses durations by repeating the note (for example, the sequence [4 4 4 5 5] stays the same). All notes take values between 1 and 29; 0 represents a silence.
and output_file_name are used to specify the paths of the pickle file and the TFRecord file.
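The three modes described above can be sketched as follows. This is a minimal illustration, not the repository's implementation: it assumes the input is a per-event list of raw pitch indices (0 to 28) with None marking a silence, and that in the note sequences mode each silent event is written as 0 rather than being followed by hold tokens. The function names are hypothetical.

```python
from itertools import groupby
from typing import List, Optional, Sequence

SILENCE = 0
HOLD = 1

Frames = Sequence[Optional[int]]  # raw pitch index per event, None = silence


def to_melodic_line(frames: Frames) -> List[int]:
    """Melodic lines: one value (0-28) per note, no silences, no durations."""
    # Collapse each run of identical events first, then drop silences, so a
    # pitch repeated after a rest still counts as two notes (design choice).
    runs = [value for value, _ in groupby(frames)]
    return [pitch for pitch in runs if pitch is not None]


def to_note_sequence(frames: Frames) -> List[int]:
    """Note sequences: onsets shifted to 2-30, 1 = hold, 0 = silence."""
    out: List[int] = []
    prev: object = object()  # sentinel that never equals a frame
    for frame in frames:
        if frame is None:
            out.append(SILENCE)
        elif frame == prev:
            out.append(HOLD)  # same pitch as the previous event: held note
        else:
            out.append(frame + 2)  # new note onset
        prev = frame
    return out


def to_no_hold(frames: Frames) -> List[int]:
    """No hold: notes shifted to 1-29 and repeated for their duration, 0 = silence."""
    return [SILENCE if frame is None else frame + 1 for frame in frames]
```

Because the offsets are applied to raw indices, to_note_sequence([2, 2, 2, 3, 3]) returns [4, 1, 1, 5, 1] and to_no_hold([3, 3, 3, 4, 4]) returns [4, 4, 4, 5, 5], matching the examples in the list.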
To train the RNN from the tfrecords, refer to
A few parameters can be specified to train the DQN agent in the file
- Seed: The random seed to make results reproducible
- Algorithm: 'q' for Q-Learning
- Reward Scaler: Multiplies the reward from the CP model before adding the RNN reward
- Reward Mode: Refer to the paper or to the collect_reward function for the available reward functions, or to add your own
- Restrict Domain: If True, filters the domain based on the number of violations before picking an action
- Output Every Nth: Number of iterations between two evaluations of the agent
- Num steps: The number of training iterations
- Num notes in composition: Number of notes in the composition (32 by default)
- Prime with midi: If True, starts the composition with a midi primer
- Checkpoint dir and checkpoint: Location of the pretrained RNN checkpoint
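For reference, the parameters above can be gathered into a single object. The sketch below is hypothetical: the field names are derived from the list, only the default stated in this README (32 notes) is filled in, and combined_reward simply spells out the Reward Scaler description (the CP-model reward is multiplied by the scaler, then the RNN reward is added). The schedule in is_evaluation_step is an assumption about how Output Every Nth is applied.

```python
from dataclasses import dataclass


@dataclass
class TunerConfig:
    """Hypothetical container for the DQN training parameters listed above."""
    seed: int                # random seed for reproducibility
    algorithm: str           # 'q' selects Q-learning
    reward_scaler: float     # multiplies the CP-model reward
    reward_mode: str         # name of a reward function (see collect_reward)
    restrict_domain: bool    # filter actions by number of violations first
    output_every_nth: int    # iterations between two evaluations
    num_steps: int           # total number of training iterations
    prime_with_midi: bool    # start the composition from a MIDI primer
    checkpoint_dir: str      # directory of the pretrained RNN checkpoint
    checkpoint: str          # checkpoint name (assumed relative to checkpoint_dir)
    num_notes_in_composition: int = 32  # default stated in the README

    def __post_init__(self) -> None:
        # Basic sanity checks on the integer parameters.
        if self.output_every_nth <= 0 or self.num_steps <= 0:
            raise ValueError("output_every_nth and num_steps must be positive")
        if self.num_notes_in_composition <= 0:
            raise ValueError("num_notes_in_composition must be positive")

    def is_evaluation_step(self, step: int) -> bool:
        """Assumed schedule: evaluate after every output_every_nth iterations."""
        return step > 0 and step % self.output_every_nth == 0


def combined_reward(rnn_reward: float, cp_reward: float, reward_scaler: float) -> float:
    """Scale the CP-model reward, then add the RNN reward."""
    return rnn_reward + reward_scaler * cp_reward
```

Keeping every parameter in one object makes a run easy to log next to its seed, which helps when comparing reward modes.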