From 26903441ef21c9063dfa42e5e661f0d3a00c9415 Mon Sep 17 00:00:00 2001 From: pao214 Date: Thu, 15 Dec 2022 20:14:48 -0600 Subject: [PATCH] Marius Script Compiler - Add marius_mpic tool to compile marius script - Add test cases to handle error scenarios - Generate code in mpic_gen directory - Add github workflow to test See test/mpic/examples for example usage --- .flake8 | 12 + .github/workflows/mpic_test.yml | 27 ++ examples/mpic/basic_layer.py | 21 + setup.cfg | 5 +- src/python/tools/mpic/__init__.py | 0 src/python/tools/mpic/astpp.py | 60 +++ src/python/tools/mpic/attrs.py | 187 ++++++++ src/python/tools/mpic/builtins.py | 121 +++++ src/python/tools/mpic/codegen.py | 273 +++++++++++ src/python/tools/mpic/compiler.py | 24 + src/python/tools/mpic/features.py | 153 ++++++ src/python/tools/mpic/globalpass.py | 72 +++ src/python/tools/mpic/layerpass.py | 78 +++ src/python/tools/mpic/marius_mpic.py | 38 ++ src/python/tools/mpic/render.py | 42 ++ src/python/tools/mpic/symtab.py | 33 ++ .../tools/mpic/templates/gnn_layer.jinja.cpp | 165 +++++++ .../tools/mpic/templates/gnn_layer.jinja.h | 40 ++ src/python/tools/mpic/typechecker.py | 443 ++++++++++++++++++ src/python/tools/mpic/utils.py | 54 +++ test/mpic/errors/call_invalid_args.py | 9 + test/mpic/errors/disable_assert.py | 10 + test/mpic/errors/disable_format.py | 10 + test/mpic/errors/disable_generators.py | 12 + test/mpic/errors/disable_globals.py | 12 + test/mpic/errors/disable_imports.py | 12 + test/mpic/errors/disable_lambdas.py | 10 + test/mpic/errors/disable_multiassign.py | 9 + test/mpic/errors/disable_nesting.py | 12 + test/mpic/errors/disable_non_layer_classes.py | 12 + test/mpic/errors/disable_print.py | 10 + test/mpic/errors/duplicate_class.py | 20 + test/mpic/errors/duplicate_fn.py | 12 + test/mpic/errors/invalid_return.py | 9 + test/mpic/errors/require_argtypes.py | 9 + test/mpic/examples/basic_layer.py | 21 + test/mpic/test_codegen.py | 20 + test/mpic/test_error_handling.py | 84 ++++ tox.ini | 13 + 39 
files changed, 2153 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/mpic_test.yml create mode 100644 examples/mpic/basic_layer.py create mode 100644 src/python/tools/mpic/__init__.py create mode 100644 src/python/tools/mpic/astpp.py create mode 100644 src/python/tools/mpic/attrs.py create mode 100644 src/python/tools/mpic/builtins.py create mode 100644 src/python/tools/mpic/codegen.py create mode 100644 src/python/tools/mpic/compiler.py create mode 100644 src/python/tools/mpic/features.py create mode 100644 src/python/tools/mpic/globalpass.py create mode 100644 src/python/tools/mpic/layerpass.py create mode 100644 src/python/tools/mpic/marius_mpic.py create mode 100644 src/python/tools/mpic/render.py create mode 100644 src/python/tools/mpic/symtab.py create mode 100644 src/python/tools/mpic/templates/gnn_layer.jinja.cpp create mode 100644 src/python/tools/mpic/templates/gnn_layer.jinja.h create mode 100644 src/python/tools/mpic/typechecker.py create mode 100644 src/python/tools/mpic/utils.py create mode 100644 test/mpic/errors/call_invalid_args.py create mode 100644 test/mpic/errors/disable_assert.py create mode 100644 test/mpic/errors/disable_format.py create mode 100644 test/mpic/errors/disable_generators.py create mode 100644 test/mpic/errors/disable_globals.py create mode 100644 test/mpic/errors/disable_imports.py create mode 100644 test/mpic/errors/disable_lambdas.py create mode 100644 test/mpic/errors/disable_multiassign.py create mode 100644 test/mpic/errors/disable_nesting.py create mode 100644 test/mpic/errors/disable_non_layer_classes.py create mode 100644 test/mpic/errors/disable_print.py create mode 100644 test/mpic/errors/duplicate_class.py create mode 100644 test/mpic/errors/duplicate_fn.py create mode 100644 test/mpic/errors/invalid_return.py create mode 100644 test/mpic/errors/require_argtypes.py create mode 100644 test/mpic/examples/basic_layer.py create mode 100644 test/mpic/test_codegen.py create mode 100644 
# Example marius script layer: mean-aggregate neighbour features and project
# the concatenation [own features, aggregated features] through one linear map.
# NOTE: keep documentation here as `#` comments. The compiler visits every
# statement in the class body and rejects anything that is not a supported
# definition, so a class-level docstring would abort compilation.
class BasicLayer(mpi.Module):
    def __init__(self, input_dim: int, output_dim: int):
        # The compiler reads input_dim/output_dim from the arguments of this super call
        super(mpi.Module, self).__init__(input_dim, output_dim)
        # forward concatenates h with h_N before the projection, hence input_dim * 2
        # (assumes h and h_N both have input_dim columns - TODO confirm)
        self.linear = mpi.Linear(input_dim * 2, output_dim)
        self.reset_parameters()

    def reset_parameters(self):
        self.linear.reset_parameters()

    def forward(self, graph: mpi.DENSEGraph, h: mpi.Tensor) -> mpi.Tensor:
        with graph.local_scope():
            graph.ndata["h"] = h
            # copy_u reads node data "h" into message "m"; mean reduces "m" into "h_N"
            graph.update_all(message_func=mpi.copy_u("h", "m"), reduce_func=mpi.mean("m", "h_N"))
            h_N = graph.ndata["h_N"]
            h_total = mpi.cat(h, h_N, dim=1)
            return self.linear(h_total)
class GlobalAttrsPass(ast.NodeVisitor):
    """
    Collect all class names in global scope

    After visiting a module
    - `global_attrs` maps every top level class name to its `ClassType`
    - `classes` maps every `ClassType` (builtin and user defined) to its attribute map

    Ensures
    - classes at the top level
    - class names are unique
    """

    def __init__(self):
        self.global_attrs = Attrs()
        self.classes = get_builtin_classes()

    def visit_ClassDef(self, classdef):
        """
        Register a user defined class

        Raises SemError if a class with the same name was already registered
        """
        if classdef.name in self.global_attrs:
            raise SemError(f"Class {classdef.name} is multiply defined! lineno:{classdef.lineno}")

        class_type = ClassType(classdef.name)
        self.global_attrs[classdef.name] = class_type
        # XXX: Add input_dim and output_dim as instance variables
        # TODO: add support for const int to prevent modification
        self.classes[class_type] = Attrs({"_mpic_class": class_type, "input_dim": IntAttrs, "output_dim": IntAttrs})

    def visit_Module(self, module):
        """
        Visit every top level class definition

        Raises SemError for any top level statement that is neither a class
        definition nor a standalone expression
        """
        for child in module.body:
            if isinstance(child, ast.ClassDef):
                self.visit(child)
            elif not isinstance(child, ast.Expr):
                # Allow comments (expressions as standalone statements)
                raise SemError(f"Only mpi.Module classes allowed at the top level! lineno:{child.lineno}")

    def generic_visit(self, node):
        # Only Module and ClassDef nodes are expected to reach this pass
        raise RuntimeError(f"Internal error!\n{astpp.dump(node)}")
