Skip to content

Commit

Permalink
make builder accessible (pytorch#3743)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#3743

Make the `builder` accessible after calling `_export_llama`. Also add a member method to return most recent saved model path (which guarantee ending with `.pte`, so that user doesn't need to think about complicated logic to check if it's model name or file name).

Reviewed By: cccclai

Differential Revision: D57801568

fbshipit-source-id: 4317d85a3aa8e54e0919e385e20674ddacfbf512
  • Loading branch information
Yanghan Wang authored and facebook-github-bot committed May 28, 2024
1 parent 2b91eba commit 24b37f2
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 6 deletions.
1 change: 1 addition & 0 deletions examples/models/llama2/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ runtime.python_library(
"//bento/...",
"//bento_kernels/...",
"//executorch/examples/...",
"@EXECUTORCH_CLIENTS",
],
deps = [
"//caffe2:torch",
Expand Down
10 changes: 9 additions & 1 deletion examples/models/llama2/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def __init__(
self.edge_manager: Optional[EdgeProgramManager] = None
self.export_program = None
self.output_dir = "."
self._saved_pte_filename = None

def set_metadata(self, metadata: Optional[dict]) -> "LlamaEdgeManager":
"""
Expand Down Expand Up @@ -388,4 +389,11 @@ def save_to_pte(self, output_name: str) -> None:
output_name (Optional[str]): The name of the .pte file.
"""
assert output_name, "Need a valid output name"
save_pte_program(self.export_program, output_name, self.output_dir)
filename = save_pte_program(self.export_program, output_name, self.output_dir)
self._saved_pte_filename = filename

def get_saved_pte_filename(self) -> Optional[str]:
"""
Return the filename of the most recenet saved .pte file. Return None if the model is not saved.
"""
return self._saved_pte_filename
16 changes: 12 additions & 4 deletions examples/models/llama2/export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,14 +293,22 @@ def export_llama(modelname, args) -> str:
from executorch.util.python_profiler import CProfilerFlameGraph

with CProfilerFlameGraph(args.profile_path):
return _export_llama(modelname, args)
builder = _export_llama(modelname, args)
assert (
filename := builder.get_saved_pte_filename()
) is not None, "Fail to get file name from builder"
return filename
except ImportError:
print(
"Please run `pip install snakeviz` to install required dependencies for cProfiler flamegraph."
)
return ""
else:
return _export_llama(modelname, args)
builder = _export_llama(modelname, args)
assert (
filename := builder.get_saved_pte_filename()
) is not None, "Fail to get file name from builder"
return filename


def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
Expand Down Expand Up @@ -383,7 +391,7 @@ def get_quantizer_and_quant_params(args):
return pt2e_quant_params, quantizers, quant_dtype


def _export_llama(modelname, args) -> str: # noqa: C901
def _export_llama(modelname, args) -> LlamaEdgeManager: # noqa: C901
pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)

# export_to_edge
Expand Down Expand Up @@ -468,4 +476,4 @@ def _export_llama(modelname, args) -> str: # noqa: C901

builder.save_to_pte(output_file)

return output_file
return builder
4 changes: 3 additions & 1 deletion examples/portable/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def export_to_exec_prog(

def save_pte_program(
prog: ExecutorchProgramManager, model_name: str, output_dir: str = ""
) -> None:
) -> str:
if model_name.endswith(".pte"):
filename = model_name
else:
Expand All @@ -114,3 +114,5 @@ def save_pte_program(
logging.info(f"Saved exported program to {filename}")
except Exception as e:
logging.error(f"Error while saving to {filename}: {e}")

return filename

0 comments on commit 24b37f2

Please sign in to comment.