Skip to content

Commit

Permalink
add code
Browse files Browse the repository at this point in the history
  • Loading branch information
FelipeMarra committed Dec 10, 2024
1 parent 82c4eba commit ebd968f
Show file tree
Hide file tree
Showing 48 changed files with 8,421 additions and 2 deletions.
9 changes: 9 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,12 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

#Pytorch Pickle Numpy
*.pickle
*.pth
*.npy

# Bardo Results
experiments/cotw/results/
experiments/osni/results/
33 changes: 31 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,31 @@
# babel-bardo
Coming Soon
# Babel Bardo
## Introduction
Babel Bardo is a system designed to generate music for Tabletop Role-Playing Games (TRPGs) in real time. The system works in 30-second windows, executing the following procedure:

```
For each 30s of gameplay:
    Extract a transcription of the players' dialogue with a Speech Recognition (SR) system
    Use an LLM to transform the transcription into a music description
    Feed a Text-to-Music (TTM) model with the music description
    Play the generated piece of music
```

A visual representation of the system can be seen in Figure 1.

![Figure 1. An overview of the Babel Bardo system](/assets/bardo_overview.png)

By prompting the LLM in different ways we obtained different versions of the system. For more details, refer to the paper here. The following list maps the names used for the systems in the paper to the names used in this repository:

* Babel Bardo - Baseline (B): Bardo 1
* Babel Bardo - Emotion (E): Bardo 0
* Babel Bardo - Description (D): Bardo 2
* Babel Bardo - Description Continuation (DC): Bardo 3

## Installation
#TODO collab

## Usage
#TODO

## Technical Details
#TODO
Binary file added assets/bardo_overview.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
154 changes: 154 additions & 0 deletions experiments/cotw/cotw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
#! python3 -m pip uninstall pt_brdo -y && python3 -m pip install --no-dependencies -e .
import os
import json
import pathlib
import shutil
from pytubefix import Playlist

from babel_bardo.templates import *
from babel_bardo import Bardo, fit_audio_in_video
from babel_bardo.eval_metrics import get_kld, get_kld_for_transitions, get_fad_vggish, get_kld_for_segments_transitions


# Experiment configuration for this campaign.
# RPGNAME is handed to every Bardo template constructor below.
RPGNAME = 'Call Of The Wild'
# YouTube playlist whose videos are processed one by one.
PLAYLIST = 'https://www.youtube.com/playlist?list=PLMZlu4rxEyKKl-Ecgca3bbVDMZ89SaQHN'
# Passed to template.set_random_excerpt together with the video length.
# presumably seconds (60 * 30 = 30 minutes) — TODO confirm against set_random_excerpt
EXCERPT_LENGTH = 60 * 30
# All outputs go under a "results" folder next to this script.
BARDO_ROOT_PATH = pathlib.Path(__file__).parent.joinpath('results').resolve()
# Path for FAD background statistics; while it is None the FAD step is skipped.
SOUNDTRACK_PATH = None # path for FAD background statistics

# The name is rebound: first the pytubefix Playlist object, then a plain list
# of (video_id, length) tuples, which is what the rest of the script uses.
playlist = Playlist(PLAYLIST)
playlist = [(video.video_id, video.length) for video in playlist.videos]

# Video ids skipped by the main loop.
# presumably videos without subtitles, going by the name — TODO confirm
no_sub_vids = ['d4WEJ2thu4E']

def write_eps_start(template: "BardoTemplate", eps_start: dict):
    """Save the per-video excerpt start values as indented JSON.

    Args:
        template: Object providing ``root_path`` (directory that is created
            when missing) and ``eps_start_file`` (path of the JSON file).
        eps_start: Mapping of video id to its excerpt start value
            (``None`` for videos that have no start yet).
    """
    # exist_ok replaces the separate isdir() test, which left a window between
    # the check and the creation.
    os.makedirs(template.root_path, exist_ok=True)

    with open(template.eps_start_file, "w") as json_file:
        json.dump(eps_start, json_file, indent=4)

def load_eps_start(template: "BardoTemplate", playlist) -> dict:
    """Load the saved excerpt start values, with an entry for every video.

    Args:
        template: Object providing ``eps_start_file``, the JSON file path.
        playlist: Iterable of ``(video_id, length)`` pairs.

    Returns:
        Mapping of video id to its saved start value. Videos from
        ``playlist`` that are not in the file (or all of them when the file
        does not exist) map to ``None``.
    """
    eps_start = {}

    if os.path.isfile(template.eps_start_file):
        with open(template.eps_start_file, 'r') as json_file:
            eps_start = json.load(json_file)

    # A file saved before a video was added to the playlist has no key for
    # it, and the caller indexes eps_start[video_id] directly. Back-fill the
    # missing ids so that lookup cannot raise KeyError.
    for video_id, _ in playlist:
        eps_start.setdefault(video_id, None)

    return eps_start

