Record the status for a benchmark run. (tensorflow#4402)
* Update the benchmark logger to update the run status.

This is important for streaming uploads to BigQuery, so that the
dashboard can ignore a benchmark that is still 'running' and not yet
finished.

* Move the run status into a separate table.

Also update the run status in the benchmark uploader and
BigqueryBenchmarkLogger.

* Insert, rather than update, the benchmark status in the file logger.

* Address review comments.

Update the logger to have a benchmark context, which will update the
run status accordingly.

* Fix broken tests.

* Move the benchmark logger context to the main function.

* Fix tests.

* Update the rest of the models to use the context in main.

* Delint.
qlzh727 authored Jun 1, 2018
1 parent d530ac5 commit 47c5642
Showing 14 changed files with 205 additions and 39 deletions.
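The logger module that provides the new context (official/utils/logs/logger.py) is presumably among the 14 changed files but is not expanded in this view. Judging only from how the models below call it, a minimal, self-contained sketch of the pattern might look like the following; apart from benchmark_context, get_benchmark_logger, and RUN_STATUS_SUCCESS, every name here is an assumption, and the stand-in logger class is purely illustrative:

# Illustrative sketch only -- not the actual logger implementation.
import contextlib

RUN_STATUS_SUCCESS = "success"    # referenced by benchmark_uploader_main.py below
RUN_STATUS_FAILURE = "failure"    # assumed constant
RUN_STATUS_RUNNING = "running"    # assumed constant; the status a dashboard can skip

_benchmark_logger = None


class _SketchLogger(object):
  """Stand-in for the configured benchmark logger (interface assumed)."""

  def log_run_info(self, model_name, dataset_name, run_params):
    print("run info:", model_name, dataset_name, run_params)

  def on_finish(self, status):
    print("run status:", status)


def get_benchmark_logger():
  """Return the logger configured by the enclosing benchmark_context."""
  return _benchmark_logger


@contextlib.contextmanager
def benchmark_context(flags_obj):
  """Configure the logger, then record success or failure when the block exits."""
  global _benchmark_logger
  del flags_obj  # the real implementation uses the flags to configure the logger
  _benchmark_logger = _SketchLogger()
  try:
    yield
    _benchmark_logger.on_finish(RUN_STATUS_SUCCESS)
  except Exception:
    _benchmark_logger.on_finish(RUN_STATUS_FAILURE)
    raise

With this pattern, each model's main() only wraps its run function in the context, and code deeper in the run fetches the already-configured logger with get_benchmark_logger(), as the per-model diffs below show.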
23 changes: 23 additions & 0 deletions official/benchmark/benchmark_uploader.py
@@ -27,6 +27,7 @@
import json

from google.cloud import bigquery
from google.cloud import exceptions

import tensorflow as tf

@@ -132,3 +133,25 @@ def _upload_json(self, dataset_name, table_name, json_list):
    if errors:
      tf.logging.error(
          "Failed to upload benchmark info to bigquery: {}".format(errors))

  def insert_run_status(self, dataset_name, table_name, run_id, run_status):
    """Insert the run status into the Bigquery run status table."""
    query = ("INSERT {ds}.{tb} "
             "(run_id, status) "
             "VALUES('{rid}', '{status}')").format(
                 ds=dataset_name, tb=table_name, rid=run_id, status=run_status)
    try:
      self._bq_client.query(query=query).result()
    except exceptions.GoogleCloudError as e:
      tf.logging.error("Failed to insert run status: %s", e)

  def update_run_status(self, dataset_name, table_name, run_id, run_status):
    """Update the run status in the Bigquery run status table."""
    query = ("UPDATE {ds}.{tb} "
             "SET status = '{status}' "
             "WHERE run_id = '{rid}'").format(
                 ds=dataset_name, tb=table_name, status=run_status, rid=run_id)
    try:
      self._bq_client.query(query=query).result()
    except exceptions.GoogleCloudError as e:
      tf.logging.error("Failed to update run status: %s", e)
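As a usage illustration (not part of this diff), a streaming caller could insert a 'running' row when a run starts and update it once the run completes. The uploader class name is inferred from the test file below, and the constructor argument, dataset, table, and run-id values are placeholders:

# Hypothetical wiring; only insert_run_status/update_run_status come from this commit.
from official.benchmark import benchmark_uploader

uploader = benchmark_uploader.BigQueryUploader(gcp_project="my-project")  # args assumed
uploader.insert_run_status(
    "benchmark_dataset", "benchmark_run_status", "run-1234", "running")
# ... execute the benchmark and upload its metrics ...
uploader.update_run_status(
    "benchmark_dataset", "benchmark_run_status", "run-1234", "success")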
4 changes: 4 additions & 0 deletions official/benchmark/benchmark_uploader_main.py
@@ -54,6 +54,10 @@ def main(_):
  uploader.upload_metric_file(
      flags.FLAGS.bigquery_data_set, flags.FLAGS.bigquery_metric_table, run_id,
      metric_json_file)
  # Assume the run finished successfully before the user invokes the upload
  # script.
  uploader.insert_run_status(
      flags.FLAGS.bigquery_data_set, flags.FLAGS.bigquery_run_status_table,
      run_id, logger.RUN_STATUS_SUCCESS)


if __name__ == "__main__":
58 changes: 37 additions & 21 deletions official/benchmark/benchmark_uploader_test.py
@@ -36,10 +36,10 @@
benchmark_uploader = None


@unittest.skipIf(bigquery is None, 'Bigquery dependency is not installed.')
@unittest.skipIf(bigquery is None, "Bigquery dependency is not installed.")
class BigQueryUploaderTest(tf.test.TestCase):

  @patch.object(bigquery, 'Client')
  @patch.object(bigquery, "Client")
  def setUp(self, mock_bigquery):
    self.mock_client = mock_bigquery.return_value
    self.mock_dataset = MagicMock(name="dataset")
@@ -52,56 +52,72 @@ def setUp(self, mock_bigquery):
    self.benchmark_uploader._bq_client = self.mock_client

    self.log_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
    with open(os.path.join(self.log_dir, 'metric.log'), 'a') as f:
      json.dump({'name': 'accuracy', 'value': 1.0}, f)
    with open(os.path.join(self.log_dir, "metric.log"), "a") as f:
      json.dump({"name": "accuracy", "value": 1.0}, f)
      f.write("\n")
      json.dump({'name': 'loss', 'value': 0.5}, f)
      json.dump({"name": "loss", "value": 0.5}, f)
      f.write("\n")
    with open(os.path.join(self.log_dir, 'run.log'), 'w') as f:
      json.dump({'model_name': 'value'}, f)
    with open(os.path.join(self.log_dir, "run.log"), "w") as f:
      json.dump({"model_name": "value"}, f)

  def tearDown(self):
    tf.gfile.DeleteRecursively(self.get_temp_dir())

  def test_upload_benchmark_run_json(self):
    self.benchmark_uploader.upload_benchmark_run_json(
        'dataset', 'table', 'run_id', {'model_name': 'value'})
        "dataset", "table", "run_id", {"model_name": "value"})

    self.mock_client.insert_rows_json.assert_called_once_with(
        self.mock_table, [{'model_name': 'value', 'model_id': 'run_id'}])
        self.mock_table, [{"model_name": "value", "model_id": "run_id"}])

  def test_upload_benchmark_metric_json(self):
    metric_json_list = [
        {'name': 'accuracy', 'value': 1.0},
        {'name': 'loss', 'value': 0.5}
        {"name": "accuracy", "value": 1.0},
        {"name": "loss", "value": 0.5}
    ]
    expected_params = [
        {'run_id': 'run_id', 'name': 'accuracy', 'value': 1.0},
        {'run_id': 'run_id', 'name': 'loss', 'value': 0.5}
        {"run_id": "run_id", "name": "accuracy", "value": 1.0},
        {"run_id": "run_id", "name": "loss", "value": 0.5}
    ]
    self.benchmark_uploader.upload_benchmark_metric_json(
        'dataset', 'table', 'run_id', metric_json_list)
        "dataset", "table", "run_id", metric_json_list)
    self.mock_client.insert_rows_json.assert_called_once_with(
        self.mock_table, expected_params)

  def test_upload_benchmark_run_file(self):
    self.benchmark_uploader.upload_benchmark_run_file(
        'dataset', 'table', 'run_id', os.path.join(self.log_dir, 'run.log'))
        "dataset", "table", "run_id", os.path.join(self.log_dir, "run.log"))

    self.mock_client.insert_rows_json.assert_called_once_with(
        self.mock_table, [{'model_name': 'value', 'model_id': 'run_id'}])
        self.mock_table, [{"model_name": "value", "model_id": "run_id"}])

  def test_upload_metric_file(self):
    self.benchmark_uploader.upload_metric_file(
        'dataset', 'table', 'run_id',
        os.path.join(self.log_dir, 'metric.log'))
        "dataset", "table", "run_id",
        os.path.join(self.log_dir, "metric.log"))
    expected_params = [
        {'run_id': 'run_id', 'name': 'accuracy', 'value': 1.0},
        {'run_id': 'run_id', 'name': 'loss', 'value': 0.5}
        {"run_id": "run_id", "name": "accuracy", "value": 1.0},
        {"run_id": "run_id", "name": "loss", "value": 0.5}
    ]
    self.mock_client.insert_rows_json.assert_called_once_with(
        self.mock_table, expected_params)

  def test_insert_run_status(self):
    self.benchmark_uploader.insert_run_status(
        "dataset", "table", "run_id", "status")
    expected_query = ("INSERT dataset.table "
                      "(run_id, status) "
                      "VALUES('run_id', 'status')")
    self.mock_client.query.assert_called_once_with(query=expected_query)

if __name__ == '__main__':
  def test_update_run_status(self):
    self.benchmark_uploader.update_run_status(
        "dataset", "table", "run_id", "status")
    expected_query = ("UPDATE dataset.table "
                      "SET status = 'status' "
                      "WHERE run_id = 'run_id'")
    self.mock_client.query.assert_called_once_with(query=expected_query)


if __name__ == "__main__":
  tf.test.main()
6 changes: 0 additions & 6 deletions official/benchmark/datastore/schema/benchmark_run.json
@@ -5,12 +5,6 @@
    "name": "model_id",
    "type": "STRING"
  },
  {
    "description": "The status of the run for the benchmark. Eg, running, failed, success",
    "mode": "NULLABLE",
    "name": "status",
    "type": "STRING"
  },
  {
    "description": "The name of the model, E.g ResNet50, LeNet-5 etc.",
    "mode": "REQUIRED",
14 changes: 14 additions & 0 deletions official/benchmark/datastore/schema/benchmark_run_status.json
@@ -0,0 +1,14 @@
[
  {
    "description": "The UUID of the run for the benchmark.",
    "mode": "REQUIRED",
    "name": "run_id",
    "type": "STRING"
  },
  {
    "description": "The status of the run for the benchmark. Eg, running, failed, success",
    "mode": "REQUIRED",
    "name": "status",
    "type": "STRING"
  }
]
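For reference only (not part of this commit), the status table could be created from this schema with the google.cloud.bigquery client before the first insert; the project and dataset names below are placeholders:

# Hypothetical setup snippet; project and dataset names are placeholders.
from google.cloud import bigquery

schema = [
    bigquery.SchemaField("run_id", "STRING", mode="REQUIRED",
                         description="The UUID of the run for the benchmark."),
    bigquery.SchemaField("status", "STRING", mode="REQUIRED",
                         description="The status of the run for the benchmark."),
]
client = bigquery.Client(project="my-project")
table_ref = client.dataset("benchmark_dataset").table("benchmark_run_status")
client.create_table(bigquery.Table(table_ref, schema=schema))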
8 changes: 7 additions & 1 deletion official/recommendation/ncf_main.py
@@ -198,6 +198,12 @@ def per_device_batch_size(batch_size, num_gpus):


def main(_):
  with logger.benchmark_context(FLAGS):
    run_ncf(FLAGS)


def run_ncf(_):
  """Run NCF training and eval loop."""
  # Data preprocessing
  # The file name of training and test dataset
  train_fname = os.path.join(
@@ -237,7 +243,7 @@ def main(_):
      "hr_threshold": FLAGS.hr_threshold,
      "train_epochs": FLAGS.train_epochs,
  }
  benchmark_logger = logger.config_benchmark_logger(FLAGS)
  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info(
      model_name="recommendation",
      dataset_name=FLAGS.dataset,
5 changes: 3 additions & 2 deletions official/resnet/cifar10_main.py
@@ -25,6 +25,7 @@
import tensorflow as tf # pylint: disable=g-bad-import-order

from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.resnet import resnet_model
from official.resnet import resnet_run_loop

@@ -236,14 +237,14 @@ def run_cifar(flags_obj):
  """
  input_function = (flags_obj.use_synthetic_data and get_synth_input_fn()
                    or input_fn)

  resnet_run_loop.resnet_main(
      flags_obj, cifar10_model_fn, input_function, DATASET_NAME,
      shape=[_HEIGHT, _WIDTH, _NUM_CHANNELS])


def main(_):
  run_cifar(flags.FLAGS)
  with logger.benchmark_context(flags.FLAGS):
    run_cifar(flags.FLAGS)


if __name__ == '__main__':
4 changes: 3 additions & 1 deletion official/resnet/imagenet_main.py
@@ -25,6 +25,7 @@
import tensorflow as tf # pylint: disable=g-bad-import-order

from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.resnet import imagenet_preprocessing
from official.resnet import resnet_model
from official.resnet import resnet_run_loop
@@ -321,7 +322,8 @@ def run_imagenet(flags_obj):


def main(_):
  run_imagenet(flags.FLAGS)
  with logger.benchmark_context(flags.FLAGS):
    run_imagenet(flags.FLAGS)


if __name__ == '__main__':
3 changes: 1 addition & 2 deletions official/resnet/resnet_run_loop.py
@@ -395,7 +395,7 @@ def resnet_main(
      'synthetic_data': flags_obj.use_synthetic_data,
      'train_epochs': flags_obj.train_epochs,
  }
  benchmark_logger = logger.config_benchmark_logger(flags_obj)
  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info('resnet', dataset_name, run_params)

  train_hooks = hooks_helper.get_train_hooks(
@@ -415,7 +415,6 @@ def input_fn_eval():
        batch_size=per_device_batch_size(
            flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=1)

  total_training_cycle = (flags_obj.train_epochs //
                          flags_obj.epochs_between_evals)
  for cycle_index in range(total_training_cycle):
6 changes: 4 additions & 2 deletions official/transformer/transformer_main.py
@@ -436,7 +436,7 @@ def run_transformer(flags_obj):
      tensors_to_log=TENSORS_TO_LOG,  # used for logging hooks
      batch_size=params.batch_size  # for ExamplesPerSecondHook
  )
  benchmark_logger = logger.config_benchmark_logger(flags_obj)
  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info(
      model_name="transformer",
      dataset_name="wmt_translate_ende",
@@ -445,6 +445,7 @@
  # Train and evaluate transformer model
  estimator = tf.estimator.Estimator(
      model_fn=model_fn, model_dir=flags_obj.model_dir, params=params)

  train_schedule(
      estimator=estimator,
      # Training arguments
@@ -461,7 +462,8 @@


def main(_):
  run_transformer(flags.FLAGS)
  with logger.benchmark_context(flags.FLAGS):
    run_transformer(flags.FLAGS)


if __name__ == "__main__":
6 changes: 6 additions & 0 deletions official/utils/flags/_benchmark.py
@@ -66,6 +66,12 @@ def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True):
      help=help_wrap("The Bigquery table name where the benchmark run "
                     "information will be uploaded."))

  flags.DEFINE_string(
      name="bigquery_run_status_table", short_name="brst",
      default="benchmark_run_status",
      help=help_wrap("The Bigquery table name where the benchmark run "
                     "status information will be uploaded."))

  flags.DEFINE_string(
      name="bigquery_metric_table", short_name="bmt",
      default="benchmark_metric",