Skip to content

Commit

Permalink
Merge pull request #42 from descriptinc/ps/weighted_choose
Browse files Browse the repository at this point in the history
Adding weights to picking from CSV files.
  • Loading branch information
pseeth authored Jun 21, 2022
2 parents 37a11aa + 7691645 commit d68ee19
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 7 deletions.
2 changes: 1 addition & 1 deletion audiotools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.3.6"
__version__ = "0.3.7"
from .core import AudioSignal, STFTParams, Meter, util
from . import metrics
from . import data
Expand Down
4 changes: 2 additions & 2 deletions audiotools/core/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ def read_csv(filelists):
return files


def choose_from_list_of_lists(state, list_of_lists):
idx = state.randint(len(list_of_lists))
def choose_from_list_of_lists(state, list_of_lists, p=None):
idx = state.choice(list(range(len(list_of_lists))), p=p)
item_idx = state.randint(len(list_of_lists[idx]))
return list_of_lists[idx][item_idx]

Expand Down
6 changes: 5 additions & 1 deletion audiotools/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def __init__(
n_examples: int = 1000,
duration: float = 0.5,
csv_files: List[str] = None,
csv_weights: List[float] = None,
loudness_cutoff: float = -40,
mono: bool = True,
transform=None,
Expand All @@ -117,13 +118,16 @@ def __init__(
self.audio_lists = util.read_csv(csv_files)
self.loudness_cutoff = loudness_cutoff
self.mono = mono
self.csv_weights = csv_weights

def __getitem__(self, idx):
state = util.random_state(idx)

# Load an audio file randomly from the list of lists,
# seeded by the current index.
audio_info = util.choose_from_list_of_lists(state, self.audio_lists)
audio_info = util.choose_from_list_of_lists(
state, self.audio_lists, p=self.csv_weights
)
signal = AudioSignal.salient_excerpt(
audio_info["path"],
duration=self.duration,
Expand Down
12 changes: 10 additions & 2 deletions audiotools/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ def __init__(
self,
snr: tuple = ("uniform", 10.0, 30.0),
csv_files: List[str] = None,
csv_weights: List[float] = None,
eq_amount: tuple = ("const", 1.0),
n_bands: int = 3,
name: str = None,
Expand All @@ -334,13 +335,16 @@ def __init__(
self.eq_amount = eq_amount
self.n_bands = n_bands
self.audio_files = util.read_csv(csv_files)
self.csv_weights = csv_weights

def _instantiate(self, state: RandomState, signal: AudioSignal):
eq_amount = util.sample_from_dist(self.eq_amount, state)
eq = -eq_amount * state.rand(self.n_bands)
snr = util.sample_from_dist(self.snr, state)

bg_path = util.choose_from_list_of_lists(state, self.audio_files)["path"]
bg_path = util.choose_from_list_of_lists(
state, self.audio_files, p=self.csv_weights
)["path"]

# Get properties of input signal to use when creating
# background signal.
Expand All @@ -367,6 +371,7 @@ def __init__(
self,
drr: tuple = ("uniform", 0.0, 30.0),
csv_files: List[str] = None,
csv_weights: List[float] = None,
eq_amount: tuple = ("const", 1.0),
n_bands: int = 6,
name: str = None,
Expand All @@ -380,13 +385,16 @@ def __init__(
self.n_bands = n_bands
self.use_original_phase = use_original_phase
self.audio_files = util.read_csv(csv_files)
self.csv_weights = csv_weights

def _instantiate(self, state: RandomState, signal: AudioSignal = None):
eq_amount = util.sample_from_dist(self.eq_amount, state)
eq = -eq_amount * state.rand(self.n_bands)
drr = util.sample_from_dist(self.drr, state)

ir_path = util.choose_from_list_of_lists(state, self.audio_files)["path"]
ir_path = util.choose_from_list_of_lists(
state, self.audio_files, p=self.csv_weights
)["path"]

# Get properties of input signal to use when creating
# background signal.
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

setup(
name="audiotools",
version="0.3.6",
version="0.3.7",
classifiers=[
"Intended Audience :: Developers",
"Intended Audience :: Education",
Expand Down

0 comments on commit d68ee19

Please sign in to comment.