Skip to content

Commit

Permalink
update configuration
Browse files Browse the repository at this point in the history
  • Loading branch information
benmalef committed Feb 10, 2025
1 parent 007a967 commit 61acfdc
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
4 changes: 3 additions & 1 deletion GANDLF/Configuration/Parameters/user_defined_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ class UserDefinedParameters(DefaultParameters):
data_postprocessing_after_reverse_one_hot_encoding: dict = Field(
description="data_postprocessing_after_reverse_one_hot_encoding.", default={}
)
differential_privacy: dict = Field(description="Differential privacy.", default={})
differential_privacy: dict = Field(
description="Differential privacy.", default=None
)
# TODO: It should be defined with a better way (using a BaseModel class)
data_preprocessing: Annotated[
dict,
Expand Down
12 changes: 7 additions & 5 deletions GANDLF/Configuration/Parameters/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,11 +410,13 @@ def validate_data_augmentation(value, patch_size) -> dict:


def validate_differential_privacy(value, batch_size):
# if not isinstance(value, dict):
# print(
# "WARNING: Non dictionary value for the key: 'differential_privacy' was used, replacing with default valued dictionary."
# )
# value = {}
if value is None:
return value
if not isinstance(value, dict):
print(
"WARNING: Non dictionary value for the key: 'differential_privacy' was used, replacing with default valued dictionary."
)
value = {}
# these are some defaults
value = initialize_key(value, "noise_multiplier", 10.0)
value = initialize_key(value, "max_grad_norm", 1.0)
Expand Down
2 changes: 1 addition & 1 deletion GANDLF/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def ConfigManager(
try:
parameters = Parameters(
**_parseConfig(config_file_path, version_check_flag)
).model_dump()
).model_dump(exclude_none=True)
return parameters
# except Exception as e:
# ## todo: ensure logging captures assertion errors
Expand Down

0 comments on commit 61acfdc

Please sign in to comment.