def visit_FunctionDef(self, func):
    """
    Collects instance functions

    Stores a `Callable` built from the annotated arguments (excluding self)
    and the return annotation in the attribute map of the current class.
    A missing return annotation is recorded as returning None.

    Ensures
    - unique names
    - valid type annotations for all arguments and return
    - self first argument
    Disables
    - operator overriding
    - decorator lists
    - type comments
    - position only arguments
    - *args and **kwargs

    Raises SemError for duplicate names and SynError for unsupported syntax
    """
    lno = func.lineno

    if func.name in self.instance_attrs:
        raise SemError(f"Duplicate attr definition! func: lineno:{lno}")

    if (
        func.name != "__init__"
        and func.name != "__call__"
        and func.name.startswith("__")
        and func.name.endswith("__")
    ):
        raise SynError(f"Operator overriding is not supported! lineno:{lno}")

    if func.name == "reset":
        raise SynError(f"Please use reset_parameters to reset torch params! lineno:{lno}")

    if func.decorator_list:
        raise SynError(f"Decorator lists are not supported! lineno:{lno}")

    if func.type_comment:
        raise SynError(f"Type comments are not supported! lineno:{lno}")

    if func.args.vararg:
        raise SynError(f"*args is not supported! lineno:{lno}")

    if func.args.kwarg:
        raise SynError(f"**kwargs is not supported! func: lineno:{lno}")

    if func.args.posonlyargs or func.args.kwonlyargs:
        raise SynError(f"pos only or kw only args are not supported! lineno:{lno}")

    if func.args.kw_defaults or func.args.defaults:
        raise SynError(f"Defaults for arguments are not supported! lineno:{lno}")

    # A method declared without any parameter (`def f():`) has an empty
    # argument list; report it as a syntax error instead of an IndexError
    if not func.args.args or func.args.args[0].arg != "self":
        raise SynError(f"Only instance methods are supported! lineno:{lno}")

    args = []
    for arg in func.args.args[1:]:
        if not arg.annotation:
            raise SynError(f"Type annotations are required! lineno:{lno}")
        arg_attrs = self.classes[self.visit(arg.annotation)]
        args.append(Arg(arg.arg, arg_attrs))

    if func.returns:
        return_attrs = self.classes[self.visit(func.returns)]
    else:
        return_attrs = NoneAttrs
    self.instance_attrs[func.name] = Callable(args, return_attrs)
def generic_visit(self, node):
    """
    Reject every node type without a dedicated visit_* handler

    Raises RuntimeError with a pretty printed dump of the offending node
    """
    # astpp exposes `dump`; calling a non-existent `dumps` would replace the
    # intended RuntimeError with an AttributeError
    raise RuntimeError(f"Internal error!\n{astpp.dump(node)}")
def get_builtin_typemap() -> dict[ClassType, str]:
    """
    Return a fresh mapping from builtin class types to the cpp type names
    used in generated code

    A new dict is built on every call so callers may add their own entries.
    """
    return {
        IntType: "int",
        StrType: "string",
        # Methods without a return annotation are recorded as returning None;
        # without this entry their declarations cannot be generated
        VoidType: "void",
        Tensor: "torch::Tensor",
        Linear: "torch::Tensor",
        DENSEGraph: "DENSEGraph",
    }
def extract_func_decls(class_attrs: Attrs, typemap: dict[str, str]) -> list[dict]:
    """
    Build the declarations of the user defined helper methods of a class

    Skips __init__, reset_parameters and forward as well as every attribute
    that is not a `Callable` or whose name contains "__mpic_".

    Returns one dict per helper with keys
    - "returns": cpp return type
    - "name": method name
    - "args": list of "<cpp type> <argument name>" strings
    """
    skipped = {"__init__", "reset_parameters", "forward"}
    func_decls = []
    for name, attrs in class_attrs.items():
        if "__mpic_" in name or name in skipped or not isinstance(attrs, Callable):
            continue
        returns = typemap[attrs.return_attrs["_mpic_class"]]
        args = []
        for arg in attrs.args:
            # typemap is a dict: it must be indexed, not called
            arg_type = typemap[arg.attrs["_mpic_class"]]
            args.append(f"{arg_type} {arg.name}")
        func_decls.append({"returns": returns, "name": name, "args": args})
    return func_decls
class FunctionGenerator(ast.NodeVisitor):
    """
    visit_astNode
    - returns the equivalent cpp source for expressions
    - writes out statements with indentation
    Generated statements are collected in `statements`, one entry per line
    TODO: Add support for splitting long lines mirroring python code
    """

    # python operator node -> (cpp operator, precedence), higher binds tighter
    # TODO: Support more operators
    _binary_ops = {
        ast.Add: ("+", 1),
        ast.Sub: ("-", 1),
        ast.Mult: ("*", 2),
    }

    def __init__(self, classes, class_attrs, expr_attrs, input_dim, output_dim):
        """
        classes: class type -> attribute map
        class_attrs: attribute map of the class being generated
        expr_attrs: expression node -> attributes of that expression
        input_dim, output_dim: names that must be rewritten to the layer
            dimension members, or None when no rewriting is wanted
        """
        self.classes = classes
        self.class_attrs = class_attrs
        self.expr_attrs = expr_attrs
        self.statements = []
        self.indentation = 0
        self.input_dim = input_dim
        self.output_dim = output_dim

    def write(self, stmt: str):
        """Append a statement to the generated body at the current indentation"""
        self.statements.append(" " * self.indentation + stmt)

    @staticmethod
    def _find_arg(call, position, name):
        """Return the node passed for a parameter either positionally or by keyword"""
        if len(call.args) > position:
            return call.args[position]
        for keyword in call.keywords:
            if keyword.arg == name:
                return keyword.value
        raise RuntimeError(f"Internal error!: missing argument {name} lineno:{call.lineno}")

    def _operand(self, node, parent_precedence, is_right):
        """Generate a binary operand, bracketed when needed to keep the python grouping"""
        text = self.visit(node)
        if isinstance(node, ast.BinOp):
            _, precedence = self._binary_ops[type(node.op)]
            if precedence < parent_precedence or (is_right and precedence == parent_precedence):
                return f"({text})"
        return text

    ##########
    # Literals
    ##########
    def visit_Constant(self, constant):
        return str(constant.value)

    ###########
    # Variables
    ###########
    def visit_Name(self, name):
        if name.id == self.input_dim:
            return "input_dim_"
        if name.id == self.output_dim:
            return "output_dim_"
        if name.id == "self":
            return "(*this)"
        return name.id

    #############
    # Expressions
    #############
    def visit_Expr(self, expr):
        if not isinstance(expr.value, ast.Constant):
            # Ignore function comments
            self.write(self.visit(expr.value) + ";")

    def visit_BinOp(self, bin_op):
        """
        Generate a binary expression

        Raises RuntimeError for operators without a cpp mapping rather than
        silently generating a different operation
        """
        try:
            symbol, precedence = self._binary_ops[type(bin_op.op)]
        except KeyError:
            op_name = type(bin_op.op).__name__
            raise RuntimeError(f"Internal error!: unsupported operator {op_name} lineno:{bin_op.lineno}") from None
        left_expr = self._operand(bin_op.left, precedence, is_right=False)
        right_expr = self._operand(bin_op.right, precedence, is_right=True)
        return f"{left_expr} {symbol} {right_expr}"

    def visit_Call(self, call):
        func_expr = self.visit(call.func)

        func_attrs = self.expr_attrs[call.func]
        if isinstance(func_attrs, ClassType):
            # Calling a class invokes its constructor
            func_attrs = self.classes[func_attrs]["__init__"]
        elif isinstance(func_attrs, Attrs):
            # Calling an object invokes its call operator
            func_attrs = func_attrs["__call__"]
        # Arguments in declaration order: positional ones first, the remaining by keyword
        pos_args = [self.visit(arg) for arg in call.args]
        keywords = {keyword.arg: self.visit(keyword.value) for keyword in call.keywords}
        keyword_args = [keywords[arg.name] for arg in func_attrs.args[len(pos_args) :]]  # noqa: E203
        args = pos_args + keyword_args

        if func_attrs is DENSEGraphAttrs["update_all"]:
            # XXX: must be DENSEGraph::update_all
            # TODO: support user defined functions
            # XXX: assume copy_u and mean
            message_func = self._find_arg(call, 0, "message_func")
            reduce_func = self._find_arg(call, 1, "reduce_func")
            copysrc = self._find_arg(message_func, 0, "u").value
            meandst = self._find_arg(reduce_func, 1, "out").value
            inputs = node_data_prefix + copysrc
            outputs = node_data_prefix + meandst
            graph = self.visit(call.func.value)
            return f"{outputs} = update_all({graph}, {inputs});"
        elif func_attrs is builtin_attrs["mpi"]["cat"]:
            return f"torch::cat({{{args[0]}, {args[1]}}}, {args[2]})"
        elif func_attrs is LinearAttrs["reset_parameters"]:
            target = self.visit(call.func.value)
            linear_input_dim = f"_mpic_{target}input_dim_"
            linear_output_dim = f"_mpic_{target}output_dim_"
            dims = f"{{{linear_output_dim}, {linear_input_dim}}}"
            init_tensor = f"initialize_tensor(config_->init, {dims}, tensor_options)"
            autograd = f"{init_tensor}.set_requires_grad(true)"
            register = f'register_parameter("{target}", {autograd})'
            return f"{target} = {register};"
        elif func_attrs is LinearAttrs["__call__"]:
            # args already holds the input whether it was passed positionally or by keyword
            return f"torch::matmul({func_expr}, {args[0]}.transpose(0, -1)).transpose(0, -1)"
        else:
            args_expr = ", ".join(args)
            return f"{func_expr}({args_expr})"

    def visit_Attribute(self, attr):
        if ast.unparse(attr.value) == "self":
            if isinstance(self.class_attrs[attr.attr], Attrs):
                # Instance variables are suffixed using an underscore
                return attr.attr + instance_var_suffix
            else:
                return attr.attr if attr.attr != "reset_parameters" else "reset"
        value_expr = self.visit(attr.value)
        return f"{value_expr}.{attr.attr}"

    ##############
    # Subscripting
    ##############
    def visit_Subscript(self, subscript):
        return f"_mpic_{subscript.value.attr}_{subscript.slice.value}"

    ############
    # Statements
    ############
    def visit_Assign(self, assign):
        target = assign.targets[0]  # only 1 target supported in the current grammar!
        target_expr = self.visit(target)

        if self.expr_attrs[assign.value] is LinearAttrs:
            # Constructing a Linear only records its dimensions
            linear_input_dim = self.visit(self._find_arg(assign.value, 0, "in_features"))
            linear_output_dim = self.visit(self._find_arg(assign.value, 1, "out_features"))
            self.write(f"_mpic_{target_expr}input_dim_ = {linear_input_dim};")
            self.write(f"_mpic_{target_expr}output_dim_ = {linear_output_dim};")
        else:
            value_expr = self.visit(assign.value)
            self.write(f"{target_expr} = {value_expr};")

    def visit_Pass(self, _):
        pass

    ##############
    # Control Flow
    ##############
    def visit_With(self, with_node):
        # Only the body is generated, the context expression is dropped
        for stmt in with_node.body:
            self.visit(stmt)

    ################################
    # Function and Class Definitions
    ################################
    def visit_FunctionDef(self, func):
        for child in func.body:
            # XXX: cleaner to throw and handle an exception
            stmt = ast.unparse(child)
            if stmt.startswith("super(mpi.Module, self).__init__("):
                continue
            self.visit(child)

    def visit_Return(self, returns):
        return_expr = self.visit(returns.value)
        self.write(f"return {return_expr};")

    def generic_visit(self, node):
        raise RuntimeError(f"Internal error!: {astpp.dump(node)}")
