diff --git a/psyneulink/core/globals/keywords.py b/psyneulink/core/globals/keywords.py index 415ed2d7c2..7543fda313 100644 --- a/psyneulink/core/globals/keywords.py +++ b/psyneulink/core/globals/keywords.py @@ -118,6 +118,7 @@ 'PROCESS_EXECUTE', 'PROCESS_INIT', 'PROCESSES', 'PROCESSES_DIM', 'PROCESSING', 'PROCESSING_MECHANISM', 'PROCESSING_PATHWAY', 'PRODUCT', 'PROGRESS_BAR_CHAR', 'PROJECTION', 'PROJECTION_DIRECTION', 'PROJECTION_PARAMS', 'PROJECTION_RECEIVER', 'PROJECTION_SENDER', 'PROJECTION_TYPE', 'PROJECTIONS', 'PROJECTION_COMPONENT_CATEGORY', + 'PNL', 'QUOTIENT', 'RANDOM', 'RANDOM_CONNECTIVITY_MATRIX', 'RATE', 'RATIO', 'REARRANGE_FUNCTION', 'RECEIVER', 'RECEIVER_ARG', 'RECURRENT_TRANSFER_MECHANISM', 'REDUCE_FUNCTION', 'REFERENCE_VALUE', 'RESET', 'RESET_STATEFUL_FUNCTION_WHEN', 'RELU_FUNCTION', 'REST', 'RESULT', 'RESULT', 'ROLES', 'RL_FUNCTION', 'RUN', @@ -143,7 +144,6 @@ from psyneulink._typing import Literal - #region ----------------------------------------- MATRICES ----------------------------------------------------------- class MatrixKeywords: @@ -462,6 +462,7 @@ class Loss(Enum): #region --------------------------------------------- GENERAL ---------------------------------------------------- # General +PNL = 'psyneulink' ON = True OFF = False diff --git a/psyneulink/library/compositions/pytorchshowgraph.py b/psyneulink/library/compositions/pytorchshowgraph.py index 6452ccf9f6..c27ccc85ec 100644 --- a/psyneulink/library/compositions/pytorchshowgraph.py +++ b/psyneulink/library/compositions/pytorchshowgraph.py @@ -12,16 +12,17 @@ from psyneulink._typing import Optional, Union, Literal -from psyneulink.core.globals.context import ContextFlags, handle_external_context from psyneulink.core.compositions import NodeRole -from psyneulink.core.compositions.showgraph import ShowGraph, SHOW_JUST_LEARNING_PROJECTIONS +from psyneulink.core.compositions.showgraph import ShowGraph, SHOW_JUST_LEARNING_PROJECTIONS, SHOW_LEARNING from psyneulink.core.components.mechanisms.mechanism import Mechanism from psyneulink.core.components.mechanisms.processing.compositioninterfacemechanism import CompositionInterfaceMechanism from psyneulink.core.components.mechanisms.modulatory.control.controlmechanism import ControlMechanism from psyneulink.core.components.projections.projection import Projection from psyneulink.core.components.projections.pathway.mappingprojection import MappingProjection from psyneulink.core.components.projections.modulatory.controlprojection import ControlProjection -from psyneulink.core.globals.keywords import BOLD, NESTED, INSET +from psyneulink.core.llvm import ExecutionMode +from psyneulink.core.globals.context import ContextFlags, handle_external_context +from psyneulink.core.globals.keywords import BOLD, INSET, NESTED, PNL __all__ = ['SHOW_PYTORCH'] @@ -56,6 +57,11 @@ def __init__(self, *args, **kwargs): @handle_external_context(source=ContextFlags.COMPOSITION) def show_graph(self, *args, **kwargs): """Override of show_graph to check if show_pytorch==True and if so build pytorch rep of autofiffcomposition""" + # if kwargs.pop(SHOW_PNL, None): + if SHOW_LEARNING in kwargs and kwargs[SHOW_LEARNING] == PNL: + self.composition.infer_backpropagation_learning_pathways(ExecutionMode.Python) + kwargs[SHOW_LEARNING] = True + return super().show_graph(*args, **kwargs) self.show_pytorch = kwargs.pop(SHOW_PYTORCH, self.show_pytorch) context = kwargs.get('context') if self.show_pytorch: