diff --git a/mlcube/mlcube/__main__.py b/mlcube/mlcube/__main__.py index c539bdb..cce4b76 100644 --- a/mlcube/mlcube/__main__.py +++ b/mlcube/mlcube/__main__.py @@ -1,6 +1,7 @@ """This requires the MLCube 2.0 that's located somewhere in one of dev branches.""" import logging import os +from pathlib import Path import shutil import sys import typing as t @@ -666,9 +667,21 @@ def create() -> None: default="json", help="Format for reporting results.", ) +@click.option( + "--output-file", + "--output_file", + required=False, + type=str, + default=None, + help="File path to store the MLCube information. Defaults to print to STDOUT", +) @Options.help def inspect( - mlcube: t.Optional[str], platform: str, force: bool = False, format_: str = "json" + mlcube: t.Optional[str], + platform: str, + force: bool = False, + format_: str = "json", + output_file: t.Optional[str] = None, ) -> None: """Return low-level information on MLCube objects.""" runner_cls, mlcube_config = parse_cli_args( @@ -676,6 +689,13 @@ def inspect( unparsed_args=[], resolve=True, ) + if output_file is None: + output_stream = sys.stdout + else: + dir_path = Path(output_file).resolve().parent + dir_path.mkdir(parents=True, exist_ok=True) + output_stream = open(output_file, "w") + try: runner = runner_cls(mlcube_config, task=None) info: t.Dict = runner.inspect(force=force) @@ -683,11 +703,13 @@ def inspect( if format_ == "json": import json - print(json.dumps(info)) + json.dump(info, output_stream) + if output_stream == sys.stdout: + print() # json doesn't print a newline else: import yaml - yaml.dump(info, sys.stdout) + yaml.dump(info, output_stream) except MLCubeError as err: print("MLCube inspect failed") logger.exception(err)