func_body[child.name] = func_generator.statements + return func_body + + +def extract_helper_funcs(func_decls, local_vars, func_body): + return [ + func_decl + | { + "local_vars": local_vars[func_decl["name"]], + "body": func_body[func_decl["name"]], + } + for func_decl in func_decls + ] diff --git a/src/python/tools/mpic/compiler.py b/src/python/tools/mpic/compiler.py new file mode 100644 index 00000000..35f2f938 --- /dev/null +++ b/src/python/tools/mpic/compiler.py @@ -0,0 +1,24 @@ +""" +Compiles the input marius script file into a header and a source file + +Tools +- ast: used for parsing marius script +- Jinja2: used as for the code skeleton +""" + +import ast +import logging + +from marius.tools.mpic.globalpass import run_global_pass +from marius.tools.mpic.utils import SynError + + +def run_compiler(filename: str): + with open(filename, "r") as modf: + logging.info(f"Generating AST from file {filename} ...") + contents = modf.read() + if "_mpic_" in contents: + # Reserve _mpic_ to not appear anywhere in the program + raise SynError("Pattern _mpic_ is reserved and cannot appear anywhere in the program!") + tree = ast.parse(contents) + run_global_pass(tree) diff --git a/src/python/tools/mpic/features.py b/src/python/tools/mpic/features.py new file mode 100644 index 00000000..bb8db12f --- /dev/null +++ b/src/python/tools/mpic/features.py @@ -0,0 +1,153 @@ +import ast + +from marius.tools.mpic import astpp # noqa: F401 +from marius.tools.mpic.utils import SynError + + +class FeatureFilterPass(ast.NodeVisitor): + """ + The compiler does not support the complete python grammar + Detect and filter syntax errors before semantic analysis + + Disables + - comprehensions + - print + - assert + - lambdas + - generators + - coroutines + - interactive and eval modes + - starred variables + - formatting strings + """ + + ########## + # Literals + ########## + + def visit_FormattedValue(self, node): + raise SynError(f"Format spec is not supported! 
lineno:{node.lineno}") + + def visit_JoinedStr(self, node): + raise SynError(f"Python style formatting is not supported! lineno:{node.lineno}") + + ########### + # Variables + ########### + def visit_Starred(self, node): + raise SynError(f"Starred variables are not supported! lineno:{node.lineno}") + + ############# + # Expressions + ############# + def visit_NamedExpr(self, node): + raise SynError(f"Named expressions are not supported! lineno:{node.lineno}") + + ############## + # Subscripting + ############## + def visit_Index(self, index): + if not isinstance(index.value, str): + raise SynError(f"Only string indices are supported! lineno:{index.lineno}") + + def visit_Slice(self, index): + raise SynError(f"Index slices are not supported! lineno:{index.lineno}") + + def visit_ExtSlice(self, extslice): + raise SynError(f"Ext slices are not supported! lineno:{extslice.lineno}") + + ################ + # Comprehensions + ################ + def visit_ListComp(self, node): + raise SynError(f"List comprehensions are not supported! lineno:{node.lineno}") + + def visit_SetComp(self, node): + raise SynError(f"Set comprehensions are not supported! lineno:{node.lineno}") + + def visit_GeneratorExp(self, node): + raise SynError(f"Generator expresions are not supported! lineno:{node.lineno}") + + def visit_DictComp(self, node): + raise SynError(f"Dict comprehensions are not supported! lineno:{node.lineno}") + + ############ + # Statements + ############ + def AugAssign(self, node): + raise SynError(f"Augmented Assign is not supported! lineno:{node.lineno}") + + def visit_Print(self, node): + raise SynError(f"Print is not supported! lineno:{node.lineno}") + + def visit_Assert(self, node): + raise SynError(f"Assert statements are not supported! lineno:{node.lineno}") + + def visit_Delete(self, node): + raise SynError(f"Delete statements are not supported! 
lineno:{node.lineno}") + + ######### + # Imports + ######### + def visit_Import(self, node): + raise SynError(f"Imports are not supported! lineno:{node.lineno}") + + def visit_ImportFrom(self, node): + raise SynError(f"Imports are not supported! lineno:{node.lineno}") + + ############## + # Control Flow + ############## + def visit_For(self, node): + raise SynError(f"Loops are not supported! lineno:{node.lineno}") + + def visit_While(self, node): + raise SynError(f"Loops are not supported! lineno:{node.lineno}") + + def visit_Break(self, node): + raise SynError(f"Loops are not supported! lineno:{node.lineno}") + + def visit_Continue(self, node): + raise SynError(f"Loops are not supported! lineno:{node.lineno}") + + def visit_Try(self, node): + raise SynError(f"Exception handling is not supported! lineno:{node.lineno}") + + def visit_Finally(self, node): + raise SynError(f"Exception handling is not supported! lineno:{node.lineno}") + + def visit_Except(self, node): + raise SynError(f"Exception handling is not supported! lineno:{node.lineno}") + + ################################ + # Function and Class Definitions + ################################ + def visit_Lambda(self, node): + raise SynError(f"No lambda definitions allowed! lineno:{node.lineno}") + + def visit_Yield(self, node): + raise SynError(f"Generators not supported! lineno:{node.lineno}") + + def visit_YieldFrom(self, node): + raise SynError(f"Generators not supported! lineno:{node.lineno}") + + def Global(self, node): + raise SynError(f"Global not supported! lineno:{node.lineno}") + + def NonLocal(self, node): + raise SynError(f"NonLocal not supported! lineno:{node.lineno}") + + ################# + # Async and Await + ################# + def visit_AsyncFunctionDef(self, node): + raise SynError(f"Coroutines not supported! lineno:{node.lineno}") + + def visit_Await(self, node): + raise SynError(f"Coroutines not supported! 
class LayerExtractorPass(ast.NodeVisitor):
    """
    Runs semantic checks for each GNN layer module
    Expects both a symbol table and class instance attributes
    Ensures
    - class inherits from mpi.Module and nothing else
    """

    def __init__(self, symbol_table, classes, typemap):
        """Store the shared analysis state and create the code renderer."""
        self.symbol_table = symbol_table
        self.classes = classes
        self.typemap = typemap
        self.renderer = CodeRenderer()

    def visit_ClassDef(self, classdef):
        """Validate the base class, run the layer pass and render both output files."""
        # TODO: Support inheritance
        # TODO: Support multiple classes in the same file
        bases = classdef.bases
        inherits_module = len(bases) == 1 and ast.unparse(bases[0]) == "mpi.Module"
        if not inherits_module:
            raise SemError(f"All classes must inherit from mpi.Module! lineno:{classdef.lineno}")

        layer_params = run_layer_pass(self.symbol_table, self.classes, classdef, self.typemap)
        stem = camel_to_snake(classdef.name)
        self.renderer.render_header(layer_params["header_params"], stem + ".h")
        self.renderer.render_source(layer_params["source_params"], stem + ".cpp")

    def visit_Module(self, module):
        """Visit every top-level statement of the module in source order."""
        # XXX: cleaner to infer member variables in a separate pass
        for statement in module.body:
            self.visit(statement)


