Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a warning when task parameters differ from those of the cache in use #1719

Draft
wants to merge 8 commits into
base: develop
Choose a base branch
from
42 changes: 42 additions & 0 deletions pyannote/audio/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from __future__ import annotations

import inspect
import itertools
import multiprocessing
import sys
Expand Down Expand Up @@ -327,6 +328,7 @@ def prepare_data(self):
'metadata-values': dict of lists of values for subset, scope and database
'metadata-`database-name`-labels': array of `database-name` labels. Each database with "database" scope labels has its own array.
'metadata-labels': array of global scope labels
'task-parameters': hyper-parameters used for the task
}

"""
Expand Down Expand Up @@ -595,6 +597,26 @@ def prepare_data(self):
prepared_data["metadata-labels"] = np.array(unique_labels, dtype=np.str_)
unique_labels.clear()

# keep track of (scalar) task hyper-parameters as a 0-d structured array,
# so that setup() can later warn when the cache in use was prepared with
# different hyper-parameter values
parameters = []
dtype = []
for param_name in inspect.signature(self.__init__).parameters:
    try:
        param_value = getattr(self, param_name)
    # skip specification-dependent parameters and non-attributed parameters
    # (for instance because they were deprecated)
    except (AttributeError, UnknownSpecificationsError):
        continue

    if isinstance(param_value, str):
        # a bare `str` dtype is zero-width ('<U0') and would silently
        # truncate the value to "": use a fixed-width unicode dtype
        # wide enough to hold the actual value instead
        parameters.append(param_value)
        dtype.append((param_name, f"U{max(len(param_value), 1)}"))
    elif isinstance(param_value, (bool, float, int)):
        # NOTE: bool must be tested before int (bool is a subclass of int)
        # so that boolean parameters keep a boolean dtype
        parameters.append(param_value)
        dtype.append((param_name, type(param_value)))
    # `type(None)` cannot be interpreted as a numpy dtype, so None-valued
    # parameters are not cached. This is safe because setup() skips
    # parameters that are missing from the cached dtype names.

prepared_data["task-parameters"] = np.array(
    tuple(parameters), dtype=np.dtype(dtype)
)
parameters.clear()
dtype.clear()

if self.has_validation:
    self.prepare_validation(prepared_data)

Expand Down Expand Up @@ -646,6 +668,26 @@ def setup(self, stage=None):
f"does not correspond to the cached one ({self.prepared_data['protocol']})"
)

# check that the task's current hyper-parameters match the cached ones,
# warning (without failing) on any mismatch
for param_name in inspect.signature(self.__init__).parameters:
    try:
        param_value = getattr(self, param_name)
    # skip specification-dependent parameters and non-attributed parameters
    # (for instance because they were deprecated)
    except (AttributeError, UnknownSpecificationsError):
        continue

    # parameters that were not cached (e.g. non-scalar ones) cannot be checked
    if param_name not in self.prepared_data["task-parameters"].dtype.names:
        continue

    cached_value = self.prepared_data["task-parameters"][param_name]
    if param_value != cached_value:
        # NOTE: the original message concatenation was missing separating
        # spaces between sentences ("…cached data.Current one…")
        warnings.warn(
            f"Value specified for {param_name} of the task differs from the one "
            f"in the cached data. Current one = {param_value}, cached one = "
            f"{cached_value}. You may need to create a new cache for this task "
            "with the new value for this hyperparameter.",
        )

@property
def automatic_optimization(self) -> bool:
    # delegate the automatic vs. manual optimization choice to the
    # attached model (read-only view of self.model.automatic_optimization)
    return self.model.automatic_optimization
Expand Down
Loading