From 005ce17ce0bcd6b7656b67c5aa207f0b307556e0 Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Thu, 7 Mar 2024 08:08:36 +0000 Subject: [PATCH 1/2] [DO NOT MERGE] JAX profiling on JetStream server --- jetstream/core/implementations/mock/server.py | 34 ++++++++++++------- jetstream/core/server_lib.py | 4 +++ 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/jetstream/core/implementations/mock/server.py b/jetstream/core/implementations/mock/server.py index 6355469c..8f312249 100644 --- a/jetstream/core/implementations/mock/server.py +++ b/jetstream/core/implementations/mock/server.py @@ -18,9 +18,8 @@ from absl import app from absl import flags - -from jetstream.core.implementations.mock import config as mock_config from jetstream.core import server_lib +from jetstream.core.implementations.mock import config as mock_config _PORT = flags.DEFINE_integer('port', 9000, 'port to listen on') @@ -30,19 +29,28 @@ 'available servers', ) + def main(argv: Sequence[str]): del argv - # No devices for local cpu test. A None for prefill and a None for generate. - devices = server_lib.get_devices() - server_config = mock_config.get_server_config(_CONFIG.value) - # We separate credential from run so that we can unit test it with local credentials. - # TODO: Add grpc credentials for OSS. - jetstream_server = server_lib.run( - port=_PORT.value, - config=server_config, - devices=devices, - ) - jetstream_server.wait_for_termination() + try: + # No devices for local cpu test. A None for prefill and a None for generate. + devices = server_lib.get_devices() + server_config = mock_config.get_server_config(_CONFIG.value) + # We separate credential from run so that we can unit test it with local credentials. + # TODO: Add grpc credentials for OSS. + jetstream_server = server_lib.run( + port=_PORT.value, + config=server_config, + devices=devices, + ) + jetstream_server.wait_for_termination() + except KeyboardInterrupt: + print('Stopping profiler and exiting...') + print( + 'NOTE: DO NOT Interrupt again; the profiler is slowly collecting data' + ' and existing...' + ) + jetstream_server.stop() if __name__ == '__main__': diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index 2f0c1093..c775ef69 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -39,6 +39,8 @@ def __init__(self, driver: orchestrator.Driver, server: grpc.Server): self._server = server def start(self, port, credentials) -> None: + # start jax profiler + jax.profiler.start_trace("/tmp/tensorboard") self._server.add_secure_port(f'{_HOST}:{port}', credentials) self._server.start() @@ -46,6 +48,8 @@ def stop(self) -> None: # Gracefully clean up threads in the orchestrator. self._driver.stop() self._server.stop(0) + # end jax profiler + jax.profiler.stop_trace() def wait_for_termination(self) -> None: self._server.wait_for_termination() From d5b313f7d722110ca61d2337ee8a55bf08cdb204 Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Thu, 7 Mar 2024 08:18:37 +0000 Subject: [PATCH 2/2] fix pytype --- jetstream/core/implementations/mock/server.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/jetstream/core/implementations/mock/server.py b/jetstream/core/implementations/mock/server.py index 8f312249..7bfc9afd 100644 --- a/jetstream/core/implementations/mock/server.py +++ b/jetstream/core/implementations/mock/server.py @@ -32,6 +32,7 @@ def main(argv: Sequence[str]): del argv + jetstream_server = None try: # No devices for local cpu test. A None for prefill and a None for generate. devices = server_lib.get_devices() @@ -50,7 +51,8 @@ def main(argv: Sequence[str]): 'NOTE: DO NOT Interrupt again; the profiler is slowly collecting data' ' and existing...' ) - jetstream_server.stop() + if jetstream_server: + jetstream_server.stop() if __name__ == '__main__':