def populate_typemap(tree: ast.AST) -> dict[str, str]:
    """Return the builtin type map extended with one identity entry per top-level class."""
    typemap = get_builtin_typemap()
    class_names = [node.name for node in tree.body if isinstance(node, ast.ClassDef)]
    for class_name in class_names:
        typemap[class_name] = class_name
    return typemap


def run_global_pass(tree: ast.AST):
    """Run the four module-level passes over ``tree`` in order."""
    logging.info("Running global pass ...")

    # Pass1: Detect and disable "fancy" features
    FeatureFilterPass().visit(tree)

    # Pass2: Global attribute pass to collect all classes
    attrs_pass = GlobalAttrsPass()
    attrs_pass.visit(tree)
    symbols = SymbolTable(attrs_pass.global_attrs)
    classes = attrs_pass.classes

    # Pass3: Instance attrs pass to generate object schemas
    InstanceAttrsPass(symbols, classes).visit(tree)

    # Pass4: Type check each class
    typemap = populate_typemap(tree)
    extractor = LayerExtractorPass(symbols, classes, typemap)
    extractor.visit(tree)
def run_layer_pass(
    symbol_table: SymbolTable,
    classes: dict[ClassType],
    classdef: ast.ClassDef,
    typemap: dict[str, str],
) -> dict:
    """
    Type check one layer class and collect the template parameters for it

    Returns a dict with two entries, "header_params" and "source_params",
    which the caller hands to the header and source renderers respectively.
    """
    # NOTE(review): the ``dict[ClassType]`` annotation omits the value type.
    logging.info(f"Running layer pass for {classdef.name} ...")

    # Pass1: Check semantic types
    checker = TypeCheckerPass(symbol_table, classes)
    checker.visit(classdef)
    class_attrs = classes[classdef.name]
    expr_attrs = checker.expr_attrs
    func_vars = checker.func_vars

    # TODO: Optimization: Inline and value propagate linear

    # Pass2: Extract options
    options, input_dim, output_dim = extract_options(classdef, class_attrs, typemap)

    # Pass3: Extract instance variables
    member_vars = extract_member_variables(class_attrs, typemap)

    # Pass4: Extract member function declarations
    func_decls = extract_func_decls(class_attrs, typemap)

    # Pass5: Extract local variables
    local_vars = extract_local_vars(func_vars, typemap)

    # Pass6: Generate function body
    func_body = generate_func_body(classes, classdef, expr_attrs, input_dim, output_dim)

    def describe(func_name):
        # Locals and generated statements of one mandatory function.
        return {"local_vars": local_vars[func_name], "body": func_body[func_name]}

    init_func = describe("__init__")
    reset_func = describe("reset_parameters")
    forward_func = describe("forward")
    # forward(graph, inputs): argument 0 is the graph, argument 1 the input tensor.
    forward_func["inputs"] = class_attrs["forward"].args[1].name
    forward_func["graph"] = class_attrs["forward"].args[0].name
    helper_funcs = extract_helper_funcs(func_decls, local_vars, func_body)

    header_params = {
        "LayerClassName": classdef.name,
        "options": options,
        "member_vars": member_vars,
        "member_fns": func_decls,
    }
    source_params = {
        "layer_class_name": camel_to_snake(classdef.name),
        "LayerClassName": classdef.name,
        "init": init_func,
        "reset": reset_func,
        "forward": forward_func,
        "member_fns": helper_funcs,
    }
    return {"header_params": header_params, "source_params": source_params}
def parse_args(argv=None):
    """
    positional args
    - pyfiles: list of files to compile
    optional args
    - -l/--log: log level (see https://docs.python.org/3/library/logging.html#logging-levels)

    argv: optional list of argument strings; when None, argparse reads
    sys.argv[1:] exactly as before.
    """
    parser = argparse.ArgumentParser(prog="mpic", description="MariusGNN MPI compiler")
    parser.add_argument("pyfiles", nargs="+")
    parser.add_argument("-l", "--log", default="WARNING")
    return parser.parse_args(argv)


def main(argv=None):
    """
    Command line entry point: configure logging, then compile every file

    argv: optional list of argument strings forwarded to parse_args.
    Raises ValueError when --log does not name an integer logging level.
    """
    args = parse_args(argv)

    # Setup logging
    # Level names are looked up as upper-case attributes of the logging module.
    log_level = getattr(logging, args.log.upper(), None)
    if not isinstance(log_level, int):
        raise ValueError(f"Invalid log level: {args.log}")

    logging.basicConfig(level=log_level, format="[%(levelname)s]\t%(message)s")

    # Compile each file
    for filename in args.pyfiles:
        run_compiler(filename)


if __name__ == "__main__":
    main()
class SymbolTable:
    """
    Stack of scopes mapping symbol names to their attributes

    The stack starts with the builtin scope followed by the global scope;
    lookups search from the innermost scope outwards.
    """

    def __init__(self, global_attrs):
        self.scopes = [builtin_attrs, global_attrs]
        self.graph_local_scope = 0

    def enter_scope(self):
        """Push a fresh, empty scope."""
        self.scopes.append(Attrs())

    def exit_scope(self):
        """Pop and return the innermost scope."""
        return self.scopes.pop()

    @contextmanager
    def managed_scope(self):
        """Context manager that pushes a scope on entry and pops it on exit."""
        self.enter_scope()
        try:
            yield
        finally:
            # Pop even when the body raises so the scope stack stays balanced.
            self.exit_scope()

    def find_symbol(self, symbol: str):
        """Return the attributes bound to ``symbol`` in the innermost scope defining it, else None."""
        for scope in reversed(self.scopes):
            if symbol in scope:
                return scope[symbol]

        return None

    def add_symbol(self, symbol: str, value_attrs: Attrs):
        """Bind ``symbol`` in the innermost scope, overwriting any binding there."""
        local_scope = self.scopes[-1]
        local_scope[symbol] = value_attrs
