Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Marius Script Compiler #127

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,15 @@ ignore =
W503
# whitespace before ':'
E203
# function name should be lowercase
N802
per-file-ignores:
src/python/tools/mpic/features.py:N802
src/python/tools/mpic/typechecker.py:N802
src/python/tools/mpic/attrs.py:N802
src/python/tools/mpic/codegen.py:N802
src/python/tools/mpic/globalpass.py:N802
src/python/tools/mpic/layerpass.py:N802
exclude =
.tox
.git,
Expand All @@ -18,6 +27,9 @@ exclude =
*.pyc,
*third_party*,
scripts
test/mpic/errors
test/mpic/examples
examples/mpic
max-line-length = 120
max-complexity = 25
import-order-style = pycharm
Expand Down
27 changes: 27 additions & 0 deletions .github/workflows/mpic_test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
name: Testing Marius Script Compiler
on:
push:
branches:
- main
pull_request:
branches:
- main

jobs:
mpic:
runs-on: ubuntu-latest
container: ${{ matrix.python_container }}
strategy:
matrix:
python_container: ["python:3.10"]

steps:
# Downloads a copy of the code in your repository before running CI tests
- name: Check out repository code
uses: actions/checkout@v3

- name: Installing dependencies
run: MARIUS_NO_BINDINGS=1 python3 -m pip install .[tests,mpic]

- name: Running pytest
run: MARIUS_NO_BINDINGS=1 pytest -s test/mpic/test_codegen.py -s test/mpic/test_error_handling.py
21 changes: 21 additions & 0 deletions examples/mpic/basic_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""
See https://docs.dgl.ai/tutorials/blitz/3_message_passing.html
"""


# Example mpic layer: concatenates each node's own features with an aggregate of
# its neighbours' features and feeds the result through one linear layer.
# NOTE(review): `mpi` is not imported here — presumably the mpic compiler supplies
# it as a builtin namespace; confirm against the compiler's symbol table.
class BasicLayer(mpi.Module):
    def __init__(self, input_dim: int, output_dim: int):
        # NOTE(review): under plain Python semantics super(mpi.Module, self) starts the
        # MRO lookup *after* mpi.Module; presumably mpic recognises this exact form — confirm.
        super(mpi.Module, self).__init__(input_dim, output_dim)
        # The linear layer takes input_dim * 2 features because forward() concatenates
        # h with h_N along dim=1.
        self.linear = mpi.Linear(input_dim * 2, output_dim)
        self.reset_parameters()

    def reset_parameters(self):
        # Delegates to the linear layer's own reset_parameters().
        self.linear.reset_parameters()

    def forward(self, graph: mpi.DENSEGraph, h: mpi.Tensor) -> mpi.Tensor:
        with graph.local_scope():
            # Store the input features on the graph under key "h".
            graph.ndata["h"] = h
            # Presumably mirrors the DGL tutorial linked in the module docstring: copy "h"
            # as message "m", then mean-reduce "m" into "h_N" — verify against mpi semantics.
            graph.update_all(message_func=mpi.copy_u("h", "m"), reduce_func=mpi.mean("m", "h_N"))
            h_N = graph.ndata["h_N"]
            # Concatenate own features with the aggregated ones before the linear layer.
            h_total = mpi.cat(h, h_N, dim=1)
            return self.linear(h_total)
5 changes: 4 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ docs =
db2graph =
psycopg2-binary
mysql-connector-python
mpic =
Jinja2

[options]
install_requires =
Expand Down Expand Up @@ -65,4 +67,5 @@ console_scripts =
marius_config_generator = marius.tools.marius_config_generator:main
marius_predict = marius.tools.marius_predict:main
marius_env_info = marius.distribution.marius_env_info:main
marius_db2graph = marius.tools.db2graph.marius_db2graph:main
marius_db2graph = marius.tools.db2graph.marius_db2graph:main
marius_mpic = marius.tools.mpic.marius_mpic:main
Empty file.
60 changes: 60 additions & 0 deletions src/python/tools/mpic/astpp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""
A pretty-printing dump function for the ast module. The code was copied from
the ast.dump function and modified slightly to pretty-print.

Alex Leone (acleone ~AT~ gmail.com), 2010-01-30
"""

from ast import AST, iter_fields, parse


