-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4d5d00a
commit 44a970c
Showing
4 changed files
with
205 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,4 +7,6 @@ pytest-xdist==3.5.0 | |
psutil | ||
fabric | ||
halo | ||
pyyaml | ||
pyyaml | ||
matplotlib | ||
librosa |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
*.wav | ||
*.png | ||
*.pkl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
import matplotlib.pyplot as plt | ||
from matplotlib import gridspec | ||
import numpy as np | ||
from halo import Halo | ||
from src.fft.demos.utils import spectrogram, plot_spectrogram, numpy_fft, mk_hard_fft | ||
import librosa | ||
from os import path | ||
import multiprocessing as mp | ||
|
||
|
||
def run_spectrogram(f, data, sample_rate, n_samples, n_overlap): | ||
if f[0] == "numpy": | ||
f = numpy_fft | ||
else: | ||
f = mk_hard_fft(*f[1]) | ||
results, bins = spectrogram(f, data, sample_rate, n_samples, n_overlap) | ||
return results, bins | ||
|
||
|
||
if __name__ == "__main__": | ||
# Check if the spectrograms have already been generated | ||
if path.exists(path.join(path.dirname(__file__), "spectrogram_results.pkl")): | ||
|
||
spinner = Halo(text="Generating spectrograms", spinner="dots") | ||
spinner.start() | ||
with open( | ||
path.join(path.dirname(__file__), "spectrogram_results.pkl"), "rb" | ||
) as f: | ||
import pickle | ||
|
||
results = pickle.load(f) | ||
spinner.succeed("Spectrograms loaded") | ||
else: | ||
sample_rate = 44800 | ||
|
||
wav_file = path.join(path.dirname(__file__), "test.wav") | ||
spinner = Halo(text="Loading audio file", spinner="dots") | ||
spinner.start() | ||
data = librosa.load(wav_file, sr=sample_rate, mono=True)[0] | ||
spinner.succeed("Audio file loaded") | ||
|
||
spinner = Halo(text="Generating spectrograms", spinner="dots") | ||
spinner.start() | ||
|
||
# Generate all the spectrograms in parallel | ||
with mp.Pool(16) as pool: | ||
results = pool.starmap( | ||
run_spectrogram, | ||
sum( | ||
[ | ||
[(["numpy", None], data, sample_rate, n_samples, n_samples - 4)] | ||
+ [ | ||
( | ||
["hard", (*n_bits, True)], | ||
data, | ||
sample_rate, | ||
n_samples, | ||
0, | ||
) | ||
for n_bits in [(4, 2), (8, 4), (12, 8), (16, 12)] | ||
] | ||
for n_samples in [8, 16, 32, 64] | ||
], | ||
[], | ||
), | ||
) | ||
|
||
# Pickle the results | ||
with open( | ||
path.join(path.dirname(__file__), "spectrogram_results.pkl"), "wb" | ||
) as f: | ||
import pickle | ||
|
||
pickle.dump(results, f) | ||
|
||
spinner.succeed("Spectrograms generated") | ||
|
||
# Plot the spectrograms | ||
plt.figure() | ||
plt.rcParams.update({"font.size": 8}) | ||
|
||
gs = gridspec.GridSpec( | ||
4, | ||
5, | ||
wspace=0.0, | ||
hspace=0.0, | ||
top=0.95, | ||
bottom=0.05, | ||
left=0.05, | ||
right=0.95, | ||
) | ||
|
||
ylabels = [8, 16, 32, 64] | ||
xlabels = [ | ||
"numpy", | ||
"4 bits, 2 decimal", | ||
"8 bits, 4 decimal", | ||
"12 bits, 8 decimal", | ||
"16 bits, 12 decimal", | ||
] | ||
|
||
for i, (result, bins) in enumerate(results): | ||
print(min(min(x) for x in result), max(max(x) for x in result)) | ||
ax = plt.subplot(gs[i // 5, i % 5]) | ||
plot_spectrogram(ax, sample_rate, result, bins) | ||
if i % 5 == 0: | ||
ax.set_ylabel(f"{ylabels[i//5]} samples") | ||
|
||
if i // 5 == 3: | ||
ax.set_xlabel(xlabels[i % 5]) | ||
|
||
plt.savefig(path.join(path.dirname(__file__), "spectrograms.png"), dpi=300) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import librosa | ||
import matplotlib as mpl | ||
import matplotlib.pyplot as plt | ||
import os | ||
from os import path | ||
import numpy as np | ||
from src.fft.sim import fft as fft_sim | ||
from fixedpt import CFixed | ||
|
||
|
||
def spectrogram( | ||
fft: callable, | ||
data: list[list[float]], | ||
sample_rate: int, | ||
n_samples: int, | ||
n_overlap: int = None, | ||
): | ||
""" | ||
Generate a spectrogram from a wav file. | ||
Args: | ||
fig (matplotlib.Figure): Figure to plot to. | ||
ax (matplotlib.Axes): Axes to plot to. | ||
fft (callable): FFT function taking a list of floating point | ||
samples and returning a list of floating point results. | ||
wav (str): Path to wav file. | ||
n_samples (int): Number of samples. | ||
n_overlap (int): Number of overlap points between windows (defaults to n_samples // 8) | ||
""" | ||
if n_overlap is None: | ||
n_overlap = n_samples // 8 | ||
|
||
results = [] | ||
|
||
for i in range(0, len(data) - n_samples, n_samples - n_overlap): | ||
sample = data[i : i + n_samples] | ||
if len(sample) < n_samples: | ||
break | ||
results.append(fft(sample)[: n_samples // 2]) | ||
|
||
# Get the frequency bins | ||
bins = np.fft.fftfreq(n_samples, 1 / sample_rate)[: n_samples // 2] | ||
|
||
return results, bins | ||
|
||
|
||
def plot_spectrogram( | ||
ax: mpl.axes.Axes, | ||
sample_rate: InterruptedError, | ||
data: list[list[float]], | ||
bins: list[float], | ||
): | ||
results = np.array(data).T | ||
ax.imshow( | ||
results, | ||
cmap="plasma", | ||
aspect="auto", | ||
origin="lower", | ||
extent=(0, len(data) / sample_rate, bins[0], bins[-1]), | ||
) | ||
|
||
# ax.set_xlabel("Time (s)") | ||
# ax.set_ylabel("Frequency (Hz)") | ||
# Remove axes | ||
ax.set_xticks([]) | ||
ax.set_yticks([]) | ||
|
||
|
||
def numpy_fft(inputs: list[float]) -> list[float]: | ||
return np.abs(np.fft.fft(inputs)) | ||
|
||
|
||
def hard_fft(n: int, d: int, inputs: list[float]) -> list[float]: | ||
n_samples = len(inputs) | ||
# Convert inputs to CFixed | ||
inputs = [CFixed((x, 0), n, d) for x in inputs] | ||
|
||
outputs = fft_sim(inputs, n, d, n_samples) | ||
|
||
# Convert back to floats | ||
outputs = [complex(x) for x in outputs] | ||
|
||
return [abs(x) for x in outputs] | ||
|
||
|
||
def mk_hard_fft(n: int, d: int, sim: bool) -> callable: | ||
return lambda inputs: hard_fft(n, d, inputs) |