+// + +#include "nn/layers/gnn/{{layer_class_name}}.h" + +#include "nn/layers/gnn/layer_helpers.h" +#include "reporting/logger.h" +#include "nn/initialization.h" + +// XXX: GCC warns on always_inline +#pragma GCC diagnostic ignored "-Wattributes" + +// XXX: always_inline to mirror original code (unsure if more efficient) +#define FUNCGEN __attribute__((always_inline)) + +namespace { + +struct SumFunc { + static constexpr bool useNumNbrs = false; + + FUNCGEN static torch::Tensor segmented_reduce(torch::Tensor const& embeds, Indices const& offsets) { + return segmented_sum_with_offsets(embeds, offsets); + } +}; + +struct MaxFunc { + static constexpr bool useNumNbrs = false; + + FUNCGEN static torch::Tensor segmented_reduce(torch::Tensor const& embeds, Indices const& offsets) { + return segmented_max_with_offsets(embeds, offsets); + } +}; + +struct MeanFunc { + static constexpr bool useNumNbrs = true; + + FUNCGEN static torch::Tensor segmented_reduce(torch::Tensor const& embeds, Indices const& offsets) { + return segmented_sum_with_offsets(embeds, offsets); + } + + FUNCGEN static torch::Tensor applyNumNbrs(torch::Tensor const& a_i, torch::Tensor const& num_nbrs) { + torch::Tensor denominator = torch::where(torch::not_equal( + num_nbrs, 0), num_nbrs, 1).to(a_i.dtype()).unsqueeze(-1); + return a_i / denominator; + } +}; + +template +FUNCGEN torch::Tensor update_all(DENSEGraph& dense_graph, torch::Tensor const& u) { + constexpr bool useNumNbrs = ReduceFunc::useNumNbrs; + torch::Tensor a_i; + [[maybe_unused]] torch::Tensor total_num_neighbors; + + if (dense_graph.out_neighbors_mapping_.defined()) { + Indices outgoing_neighbors = dense_graph.getNeighborIDs(false, false); + Indices outgoing_neighbor_offsets = dense_graph.getNeighborOffsets(false); + torch::Tensor outgoing_num = dense_graph.getNumNeighbors(false); + + torch::Tensor outgoing_embeddings = u.index_select(0, outgoing_neighbors); + a_i = ReduceFunc::segmented_reduce(outgoing_embeddings, 
outgoing_neighbor_offsets); + + // often, aggregation functions require the number of neighbors + if constexpr(useNumNbrs) { + total_num_neighbors = outgoing_num; + } + } + + if (dense_graph.in_neighbors_mapping_.defined()) { + Indices incoming_neighbors = dense_graph.getNeighborIDs(true, false); + Indices incoming_neighbor_offsets = dense_graph.getNeighborOffsets(true); + torch::Tensor incoming_num = dense_graph.getNumNeighbors(true); + + torch::Tensor incoming_embeddings = u.index_select(0, incoming_neighbors); + + if (a_i.defined()) { + a_i = a_i + segmented_sum_with_offsets(incoming_embeddings, incoming_neighbor_offsets); + } else { + a_i = segmented_sum_with_offsets(incoming_embeddings, incoming_neighbor_offsets); + } + + // often, aggregation functions require the number of neighbors + if constexpr(useNumNbrs) { + if (total_num_neighbors.defined()) { + total_num_neighbors = total_num_neighbors + incoming_num; + } else { + total_num_neighbors = incoming_num; + } + } + } + + if constexpr(useNumNbrs) { + return ReduceFunc::applyNumNbrs(a_i, total_num_neighbors); + } else { + return a_i; + } +} + +} // anonymous namespace + +{{LayerClassName}}::{{LayerClassName}}(shared_ptr layer_config, torch::Device device) { + config_ = layer_config; + options_ = std::dynamic_pointer_cast<{{LayerClassName}}Options>(config_->options); + input_dim_ = config_->input_dim; + output_dim_ = config_->output_dim; + device_ = device; + {% if init.local_vars %} + + {% endif %} + {% for name, var_type in init.local_vars.items() %} + {{var_type}} {{name}}; + {% endfor %} + {% if init.body %} + + {% endif %} + {% for stmt in init.body %} + {{stmt}} + {% endfor %} +} + +void {{LayerClassName}}::reset() { + [[maybe_unused]] auto tensor_options = torch::TensorOptions().dtype(torch::kFloat32).device(device_); + + {% for name, var_type in reset.local_vars.items() %} + {{var_type}} {{name}}; + {% endfor %} + {% if reset.local_vars %} + + {% endif %} + {% for stmt in reset.body %} + {{stmt}} + {% 
endfor %} + {% if reset.body %} + + {% endif %} + if (config_->bias) { + init_bias(); + } +} + +torch::Tensor {{LayerClassName}}::forward(torch::Tensor {{forward.inputs}}, DENSEGraph {{forward.graph}}, bool train) { + {% for name, var_type in forward.local_vars.items() %} + {{var_type}} {{name}}; + {% endfor %} + {% if forward.local_vars and forward.body %} + + {% endif %} + {% for stmt in forward.body %} + {{stmt}} + {% endfor %} +} +{%- for fn in member_fns -%} + +{{fn.returns}} {{LayerClassName}}::{{fn.name}}({{fn.args|join(', ')}}) { + {%- for name, var_type in fn.local_vars.items() -%} + {{var_type}} {{name}}; + {%- endfor -%} + {% if fn.local_vars and fn.body %} + + {% endif %} + {% for stmt in fn.body %} + {{stmt}} + {% endfor %} +} +{%- endfor -%} diff --git a/src/python/tools/mpic/templates/gnn_layer.jinja.h b/src/python/tools/mpic/templates/gnn_layer.jinja.h new file mode 100644 index 00000000..690f7d22 --- /dev/null +++ b/src/python/tools/mpic/templates/gnn_layer.jinja.h @@ -0,0 +1,40 @@ +// +// Autogenerated file! 
def update_expr_attrs(visit_expr):
    """Decorator: record the attrs a visit method returns in ``typechecker.expr_attrs[node]``."""

    @wraps(visit_expr)
    def wrapper(typechecker, node):
        expr_attrs = visit_expr(typechecker, node)
        typechecker.expr_attrs[node] = expr_attrs
        return expr_attrs

    return wrapper


