Skip to content

Commit

Permalink
update plan
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00-pl committed Nov 27, 2024
1 parent e0867bf commit 16edd59
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 5 deletions.
22 changes: 17 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,33 @@
开发计划
-------

- [x] 引入pytorch
- [x] 接入torch compile
### 项目结构
- [x] CI script.
- ![CI status](https://github.com/0x00-pl/plai/actions/workflows/ci.yml/badge.svg?branch=master)
- [x] 定义新的graph/node格式
- [x] node定义中添加namespace信息

### 接口相关
- [x] 引入pytorch
- [x] 接入torch compile
- [x] 从函数地址解析函数
- [x] 在训练时也调用自定义的compiler
- [x] 解析出计算图

### 算子定义相关
- [x] 定义新的graph/node格式
- [x] node定义中添加namespace信息
- [x] 添加torch和aten的namespace, 覆盖简单模型.
- [ ] 添加numpy的namespace, 覆盖简单模型, 用于实现运行时.
- [ ] node中添加location信息
- [ ] 添加多输出的支持
- [ ] 添加sub-graph的支持

### 运行时相关
- [ ] 添加numpy的namespace, 覆盖简单模型, 用于实现运行时.

### 其他
- [ ] 编译简单的四则运算
- [ ] 简单的四则运算运行时


commands
--------

Expand Down
2 changes: 2 additions & 0 deletions plai/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,5 @@ def __str__(self):
result += f' {idx}: {node_name_dict[node]} = {node.to_string(node_name_dict)}\n'
result += f' output ({", ".join(node_name_dict[i] for i in self.outputs)})\n'
return result


22 changes: 22 additions & 0 deletions plai/dialect/numpy_dialect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from plai.core import module
from plai.core.location import Location


class NumpyNode(module.Node):
@staticmethod
def build(op_name: str, args: list, attrs: dict, loc: Location = None):
raise ValueError('this is a dialect, should not using Build.')

@classmethod
def get_namespace(cls):
return 'torch.nn'


class Relu(NumpyNode):
def __init__(self, arg: module.Node, loc: Location = None):
super().__init__([arg], {}, loc)

@staticmethod
def build(op_name: str, args: list, attrs: dict, loc: Location = None):
assert op_name == 'relu'
return Relu(args[0], loc)

0 comments on commit 16edd59

Please sign in to comment.