def write_metrics(template: "BardoTemplate", metrics: dict, type: str):
    """Write a metrics dict as indented JSON to the file for its metric type.

    Args:
        template: Object providing the destination paths ``kld_file``,
            ``fad_file``, ``transitions_kld_metrics_file`` and
            ``segments_transitions_kld_metrics_file``.
        metrics: JSON-serialisable metrics to store.
        type: One of ``'kld'``, ``'fad'``, ``'t_kld'`` or ``'s_t_kld'``.
            (The name shadows the builtin; it is kept so existing callers
            keep working.)

    Raises:
        ValueError: If ``type`` is not one of the supported metric types.
    """
    # Attribute names rather than values, so only the requested path is read
    # from the template.
    file_attr_by_type = {
        'kld': 'kld_file',
        'fad': 'fad_file',
        't_kld': 'transitions_kld_metrics_file',
        's_t_kld': 'segments_transitions_kld_metrics_file',
    }

    try:
        file_attr = file_attr_by_type[type]
    except KeyError:
        # Previously an unknown type reached open('') and failed with a
        # FileNotFoundError that did not mention the real cause.
        raise ValueError(
            f"unknown metrics type {type!r}; expected one of {sorted(file_attr_by_type)}"
        ) from None

    with open(getattr(template, file_attr), "w") as json_file:
        json.dump(metrics, json_file, indent=4)

def load_metrics(path_to_file:str) -> dict:
    """Return the JSON content of ``path_to_file``, or ``{}`` if it is not an existing file."""
    if not os.path.isfile(path_to_file):
        return {}

    with open(path_to_file, 'r') as metrics_fh:
        return json.load(metrics_fh)

# Main experiment loop: for every playlist video and every Bardo variant,
# pick (or reload) the excerpt start, generate the audio if it is not on disk
# yet, fit it into the video and compute whichever metrics are still missing
# from the per-video metrics files.

# Receives the third value returned by get_fad_vggish; it stays '' for the
# whole run when the FAD step never executes.
wav_background_path = ''

for video_id, video_length in playlist:
    bardo_templates:list[BardoTemplate] = [
        Bardo1(RPGNAME, BARDO_ROOT_PATH, video_id, translate=False), # Bardo1 comes first because it doesn't use Ollama
        Bardo0(RPGNAME, BARDO_ROOT_PATH, video_id),
        Bardo2(RPGNAME, BARDO_ROOT_PATH, video_id),
        Bardo3(RPGNAME, BARDO_ROOT_PATH, video_id)
    ]

    if video_id in no_sub_vids:
        continue

    for template in bardo_templates:
        # Set the same start for the same episode in set_random_excerpt
        eps_start = load_eps_start(template, playlist)

        if eps_start[video_id] is not None:
            print(f"\n LOADED eps_start {eps_start[video_id]} for {video_id} \n")

        eps_start[video_id] = template.set_random_excerpt(EXCERPT_LENGTH, video_length, eps_start[video_id])

        print(f"\n Video {video_id} is starting at {eps_start[video_id]}")
        write_eps_start(template, eps_start)

        print(template.log_header)

        # Bardo Play: generation is skipped when the audio file already exists
        if not os.path.isfile(template.generated_audio_file):
            bardo = Bardo(template)
            bardo.play()
        else:
            print("Skipping generating", template.generated_audio_file)

        fit_audio_in_video(template, video_id)

        # Get Metrics
        # Each metric is computed only when the metrics file has no entry for
        # this Bardo yet; values are stored as strings.

        # KLD
        kld_metrics = load_metrics(template.kld_file)

        if kld_metrics.get(template.bardo_name) is None:
            kld_data = get_kld(template.original_audios_path, template.original_audio_file, template.generated_audio_file)

            print(f"\n MEAN KLD for {template.bardo_name}, video {video_id} = {kld_data['mean']} \n")

            kld_metrics[template.bardo_name] = {k:str(v) for k,v in kld_data.items()}
            write_metrics(template, kld_metrics, 'kld')

        # KLD for segments transitions
        s_t_kld_metrics = load_metrics(template.segments_transitions_kld_metrics_file)

        if s_t_kld_metrics.get(template.bardo_name) is None:
            s_t_kld_data = get_kld_for_segments_transitions(template.generated_audio_file)

            print(f"\n MEAN SEGMENT TRANSITION KLD for {template.bardo_name}, video {video_id} = {s_t_kld_data['mean']} \n")

            s_t_kld_metrics[template.bardo_name] = {k:str(v) for k,v in s_t_kld_data.items()}
            write_metrics(template, s_t_kld_metrics, 's_t_kld')

        # KLD for transitions
        t_kld_metrics = load_metrics(template.transitions_kld_metrics_file)

        if t_kld_metrics.get(template.bardo_name) is None:
            t_kld_data = get_kld_for_transitions(template.original_audios_path, template.original_audio_file, template.generated_audio_file, template.transitions_eval_audios_path)

            print(f"\n MEAN TRANSITION KLD for {template.bardo_name}, video {video_id} = {t_kld_data['mean']} \n")

            t_kld_metrics[template.bardo_name] = {k:str(v) for k,v in t_kld_data.items()}
            write_metrics(template, t_kld_metrics, 't_kld')

        # FAD (only when a background soundtrack path is configured)
        if SOUNDTRACK_PATH is not None:
            fad_metrics = load_metrics(template.fad_file)
            if fad_metrics.get(template.bardo_name) is None:
                overall_fad, ep_fad, wav_background_path = get_fad_vggish(SOUNDTRACK_PATH, template.generated_audios_path, template.generated_audio_file, False)

                fad_metrics[template.bardo_name] = {}
                fad_metrics[template.bardo_name]['ep'] = str(ep_fad)
                fad_metrics[template.bardo_name]['overall'] = str(overall_fad)

                print(f"\n FAD for {template.bardo_name} = ep: {ep_fad}, overall: {overall_fad} \n")
                write_metrics(template, fad_metrics, 'fad')

