From 9f5b2448201e467e69484dee48f76b959ed2ecb6 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Tue, 28 May 2024 14:59:10 +0200 Subject: [PATCH 1/3] Add a warning when task parameters differ from those of the cache in use --- pyannote/audio/core/task.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index 04c73ab51..8aec4b0cb 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -327,6 +327,7 @@ def prepare_data(self): 'metadata-values': dict of lists of values for subset, scope and database 'metadata-`database-name`-labels': array of `database-name` labels. Each database with "database" scope labels has it own array. 'metadata-labels': array of global scope labels + 'task-parameters': hyper-parameters used for the task } """ @@ -595,6 +596,23 @@ def prepare_data(self): prepared_data["metadata-labels"] = np.array(unique_labels, dtype=np.str_) unique_labels.clear() + # keep track of task parameters + parameters = [] + dtype = [] + for param_name, param_value in self.__dict__.items(): + # only keep public parameters with native type + if param_name[0] == "_": + continue + if isinstance(param_value, (bool, float, int, str)): + parameters.append(param_value) + dtype.append((param_name, type(param_value))) + + prepared_data["task-parameters"] = np.array( + tuple(parameters), dtype=np.dtype(dtype) + ) + parameters.clear() + dtype.clear() + if self.has_validation: self.prepare_validation(prepared_data) @@ -646,6 +664,19 @@ def setup(self, stage=None): f"does not correspond to the cached one ({self.prepared_data['protocol']})" ) + # checks that the task current hyperparameters matches the cached ones + for param_name, param_value in self.__dict__.items(): + if param_name not in self.prepared_data["task-parameters"].dtype.names: + continue + cached_value = self.prepared_data["task-parameters"][param_name] + if param_value != cached_value: + warnings.warn( + f"Value specified for {param_name} of the task differs from the one in the cached data." + f"Current one = {param_value}, cached one = {cached_value}." + "You may need to create a new cache for this task with" + " the new value for this hyperparameter.", + ) + @property def specifications(self) -> Union[Specifications, Tuple[Specifications]]: # setup metadata on-demand the first time specifications are requested and missing From e95f3c37b6d17d58d172c79640df6dbe85864e0d Mon Sep 17 00:00:00 2001 From: clement-pages Date: Tue, 4 Jun 2024 15:28:51 +0200 Subject: [PATCH 2/3] use `inspect.signature` instead `__dict__` --- pyannote/audio/core/task.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index 8aec4b0cb..c10de3ac8 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -23,6 +23,7 @@ from __future__ import annotations +import inspect import itertools import multiprocessing import sys @@ -599,11 +600,15 @@ def prepare_data(self): # keep track of task parameters parameters = [] dtype = [] - for param_name, param_value in self.__dict__.items(): - # only keep public parameters with native type - if param_name[0] == "_": + for param_name in inspect.signature(self.__init__).parameters: + try: + param_value = getattr(self, param_name) + # skip specification-dependent parameters and non-attributed parameters + # (for instance because they were deprecated) + except (AttributeError, UnknownSpecificationsError): + print(param_name) continue - if isinstance(param_value, (bool, float, int, str)): + if isinstance(param_value, (bool, float, int, str, type(None))): parameters.append(param_value) dtype.append((param_name, type(param_value))) @@ -665,11 +670,19 @@ def setup(self, stage=None): ) # checks that the task current hyperparameters matches the cached ones - for param_name, param_value in self.__dict__.items(): + for param_name in inspect.signature(self.__init__).parameters: + try: + param_value = getattr(self, param_name) + # skip specification-dependent parameters and non-attributed parameters + # (for instance because they were deprecated) + except (AttributeError, UnknownSpecificationsError): + continue + if param_name not in self.prepared_data["task-parameters"].dtype.names: continue cached_value = self.prepared_data["task-parameters"][param_name] if param_value != cached_value: + print("passing here") warnings.warn( f"Value specified for {param_name} of the task differs from the one in the cached data." f"Current one = {param_value}, cached one = {cached_value}." From 4bccb40e5e389e0b0ed837a6539562fe555dc942 Mon Sep 17 00:00:00 2001 From: clement-pages Date: Tue, 4 Jun 2024 15:30:35 +0200 Subject: [PATCH 3/3] clear the code --- pyannote/audio/core/task.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pyannote/audio/core/task.py b/pyannote/audio/core/task.py index c10de3ac8..4f4670272 100644 --- a/pyannote/audio/core/task.py +++ b/pyannote/audio/core/task.py @@ -606,7 +606,6 @@ def prepare_data(self): # skip specification-dependent parameters and non-attributed parameters # (for instance because they were deprecated) except (AttributeError, UnknownSpecificationsError): - print(param_name) continue if isinstance(param_value, (bool, float, int, str, type(None))): parameters.append(param_value) @@ -682,7 +681,6 @@ def setup(self, stage=None): continue cached_value = self.prepared_data["task-parameters"][param_name] if param_value != cached_value: - print("passing here") warnings.warn( f"Value specified for {param_name} of the task differs from the one in the cached data." f"Current one = {param_value}, cached one = {cached_value}."