class TypeCheckerPass(ast.NodeVisitor):
    """
    visit_className returns object attributes of the expression, None otherwise
    Runs semantic checks
    Registers new instance variables introduced in __init__ using type inference
    Ensures
    - functions are called with valid arguments
    - called symbols are callable
    - new instance variables introduced in __init__ are mapped
    - assignments are on consistent types
    - expressions are properly typed
    - operators are called with compatible types
    Disables
    - named expressions
    - nested functions or classes
    - registering new variables with non-enclosing scopes
    - registering new variables in non __init__ functions
    - generic subscripting
    Trampolines
    - constructors using __init__
    - callable objects using __call__
    """

    def __init__(self, symbol_table: SymbolTable, classes: dict[str, Attrs]):
        self.symbol_table = symbol_table
        self.classes = classes
        # expression node -> attrs, filled in by the update_expr_attrs decorator
        self.expr_attrs = dict()
        # attrs of the class being checked; None outside of a class
        self.self_attrs = None
        # True only while __init__ is being processed
        self.init_mode = False
        # declared return attrs of the function being checked; None outside of one
        self.return_attrs = None
        # function name -> {local variable name -> attrs}
        self.func_vars = dict()
        self.current_locals = None
        # nesting depth of graph-local `with` blocks
        self.graph_local_scope = 0

    ##########
    # Literals
    ##########
    @update_expr_attrs
    def visit_Constant(self, constant):
        """Type int and str literals; raise SemError for any other literal except ``...``."""
        if isinstance(constant.value, int):
            return IntAttrs
        elif isinstance(constant.value, str):
            return StrAttrs
        elif constant.value is Ellipsis:
            # Ellipsis is a singleton object, not a type, so it cannot be
            # passed to isinstance(); compare by identity instead.
            return None
        else:
            raise SemError(f"Literal is not supported! lineno:{constant.lineno}")

    @update_expr_attrs
    def visit_List(self, node):
        raise SynError(f"Lists are not supported! lineno:{node.lineno}")

    @update_expr_attrs
    def visit_Tuple(self, node):
        raise SynError(f"Tuples are not supported! lineno:{node.lineno}")

    @update_expr_attrs
    def visit_Set(self, node):
        raise SynError(f"Sets are not supported! lineno:{node.lineno}")

    @update_expr_attrs
    def visit_Dict(self, node):
        raise SynError(f"Dicts are not supported! lineno:{node.lineno}")

    ###########
    # Variables
    ###########
    @update_expr_attrs
    def visit_Name(self, name):
        """Return the attrs bound to the name, or None when it is unknown."""
        return self.symbol_table.find_symbol(name.id)

    #############
    # Expressions
    #############
    def visit_Expr(self, expr):
        self.visit(expr.value)

    @update_expr_attrs
    def visit_UnaryOp(self, node):
        raise SynError(f"Unary operations are not supported! lineno:{node.lineno}")

    @update_expr_attrs
    def visit_BinOp(self, bin_op):
        """Only multiplication of two numeric operands is accepted; the result is typed int."""
        if not isinstance(bin_op.op, ast.Mult):
            raise SynError(f"Unexpected operation! lineno:{bin_op.lineno}")

        if not is_numeric(self.visit(bin_op.left)):
            raise SemError(f"Unexpected operand! lineno:{bin_op.lineno}")

        if not is_numeric(self.visit(bin_op.right)):
            raise SemError(f"Unexpected operand! lineno:{bin_op.lineno}")

        return IntAttrs

    @update_expr_attrs
    def visit_BoolOp(self, node):
        raise SynError(f"bool ops are not supported! lineno:{node.lineno}")

    @update_expr_attrs
    def visit_Compare(self, node):
        raise SynError(f"Compare ops not supported! lineno:{node.lineno}")

    @update_expr_attrs
    def visit_Call(self, call):
        """
        Check a call against the callee signature and return the attrs of its result

        Positional arguments are matched first, the remaining formal arguments
        must all be supplied by keyword. Raises SemError/SynError on mismatch.
        """
        call_attrs = self.visit(call.func)
        fqname = ast.unparse(call.func)
        lno = call.lineno
        return_attrs = None

        # Extract the type of callable
        if fqname == "super":
            # FIXME: semantic checking for call to super
            return LayerAttrs
        elif isinstance(call_attrs, ClassType):
            # Constructor
            class_type = call_attrs
            class_attrs = self.classes[class_type]
            func_sig = class_attrs["__init__"]
            return_attrs = class_attrs
        elif isinstance(call_attrs, Attrs):
            # callable object
            if "__call__" not in call_attrs:
                raise SemError(f"Expression {fqname} is not callable! lineno:{lno}")
            func_sig = call_attrs["__call__"]
        else:
            func_sig = call_attrs

        if not isinstance(func_sig, Callable):
            raise SemError(f"Expression {fqname} is not callable! lineno:{lno}")

        # Verify function call arguments
        formal_args = list(reversed(func_sig.args))  # reversed to pop args from behind
        if len(formal_args) < len(call.args):
            raise SemError(f"Function called with invalid number of arguments! lineno:{lno}")
        for actual_arg in call.args:
            # Verify positional arguments
            actual_attrs = self.visit(actual_arg)
            if not is_consistent_with(actual_attrs, formal_args[-1].attrs):
                aname = ast.unparse(actual_arg)
                raise SemError(f"Function called with invalid arg {aname}! lineno:{lno}")

            formal_args.pop()

        # Expect the rest of the arguments to be called using kw args
        if len(formal_args) != len(call.keywords):
            raise SemError(f"Function called with invalid number of arguments! lineno:{lno}")
        formal_args = {arg.name: arg.attrs for arg in formal_args}
        for keyword in call.keywords:
            aname = keyword.arg
            actual_attrs = self.visit(keyword.value)
            if aname not in formal_args:
                raise SemError(f"keyword not available in remaining arguments! lineno:{lno}")

            if not is_consistent_with(actual_attrs, formal_args[aname]):
                raise SemError(f"Function called with invalid arg {aname}! lineno:{lno}")

            del formal_args[aname]

        if call_attrs == DENSEGraphAttrs["update_all"]:
            # XXX: assume copy_u and mean
            message_func = None
            if len(call.args) > 0:
                message_func = call.args[0]
            reduce_func = None
            if len(call.args) > 1:
                reduce_func = call.args[1]
            for keyword in call.keywords:
                if keyword.arg == "message_func":
                    message_func = keyword.value
                elif keyword.arg == "reduce_func":
                    reduce_func = keyword.value
            # Both arguments must themselves be calls with two positional
            # arguments; report a compile error instead of crashing otherwise.
            for func in (message_func, reduce_func):
                if not isinstance(func, ast.Call) or len(func.args) < 2:
                    raise SemError(f"Invalid update_all invocation! lineno:{lno}")
            copysrc, copydst = message_func.args[0], message_func.args[1]
            meansrc, meandst = reduce_func.args[0], reduce_func.args[1]
            if (
                not isinstance(copysrc, ast.Constant)
                or not isinstance(copysrc.value, str)
                or not isinstance(copydst, ast.Constant)
                or not isinstance(copydst.value, str)
                or not isinstance(meansrc, ast.Constant)
                or not isinstance(meansrc.value, str)
                or not isinstance(meandst, ast.Constant)
                or not isinstance(meandst.value, str)
                or copydst.value != meansrc.value
            ):
                raise SemError(f"Invalid update_all invocation! lineno:{lno}")
            # Add meandst to the list of local variables
            var_name = node_data_prefix + meandst.value
            self.symbol_table.add_symbol(var_name, NodeDataAttrs)
            self.current_locals[var_name] = NodeDataAttrs

        if call_attrs is LayerAttrs["__init__"]:
            # Both dims must be passed positionally as plain names.
            if (
                len(call.args) < 2
                or not isinstance(call.args[0], ast.Name)
                or not isinstance(call.args[1], ast.Name)
            ):
                raise SynError(f"Must pass a name! lineno:{lno}")

        # Type of the call expression is the return type of the call expression
        # XXX: return_attrs handles the special case of a constructor
        return return_attrs if return_attrs else func_sig.return_attrs

    @update_expr_attrs
    def visit_IfExp(self, node):
        raise SynError(f"if-else is not supported! lineno:{node.lineno}")

    @update_expr_attrs
    def visit_Attribute(self, attr):
        """Return the attrs of ``value.attr``, or None when the attribute is not known."""
        value_attrs = self.visit(attr.value)
        # value_attrs is None for an unknown base symbol and may be a
        # non-mapping (e.g. a class name or a Callable); only mappings can
        # be searched for the attribute.
        if isinstance(value_attrs, Attrs) and attr.attr in value_attrs:
            return value_attrs[attr.attr]

        return None

    ##############
    # Subscripting
    ##############
    def get_graph_local_name(self, subscript):
        """Map ``<node data>[key]`` / ``<edge data>[key]`` to its prefixed local variable name."""
        if not self.graph_local_scope:
            raise SemError(f"Not in graph local scope! lineno:{subscript.lineno}")
        graph_attrs = self.visit(subscript.value)
        if graph_attrs is NodeDataAttrs:
            return node_data_prefix + subscript.slice.value
        elif graph_attrs is EdgeDataAttrs:
            return edge_data_prefix + subscript.slice.value
        else:
            raise SemError(f"Generic subscripts are not supported! lineno:{subscript.lineno}")

    @update_expr_attrs
    def visit_Subscript(self, subscript):
        return self.symbol_table.find_symbol(self.get_graph_local_name(subscript))

    ############
    # Statements
    ############
    def visit_Assign(self, assign):
        """
        Implements core logic of automatically inferring new local and instance variables
        """
        lno = assign.lineno
        if len(assign.targets) != 1:
            raise SynError(f"Multi-assignments are not supported! lineno:{lno}")

        target = assign.targets[0]
        value_attrs = self.visit(assign.value)

        formal_attrs = self.visit(target)
        if formal_attrs is not None:
            # Definition already exists!
            if not is_consistent_with(value_attrs, formal_attrs):
                raise SemError(f"Type mismatch: cannot assign to variable! lineno:{lno}")
        elif isinstance(target, ast.Name):
            # register a new local variable
            if target.id.endswith("_"):
                raise SynError(f"Variables cannot have a trailing underscore! lineno:{lno}")
            self.symbol_table.add_symbol(target.id, value_attrs)
            self.current_locals[target.id] = value_attrs
        elif isinstance(target, ast.Attribute):
            # Only allow `self.<attr> = <value>` to register new variables
            if ast.unparse(target.value) != "self":
                raise SemError(f"Cannot add symbols to non-owning scopes! lineno:{lno}")
            if not self.init_mode:
                raise SemError(f"Cannot register new instance variables outside of __init__! lineno:{lno}")
            if target.attr in {"input_dim", "output_dim"}:
                raise SynError(f"Cannot modify input_dim and output_dim! lineno:{lno}")
            if target.attr.endswith("_"):
                raise SynError(f"Variables cannot have a trailing underscore! lineno:{lno}")
            self.self_attrs[target.attr] = value_attrs
            if value_attrs is LinearAttrs:
                # A linear member also gets two int members for its dimensions.
                self.self_attrs[f"_mpic_{target.attr}_input_dim"] = IntAttrs
                self.self_attrs[f"_mpic_{target.attr}_output_dim"] = IntAttrs
        elif isinstance(target, ast.Subscript):
            # graph.ndata['h'] = ...
            local_name = self.get_graph_local_name(target)
            if local_name.endswith("_"):
                raise SynError(f"Variables cannot have a trailing underscore! lineno:{lno}")
            self.symbol_table.add_symbol(local_name, value_attrs)
            self.current_locals[local_name] = value_attrs

    def visit_AnnAssign(self, assign):
        if self.return_attrs:
            raise SemError(f"Annotated assignments are not supported in func body! lineno:{assign.lineno}")

    def visit_Raise(self, node):
        raise SynError(f"Cannot raise exceptions! lineno:{node.lineno}")

    def visit_Pass(self, _):
        pass

    ##############
    # Control Flow
    ##############
    def visit_If(self, node):
        raise SemError(f"Control flow is not supported! lineno:{node.lineno}")

    ################################
    # Function and Class Definitions
    ################################
    def visit_FunctionDef(self, func):
        """Check a method body in a fresh scope and record its locals in ``func_vars``."""
        if self.return_attrs:
            raise SemError(f"Nested functions are not supported! lineno:{func.lineno}")

        func_sig = self.self_attrs[func.name]

        @contextmanager
        def managed_returns():
            self.return_attrs = func_sig.return_attrs
            yield
            self.return_attrs = None

        @contextmanager
        def managed_args():
            # add self
            self.symbol_table.add_symbol("self", self.self_attrs)
            # add other args
            for arg in func_sig.args:
                self.symbol_table.add_symbol(arg.name, arg.attrs)
            yield
            # XXX: symbols removed automatically through managed scope

        @contextmanager
        def managed_locals():
            self.current_locals = dict()
            yield
            self.func_vars[func.name] = self.current_locals
            self.current_locals = None

        with managed_returns(), self.symbol_table.managed_scope(), managed_args(), managed_locals():
            for child in func.body:
                self.visit(child)

    def visit_With(self, with_node):
        """Accept only a single, unnamed graph-local-scope `with` and check its body."""
        lno = with_node.lineno
        if len(with_node.items) != 1:
            raise SemError(f"General case `with` is not supported! lineno:{lno}")

        with_item = with_node.items[0]
        if with_item.optional_vars:
            raise SemError(f"General case `with` is not supported! lineno:{lno}")

        if self.visit(with_item.context_expr) != GraphLocalScopeAttrs:
            raise SemError(f"General case `with` is not supported! lineno:{lno}")

        @contextmanager
        def managed_graph_local_scope():
            self.graph_local_scope += 1
            yield
            self.graph_local_scope -= 1

        with managed_graph_local_scope():
            for stmt in with_node.body:
                self.visit(stmt)

    def visit_Return(self, return_node):
        return_attrs = self.visit(return_node.value) if return_node.value else NoneAttrs
        if not is_consistent_with(return_attrs, self.return_attrs):
            raise SemError(f"Return type does not match! lineno:{return_node.lineno}")

    def visit_ClassDef(self, classdef):
        """Validate the mandatory method signatures, then check __init__ before all other members."""
        lno = classdef.lineno

        if self.self_attrs or self.return_attrs:
            raise SemError(f"Nested classes are not supported! lineno:{lno}")

        self_attrs = self.classes[ClassType(classdef.name)]

        # verify the presence of mandatory functions
        if "__init__" not in self_attrs or "reset_parameters" not in self_attrs or "forward" not in self_attrs:
            raise SemError(f"Must define __init__, reset_parameters, and forward! lineno:{lno}")
        reset_func = self_attrs["reset_parameters"]
        forward_func = self_attrs["forward"]

        # verify the signature of the reset_parameters function
        if len(reset_func.args) != 0 or not is_consistent_with(reset_func.return_attrs, NoneAttrs):
            raise SemError(f"Invalid signature of reset_parameters! lineno:{lno}")

        # verify the signature of the forward function
        if (
            len(forward_func.args) != 2
            or not is_consistent_with(forward_func.args[0].attrs, DENSEGraphAttrs)
            or not is_consistent_with(forward_func.args[1].attrs, TensorAttrs)
            or not is_consistent_with(forward_func.return_attrs, TensorAttrs)
        ):
            raise SemError(f"Invalid signature of forward! lineno:{lno}")

        @contextmanager
        def managed_self_attrs():
            self.self_attrs = self_attrs
            yield
            self.self_attrs = None

        @contextmanager
        def managed_init():
            self.init_mode = True
            yield
            self.init_mode = False

        def is_init(child):
            # Class bodies may hold nodes without a `name` (docstring, pass).
            return isinstance(child, ast.FunctionDef) and child.name == "__init__"

        with managed_self_attrs():
            # Pass1: process __init__
            with managed_init():
                for child in classdef.body:
                    if is_init(child):
                        self.visit(child)

            # Pass2: process other nodes
            for child in classdef.body:
                if not is_init(child):
                    self.visit(child)

    def generic_visit(self, node):
        # Any node type without an explicit handler is a compiler bug.
        raise RuntimeError(f"Internal error!\n{astpp.dump(node)}")
