Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…Link into devel
  • Loading branch information
jdcpni committed Nov 22, 2024
2 parents 1c1470b + c0f73e2 commit ee8fada
Show file tree
Hide file tree
Showing 13 changed files with 785 additions and 297 deletions.
37 changes: 33 additions & 4 deletions psyneulink/core/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,14 +528,14 @@
Context, ContextError, ContextFlags, INITIALIZATION_STATUS_FLAGS, _get_time, handle_external_context
from psyneulink.core.globals.mdf import MDFSerializable
from psyneulink.core.globals.keywords import \
CONTEXT, CONTROL_PROJECTION, DEFERRED_INITIALIZATION, EXECUTE_UNTIL_FINISHED, \
CONTEXT, CONTROL_PROJECTION, DEFERRED_INITIALIZATION, DETERMINISTIC, EXECUTE_UNTIL_FINISHED, \
FUNCTION, FUNCTION_PARAMS, INIT_FULL_EXECUTE_METHOD, INPUT_PORTS, \
LEARNING, LEARNING_PROJECTION, MATRIX, MAX_EXECUTIONS_BEFORE_FINISHED, \
MODEL_SPEC_ID_PSYNEULINK, MODEL_SPEC_ID_METADATA, \
MODEL_SPEC_ID_INPUT_PORTS, MODEL_SPEC_ID_OUTPUT_PORTS, \
MODEL_SPEC_ID_MDF_VARIABLE, \
MODULATORY_SPEC_KEYWORDS, NAME, OUTPUT_PORTS, OWNER, PARAMS, PREFS_ARG, \
RESET_STATEFUL_FUNCTION_WHEN, INPUT_SHAPES, VALUE, VARIABLE, SHARED_COMPONENT_TYPES
RANDOM, RESET_STATEFUL_FUNCTION_WHEN, INPUT_SHAPES, VALUE, VARIABLE, SHARED_COMPONENT_TYPES
from psyneulink.core.globals.log import LogCondition
from psyneulink.core.globals.parameters import \
Defaults, SharedParameter, Parameter, ParameterAlias, ParameterError, ParametersBase, check_user_specified, copy_parameter_value, is_array_like
Expand Down Expand Up @@ -931,6 +931,9 @@ class Component(MDFSerializable, metaclass=ComponentsMeta):
componentType = None

standard_constructor_args = {EXECUTE_UNTIL_FINISHED, FUNCTION_PARAMS, MAX_EXECUTIONS_BEFORE_FINISHED, RESET_STATEFUL_FUNCTION_WHEN, INPUT_SHAPES}
deprecated_constructor_args = {
'size': 'input_shapes',
}

# helper attributes for MDF model spec
_model_spec_id_parameters = 'parameters'
Expand Down Expand Up @@ -1388,6 +1391,9 @@ def _get_compilation_state(self):
if cost_functions.DURATION not in cost_functions:
blacklist.add('duration_cost_fct')

if getattr(self, "mode", None) == DETERMINISTIC and getattr(self, "tie", None) != RANDOM:
whitelist.remove('random_state')

# Drop previous_value from MemoryFunctions
if hasattr(self.parameters, 'duplicate_keys'):
blacklist.add("previous_value")
Expand Down Expand Up @@ -1505,13 +1511,20 @@ def _get_compilation_params(self):
"retain_torch_trained_outputs", "retain_torch_targets", "retain_torch_losses"
"torch_trained_outputs", "torch_targets", "torch_losses",
# should be added to relevant _gen_llvm_function... when aug:
# OneHot:
'abs_val', 'indicator',
# SoftMax:
'mask_threshold', 'adapt_scale', 'adapt_base', 'adapt_entropy_weighting',
# LCAMechanism
"mask"
}

# OneHot:
# * runtime abs_val and indicator are only used in deterministic mode.
# * random_state and seed are only used in RANDOM tie resolution.
if getattr(self, "mode", None) != DETERMINISTIC:
blacklist.update(['abs_val', 'indicator'])
elif getattr(self, "tie", None) != RANDOM:
blacklist.add("seed")

# Mechanism's need few extra entries:
# * matrix -- is never used directly, and is flatened below
# * integration_rate -- shape mismatch with param port input
Expand Down Expand Up @@ -2150,8 +2163,11 @@ def alias_conflicts(alias, passed_name):

conflicting_aliases = []
unused_constructor_args = {}
deprecated_args = {}
for p in self.parameters:
if p.name in illegal_passed_args:
# p must have a constructor_argument, because otherwise
# p.name would not be in illegal_passed_args
assert p.constructor_argument is not None
unused_constructor_args[p.name] = p.constructor_argument

Expand All @@ -2164,13 +2180,26 @@ def alias_conflicts(alias, passed_name):
if alias_conflicts(p, passed_name):
conflicting_aliases.append((p.source.name, passed_name, p.name))

for arg in illegal_passed_args:
try:
deprecated_args[arg] = self.deprecated_constructor_args[arg]
except KeyError:
continue

# raise constructor arg errors
if len(unused_constructor_args) > 0:
raise create_illegal_argument_error([
f"'{arg}': must use '{constr_arg}' instead"
for arg, constr_arg in unused_constructor_args.items()
])

# raise deprecated argument errors
if len(deprecated_args) > 0:
raise create_illegal_argument_error([
f"'{arg}' is deprecated. Use '{new_arg}' instead"
for arg, new_arg in deprecated_args.items()
])

# raise generic illegal argument error
unknown_args = illegal_passed_args.difference(unused_constructor_args)
if len(unknown_args) > 0:
Expand Down
Loading

0 comments on commit ee8fada

Please sign in to comment.