Skip to content

Commit

Permalink
modified: separately learning onset/velocity
Browse files Browse the repository at this point in the history
  • Loading branch information
naotokui committed Dec 21, 2019
1 parent 99fec53 commit 2c86aab
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 149 deletions.
2 changes: 1 addition & 1 deletion M4L.MelodyVAE/M4L.MelodyVAE.maxproj
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"name" : "M4L.MelodyVAE",
"version" : 1,
"creationdate" : 3656588145,
"modificationdate" : 3659695542,
"modificationdate" : 3659769440,
"viewrect" : [ 27.0, 79.0, 300.0, 500.0 ],
"autoorganize" : 1,
"hideprojectwindow" : 0,
Expand Down
88 changes: 57 additions & 31 deletions melodyvae.js
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ const vae = require('./src/vae.js');
Max.post(`Loaded the ${path.basename(__filename)} script`);

// Global variables
var train_data = [];
var train_data_onsets = [];
var train_data_velocities = [];
var train_data_durations = [];
var isGenerating = false;

Expand Down Expand Up @@ -55,7 +56,9 @@ function getNoteIndexAndTimeshift(note, tempo){
function processPianoroll(midiFile, augmentation){
const tempo = getTempo(midiFile);

var pianorolls = [];
// data array
var onsets = [];
var velocities = [];
var durations = [];

midiFile.tracks.forEach(track => {
Expand All @@ -70,15 +73,22 @@ function processPianoroll(midiFile, augmentation){
let duration = timing[2];

// add new array
while (Math.floor(index / LOOP_DURATION) >= pianorolls.length){
pianorolls.push(utils.create2DArray(NUM_MIDI_CLASSES, LOOP_DURATION));
while (Math.floor(index / LOOP_DURATION) >= onsets.length){
onsets.push(utils.create2DArray(NUM_MIDI_CLASSES, LOOP_DURATION));
velocities.push(utils.create2DArray(NUM_MIDI_CLASSES, LOOP_DURATION));
durations.push(utils.create2DArray(NUM_MIDI_CLASSES, LOOP_DURATION));
}

let matrix = pianorolls[Math.floor(index / LOOP_DURATION)];
let note_id = note.midi - MIN_MIDI_NOTE;
matrix[note_id][index % LOOP_DURATION] = note.velocity;

// store onset
let matrix = onsets[Math.floor(index / LOOP_DURATION)];
matrix[note_id][index % LOOP_DURATION] = 1; // 1 for onsets

// store velocity
matrix = velocities[Math.floor(index / LOOP_DURATION)];
matrix[note_id][index % LOOP_DURATION] = note.velocity; // normalized 0 - 1

// store duration
matrix = durations[Math.floor(index / LOOP_DURATION)];
matrix[note_id][index % LOOP_DURATION] = duration;
Expand All @@ -89,35 +99,45 @@ function processPianoroll(midiFile, augmentation){

//data augmentation - with all keys
if (augmentation){
aug_pianorolls = [];
aug_onsets = [];
aug_velocities = [];
aug_durations = [];

pianorolls.forEach(function (pianoroll, i){
onsets.forEach(function (onset, i){
let velocity = velocities[i];
let duration = durations[i];
let maxv = utils.getMaxPitch(pianoroll) + MIN_MIDI_NOTE;
let minv = utils.getMinPitch(pianoroll) + MIN_MIDI_NOTE;
let maxv = utils.getMaxPitch(onset) + MIN_MIDI_NOTE;
let minv = utils.getMinPitch(onset) + MIN_MIDI_NOTE;
for (let diff = -12; diff <= 12; diff++){
if (maxv + diff <= MAX_MIDI_NOTE && minv + diff >= MIN_MIDI_NOTE){ // if it's in the transposition range...
let newroll = utils.create2DArray(NUM_MIDI_CLASSES, LOOP_DURATION);
let newonset = utils.create2DArray(NUM_MIDI_CLASSES, LOOP_DURATION);
let newvelocity = utils.create2DArray(NUM_MIDI_CLASSES, LOOP_DURATION);
let newduration = utils.create2DArray(NUM_MIDI_CLASSES, LOOP_DURATION);
for (var i = 0; i < NUM_MIDI_CLASSES; i++){
for (var j =0; j < LOOP_DURATION; j++){
if (i + diff >= 0 && i + diff < NUM_MIDI_CLASSES){
newroll[i + diff][j] = pianoroll[i][j]; // transpose
newduration[i+diff][j] = duration[i][j];
if (onset[i][j] > 0) { // only if there is onset
newonset[i + diff][j] = 1; // transpose
newvelocity[i + diff][j] = velocity[i][j];
newduration[i + diff][j] = duration[i][j];
}
}
}
}
aug_pianorolls.push(newroll);
aug_onsets.push(newonset);
aug_velocities.push(newvelocity);
aug_durations.push(newduration);
}
}
});

pianorolls.push(...aug_pianorolls);
onsets.push(...aug_onsets);
velocities.push(...aug_velocities);
durations.push(...aug_durations);
}

console.assert(onsets.length == velocities.length && velocities.length == durations.length,
"Something wrong with augmentation? array length must be the same.");
// /* for debug - output pianoroll */
// if (durations.length > 0){
// var index = utils.getRandomInt(durations.length);
Expand All @@ -131,8 +151,9 @@ function processPianoroll(midiFile, augmentation){
// }

// 2D array to tf.tensor2d
for (var i=0; i < pianorolls.length; i++){
train_data.push(tf.tensor2d(pianorolls[i], [NUM_MIDI_CLASSES, LOOP_DURATION]));
for (var i=0; i < onsets.length; i++){
train_data_onsets.push(tf.tensor2d(onsets[i], [NUM_MIDI_CLASSES, LOOP_DURATION]));
train_data_velocities.push(tf.tensor2d(velocities[i], [NUM_MIDI_CLASSES, LOOP_DURATION]));
train_data_durations.push(tf.tensor2d(durations[i], [NUM_MIDI_CLASSES, LOOP_DURATION]));
}
}
Expand Down Expand Up @@ -165,18 +186,21 @@ Max.addHandler("midi", (filename, augmentation) => {
glob(filename + '**/*.mid', {}, (err, files)=>{
if (err) console.log(err);
else {
for (var idx in files){
if (processMidiFile(files[idx],augmentation)) count += 1;
for (var idx in files){
try {
if (processMidiFile(files[idx], augmentation)) count += 1;
} catch(error) {
utils.error("failed to process " + files[idx] + " - " + error);
}
}
Max.post("# of midi files added: " + count);
utils.post("# of midi files added: " + count);
reportNumberOfBars();
}
})
} else {
if (processMidiFile(filename,augmentation)) count += 1;
Max.post("# of midi files added: " + count);
reportNumberOfBars();

}
});

Expand All @@ -188,9 +212,9 @@ Max.addHandler("train", ()=>{
}

utils.log_status("Start training...");
console.log("# of bars in training data:", train_data.length * 2);
console.log("# of bars in training data:", train_data_onsets.length * 2);
reportNumberOfBars();
vae.loadAndTrain(train_data, train_data_durations);
vae.loadAndTrain(train_data_onsets, train_data_velocities, train_data_durations);
});

// Generate a rhythm pattern
Expand All @@ -207,24 +231,24 @@ async function generatePattern(z1, z2, threshold){
if (isGenerating) return;

isGenerating = true;
let [pattern, durations] = vae.generatePattern(z1, z2);
let [onsets, velocities, durations] = vae.generatePattern(z1, z2);
Max.outlet("matrix_clear",1); // clear all

// Velocity
// For Grid
for (var i=0; i< NUM_MIDI_CLASSES; i++){
var sequence = [];
// output for matrix view
for (var j=0; j < LOOP_DURATION; j++){
var x = 0.0;
// if (pattern[i * LOOP_DURATION + j] > 0.2) x = 1;
if (pattern[i][j] > threshold){
if (onsets[i][j] > threshold){
x = 1;
Max.outlet("matrix_output", j + 1, i + 1, x); // index for live.grid starts from 1
}
}
}

// Pitch

// live.step has mono-phonic sequences (up to 16 tracks)
for (var k=0; k< 16; k++){ // 16 = number of monophonic sequence in live.step
var pitch_sequence = [];
Expand All @@ -234,10 +258,10 @@ async function generatePattern(z1, z2, threshold){

var count = 0;
for (var i=0; i< NUM_MIDI_CLASSES; i++){
if (pattern[i][j] > threshold) count++;
if (onsets[i][j] > threshold) count++; // if there is an onset
if (count > k) {
pitch_sequence.push(i + MIN_MIDI_NOTE);
velocity_sequence.push(Math.floor(pattern[i][j]*127.));
velocity_sequence.push(Math.floor(velocities[i][j]*127.));
duration_sequence.push(Math.min(Math.floor(durations[i][j]*64.), 127));
break;
}
Expand Down Expand Up @@ -265,7 +289,9 @@ async function generatePattern(z1, z2, threshold){

// Clear training data
// Handler for the "clear_train" message: resets every accumulated training
// array and then re-reports the bar count via reportNumberOfBars().
Max.addHandler("clear_train", ()=>{
    train_data = []; // clear
    train_data_onsets = []; // clear
    train_data_velocities = [];
    // Fix: this previously assigned `train_data_timeshift`, which is not one of
    // the arrays declared with the other training globals. That created an
    // implicit global and left `train_data_durations` populated, so the three
    // arrays went out of step with each other on the next MIDI import.
    train_data_durations = [];
    // NOTE(review): the discarded arrays hold tensors created with tf.tensor2d;
    // they may need explicit disposal to release memory - confirm.
    reportNumberOfBars();
});

Expand Down Expand Up @@ -296,5 +322,5 @@ Max.addHandler("epochs", (e)=>{
});

function reportNumberOfBars(){
Max.outlet("train_bars", train_data.length * 2); // number of bars for training
Max.outlet("train_bars", train_data_onsets.length * 2); // number of bars for training
}
Loading

0 comments on commit 2c86aab

Please sign in to comment.