# When FAD never ran (for example SOUNDTRACK_PATH is None) the path is still
# '', and shutil.rmtree('') raises FileNotFoundError, so only clean up a path
# that was actually returned.
# NOTE(review): cleanup is assumed to run once after all videos — confirm.
if wav_background_path:
    shutil.rmtree(wav_background_path)
8 changes: 8 additions & 0 deletions experiments/eval/download_soudtrack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from babel_bardo.video_manager import get_audio_playlist_for_fad

# Destination directory handed to the downloader (machine-specific absolute path).
DOWNLOAD_PATH = '/home/felipe/Desktop/movies_soundtracks/inter'
# YouTube playlist handed to the downloader.
PLAYLIST = 'https://www.youtube.com/playlist?list=PLco_u-O9FeQ_cV5gc3VdUHoQYBI73MYkU'

# Echo the destination before the download starts.
print(DOWNLOAD_PATH)

# presumably downloads the playlist's audio into DOWNLOAD_PATH for use as the
# FAD reference set — confirm in babel_bardo.video_manager
get_audio_playlist_for_fad(DOWNLOAD_PATH, PLAYLIST)
77 changes: 77 additions & 0 deletions experiments/eval/fad_metrics_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#%% Imports & Constants
import os
import json

import pandas as pd
import matplotlib.pyplot as plt

# Experiment folder name, interpolated into the two paths below.
RPG = 'cotw'
# Folder holding one "fad_<video_id>.json" file per video (machine-specific absolute path).
METRICS_FOLDER = f'/home/felipe/Documents/Github/Pt-Brdo/experiments/{RPG}/results/metrics'
# JSON file mapping each video id to its excerpt start (machine-specific absolute path).
EPS_START_FILE = f'/home/felipe/Documents/Github/Pt-Brdo/experiments/{RPG}/results/eps_start.json'

# Keys looked up in every metrics file; the list order is also the column order of the table.
BARDOS = ['bardo_1', 'bardo_0', 'bardo_2', 'bardo_3']

#%%
def load_dict(path:str) -> dict:
    """Parse the JSON file at ``path`` and return its content."""
    with open(path, 'r') as source:
        return json.load(source)

#%%
# Build the FAD results table twice: as LaTeX rows (latex_table) and as a
# pandas DataFrame with one row per episode and one column per Bardo.
data_dict = {bardo:[] for bardo in BARDOS}
data_index = []

eps_start = load_dict(EPS_START_FILE)
# NOTE(review): the "Overall" row is only emitted when the last key of the
# eps_start file also has a start value — confirm that is intended.
last_video = list(eps_start.keys())[-1]
# Only videos that have a recorded start are tabulated.
valid_vides_ids = [video_id for video_id, start_time in eps_start.items() if start_time is not None]

latex_table = ""
overall_row = ""

