Skip to content

Commit

Permalink
add demo
Browse files Browse the repository at this point in the history
  • Loading branch information
UnsignedByte committed Mar 16, 2024
1 parent 4d5d00a commit 44a970c
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 1 deletion.
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,6 @@ pytest-xdist==3.5.0
psutil
fabric
halo
pyyaml
pyyaml
matplotlib
librosa
3 changes: 3 additions & 0 deletions src/fft/demos/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
*.wav
*.png
*.pkl
112 changes: 112 additions & 0 deletions src/fft/demos/model_comparison.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import matplotlib.pyplot as plt
from matplotlib import gridspec
import numpy as np
from halo import Halo
from src.fft.demos.utils import spectrogram, plot_spectrogram, numpy_fft, mk_hard_fft
import librosa
from os import path
import multiprocessing as mp


def run_spectrogram(f, data, sample_rate, n_samples, n_overlap):
if f[0] == "numpy":
f = numpy_fft
else:
f = mk_hard_fft(*f[1])
results, bins = spectrogram(f, data, sample_rate, n_samples, n_overlap)
return results, bins


if __name__ == "__main__":
# Check if the spectrograms have already been generated
if path.exists(path.join(path.dirname(__file__), "spectrogram_results.pkl")):

spinner = Halo(text="Generating spectrograms", spinner="dots")
spinner.start()
with open(
path.join(path.dirname(__file__), "spectrogram_results.pkl"), "rb"
) as f:
import pickle

results = pickle.load(f)
spinner.succeed("Spectrograms loaded")
else:
sample_rate = 44800

wav_file = path.join(path.dirname(__file__), "test.wav")
spinner = Halo(text="Loading audio file", spinner="dots")
spinner.start()
data = librosa.load(wav_file, sr=sample_rate, mono=True)[0]
spinner.succeed("Audio file loaded")

spinner = Halo(text="Generating spectrograms", spinner="dots")
spinner.start()

# Generate all the spectrograms in parallel
with mp.Pool(16) as pool:
results = pool.starmap(
run_spectrogram,
sum(
[
[(["numpy", None], data, sample_rate, n_samples, n_samples - 4)]
+ [
(
["hard", (*n_bits, True)],
data,
sample_rate,
n_samples,
0,
)
for n_bits in [(4, 2), (8, 4), (12, 8), (16, 12)]
]
for n_samples in [8, 16, 32, 64]
],
[],
),
)

# Pickle the results
with open(
path.join(path.dirname(__file__), "spectrogram_results.pkl"), "wb"
) as f:
import pickle

pickle.dump(results, f)

spinner.succeed("Spectrograms generated")

# Plot the spectrograms
plt.figure()
plt.rcParams.update({"font.size": 8})

gs = gridspec.GridSpec(
4,
5,
wspace=0.0,
hspace=0.0,
top=0.95,
bottom=0.05,
left=0.05,
right=0.95,
)

ylabels = [8, 16, 32, 64]
xlabels = [
"numpy",
"4 bits, 2 decimal",
"8 bits, 4 decimal",
"12 bits, 8 decimal",
"16 bits, 12 decimal",
]

for i, (result, bins) in enumerate(results):
print(min(min(x) for x in result), max(max(x) for x in result))
ax = plt.subplot(gs[i // 5, i % 5])
plot_spectrogram(ax, sample_rate, result, bins)
if i % 5 == 0:
ax.set_ylabel(f"{ylabels[i//5]} samples")

if i // 5 == 3:
ax.set_xlabel(xlabels[i % 5])

plt.savefig(path.join(path.dirname(__file__), "spectrograms.png"), dpi=300)
87 changes: 87 additions & 0 deletions src/fft/demos/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import librosa
import matplotlib as mpl
import matplotlib.pyplot as plt
import os
from os import path
import numpy as np
from src.fft.sim import fft as fft_sim
from fixedpt import CFixed


def spectrogram(
fft: callable,
data: list[list[float]],
sample_rate: int,
n_samples: int,
n_overlap: int = None,
):
"""
Generate a spectrogram from a wav file.
Args:
fig (matplotlib.Figure): Figure to plot to.
ax (matplotlib.Axes): Axes to plot to.
fft (callable): FFT function taking a list of floating point
samples and returning a list of floating point results.
wav (str): Path to wav file.
n_samples (int): Number of samples.
n_overlap (int): Number of overlap points between windows (defaults to n_samples // 8)
"""
if n_overlap is None:
n_overlap = n_samples // 8

results = []

for i in range(0, len(data) - n_samples, n_samples - n_overlap):
sample = data[i : i + n_samples]
if len(sample) < n_samples:
break
results.append(fft(sample)[: n_samples // 2])

# Get the frequency bins
bins = np.fft.fftfreq(n_samples, 1 / sample_rate)[: n_samples // 2]

return results, bins


def plot_spectrogram(
ax: mpl.axes.Axes,
sample_rate: InterruptedError,
data: list[list[float]],
bins: list[float],
):
results = np.array(data).T
ax.imshow(
results,
cmap="plasma",
aspect="auto",
origin="lower",
extent=(0, len(data) / sample_rate, bins[0], bins[-1]),
)

# ax.set_xlabel("Time (s)")
# ax.set_ylabel("Frequency (Hz)")
# Remove axes
ax.set_xticks([])
ax.set_yticks([])


def numpy_fft(inputs: list[float]) -> list[float]:
return np.abs(np.fft.fft(inputs))


def hard_fft(n: int, d: int, inputs: list[float]) -> list[float]:
n_samples = len(inputs)
# Convert inputs to CFixed
inputs = [CFixed((x, 0), n, d) for x in inputs]

outputs = fft_sim(inputs, n, d, n_samples)

# Convert back to floats
outputs = [complex(x) for x in outputs]

return [abs(x) for x in outputs]


def mk_hard_fft(n: int, d: int, sim: bool) -> callable:
return lambda inputs: hard_fft(n, d, inputs)

0 comments on commit 44a970c

Please sign in to comment.