def dump(node, annotate_fields=True, include_attributes=False, indent=" "):
    """
    Render the AST rooted at *node* as a pretty-printed string.

    Intended as a debugging aid. Field names are shown next to their values
    unless *annotate_fields* is False (names make the output impossible to
    evaluate). Node attributes listed in ``_attributes`` (such as line numbers
    and column offsets) are appended only when *include_attributes* is True.
    List-valued fields are printed one element per line, with nesting expressed
    by repeating *indent*.

    Raises TypeError when *node* is not an ``ast.AST`` instance.
    """
    if not isinstance(node, AST):
        raise TypeError("expected AST, got %r" % node.__class__.__name__)

    def _format(item, level=0):
        # AST nodes: ClassName(field=value, ...) on a single line; the nesting
        # level is passed through unchanged.
        if isinstance(item, AST):
            pairs = [(name, _format(value, level)) for name, value in iter_fields(item)]
            if include_attributes and item._attributes:
                for name in item._attributes:
                    pairs.append((name, _format(getattr(item, name), level)))
            if annotate_fields:
                rendered = ["%s=%s" % pair for pair in pairs]
            else:
                rendered = [text for _, text in pairs]
            return item.__class__.__name__ + "(" + ", ".join(rendered) + ")"

        # Anything that is neither a node nor a list is shown via repr().
        if not isinstance(item, list):
            return repr(item)

        # An empty list stays compact on one line.
        if not item:
            return "[]"

        # Non-empty list: one element per line, two levels deeper than the
        # enclosing node, with the closing bracket one level deeper.
        rows = ["["]
        for element in item:
            rows.append(indent * (level + 2) + _format(element, level + 2) + ",")
        rows.append(indent * (level + 1) + "]")
        return "\n".join(rows)

    return _format(node)


if __name__ == "__main__":
    import sys

    # Command-line use: pretty-print the AST of every file named in argv,
    # each one preceded by a banner with its filename.
    for filename in sys.argv[1:]:
        print("=" * 50)
        print("AST tree for", filename)
        print("=" * 50)
        # The context manager closes the file even if read() raises.
        with open(filename, "r") as f:
            fstr = f.read()
        print(dump(parse(fstr, filename=filename), include_attributes=True))
        print()
187 changes: 187 additions & 0 deletions src/python/tools/mpic/attrs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
import ast
from contextlib import contextmanager

from marius.tools.mpic.builtins import IntAttrs, NoneAttrs, get_builtin_classes
from marius.tools.mpic.symtab import SymbolTable
from marius.tools.mpic import astpp # noqa: F401
from marius.tools.mpic.utils import Arg, Attrs, Callable, ClassType, SemError, SynError


class GlobalAttrsPass(ast.NodeVisitor):
    """
    Collect all class names in global scope
    Ensures
    - only class definitions (and standalone expressions) at the top level
    - every class name is defined at most once

    After visiting a module:
    - `global_attrs` maps each class name to its `ClassType`
    - `classes` maps each `ClassType` (builtins included) to its attribute map
    """

    def __init__(self):
        self.global_attrs = Attrs()
        # Start from the builtin classes; user classes are added as they are visited
        self.classes = get_builtin_classes()

    def visit_ClassDef(self, classdef):
        """
        Register a top-level class
        Raises SemError if a class with the same name was already registered
        """
        if classdef.name in self.global_attrs:
            raise SemError(f"Class {classdef.name} is multiply defined! lineno:{classdef.lineno}")

        class_type = ClassType(classdef.name)
        self.global_attrs[classdef.name] = class_type
        # XXX: Add input_dim and output_dim as instance variables
        # TODO: add support for const int to prevent modification
        self.classes[class_type] = Attrs({"_mpic_class": class_type, "input_dim": IntAttrs, "output_dim": IntAttrs})

    def visit_Module(self, module):
        """
        Visit every top-level class definition
        Raises SemError for any top-level statement that is neither a class
        definition nor a standalone expression
        """
        for child in module.body:
            if isinstance(child, ast.ClassDef):
                self.visit(child)
            elif not isinstance(child, ast.Expr):
                # Allow comments (expressions as standalone statements)
                raise SemError(f"Only mpi.Module classes allowed at the top level! lineno:{child.lineno}")

    def generic_visit(self, node):
        # Any node type without a dedicated visit_* method is unexpected in this pass
        raise RuntimeError(f"Internal error!\n{astpp.dump(node)}")