for v_idx, video_id in enumerate(valid_vides_ids):
    metrics_file = os.path.join(METRICS_FOLDER, f"fad_{video_id}.json")
    # Read the episode's file once instead of once per Bardo.
    video_metrics = load_dict(metrics_file)

    data_index.append(f"Ep {v_idx+1}")

    # Keep the cells numeric so highlight_min compares numbers; as strings,
    # "10.2" would sort before "9.8".
    ep_fads = [round(float(video_metrics[bardo]['ep']), 2) for bardo in BARDOS]
    for bardo, fad_ep in zip(BARDOS, ep_fads):
        data_dict[bardo].append(fad_ep)

    # Backslashes are escaped explicitly; "\m" is an invalid escape sequence.
    latex_table += (
        "\\multicolumn{1}{c}{\\textbf{" + f"{v_idx+1}" + "}} & "
        + " & ".join(str(fad_ep) for fad_ep in ep_fads)
        + " \\\\ \n"
    )

    if video_id == last_video:
        overall_fads = [round(float(video_metrics[bardo]['overall']), 2) for bardo in BARDOS]
        for bardo, fad_overall in zip(BARDOS, overall_fads):
            data_dict[bardo].append(fad_overall)

        data_index.append("Overall:")
        overall_row += (
            "\\multicolumn{1}{c}{\\textbf{Overall}} &"
            + " & ".join(str(fad_overall) for fad_overall in overall_fads)
            + " \\\\ \n"
        )

latex_table += overall_row

# %%
print(latex_table)

# %%
df = pd.DataFrame(data_dict, index=data_index)

# The caption previously said "Mean KLD", copied from the KLD table script.
# format("{}") shows the already-rounded values without padded decimals.
df.style.set_caption(f"FAD for {RPG.upper()}").format("{}").highlight_min(color='blue', axis=1)
59 changes: 59 additions & 0 deletions experiments/eval/kld_metrics_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#%% Imports & Constants
import os
import json

import pandas as pd
import matplotlib.pyplot as plt

# Experiment folder name, interpolated into the paths below.
RPG = 'cotw'
# JSON file mapping each video id to its excerpt start (machine-specific absolute path).
EPS_START_FILE = f'/home/felipe/Documents/Github/Pt-Brdo/experiments/{RPG}/results/eps_start.json'

# Folder the per-video metrics files are read from. The commented-out lines are
# alternative result folders; only the last assignment is active.
# NOTE(review): the file name prefix used when reading is fixed to "s_t_kld_",
# so switching folders presumably needs a matching prefix change — confirm.
#METRICS_FOLDER = f'/home/felipe/Documents/Github/Pt-Brdo/experiments/{RPG}/results/metrics'
#METRICS_FOLDER = f'/home/felipe/Documents/Github/Pt-Brdo/experiments/{RPG}/results/transitions_eval/metrics'
METRICS_FOLDER = f'/home/felipe/Documents/Github/Pt-Brdo/experiments/{RPG}/results/segments_transitions_eval'

# Keys looked up in every metrics file; the list order is also the column order of the table.
BARDOS = ['bardo_1', 'bardo_0', 'bardo_2', 'bardo_3']

#%%
def load_dict(path:str) -> dict:
    """Read the JSON document stored at ``path`` and return it."""
    with open(path, 'r') as handle:
        content = json.load(handle)
    return content

#%%
# Build the KLD results table twice: as LaTeX rows (latex_table) and as a
# pandas DataFrame with one row per episode and one column per Bardo. Every
# cell is the text "<mean>$\pm$<std>", both rounded to two decimals.
data_dict = {bardo:[] for bardo in BARDOS}
data_index = []

eps_start = load_dict(EPS_START_FILE)
# Only videos that have a recorded start are tabulated.
valid_vides_ids = [video_id for video_id, start_time in eps_start.items() if start_time is not None]

latex_table = ""

for v_idx, video_id in enumerate(valid_vides_ids):
    metrics_file = os.path.join(METRICS_FOLDER, f"s_t_kld_{video_id}.json")
    # Read the episode's file once instead of once per Bardo.
    video_metrics = load_dict(metrics_file)

    data_index.append(f"Ep {v_idx+1}")

    cells = []
    for bardo in BARDOS:
        kld_mean = round(float(video_metrics[bardo]['mean']), 2)
        kld_std = round(float(video_metrics[bardo]['std']), 2)

        # Backslashes are escaped explicitly; "\p" is an invalid escape sequence.
        cell = f"{kld_mean}$\\pm${kld_std}"
        cells.append(cell)
        data_dict[bardo].append(cell)

    latex_table += (
        "\\multicolumn{1}{c}{\\textbf{" + f"{v_idx+1}" + "}} & "
        + " & ".join(cells)
        + " \\\\ \n"
    )

# %%
print(latex_table)
# %%
df = pd.DataFrame(data_dict, index=data_index)

# NOTE(review): the cells are strings, so highlight_min presumably compares
# them as text rather than by numeric mean — confirm this is acceptable.
df.style.set_caption("Mean KLD for COTW").highlight_min(color='blue', axis=1)
Loading

0 comments on commit ebd968f

Please sign in to comment.