# TODO: Should we inherit from TypeError?
# TODO: Show context of compiler error instead of stack trace!
class CompileError(Exception):
    """
    Base class for errors in marius script
    """

    def __init__(self, msg):
        # Forward the message to Exception so that `args` and repr() carry it.
        super().__init__(msg)
        self.msg = msg

    def __str__(self):
        return self.msg


class SynError(CompileError):
    """
    Syntax errors for scripts that do not follow the Grammar
    """

    pass


class SemError(CompileError):
    """
    Semantic errors for scripts that fail type checking
    """

    pass


def camel_to_snake(name):
    """
    Convert a CamelCase name to snake_case, e.g. "BasicLayer" -> "basic_layer"

    See https://stackoverflow.com/a/1176023/12160191
    """
    # Insert "_" before every upper-case letter except one at the very start,
    # then lower-case the whole string.
    return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()
standalone(self): + yield diff --git a/test/mpic/errors/disable_globals.py b/test/mpic/errors/disable_globals.py new file mode 100644 index 00000000..7aa8eafe --- /dev/null +++ b/test/mpic/errors/disable_globals.py @@ -0,0 +1,12 @@ +name = "Hello World!" + + +class ErroneousLayer(nn.Module): + def __init__(self, input_dim: int, output_dim: int): + pass + + def reset_parameters(self): + pass + + def forward(self, graph: mpi.DENSEGraph, h: torch.Tensor) -> torch.Tensor: + return h diff --git a/test/mpic/errors/disable_imports.py b/test/mpic/errors/disable_imports.py new file mode 100644 index 00000000..f68a42b5 --- /dev/null +++ b/test/mpic/errors/disable_imports.py @@ -0,0 +1,12 @@ +from torch import nn + + +class ErroneousLayer(nn.Module): + def __init__(self, input_dim: int, output_dim: int): + pass + + def reset_parameters(self): + pass + + def forward(self, graph: mpi.DENSEGraph, h: torch.Tensor) -> torch.Tensor: + return h diff --git a/test/mpic/errors/disable_lambdas.py b/test/mpic/errors/disable_lambdas.py new file mode 100644 index 00000000..d9b339cd --- /dev/null +++ b/test/mpic/errors/disable_lambdas.py @@ -0,0 +1,10 @@ +class BasicLayer(mpi.Module): + def __init__(self, input_dim: int, output_dim: int): + lambda x: x + self.reset_parameters() + + def reset_parameters(self): + pass + + def forward(self, graph: mpi.DENSEGraph, h: mpi.Tensor) -> mpi.Tensor: + return h diff --git a/test/mpic/errors/disable_multiassign.py b/test/mpic/errors/disable_multiassign.py new file mode 100644 index 00000000..ecd2148a --- /dev/null +++ b/test/mpic/errors/disable_multiassign.py @@ -0,0 +1,9 @@ +class ErroneousLayer(nn.Module): + def __init__(self, input_dim: int, output_dim: int): + output_dim, input_dim = input_dim, output_dim + + def reset_parameters(self): + pass + + def forward(self, graph: mpi.DENSEGraph, h: torch.Tensor) -> torch.Tensor: + return h diff --git a/test/mpic/errors/disable_nesting.py b/test/mpic/errors/disable_nesting.py new file mode 100644 index 
00000000..ddecafdf --- /dev/null +++ b/test/mpic/errors/disable_nesting.py @@ -0,0 +1,12 @@ +class ErroneousLayer(nn.Module): + def __init__(self, input_dim: int, output_dim: int): + def init(): + pass + + init() + + def reset_parameters(self): + pass + + def forward(self, graph: mpi.DENSEGraph, h: torch.Tensor) -> torch.Tensor: + return h diff --git a/test/mpic/errors/disable_non_layer_classes.py b/test/mpic/errors/disable_non_layer_classes.py new file mode 100644 index 00000000..0691fa86 --- /dev/null +++ b/test/mpic/errors/disable_non_layer_classes.py @@ -0,0 +1,12 @@ +class ErroneousClass: + def __init__(self, input_dim: int, output_dim: int): + def init(): + pass + + init() + + def reset_parameters(self): + pass + + def forward(self, graph: mpi.DENSEGraph, h: torch.Tensor) -> torch.Tensor: + return h diff --git a/test/mpic/errors/disable_print.py b/test/mpic/errors/disable_print.py new file mode 100644 index 00000000..fc70ef3e --- /dev/null +++ b/test/mpic/errors/disable_print.py @@ -0,0 +1,10 @@ +class BasicLayer(mpi.Module): + def __init__(self, input_dim: int, output_dim: int): + print() + self.reset_parameters() + + def reset_parameters(self): + pass + + def forward(self, graph: mpi.DENSEGraph, h: mpi.Tensor) -> mpi.Tensor: + return h diff --git a/test/mpic/errors/duplicate_class.py b/test/mpic/errors/duplicate_class.py new file mode 100644 index 00000000..547ddb43 --- /dev/null +++ b/test/mpic/errors/duplicate_class.py @@ -0,0 +1,20 @@ +class DuplicateModule(nn.Module): + def __init__(self, input_dim: int, output_dim: int): + pass + + def reset_parameters(self): + pass + + def forward(self, graph: mpi.DENSEGraph, h: torch.Tensor) -> torch.Tensor: + return h + + +class DuplicateModule(nn.Module): + def __init__(self, input_dim: int, output_dim: int): + pass + + def reset_parameters(self): + pass + + def forward(self, graph: mpi.DENSEGraph, h: torch.Tensor) -> torch.Tensor: + return h diff --git a/test/mpic/errors/duplicate_fn.py 
b/test/mpic/errors/duplicate_fn.py new file mode 100644 index 00000000..ddf9fc5b --- /dev/null +++ b/test/mpic/errors/duplicate_fn.py @@ -0,0 +1,12 @@ +class DuplicateModule(nn.Module): + def __init__(self, input_dim: int, output_dim: int): + pass + + def reset_parameters(self): + pass + + def reset_parameters(self): + pass + + def forward(self, graph: mpi.DENSEGraph, h: torch.Tensor) -> torch.Tensor: + return h diff --git a/test/mpic/errors/invalid_return.py b/test/mpic/errors/invalid_return.py new file mode 100644 index 00000000..5f0fd481 --- /dev/null +++ b/test/mpic/errors/invalid_return.py @@ -0,0 +1,9 @@ +class DuplicateModule(nn.Module): + def __init__(self, input_dim: int, output_dim: int): + pass + + def reset_parameters(self): + pass + + def forward(self, graph: mpi.DENSEGraph, h: torch.Tensor): + pass diff --git a/test/mpic/errors/require_argtypes.py b/test/mpic/errors/require_argtypes.py new file mode 100644 index 00000000..2dd7c497 --- /dev/null +++ b/test/mpic/errors/require_argtypes.py @@ -0,0 +1,9 @@ +class ErroneousLayer(nn.Module): + def __init__(self, input_dim, output_dim): + pass + + def reset_parameters(self): + pass + + def forward(self, graph, h): + return h diff --git a/test/mpic/examples/basic_layer.py b/test/mpic/examples/basic_layer.py new file mode 100644 index 00000000..1bcbe0e7 --- /dev/null +++ b/test/mpic/examples/basic_layer.py @@ -0,0 +1,21 @@ +""" +See https://docs.dgl.ai/tutorials/blitz/3_message_passing.html +""" + + +class BasicLayer(mpi.Module): + def __init__(self, input_dim: int, output_dim: int): + super(mpi.Module, self).__init__(input_dim, output_dim) + self.linear = mpi.Linear(input_dim * 2, output_dim) + self.reset_parameters() + + def reset_parameters(self): + self.linear.reset_parameters() + + def forward(self, graph: mpi.DENSEGraph, h: mpi.Tensor) -> mpi.Tensor: + with graph.local_scope(): + graph.ndata["h"] = h + graph.update_all(message_func=mpi.copy_u("h", "m"), reduce_func=mpi.mean("m", "h_N")) + h_N = 
graph.ndata["h_N"] + h_total = mpi.cat(h, h_N, dim=1) + return self.linear(h_total) diff --git a/test/mpic/test_codegen.py b/test/mpic/test_codegen.py new file mode 100644 index 00000000..ac5ba1d9 --- /dev/null +++ b/test/mpic/test_codegen.py @@ -0,0 +1,20 @@ +import os +import unittest +from pathlib import Path + +from marius.tools.mpic.compiler import run_compiler + +examples_dir = Path(Path(__file__).parent, "examples").resolve() +TEST_GEN_DIR = Path(os.getcwd(), "build/mpic_gen").resolve() + + +def run_codegen_test(filename): + run_compiler(os.path.join(examples_dir, filename)) + return os.path.exists(os.path.join(TEST_GEN_DIR, Path(filename).with_suffix(".h"))) and os.path.exists( + os.path.join(TEST_GEN_DIR, Path(filename).with_suffix(".cpp")) + ) + + +class TestCodeGen(unittest.TestCase): + def test_basic(self): + self.assertTrue(run_codegen_test("basic_layer.py")) diff --git a/test/mpic/test_error_handling.py b/test/mpic/test_error_handling.py new file mode 100644 index 00000000..8de61f0c --- /dev/null +++ b/test/mpic/test_error_handling.py @@ -0,0 +1,84 @@ +import os +import unittest +from pathlib import Path + +import pytest + +from marius.tools.mpic.compiler import run_compiler +from marius.tools.mpic.utils import CompileError + +errors_dir = Path(Path(__file__).parent, "errors").resolve() + + +def run_error_test(filename): + run_compiler(os.path.join(errors_dir, filename)) + + +class TestModuleErrors(unittest.TestCase): + def test_disable_imports(self): + with self.assertRaises(CompileError): + run_error_test("disable_imports.py") + + def test_disable_globals(self): + with self.assertRaises(CompileError): + run_error_test("disable_globals.py") + + def test_disable_non_layer_classes(self): + with self.assertRaises(CompileError): + run_error_test("disable_non_layer_classes.py") + + def test_duplicate_classes(self): + with self.assertRaises(CompileError): + run_error_test("duplicate_class.py") + + +class TestClassErrors(unittest.TestCase): + def 
test_duplciate_fn(self): + with self.assertRaises(CompileError): + run_error_test("duplicate_fn.py") + + +class TestFunctionErrors(unittest.TestCase): + def test_disable_nesting(self): + with self.assertRaises(CompileError): + run_error_test("disable_nesting.py") + + def test_require_argtypes(self): + with self.assertRaises(CompileError): + run_error_test("require_argtypes.py") + + def test_invalid_return(self): + with self.assertRaises(CompileError): + run_error_test("invalid_return.py") + + def test_disable_lambdas(self): + with self.assertRaises(CompileError): + run_error_test("disable_lambdas.py") + + def test_disable_generators(self): + with self.assertRaises(CompileError): + run_error_test("disable_generators.py") + + +class TestExpressionErrors(unittest.TestCase): + def test_multiassign(self): + with self.assertRaises(CompileError): + run_error_test("disable_multiassign.py") + + def test_call_invalid_args(self): + with self.assertRaises(CompileError): + run_error_test("call_invalid_args.py") + + +class TestStatementErrors(unittest.TestCase): + def test_disable_format(self): + with self.assertRaises(CompileError): + run_error_test("disable_format.py") + + def test_disable_print(self): + with self.assertRaises(CompileError): + run_error_test("disable_print.py") + + def test_disable_assert(self): + with self.assertRaises(CompileError): + run_error_test("disable_assert.py") diff --git a/tox.ini b/tox.ini index 751fe5f4..300d3210 100644 --- a/tox.ini +++ b/tox.ini @@ -83,3 +83,16 @@ whitelist_externals = /usr/bin/clang-format /usr/local/bin/clang-format /usr/local/bin/bash + +[testenv:test_mpic] +deps = + pytest +commands = + bash -ec 'pip install .[mpic,tests]' + bash -ec 'pytest -s test/mpic/test_error_handling.py -x' + bash -ec 'pytest -s test/mpic/test_codegen.py -x' +setenv = + MARIUS_NO_BINDINGS = 1 +allowlist_externals = + /bin/bash + /usr/bin/bash