[torch-frontend] update torch to 2.4.1
qingyunqu committed Oct 25, 2024
1 parent 8215fe2 commit 72821a0
Showing 9 changed files with 59 additions and 314 deletions.
2 changes: 1 addition & 1 deletion frontends/torch-frontend/scripts/build_and_test.sh
@@ -51,7 +51,7 @@ cmake --build ./build --target all
 if [[ $TORCH_FRONTEND_TEST == "ON" ]]; then
 python3 -m pip install -r test-requirements.txt
 install_mhlo_tools
-PYTHONPATH=build/python_packages/:build/torch_mlir_build/python_packages/torch_mlir TORCH_DISABLE_NATIVE_FUNCOL=1 python3 -m pytest torch-frontend/python/test
+PYTHONPATH=build/python_packages/:build/torch_mlir_build/python_packages/torch_mlir TORCH_DISABLE_NATIVE_FUNCOL=1 python3 -m pytest -m "not attention_rewriter" torch-frontend/python/test
 fi

popd
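Note on the test change: `-m "not attention_rewriter"` deselects any test carrying the `attention_rewriter` pytest marker, presumably because those tests need the byteir flash-attn runtime. A minimal sketch of how such a test would be tagged — the test name here is hypothetical, and the marker would normally also be registered in pytest.ini:

    import pytest

    # Tests tagged like this are skipped by `pytest -m "not attention_rewriter"`.
    @pytest.mark.attention_rewriter
    def test_llama_attn_rewrite_example():
        assert True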
4 changes: 2 additions & 2 deletions frontends/torch-frontend/scripts/envsetup.sh
@@ -33,7 +33,6 @@ function prepare_for_build_with_prebuilt() {
 pushd ${PROJ_DIR}
 # install requirements
 python3 -m pip install -r build-requirements.txt
-# python3 -m pip install --no-cache-dir torch==2.1.0+cu118 torchvision==0.16.0+cu118 -f https://download.pytorch.org/whl/torch_stable.html
 
 # initialize submodule
 git submodule update --init -f $TORCH_MLIR_ROOT
@@ -49,10 +48,11 @@ function prepare_for_build() {
 pushd ${PROJ_DIR}
 # install requirements
 python3 -m pip install -r build-requirements.txt
-# python3 -m pip install --no-cache-dir torch==2.1.0+cu118 torchvision==0.16.0+cu118 -f https://download.pytorch.org/whl/torch_stable.html
 
 # initialize submodule
 git submodule update --init --recursive -f $TORCH_MLIR_ROOT
 
 apply_patches
 }
+
+# python3 -m pip install --no-cache-dir torch==2.1.0+cu118 torchvision==0.16.0+cu118 -f https://download.pytorch.org/whl/torch_stable.html
4 changes: 2 additions & 2 deletions frontends/torch-frontend/torch-cpu-requirements.txt
@@ -1,4 +1,4 @@
 --extra-index-url https://download.pytorch.org/whl/cpu
 --pre
-torch==2.1.0+cpu
-torchvision==0.16.0+cpu
+torch==2.4.1+cpu
+torchvision==0.19.1+cpu
4 changes: 2 additions & 2 deletions frontends/torch-frontend/torch-cuda-requirements.txt
@@ -1,4 +1,4 @@
 --extra-index-url https://download.pytorch.org/whl/cu118
 --pre
-torch==2.1.0+cu118
-torchvision==0.16.0+cu118
+torch==2.4.1+cu118
+torchvision==0.19.1+cu118
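These pins move both requirement files from torch 2.1.0 / torchvision 0.16.0 to 2.4.1 / 0.19.1. A quick sanity check (a sketch) that an environment actually picked up the new wheels:

    # Verify the installed wheels match the pins above.
    import torch
    import torchvision

    assert torch.__version__.startswith("2.4.1"), torch.__version__
    assert torchvision.__version__.startswith("0.19.1"), torchvision.__version__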

This file was deleted.

@@ -27,6 +27,7 @@ def forward(self, x):
 def test_nonzero():
     inputs = (torch.tensor([1, 0, 0, 1, 1]),)
     prog = torch.export.export(AtenNonZeroModule(), inputs)
+    print(prog)
     module = compile_dynamo_model(prog, "raw")
     print(module.operation.get_asm())

@@ -38,9 +39,10 @@ def forward(self, x):

 def test_view_dtype():
     inputs = (torch.rand(4, 5),)
-    # note: torch==2.1.0's export has bug on view.dtype
     prog = torch.export.export(ViewDtypeModule(), inputs)
-    module = compile_dynamo_model(prog, "raw") # note: torch2.1 export has bug on view.dtype
+    print(prog)
+    module = compile_dynamo_model(prog, "stablehlo")
     print(module.operation.get_asm())

# ==============================================================================
@@ -60,5 +62,8 @@ def test_mlp():
     print(mlir_str)
     assert "dense_resource" not in mlir_str
 
+# ==============================================================================
+
 if __name__ == "__main__":
     test_nonzero()
+    test_view_dtype()

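For context, both tests drive the torch.export path that this commit revalidates on torch 2.4.1. A self-contained sketch of that flow — the module body is an assumption, since ViewDtypeModule's definition sits outside this hunk, and compile_dynamo_model comes from torch_frontend:

    import torch

    # Stand-in for the ViewDtypeModule used above (definition not shown in the diff):
    # view(dtype) reinterprets the tensor's bytes, here float32 -> int32.
    class ViewDtypeModule(torch.nn.Module):
        def forward(self, x):
            return x.view(torch.int32)

    prog = torch.export.export(ViewDtypeModule(), (torch.rand(4, 5),))
    print(prog)  # ExportedProgram containing an aten.view.dtype op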
This file was deleted.

@@ -179,6 +179,51 @@ def AttnReplacement5(q, k, v, attn_mask, inv_scale):
 )
 
+
+# LLaMA aten attention op pattern
+def LLaMAAttnPattern(query, key, value, attn_mask, min_val, inv_scale, batch, num_head, fused_batch, seq_len, head_dim):
+    transpose_3 = torch.ops.aten.transpose.int(key, 2, 3)
+    expand_2 = torch.ops.aten.expand.default(query, [batch, num_head, seq_len, head_dim])
+    clone = torch.ops.aten.clone.default(expand_2, memory_format = torch.contiguous_format)
+    _unsafe_view_3 = torch.ops.aten._unsafe_view.default(clone, [fused_batch, seq_len, head_dim])
+    expand_3 = torch.ops.aten.expand.default(transpose_3, [batch, num_head, head_dim, seq_len])
+    clone_1 = torch.ops.aten.clone.default(expand_3, memory_format = torch.contiguous_format)
+    _unsafe_view_4 = torch.ops.aten._unsafe_view.default(clone_1, [fused_batch, head_dim, seq_len])
+    bmm = torch.ops.aten.bmm.default(_unsafe_view_3, _unsafe_view_4)
+    _unsafe_view_5 = torch.ops.aten._unsafe_view.default(bmm, [batch, num_head, seq_len, seq_len])
+    div = torch.ops.aten.div.Tensor(_unsafe_view_5, inv_scale)
+    add_5 = torch.ops.aten.add.Tensor(div, attn_mask)
+    maximum = torch.ops.aten.maximum.default(add_5, min_val)
+    _softmax = torch.ops.aten._softmax.default(maximum, -1, False)
+    _to_copy_10 = torch.ops.aten._to_copy.default(_softmax, dtype = torch.float16)
+    expand_4 = torch.ops.aten.expand.default(_to_copy_10, [batch, num_head, seq_len, seq_len])
+    view_8 = torch.ops.aten.view.default(expand_4, [fused_batch, seq_len, seq_len]); expand_4 = None
+    expand_5 = torch.ops.aten.expand.default(value, [batch, num_head, seq_len, head_dim])
+    clone_2 = torch.ops.aten.clone.default(expand_5, memory_format = torch.contiguous_format)
+    _unsafe_view_6 = torch.ops.aten._unsafe_view.default(clone_2, [fused_batch, seq_len, head_dim])
+    bmm_1 = torch.ops.aten.bmm.default(view_8, _unsafe_view_6)
+    _unsafe_view_5 = torch.ops.aten._unsafe_view.default(bmm_1, [batch, num_head, seq_len, head_dim])
+    return _softmax, _unsafe_view_5
+
+
+def LLaMAAttnReplacement(query, key, value, attn_mask, min_val, inv_scale, batch, num_head, fused_batch, seq_len, head_dim):
+    # q, k, v needs to be transposed for flash attn v2
+    query = query.transpose(1, 2)
+    key = key.transpose(1, 2)
+    value = value.transpose(1, 2)
+    out, q_pad, k_pad, v_pad, out_pad, softmax_lse, S_dmask, rng_state = torch.ops.byteir.flash_attn_fwd(
+        query,
+        key,
+        value,
+        0.0,
+        1.0/inv_scale,
+        True,
+        False
+    )
+    # output also needs to be transposed
+    out = out.transpose(1, 2)
+    return out, out
+
 
 def canonicalize_graph_before_replacement(gm):
     for n in gm.graph.nodes:
         if n.op == "call_module":
@@ -243,4 +288,5 @@ def fx_replace_attn_pattern(gm: torch.fx.GraphModule):
     torch.fx.replace_pattern(gm, AttnPattern3, AttnReplacement3)
     torch.fx.replace_pattern(gm, AttnPattern4, AttnReplacement4)
     torch.fx.replace_pattern(gm, AttnPattern5, AttnReplacement5)
+    torch.fx.replace_pattern(gm, LLaMAAttnPattern, LLaMAAttnReplacement)
     return gm
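The new rewrite is registered through the same torch.fx.replace_pattern mechanism as the existing AttnPattern entries: every subgraph of gm matching LLaMAAttnPattern is rewritten into the byteir flash_attn_fwd form. A toy, self-contained illustration of that mechanism (the add/sub pattern below is ours, not from this commit):

    import torch
    import torch.fx

    def pattern(x, y):
        return torch.add(x, y)

    def replacement(x, y):
        return torch.sub(x, y)

    class Toy(torch.nn.Module):
        def forward(self, a, b):
            return torch.add(a, b)

    gm = torch.fx.symbolic_trace(Toy())
    torch.fx.replace_pattern(gm, pattern, replacement)  # add -> sub
    print(gm.code)  # forward now computes torch.sub(a, b)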