Improve documentation and add paper bibtex
PiperOrigin-RevId: 581523955
samihaija authored and mangpo committed Nov 15, 2023
1 parent e6cf82f commit 60c794b
Showing 2 changed files with 137 additions and 40 deletions.
107 changes: 67 additions & 40 deletions README.md
@@ -2,7 +2,17 @@

TpuGraphs is a performance prediction dataset on full tensor programs, represented as computational graphs, running on Tensor Processing Units (TPUs). Each graph in the dataset represents the main computation of a machine learning workload, e.g., a training epoch or an inference step. Each data sample contains a computational graph, a compilation configuration, and the execution time of the graph when compiled with the configuration. The graphs in the dataset are collected from open-source machine learning programs, featuring popular model architectures (e.g., ResNet, EfficientNet, Mask R-CNN, and Transformer).

Please refer to our [paper](https://arxiv.org/abs/2308.13490) for more details about the importance and challenges of the dataset, how the dataset is generated, the model baselines, and the experimental results. Please cite the paper when using this dataset.
Please refer to our [paper](https://arxiv.org/abs/2308.13490) for more details about the importance and challenges of the dataset, how the dataset is generated, the model baselines, and the experimental results. If you find this dataset useful in your research, please cite our paper as:

```
@inproceedings{tpugraphs,
title={TpuGraphs: A Performance Prediction Dataset on Large Tensor Computational Graphs},
author={Phitchaya Mangpo Phothilimthana and Sami Abu-El-Haija and Kaidi Cao and Bahare Fatemi and Michael Burrows and Charith Mendis and Bryan Perozzi},
booktitle={Thirty-seventh Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
year={2023},
url={https://openreview.net/forum?id=plAix1NxhU}
}
```

*This is not an officially supported Google product.*

@@ -24,31 +34,48 @@ You can use `wget` or `curl` command to download files.
- {search}: `default` or `random`
- {split}: `train`, `valid`, or `test`

To download all files, you may run (from a clone of this directory):
To download all files, please follow **one** of these options:

```sh
python3 echo_download_commands.py | bash
```
1. Download http://download.tensorflow.org/data/tpu_graphs/v0/npz_all.tar, preferably to `~/data/tpugraphs` (our training pipelines read from there),
then extract it with `tar xvf npz_all.tar`. For example, in bash:

Removing the last pipe (`| bash`) shows the commands for downloading the dataset
(a few `curl` commands followed by `tar xvf`).
```
mkdir -p ~/data/tpugraphs
cd ~/data/tpugraphs
curl http://download.tensorflow.org/data/tpu_graphs/v0/npz_all.tar > npz_all.tar
tar xvf npz_all.tar
```

To copy data for a specific collection, e.g. the layout:xla:random collection, run:
2. Download from
[Kaggle](https://www.kaggle.com/competitions/predict-ai-model-runtime/data).

```sh
mkdir -p ~/data/tpugraphs
cd ~/data/tpugraphs
3. Use our helper script `echo_download_commands.py`.
From a clone of this directory:

curl http://download.tensorflow.org/data/tpu_graphs/v0/npz_layout_xla_random_train.tar > npz_layout_xla_random_train.tar
curl http://download.tensorflow.org/data/tpu_graphs/v0/npz_layout_xla_random_valid.tar > npz_layout_xla_random_valid.tar
curl http://download.tensorflow.org/data/tpu_graphs/v0/npz_layout_xla_random_test.tar > npz_layout_xla_random_test.tar
tar xvf npz_layout_xla_random_train.tar
tar xvf npz_layout_xla_random_valid.tar
tar xvf npz_layout_xla_random_test.tar
```
```sh
python3 echo_download_commands.py | bash
```

Removing the last pipe (`| bash`) shows the commands for downloading the dataset
(a few `curl` commands followed by `tar xvf`).

To download the train, validation, and test data for a single layout
collection, e.g., `layout:xla:random`, run:

For a description of these files, please scroll towards the end of this page
("Dataset File Description").
```sh
mkdir -p ~/data/tpugraphs
cd ~/data/tpugraphs

curl http://download.tensorflow.org/data/tpu_graphs/v0/npz_layout_xla_random_train.tar > npz_layout_xla_random_train.tar
curl http://download.tensorflow.org/data/tpu_graphs/v0/npz_layout_xla_random_valid.tar > npz_layout_xla_random_valid.tar
curl http://download.tensorflow.org/data/tpu_graphs/v0/npz_layout_xla_random_test.tar > npz_layout_xla_random_test.tar
tar xvf npz_layout_xla_random_train.tar
tar xvf npz_layout_xla_random_valid.tar
tar xvf npz_layout_xla_random_test.tar
```

For a description of these files, you may scroll down to
["Dataset File Description"](#dataset-file-description).
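After extracting the archives, it can help to confirm that every layout split is present before training. The snippet below is a minimal sketch, not part of this repo: the helper name `count_npz_files` is hypothetical, and it assumes the extracted files live under `~/data/tpugraphs/npz/layout/{source}/{search}/{split}`, which is the layout used by the evaluation snippet in `tpu_graphs/evals`.

```python
import glob
import os


def count_npz_files(layout_root):
    """Returns {(source, search, split): number of .npz files found}.

    Hypothetical helper; the directory layout is an assumption (see above).
    """
    counts = {}
    for source in ('xla', 'nlp'):
        for search in ('default', 'random'):
            for split in ('train', 'valid', 'test'):
                pattern = os.path.join(
                    layout_root, source, search, split, '*.npz')
                counts[(source, search, split)] = len(glob.glob(pattern))
    return counts


if __name__ == '__main__':
    root = os.path.expanduser('~/data/tpugraphs/npz/layout')
    for key, num_files in sorted(count_npz_files(root).items()):
        print(key, num_files)
```

A count of zero for a split most likely means that the corresponding archive was not downloaded or was extracted somewhere else.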

## Running Baseline Models

@@ -136,11 +163,11 @@ To run the pipeline on Google Cloud, please follow [this instruction](https://cl


#### Evaluate model
Once the training is done, the training output directory specified with
`--out_dir` (~/out/tpugraphs_tiles by default) will contain a model directory,
whose name starts with the prefix `model_`.

To evaluate a model(s), run:
To evaluate models trained on the `tile` collection, locate the model
directory inside the training output directory (the flag `--out_dir` defaults
to `~/out/tpugraphs_tiles`); its name starts with the prefix `model_`.
To evaluate one or more models, run:
```
python tiles_evaluate.py --dirs <comma-separated list of model dirs>
```
@@ -161,17 +188,8 @@ Currently, the evaluation script does not produce the ranking `.csv` file.

### Model on `layout:{xla|nlp}:{random|default}` collections

You may run the GST model (used in the paper), which is available at:
https://github.com/kaidic/GST.
The GST model is built on top of [GraphGPS](https://github.com/rampasek/GraphGPS)
framework, implemented using PyTorch. It can be trained either on a CPU or GPU.
The current code does not output the ranking `.csv` file.

We also provide another baseline for the layout collections
(implemented after the paper was written) in this repo. It is similar to
the GST model described in the paper, but implemented using TF-GNN and
running on a CPU.
You can train this baseline model by invoking:
We provide a baseline for the layout collections in this repo.
You can train the layout baseline model by invoking:

```sh
# As a test.
@@ -190,11 +208,14 @@ python layout_train.py --source nlp --search random --epochs 10 --max_configs 10
python layout_train.py --source nlp --search default --epochs 10 --max_configs 1000
```

NOTE: For running the NLP models, since the data is large, our trainer script
cannot fit the data into memory. The flag `--max_configs 1000` allows us to run,
by sampling only this many configurations per graph. However, you may write your
own scalable implementation, or modify ours, or run
GST: https://github.com/kaidic/GST.
NOTE: When training on the NLP collections, the data is too large for our
trainer script to fit into memory. The flag `--max_configs 1000` works around
this by sampling only this many configurations per graph. Alternatively, you
may write your own scalable implementation, or modify ours.
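To illustrate the idea behind `--max_configs`, here is a minimal sketch of sub-sampling the configurations of one graph with NumPy. It is not the code used by `layout_train.py`: the function name `sample_configs` and the array shapes are hypothetical, and uniform sampling without replacement is an assumption.

```python
import numpy as np


def sample_configs(config_feats, runtimes, max_configs, rng):
    """Keeps at most `max_configs` configurations of one graph.

    The first axis of both arrays indexes configurations, so the same
    indices are applied to both to keep features and runtimes aligned.
    """
    num_configs = runtimes.shape[0]
    if num_configs <= max_configs:
        return config_feats, runtimes
    keep = rng.choice(num_configs, size=max_configs, replace=False)
    return config_feats[keep], runtimes[keep]


rng = np.random.default_rng(0)
# Synthetic stand-in data; the feature shape here is made up.
feats = rng.normal(size=(5000, 4, 18))
runtimes = rng.uniform(1.0, 2.0, size=5000)
sub_feats, sub_runtimes = sample_configs(feats, runtimes, 1000, rng)
print(sub_feats.shape, sub_runtimes.shape)  # (1000, 4, 18) (1000,)
```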

For an alternative implementation (that uses PyTorch), you may view our
collaborators' Graph Segmented Training implementation (GST) at
https://github.com/kaidic/GST.


Each (complete) invocation of `python layout_train.py` should train the model,
@@ -217,6 +238,12 @@ the training pipelines (i.e. `~/out/tpugraphs_layout` for `layout_train.py`, and
`~/out/tpugraphs_tiles/` for `tiles_train.py`).


#### Evaluate model

To evaluate models on the validation set of layout collections, please refer to
[tpu_graphs/evals](https://github.com/google-research-datasets/tpu_graphs/tree/main/tpu_graphs/evals).


## Dataset File Description

### Tiles Collection `.npz` files
70 changes: 70 additions & 0 deletions tpu_graphs/evals/README.md
@@ -0,0 +1,70 @@
# Evaluation

This page contains guidelines on the evaluation procedure for TpuGraphs.

This page assumes that you have already [downloaded the dataset](https://github.com/google-research-datasets/tpu_graphs#dataset).

## Tile collection

The tile graphs are small (e.g., tens of nodes), the configuration features are
at the graph level, and the number of available configurations is relatively
small (from a few to hundreds). We therefore report *slowdown metrics* on all
configurations. For details, please refer to
[our paper](https://openreview.net/forum?id=plAix1NxhU) or
[tiles_evaluate.py](https://github.com/google-research-datasets/tpu_graphs/blob/main/tiles_evaluate.py).
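As a rough illustration of what a slowdown metric measures, the sketch below compares the best true runtime among a model's top-`k` predicted configurations against the best true runtime overall. This is one common formulation and an assumption on our part; the function name `top_k_slowdown` is hypothetical, and the authoritative definition is the one in the paper and in `tiles_evaluate.py`.

```python
import numpy as np


def top_k_slowdown(predicted_scores, runtimes, k):
    """Slowdown of the best of the top-k predicted configs vs. the optimum.

    Assumes a lower predicted score means a configuration is predicted to
    be faster. Returns 0.0 when the top-k set contains the true optimum.
    """
    top_k = np.argsort(predicted_scores)[:k]
    best_of_top_k = runtimes[top_k].min()
    return best_of_top_k / runtimes.min() - 1.0


runtimes = np.array([10.0, 12.0, 8.0, 20.0])
scores = np.array([0.2, 0.1, 0.9, 0.5])
print(top_k_slowdown(scores, runtimes, k=1))  # 12 / 8 - 1 = 0.5
print(top_k_slowdown(scores, runtimes, k=3))  # 10 / 8 - 1 = 0.25
```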

## Layout collections

In contrast, the layout graphs are larger (up to hundreds of thousands of
nodes), the configuration features are at the node level, and the number of
available configurations is relatively large (up to hundreds of thousands).
Therefore, in our main paper, we score and report metrics on **only** 1000
configurations. This should reduce the burden on the academic community of
training and evaluating models, especially when reporting experimental
metrics.

For every **validation** graph in every subcollection
{xla|nlp}:{default|random}, we pre-compute indices of configuration features
and their corresponding runtimes.
Specifically, the indices for each graph are available in JSON format at:

https://github.com/google-research-datasets/tpu_graphs/tree/main/tpu_graphs/baselines/layout/eval_indices


The following code snippet reads in a validation graph and restricts it to the
validation indices.

```py
import json
import os

import numpy as np

# Assume that you did `git clone` inside of `~/code`:
_JSON_ROOT_DIR = os.path.expanduser(
    '~/code/tpu_graphs/tpu_graphs/baselines/layout/eval_indices')
# Assume data was downloaded per
# https://github.com/google-research-datasets/tpu_graphs#dataset:
_LAYOUT_DATA_ROOT = os.path.expanduser('~/data/tpugraphs/npz/layout')


def _load_json(name):
  # Reads one index file and closes the file handle afterwards.
  with open(os.path.join(_JSON_ROOT_DIR, name)) as f:
    return json.load(f)


# Maps (source, search) -> {graph name: list of configuration indices}.
_JSON_DATA = {
    ('nlp', 'default'): _load_json('nlp_default.json'),
    ('nlp', 'random'): _load_json('nlp_random.json'),
    ('xla', 'default'): _load_json('xla_default.json'),
    ('xla', 'random'): _load_json('xla_random.json'),
}


def read_validation_graph(source, search, graph_name):
  npz_path = os.path.join(
      _LAYOUT_DATA_ROOT, source, search, 'valid', graph_name + '.npz')
  npz_data = dict(np.load(npz_path))
  ids = _JSON_DATA[(source, search)][graph_name]
  # Shuffles the cached index list in place; the same order is then applied
  # to both arrays, so runtimes and features stay aligned.
  np.random.shuffle(ids)
  npz_data['config_runtime'] = npz_data['config_runtime'][ids]
  npz_data['node_config_feat'] = npz_data['node_config_feat'][ids]
  return npz_data


print(read_validation_graph('xla', 'random', 'resnet50.4x4.fp16'))
```
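Once a graph has been restricted to its evaluation indices, model predictions can be scored against `config_runtime`. The sketch below shows one possible ranking score, the fraction of configuration pairs that the model orders correctly. It is an illustrative assumption and not necessarily the metric reported in the paper; the function name `ordered_pair_accuracy` is hypothetical, and the example uses synthetic arrays rather than dataset files.

```python
import numpy as np


def ordered_pair_accuracy(predicted_scores, runtimes):
    """Fraction of pairs (i, j) with runtime_i > runtime_j that the model
    also ranks as score_i > score_j. Pairs with equal runtimes are ignored.
    """
    pred_diff = predicted_scores[:, None] - predicted_scores[None, :]
    true_diff = runtimes[:, None] - runtimes[None, :]
    valid = true_diff > 0
    if not valid.any():
        return float('nan')
    return float((pred_diff[valid] > 0).mean())


# Synthetic example: one of the six ordered pairs is ranked incorrectly.
runtimes = np.array([1.0, 2.0, 3.0, 4.0])
scores = np.array([0.1, 0.3, 0.2, 0.9])
print(ordered_pair_accuracy(scores, runtimes))  # 5 / 6
```

For 1000 configurations the pairwise matrices have one million entries each, which remains cheap to compute per graph.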
