-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4fcca0f
commit e844a32
Showing
4 changed files
with
156 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
#!/usr/bin/python3 | ||
|
||
# Desc: Runs the Mish benchmark using Pytorch with JIT | ||
# Usage: python3 pytorch_jit.py <Repetitions> <Test Output (T/F)> | ||
|
||
import sys | ||
import numpy as np | ||
import torch | ||
from torch import nn | ||
import time | ||
import torch_mlir | ||
|
||
from torch_mlir_e2e_test.mhlo_backends.linalg_on_tensors import LinalgOnTensorsMhloBackend | ||
|
||
if len(sys.argv) != 3: | ||
print("PyTorch(JIT) Mish Benchmarking Tool") | ||
print("Arguments:") | ||
print(" Repetitions: How many times to run the benchmark") | ||
print(" Test Output (T/F): If 'T', tests the output against Torch-MLIR") | ||
exit(1) | ||
|
||
repetitions = int(sys.argv[1]) | ||
test_output = sys.argv[2] == 'T' | ||
|
||
|
||
# Load model | ||
class Mish(nn.Module): | ||
|
||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, x): | ||
x = torch.log(1 + torch.exp(x)) | ||
return x | ||
|
||
|
||
data = torch.rand(8, 32, 224, 224) | ||
model = torch.jit.trace(Mish(), data).to(torch.device('cpu')) | ||
model.eval() | ||
|
||
# Benchmark | ||
for i in range(repetitions): | ||
data = torch.rand(8, 32, 224, 224) | ||
start = time.time() | ||
model.forward(data) | ||
runtime = time.time() - start | ||
print(runtime * 1000) | ||
|
||
# Test output | ||
if test_output: | ||
data = torch.zeros(8, 32, 224, 224) | ||
|
||
for i in range(8): | ||
for j in range(32): | ||
for k in range(224): | ||
for l in range(224): | ||
data[i, j, k, l] = (i + j + k + l) / (8 + 32 + 224 + 224) | ||
|
||
prediction_pytorch = model.forward(data).numpy() | ||
|
||
# Generate MHLO | ||
backend = LinalgOnTensorsMhloBackend() | ||
|
||
model_mlir = nn.Sequential(Mish()).to(torch.device('cpu')) | ||
model_mlir.eval() | ||
module = torch_mlir.compile(model_mlir, | ||
data, | ||
output_type=torch_mlir.OutputType.MHLO) | ||
|
||
compiled = backend.compile(module) | ||
jit_module = backend.load(compiled) | ||
|
||
prediction_torch_mlir = jit_module.forward(data.numpy()) | ||
|
||
# Compare | ||
if not np.allclose( | ||
prediction_pytorch, prediction_torch_mlir, rtol=1e-5, atol=1e-8): | ||
exit(1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
#!/bin/bash | ||
|
||
# Desc: Runs a pytorch benchmark using Pytorch with JIT. The output contains any | ||
# intermediate results and the times in the CSV format | ||
# Usage: ./pytorch_jit.sh <Benchmark File> <Output Dir> <Repetitions> | ||
|
||
# Be safe | ||
set -e # Fail script when subcommand fails | ||
set -u # Disallow using undefined variables | ||
set -o pipefail # Prevent errors from being masked | ||
|
||
# Check args | ||
if [ $# -ne 3 ]; then | ||
echo "Usage: ./pytorch_jit.sh <Benchmark File> <Output Dir> <Repetitions>" | ||
exit 1 | ||
fi | ||
|
||
# Read args | ||
input_file=$1 | ||
output_dir=$2 | ||
repetitions=$3 | ||
|
||
# Check tools | ||
check_tool() { | ||
if ! command -v "$1" &>/dev/null; then | ||
echo "$1 could not be found" | ||
exit 1 | ||
fi | ||
} | ||
|
||
check_tool python3 | ||
|
||
# Create output directory | ||
if [ ! -d "$output_dir" ]; then | ||
mkdir -p "$output_dir" | ||
fi | ||
|
||
# Silence Python warnings | ||
export PYTHONWARNINGS="ignore" | ||
|
||
# Helpers | ||
input_dir=$(dirname "$input_file") | ||
input_name=$(basename "$input_dir") | ||
input_file=$input_dir/pytorch_jit.py | ||
timings_file=$output_dir/${input_name}_timings.csv | ||
touch "$timings_file" | ||
|
||
# Adds a value to the timings file, jumps to the next row after a write | ||
csv_line=1 | ||
add_csv() { | ||
while [[ $(grep -c ^ "$timings_file") -lt $csv_line ]]; do | ||
echo '' >>"$timings_file" | ||
done | ||
|
||
if [ -n "$(sed "${csv_line}q;d" "$timings_file")" ]; then | ||
sed -i "${csv_line}s/$/,/" "$timings_file" | ||
fi | ||
|
||
sed -i "${csv_line}s/$/$1/" "$timings_file" | ||
csv_line=$((csv_line + 1)) | ||
} | ||
|
||
# Check output | ||
if ! python3 "$input_file" 0 T; then | ||
echo "Output incorrect!" | ||
exit 1 | ||
fi | ||
|
||
# Running the benchmark | ||
runtimes=$(OMP_NUM_THREADS=1 taskset -c 0 python3 "$input_file" "$repetitions" F) | ||
|
||
add_csv "PyTorch (JIT)" | ||
|
||
for i in $runtimes; do | ||
add_csv "$i" | ||
done |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters