Skip to content

Commit

Permalink
added pytorch with jit
Browse files Browse the repository at this point in the history
  • Loading branch information
Berke-Ates committed Jan 9, 2023
1 parent 4fcca0f commit e844a32
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 4 deletions.
3 changes: 0 additions & 3 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,6 @@ RUN cmake -G Ninja ../llvm \
-DLLVM_INSTALL_UTILS=ON && \
ninja

# Add binaries to PATH
# ENV PATH=$HOME/llvm-dcir/usr/local/bin:$PATH

# Build mlir-dace
WORKDIR $HOME/mlir-dace/build

Expand Down
78 changes: 78 additions & 0 deletions benchmarks/pytorch/mish/pytorch_jit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#!/usr/bin/python3

# Desc: Runs the Mish benchmark using Pytorch with JIT
# Usage: python3 pytorch_jit.py <Repetitions> <Test Output (T/F)>

import sys
import numpy as np
import torch
from torch import nn
import time
import torch_mlir

from torch_mlir_e2e_test.mhlo_backends.linalg_on_tensors import LinalgOnTensorsMhloBackend

if len(sys.argv) != 3:
print("PyTorch(JIT) Mish Benchmarking Tool")
print("Arguments:")
print(" Repetitions: How many times to run the benchmark")
print(" Test Output (T/F): If 'T', tests the output against Torch-MLIR")
exit(1)

repetitions = int(sys.argv[1])
test_output = sys.argv[2] == 'T'


# Load model
class Mish(nn.Module):

def __init__(self):
super().__init__()

def forward(self, x):
x = torch.log(1 + torch.exp(x))
return x


data = torch.rand(8, 32, 224, 224)
model = torch.jit.trace(Mish(), data).to(torch.device('cpu'))
model.eval()

# Benchmark
for i in range(repetitions):
data = torch.rand(8, 32, 224, 224)
start = time.time()
model.forward(data)
runtime = time.time() - start
print(runtime * 1000)

# Test output
if test_output:
data = torch.zeros(8, 32, 224, 224)

for i in range(8):
for j in range(32):
for k in range(224):
for l in range(224):
data[i, j, k, l] = (i + j + k + l) / (8 + 32 + 224 + 224)

prediction_pytorch = model.forward(data).numpy()

# Generate MHLO
backend = LinalgOnTensorsMhloBackend()

model_mlir = nn.Sequential(Mish()).to(torch.device('cpu'))
model_mlir.eval()
module = torch_mlir.compile(model_mlir,
data,
output_type=torch_mlir.OutputType.MHLO)

compiled = backend.compile(module)
jit_module = backend.load(compiled)

prediction_torch_mlir = jit_module.forward(data.numpy())

# Compare
if not np.allclose(
prediction_pytorch, prediction_torch_mlir, rtol=1e-5, atol=1e-8):
exit(1)
76 changes: 76 additions & 0 deletions scripts/pytorch/pytorch_jit.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#!/bin/bash

# Desc: Runs a pytorch benchmark using Pytorch with JIT. The output contains any
# intermediate results and the times in the CSV format
# Usage: ./pytorch_jit.sh <Benchmark File> <Output Dir> <Repetitions>

# Be safe
set -e # Fail script when subcommand fails
set -u # Disallow using undefined variables
set -o pipefail # Prevent errors from being masked

# Check args
if [ $# -ne 3 ]; then
echo "Usage: ./pytorch_jit.sh <Benchmark File> <Output Dir> <Repetitions>"
exit 1
fi

# Read args
input_file=$1
output_dir=$2
repetitions=$3

# Check tools
check_tool() {
if ! command -v "$1" &>/dev/null; then
echo "$1 could not be found"
exit 1
fi
}

check_tool python3

# Create output directory
if [ ! -d "$output_dir" ]; then
mkdir -p "$output_dir"
fi

# Silence Python warnings
export PYTHONWARNINGS="ignore"

# Helpers
input_dir=$(dirname "$input_file")
input_name=$(basename "$input_dir")
input_file=$input_dir/pytorch_jit.py
timings_file=$output_dir/${input_name}_timings.csv
touch "$timings_file"

# Adds a value to the timings file, jumps to the next row after a write
csv_line=1
add_csv() {
while [[ $(grep -c ^ "$timings_file") -lt $csv_line ]]; do
echo '' >>"$timings_file"
done

if [ -n "$(sed "${csv_line}q;d" "$timings_file")" ]; then
sed -i "${csv_line}s/$/,/" "$timings_file"
fi

sed -i "${csv_line}s/$/$1/" "$timings_file"
csv_line=$((csv_line + 1))
}

# Check output
if ! python3 "$input_file" 0 T; then
echo "Output incorrect!"
exit 1
fi

# Running the benchmark
runtimes=$(OMP_NUM_THREADS=1 taskset -c 0 python3 "$input_file" "$repetitions" F)

add_csv "PyTorch (JIT)"

for i in $runtimes; do
add_csv "$i"
done
3 changes: 2 additions & 1 deletion scripts/pytorch/run_all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ benchmarks_dir=$(dirname "$0")/../../benchmarks/pytorch
benchmarks=$(find "$benchmarks_dir"/* -name 'pytorch.py')
total=$(echo "$benchmarks" | wc -l)

runners="$runners_dir/pytorch.sh $runners_dir/torch-mlir.sh $runners_dir/dcir.sh"
runners="$runners_dir/pytorch.sh $runners_dir/pytorch_jit.sh \
$runners_dir/torch-mlir.sh $runners_dir/dcir.sh"

for runner in $runners; do
count=0
Expand Down

0 comments on commit e844a32

Please sign in to comment.