Skip to content

Commit

Permalink
Update teaal to Python 3.12
Browse files Browse the repository at this point in the history
  • Loading branch information
nandeeka committed Sep 24, 2024
1 parent 600ef0a commit b5d5c53
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 38 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:

strategy:
matrix:
python-version: [3.8]
python-version: [3.12]

# Steps represent a sequence of tasks that will be executed as part of the job
steps:
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name = "teaal"
description = "A compiler from a YAML description to HiFiber code"
version = "0.0.1"
readme = "README.md"
requires-python = ">=3.8"
requires-python = ">=3.12"
license = {file = "LICENSE"}
authors = [
{"name" = "Nandeeka Nayak", "email" = "[email protected]"}
Expand All @@ -17,8 +17,8 @@ dependencies = [
"pytest-cov",
"autopep8",
"mypy",
"lark-parser",
"ruamel-yaml",
"lark",
"ruamel-yaml<0.18.0",
"networkx",
"matplotlib",
"sympy"
Expand Down
10 changes: 0 additions & 10 deletions requirements.txt

This file was deleted.

7 changes: 6 additions & 1 deletion teaal/parse/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class EquationParser:
?num: NUMBER -> pos
| "-" NUMBER -> neg
?output: NAME "[" ranks "]"
?output: NAME "[" ranks "]" -> output
?tensor: NAME "[" ranks "]"
Expand All @@ -47,6 +47,11 @@ class EquationParser:
def parse(equation: str) -> Tree:
tree = EquationParser.parser.parse(equation)

# Remove None due to empty ranks
for ranks in tree.find_data("ranks"):
if ranks.children == [None]:
ranks.children = []

# Parse both positive and negative numbers
for itimes in tree.find_data("itimes"):
num = itimes.children[0]
Expand Down
4 changes: 3 additions & 1 deletion teaal/trans/partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,9 @@ def unpartition(self, tensor: Tensor) -> Statement:
swizzled_ranks, part_ir.get_all_parts(), False)
for part in valid_parts:
trans.append((part, tensor.get_ranks()))
self.program.apply_partitioning(tensor, part)
tensor.update_ranks(
part_ir.partition_ranks(
tensor.get_ranks(), {part}, True, False))

new_ranks = tensor.get_ranks()

Expand Down
38 changes: 16 additions & 22 deletions tests/integration/demo.yaml
Original file line number Diff line number Diff line change
@@ -1,24 +1,18 @@
einsum:
declaration:
A: [K, M]
B: [K, N]
T: [K, M, N]
Z: [M, N]
expressions:
- T[k, m, n] = A[k, m] * B[k, n]
declaration:
A: [K, M]
B: [K, N]
Z: [M, N]
expressions:
- Z[m, n] = A[k, m] * B[k, n]
mapping:
rank-order:
A: [K, M]
B: [K, N]
T: [M, K, N]
Z: [M, N]
partitioning:
T:
K: [ uniform_shape(16) ]
(K0, M): [ flatten() ]
K0M: [ uniform_occupancy(A.64) ]
Z:
K: [ uniform_occupancy(T.64) ]
loop-order:
T: [K1, K0M1, K0M0, N]
Z: [M, K1, K0, N]
rank-order:
A: [K, M]
B: [K, N]
Z: [M, N]
partitioning:
Z:
M: [uniform_shape(M2), uniform_occupancy(A.M1), uniform_occupancy(A.M0)]
N: [uniform_shape(N2), uniform_occupancy(B.N1), uniform_occupancy(B.N0)]
loop-order:
Z: [M3, N3, K, M2, N2, M1, N1, M0, N0]

0 comments on commit b5d5c53

Please sign in to comment.