From 788d46703a99ea9746943122418a5a29bc641ffd Mon Sep 17 00:00:00 2001 From: Alexis Plaquet <14005967+FrenchKrab@users.noreply.github.com> Date: Tue, 12 Nov 2024 10:55:24 +0100 Subject: [PATCH] Fix balance: 1. with the new prepare_data 2. when the balanced key has different set of values for train and development protocols. --- pyannote/audio/core/task.py | 1 + pyannote/audio/tasks/segmentation/mixins.py | 19 +++++++++++++++---- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index 04c73ab51..de39054f8 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -632,6 +632,7 @@ def setup(self, stage=None): try: with open(self.cache, "rb") as cache_file: self.prepared_data = dict(np.load(cache_file, allow_pickle=True)) + self.prepared_data["metadata-values"] = self.prepared_data["metadata-values"].item() except FileNotFoundError: print( "Cached data for protocol not found. Ensure that prepare_data() was called", diff --git a/pyannote/audio/tasks/segmentation/mixins.py b/pyannote/audio/tasks/segmentation/mixins.py index cf6e3004a..c709b54e4 100644 --- a/pyannote/audio/tasks/segmentation/mixins.py +++ b/pyannote/audio/tasks/segmentation/mixins.py @@ -107,14 +107,20 @@ def train__iter__helper(self, rng: random.Random, **filters): ) for key, value in filters.items(): training &= self.prepared_data["audio-metadata"][key] == self.prepared_data[ - "metadata" + "metadata-values" ][key].index(value) file_ids = np.where(training)[0] # turn annotated duration into a probability distribution annotated_duration = self.prepared_data["audio-annotated"][file_ids] + annotated_duration_total = np.sum(annotated_duration) + # Exit early (this will discard this worker/helper). + # This should happen only when the filter contains no data. + # (which happens when balance is used since we use all combinations of balance filters) + if annotated_duration_total == 0: + yield None cum_prob_annotated_duration = np.cumsum( - annotated_duration / np.sum(annotated_duration) + annotated_duration / annotated_duration_total ) duration = self.duration @@ -184,13 +190,18 @@ def train__iter__(self): # create a subchunk generator for each combination of "balance" keys subchunks = dict() for product in itertools.product( - *[self.prepared_data["metadata"][key] for key in balance] + *[self.prepared_data["metadata-values"][key] for key in balance] ): # we iterate on the cartesian product of the values in metadata_unique_values # eg: for balance=["database", "split"], with 2 databases and 2 splits: # ("DIHARD", "A"), ("DIHARD", "B"), ("REPERE", "A"), ("REPERE", "B") filters = {key: value for key, value in zip(balance, product)} - subchunks[product] = self.train__iter__helper(rng, **filters) + product_iterator = self.train__iter__helper(rng, **filters) + + # This specific product may not exist. For example, if balance=['database'] + # and there is a database that's not present in the training set. + if next(product_iterator) is not None: + subchunks[product] = product_iterator while True: # select one subchunk generator at random (with uniform probability)