From e844a32b1fade8a64e6ccf229a5d307f85b5a9db Mon Sep 17 00:00:00 2001
From: Berke-Ates
Date: Mon, 9 Jan 2023 15:32:59 +0100
Subject: [PATCH] added pytorch with jit

---
 Dockerfile                             |  3 -
 benchmarks/pytorch/mish/pytorch_jit.py | 82 ++++++++++++++++++++++++++
 scripts/pytorch/pytorch_jit.sh         | 76 ++++++++++++++++++++++++
 scripts/pytorch/run_all.sh             |  3 +-
 4 files changed, 160 insertions(+), 4 deletions(-)
 create mode 100644 benchmarks/pytorch/mish/pytorch_jit.py
 create mode 100755 scripts/pytorch/pytorch_jit.sh

diff --git a/Dockerfile b/Dockerfile
index 82a54f1..304c332 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -45,9 +45,6 @@ RUN cmake -G Ninja ../llvm \
     -DLLVM_INSTALL_UTILS=ON && \
     ninja
 
-# Add binaries to PATH
-# ENV PATH=$HOME/llvm-dcir/usr/local/bin:$PATH
-
 # Build mlir-dace
 WORKDIR $HOME/mlir-dace/build
 
diff --git a/benchmarks/pytorch/mish/pytorch_jit.py b/benchmarks/pytorch/mish/pytorch_jit.py
new file mode 100644
index 0000000..c58e529
--- /dev/null
+++ b/benchmarks/pytorch/mish/pytorch_jit.py
@@ -0,0 +1,82 @@
+#!/usr/bin/python3
+
+# Desc: Runs the Mish benchmark using Pytorch with JIT
+# Usage: python3 pytorch_jit.py <Repetitions> <Test Output (T/F)>
+
+import sys
+import numpy as np
+import torch
+from torch import nn
+import time
+import torch_mlir
+
+from torch_mlir_e2e_test.mhlo_backends.linalg_on_tensors import LinalgOnTensorsMhloBackend
+
+if len(sys.argv) != 3:
+    print("PyTorch(JIT) Mish Benchmarking Tool")
+    print("Arguments:")
+    print("  Repetitions: How many times to run the benchmark")
+    print("  Test Output (T/F): If 'T', tests the output against Torch-MLIR")
+    sys.exit(1)
+
+repetitions = int(sys.argv[1])
+test_output = sys.argv[2] == 'T'
+
+
+# Load model
+# NOTE(review): forward() computes only softplus, log(1 + exp(x)); the full
+# Mish activation is x * tanh(softplus(x)) - confirm against the other runners
+class Mish(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        x = torch.log(1 + torch.exp(x))
+        return x
+
+
+data = torch.rand(8, 32, 224, 224)
+model = torch.jit.trace(Mish(), data).to(torch.device('cpu'))
+model.eval()
+
+# Benchmark: prints one runtime in milliseconds per repetition
+for _ in range(repetitions):
+    data = torch.rand(8, 32, 224, 224)
+    # perf_counter is monotonic and intended for timing short intervals
+    start = time.perf_counter()
+    model.forward(data)
+    runtime = time.perf_counter() - start
+    print(runtime * 1000)
+
+# Test output
+if test_output:
+    # Deterministic input: data[i, j, k, l] = (i + j + k + l) / 488. One index
+    # vector per axis is broadcast to the full shape instead of assigning all
+    # 8 * 32 * 224 * 224 elements one by one from Python
+    i = torch.arange(8).reshape(8, 1, 1, 1)
+    j = torch.arange(32).reshape(1, 32, 1, 1)
+    k = torch.arange(224).reshape(1, 1, 224, 1)
+    l = torch.arange(224).reshape(1, 1, 1, 224)
+    data = (i + j + k + l) / (8 + 32 + 224 + 224)
+
+    prediction_pytorch = model.forward(data).numpy()
+
+    # Generate MHLO
+    backend = LinalgOnTensorsMhloBackend()
+
+    model_mlir = nn.Sequential(Mish()).to(torch.device('cpu'))
+    model_mlir.eval()
+    module = torch_mlir.compile(model_mlir,
+                                data,
+                                output_type=torch_mlir.OutputType.MHLO)
+
+    compiled = backend.compile(module)
+    jit_module = backend.load(compiled)
+
+    prediction_torch_mlir = jit_module.forward(data.numpy())
+
+    # Compare: a non-zero exit status reports the mismatch to the caller
+    if not np.allclose(
+            prediction_pytorch, prediction_torch_mlir, rtol=1e-5, atol=1e-8):
+        sys.exit(1)
diff --git a/scripts/pytorch/pytorch_jit.sh b/scripts/pytorch/pytorch_jit.sh
new file mode 100755
index 0000000..5008ceb
--- /dev/null
+++ b/scripts/pytorch/pytorch_jit.sh
@@ -0,0 +1,76 @@
+#!/bin/bash
+
+# Desc: Runs a pytorch benchmark using Pytorch with JIT. The output contains any
+# intermediate results and the times in the CSV format
+# Usage: ./pytorch_jit.sh <PyTorch File> <Output Dir> <Repetitions>
+
+# Be safe
+set -e          # Fail script when subcommand fails
+set -u          # Disallow using undefined variables
+set -o pipefail # Prevent errors from being masked
+
+# Check args
+if [ $# -ne 3 ]; then
+  echo "Usage: ./pytorch_jit.sh <PyTorch File> <Output Dir> <Repetitions>"
+  exit 1
+fi
+
+# Read args
+input_file=$1
+output_dir=$2
+repetitions=$3
+
+# Check tools
+check_tool() {
+  if ! command -v "$1" &>/dev/null; then
+    echo "$1 could not be found"
+    exit 1
+  fi
+}
+
+check_tool python3
+
+# Create output directory
+if [ ! -d "$output_dir" ]; then
+  mkdir -p "$output_dir"
+fi
+
+# Silence Python warnings
+export PYTHONWARNINGS="ignore"
+
+# Helpers
+input_dir=$(dirname "$input_file")
+input_name=$(basename "$input_dir")
+input_file=$input_dir/pytorch_jit.py
+timings_file=$output_dir/${input_name}_timings.csv
+touch "$timings_file"
+
+# Adds a value to the timings file, jumps to the next row after a write
+csv_line=1
+add_csv() {
+  while [[ $(grep -c ^ "$timings_file") -lt $csv_line ]]; do
+    echo '' >>"$timings_file"
+  done
+
+  if [ -n "$(sed "${csv_line}q;d" "$timings_file")" ]; then
+    sed -i "${csv_line}s/$/,/" "$timings_file"
+  fi
+
+  sed -i "${csv_line}s/$/$1/" "$timings_file"
+  csv_line=$((csv_line + 1))
+}
+
+# Check output
+if ! python3 "$input_file" 0 T; then
+  echo "Output incorrect!"
+  exit 1
+fi
+
+# Running the benchmark
+runtimes=$(OMP_NUM_THREADS=1 taskset -c 0 python3 "$input_file" "$repetitions" F)
+
+add_csv "PyTorch (JIT)"
+
+for i in $runtimes; do
+  add_csv "$i"
+done
diff --git a/scripts/pytorch/run_all.sh b/scripts/pytorch/run_all.sh
index ae47d86..37b694f 100755
--- a/scripts/pytorch/run_all.sh
+++ b/scripts/pytorch/run_all.sh
@@ -37,7 +37,8 @@ benchmarks_dir=$(dirname "$0")/../../benchmarks/pytorch
 benchmarks=$(find "$benchmarks_dir"/* -name 'pytorch.py')
 total=$(echo "$benchmarks" | wc -l)
 
-runners="$runners_dir/pytorch.sh $runners_dir/torch-mlir.sh $runners_dir/dcir.sh"
+runners="$runners_dir/pytorch.sh $runners_dir/pytorch_jit.sh \
+  $runners_dir/torch-mlir.sh $runners_dir/dcir.sh"
 
 for runner in $runners; do
   count=0