diff --git a/jetstream/core/implementations/mock/server.py b/jetstream/core/implementations/mock/server.py index 6355469c..7bfc9afd 100644 --- a/jetstream/core/implementations/mock/server.py +++ b/jetstream/core/implementations/mock/server.py @@ -18,9 +18,8 @@ from absl import app from absl import flags - -from jetstream.core.implementations.mock import config as mock_config from jetstream.core import server_lib +from jetstream.core.implementations.mock import config as mock_config _PORT = flags.DEFINE_integer('port', 9000, 'port to listen on') @@ -30,19 +29,30 @@ 'available servers', ) + def main(argv: Sequence[str]): del argv - # No devices for local cpu test. A None for prefill and a None for generate. - devices = server_lib.get_devices() - server_config = mock_config.get_server_config(_CONFIG.value) - # We separate credential from run so that we can unit test it with local credentials. - # TODO: Add grpc credentials for OSS. - jetstream_server = server_lib.run( - port=_PORT.value, - config=server_config, - devices=devices, - ) - jetstream_server.wait_for_termination() + jetstream_server = None + try: + # No devices for local cpu test. A None for prefill and a None for generate. + devices = server_lib.get_devices() + server_config = mock_config.get_server_config(_CONFIG.value) + # We separate credential from run so that we can unit test it with local credentials. + # TODO: Add grpc credentials for OSS. + jetstream_server = server_lib.run( + port=_PORT.value, + config=server_config, + devices=devices, + ) + jetstream_server.wait_for_termination() + except KeyboardInterrupt: + print('Stopping profiler and exiting...') + print( + 'NOTE: DO NOT Interrupt again; the profiler is slowly collecting data' + ' and existing...' + ) + if jetstream_server: + jetstream_server.stop() if __name__ == '__main__': diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index 2f0c1093..c775ef69 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -39,6 +39,8 @@ def __init__(self, driver: orchestrator.Driver, server: grpc.Server): self._server = server def start(self, port, credentials) -> None: + # start jax profiler + jax.profiler.start_trace("/tmp/tensorboard") self._server.add_secure_port(f'{_HOST}:{port}', credentials) self._server.start() @@ -46,6 +48,8 @@ def stop(self) -> None: # Gracefully clean up threads in the orchestrator. self._driver.stop() self._server.stop(0) + # end jax profiler + jax.profiler.stop_trace() def wait_for_termination(self) -> None: self._server.wait_for_termination()