From f7a577375fa1bcdb477b5c8cff2a78fe4a97558a Mon Sep 17 00:00:00 2001 From: Alexey Kamenev Date: Tue, 29 Oct 2024 10:44:01 -0700 Subject: [PATCH] Add kwargs to MGN forward. (#701) --- examples/cfd/aero_graph_net/models.py | 1 + modulus/models/meshgraphnet/meshgraphnet.py | 1 + 2 files changed, 2 insertions(+) diff --git a/examples/cfd/aero_graph_net/models.py b/examples/cfd/aero_graph_net/models.py index 64156249d5..01c483ec01 100644 --- a/examples/cfd/aero_graph_net/models.py +++ b/examples/cfd/aero_graph_net/models.py @@ -88,6 +88,7 @@ def forward( node_features: Tensor, edge_features: Tensor, graph: Union[DGLGraph, list[DGLGraph], "CuGraphCSC"], + **kwargs, ) -> Tensor: edge_features = self.edge_encoder(edge_features) node_features = self.node_encoder(node_features) diff --git a/modulus/models/meshgraphnet/meshgraphnet.py b/modulus/models/meshgraphnet/meshgraphnet.py index 9ff65b7325..72628efb50 100644 --- a/modulus/models/meshgraphnet/meshgraphnet.py +++ b/modulus/models/meshgraphnet/meshgraphnet.py @@ -191,6 +191,7 @@ def forward( node_features: Tensor, edge_features: Tensor, graph: Union[DGLGraph, List[DGLGraph], CuGraphCSC], + **kwargs, ) -> Tensor: edge_features = self.edge_encoder(edge_features) node_features = self.node_encoder(node_features)