Skip to content

Commit

Permalink
• pytorchshowgraph.py
Browse files Browse the repository at this point in the history
  - show_graph(): support show_learning=PNL
  • Loading branch information
jdcpni committed Nov 22, 2024
1 parent 1c1470b commit 7665b47
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
3 changes: 2 additions & 1 deletion psyneulink/core/globals/keywords.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@
'PROCESS_EXECUTE', 'PROCESS_INIT', 'PROCESSES', 'PROCESSES_DIM', 'PROCESSING', 'PROCESSING_MECHANISM',
'PROCESSING_PATHWAY', 'PRODUCT', 'PROGRESS_BAR_CHAR', 'PROJECTION', 'PROJECTION_DIRECTION', 'PROJECTION_PARAMS',
'PROJECTION_RECEIVER', 'PROJECTION_SENDER', 'PROJECTION_TYPE', 'PROJECTIONS', 'PROJECTION_COMPONENT_CATEGORY',
'PNL',
'QUOTIENT', 'RANDOM', 'RANDOM_CONNECTIVITY_MATRIX', 'RATE', 'RATIO', 'REARRANGE_FUNCTION', 'RECEIVER',
'RECEIVER_ARG', 'RECURRENT_TRANSFER_MECHANISM', 'REDUCE_FUNCTION', 'REFERENCE_VALUE', 'RESET',
'RESET_STATEFUL_FUNCTION_WHEN', 'RELU_FUNCTION', 'REST', 'RESULT', 'RESULT', 'ROLES', 'RL_FUNCTION', 'RUN',
Expand All @@ -143,7 +144,6 @@

from psyneulink._typing import Literal


#region ----------------------------------------- MATRICES -----------------------------------------------------------

class MatrixKeywords:
Expand Down Expand Up @@ -462,6 +462,7 @@ class Loss(Enum):
#region --------------------------------------------- GENERAL ----------------------------------------------------
# General

PNL = 'psyneulink'

ON = True
OFF = False
Expand Down
12 changes: 9 additions & 3 deletions psyneulink/library/compositions/pytorchshowgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,17 @@

from psyneulink._typing import Optional, Union, Literal

from psyneulink.core.globals.context import ContextFlags, handle_external_context
from psyneulink.core.compositions import NodeRole
from psyneulink.core.compositions.showgraph import ShowGraph, SHOW_JUST_LEARNING_PROJECTIONS
from psyneulink.core.compositions.showgraph import ShowGraph, SHOW_JUST_LEARNING_PROJECTIONS, SHOW_LEARNING
from psyneulink.core.components.mechanisms.mechanism import Mechanism
from psyneulink.core.components.mechanisms.processing.compositioninterfacemechanism import CompositionInterfaceMechanism
from psyneulink.core.components.mechanisms.modulatory.control.controlmechanism import ControlMechanism
from psyneulink.core.components.projections.projection import Projection
from psyneulink.core.components.projections.pathway.mappingprojection import MappingProjection
from psyneulink.core.components.projections.modulatory.controlprojection import ControlProjection
from psyneulink.core.globals.keywords import BOLD, NESTED, INSET
from psyneulink.core.llvm import ExecutionMode
from psyneulink.core.globals.context import ContextFlags, handle_external_context
from psyneulink.core.globals.keywords import BOLD, INSET, NESTED, PNL

__all__ = ['SHOW_PYTORCH']

Expand Down Expand Up @@ -56,6 +57,11 @@ def __init__(self, *args, **kwargs):
@handle_external_context(source=ContextFlags.COMPOSITION)
def show_graph(self, *args, **kwargs):
"""Override of show_graph to check if show_pytorch==True and if so build pytorch rep of autofiffcomposition"""
# if kwargs.pop(SHOW_PNL, None):
if SHOW_LEARNING in kwargs and kwargs[SHOW_LEARNING] == PNL:
self.composition.infer_backpropagation_learning_pathways(ExecutionMode.Python)
kwargs[SHOW_LEARNING] = True
return super().show_graph(*args, **kwargs)
self.show_pytorch = kwargs.pop(SHOW_PYTORCH, self.show_pytorch)
context = kwargs.get('context')
if self.show_pytorch:
Expand Down

0 comments on commit 7665b47

Please sign in to comment.