class InstanceAttrsPass(ast.NodeVisitor):
    """
    Generate a mapping from class names to the class attributes
    The attribute map has instance variables and instance methods
    Class variables and class methods are disallowed at the moment
    Instance variables map to its attribute map
    Instance methods map to its `Callable` object
    Expects a symbol table populated with symbols in the builtin and global scopes
    """

    def __init__(self, symbol_table: SymbolTable, classes: dict[ClassType, Attrs]):
        self.symbol_table = symbol_table
        self.classes = classes
        # Attribute map of the class currently being visited; None outside a class body
        self.instance_attrs = None

    def visit_Name(self, name):
        """
        Resolve a bare name through the symbol table
        Raises SemError if the symbol is unknown
        """
        attrs = self.symbol_table.find_symbol(name.id)
        if attrs is None:
            raise SemError(f"Could not find symbol {name.id}! lineno:{name.lineno}")
        return attrs

    def visit_Attribute(self, attr):
        """
        Resolve a dotted name by looking the attribute up in the attrs of its value
        Raises SemError if the attribute is unknown
        """
        value_attrs = self.visit(attr.value)
        if attr.attr not in value_attrs:
            raise SemError(f"Could not find symbol {attr.attr}! lineno:{attr.lineno}")
        return value_attrs[attr.attr]

    def visit_AnnAssign(self, assign):
        """
        Collect instance variables
        Ensures
        - valid type annotations
        - unique names
        """
        lno = assign.lineno

        if not isinstance(assign.target, ast.Name):
            raise SynError(f"Definition is not supported! lineno:{lno}")

        if not assign.annotation:
            raise SynError(f"Require type annotation for instance variables! lineno:{lno}")

        var_name = assign.target.id
        var_attrs = self.classes[self.visit(assign.annotation)]

        if var_name in self.instance_attrs:
            raise SemError(f"Duplicate attr definition! lineno:{assign.lineno}")

        self.instance_attrs[var_name] = var_attrs

    def visit_FunctionDef(self, func):
        """
        Collects instance functions
        Ensures
        - unique names
        - valid type annotations for all arguments and return
        - self first argument
        Disables
        - operator overriding
        - decorator lists
        - type comments
        - position only arguments
        - *args and **kwargs
        """
        lno = func.lineno

        if func.name in self.instance_attrs:
            raise SemError(f"Duplicate attr definition! func: lineno:{lno}")

        if (
            func.name != "__init__"
            and func.name != "__call__"
            and func.name.startswith("__")
            and func.name.endswith("__")
        ):
            raise SynError(f"Operator overriding is not supported! lineno:{lno}")

        if func.name == "reset":
            raise SynError(f"Please use reset_parameters to reset torch params! lineno:{lno}")

        if func.decorator_list:
            raise SynError(f"Decorator lists are not supported! lineno:{lno}")

        if func.type_comment:
            raise SynError(f"Type comments are not supported! lineno:{lno}")

        if func.args.vararg:
            raise SynError(f"*args is not supported! lineno:{lno}")

        if func.args.kwarg:
            raise SynError(f"**kwargs is not supported! func: lineno:{lno}")

        if func.args.posonlyargs or func.args.kwonlyargs:
            raise SynError(f"pos only or kw only args are not supported! lineno:{lno}")

        if func.args.kw_defaults or func.args.defaults:
            raise SynError(f"Defaults for arguments are not supported! lineno:{lno}")

        # A method declared without any parameters has no `self` either; report it
        # as a syntax error instead of failing with IndexError on args[0]
        if not func.args.args or func.args.args[0].arg != "self":
            raise SynError(f"Only instance methods are supported! lineno:{lno}")

        args = []
        for arg in func.args.args[1:]:
            if not arg.annotation:
                raise SynError(f"Type annotations are required! lineno:{lno}")
            arg_attrs = self.classes[self.visit(arg.annotation)]
            args.append(Arg(arg.arg, arg_attrs))

        # A missing return annotation means the method returns None
        if func.returns:
            return_attrs = self.classes[self.visit(func.returns)]
        else:
            return_attrs = NoneAttrs
        self.instance_attrs[func.name] = Callable(args, return_attrs)

    def visit_ClassDef(self, classdef):
        """
        Populates class attributes
        """
        if classdef.decorator_list:
            raise SemError(f"Decorator lists are not supported! lineno:{classdef.lineno}")

        if classdef.keywords:
            raise SemError(f"Class arguments are not supported! lineno:{classdef.lineno}")

        @contextmanager
        def managed_instance_attrs():
            self.instance_attrs = self.classes[ClassType(classdef.name)]
            try:
                yield
            finally:
                # Reset even when a child visit raises, so the visitor never keeps
                # pointing at the attrs of a class it has already left
                self.instance_attrs = None

        with managed_instance_attrs():
            for child in classdef.body:
                self.visit(child)

    def visit_Module(self, module):
        """
        Visit every top-level class definition
        Raises SemError for any top-level statement that is neither a class
        definition nor a standalone expression
        """
        for child in module.body:
            if isinstance(child, ast.ClassDef):
                self.visit(child)
            elif not isinstance(child, ast.Expr):
                # Allow comments
                raise SemError(f"Only mpi.Module classes allowed at the top level! lineno:{child.lineno}")

    def generic_visit(self, node):
        # Any node type without a dedicated visit_* method is unexpected in this pass.
        # astpp exposes `dump` (there is no `dumps`).
        raise RuntimeError(f"Internal error!\n{astpp.dump(node)}")
Loading