diff --git a/twang/source_separation/nussl.py b/twang/source_separation/nussl.py index f99590b..2874188 100644 --- a/twang/source_separation/nussl.py +++ b/twang/source_separation/nussl.py @@ -1,7 +1,7 @@ import nussl from twang.source_separation.base import SourceSeparation, SourceSeparationDict -from twang.track import BaseTrack +from twang.track import BaseTrack, LibrosaTrack def _track_to_nussl_audio_signal(track: BaseTrack) -> nussl.AudioSignal: @@ -9,6 +9,10 @@ def _track_to_nussl_audio_signal(track: BaseTrack) -> nussl.AudioSignal: return nussl.AudioSignal(audio_data_array=librosa_track.y, sample_rate=librosa_track.sr) +def _audio_signal_to_librosa_track(audio_signal: nussl.AudioSignal) -> LibrosaTrack: + return LibrosaTrack(audio_signal.audio_data, sr=audio_signal.sample_rate) + + class Repet(SourceSeparation): """https://nussl.github.io/docs/separation.html#foreground-background-via-repet""" @@ -18,7 +22,10 @@ def run(self, track: BaseTrack) -> SourceSeparationDict: audio_signal = _track_to_nussl_audio_signal(track) repet = self.repet_cls(audio_signal) estimates = repet() - return {"background": estimates[0], "foreground": estimates[1]} + return { + "background": _audio_signal_to_librosa_track(estimates[0]), + "foreground": _audio_signal_to_librosa_track(estimates[1]), + } class RepetSim(Repet): diff --git a/twang/track/base.py b/twang/track/base.py index f680b4e..888ca8c 100644 --- a/twang/track/base.py +++ b/twang/track/base.py @@ -121,6 +121,7 @@ def _repr_html_(self) -> str: def __getitem__(self, ms_or_slice: Union[int, slice]): """TODO: test & document""" + self.file_path = None # because the object no longer corresponds to the file on disk self.y = self.y[ms_or_slice] return self