emit metadata
Differential Revision: D61625159

Pull Request resolved: pytorch#4837
JacobSzwejbka authored Aug 22, 2024
1 parent c2044a4 commit 3af50f9
Showing 3 changed files with 48 additions and 1 deletion.
27 changes: 27 additions & 0 deletions exir/emit/_emit_program.py
@@ -78,6 +78,29 @@ def _remove_non_user_outputs(exported_program: ExportedProgram) -> torch.fx.GraphModule:
return gm


# For each entry point in the model, determine whether it is a joint graph,
# and if so record the indices in the model output at which the gradient
# outputs and the parameter outputs start.
def _get_training_metadata(methods: Dict[str, ExportedProgram]) -> Dict[str, int]:
    gradients_method_prefix = "__et_training_gradients_index_"
    parameters_method_prefix = "__et_training_parameters_index_"
    training_metadata = {}
    for name, method in methods.items():
        found_grad = False
        found_param = False
        i = 0
        for output_spec in method.graph_signature.output_specs:
            if output_spec.kind == OutputKind.GRADIENT_TO_PARAMETER and not found_grad:
                training_metadata[gradients_method_prefix + name] = i
                found_grad = True
            elif output_spec.kind == OutputKind.TOKEN and not found_param:
                assert found_grad  # Params must come after gradients
                training_metadata[parameters_method_prefix + name] = i
                found_param = True
            i += 1
    return training_metadata


def emit_program(
methods: Union[ExportedProgram, Dict[str, ExportedProgram]],
emit_stacktrace: bool = False,
@@ -143,6 +166,10 @@ def emit_program(
            emitter.instr_id_to_delegate_debug_id_map
        )

    training_metadata = _get_training_metadata(methods)
    if len(training_metadata) > 0:
        plans.extend(emitter._emit_prim_getters(training_metadata))

    # emit any primitive getters
    if prim_getters is not None:
        plans.extend(emitter._emit_prim_getters(prim_getters))
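For a joint graph like the one exercised in exir/tests/test_joint_graph.py below, the metadata computed above reduces to a small name-to-index map, and emit_program turns each entry into its own execution plan via _emit_prim_getters. A minimal sketch of the expected result, assuming the forward method's outputs are laid out as [loss, grad_weight, grad_bias, weight, bias] (the layout the test asserts against):

# Illustrative only; not part of the commit.
# Given the output layout [loss, grad_weight, grad_bias, weight, bias],
# _get_training_metadata({"forward": exported_program}) would return:
expected_metadata = {
    "__et_training_gradients_index_forward": 1,   # gradients start at output index 1
    "__et_training_parameters_index_forward": 3,  # parameters start at output index 3
}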
20 changes: 20 additions & 0 deletions exir/tests/test_joint_graph.py
@@ -108,3 +108,23 @@ def forward(self, x, y):
        self.assertTrue(torch.allclose(m.linear.bias.grad, et_outputs[2]))
        self.assertTrue(torch.allclose(m.linear.weight, et_outputs[3]))
        self.assertTrue(torch.allclose(m.linear.bias, et_outputs[4]))

        self.assertEqual(
            len(et.executorch_program.execution_plan), 3
        )  # forward + 2 training metadata functions

        # gradient outputs start at index 1
        self.assertEqual(
            et.executorch_program.execution_plan[1]  # pyre-ignore
            .values[0]
            .val.int_val,
            1,
        )

        # parameter outputs start at index 3
        self.assertEqual(
            et.executorch_program.execution_plan[2]  # pyre-ignore
            .values[0]
            .val.int_val,
            3,
        )
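The test reads the two metadata plans by their position in execution_plan (indices 1 and 2). Because _emit_prim_getters names each emitted plan after its metadata key, the same values can in principle be looked up by name; a hypothetical helper sketching that, assuming the exir schema's ExecutionPlan.name field (the helper itself is illustrative, not part of the commit):

# Hypothetical convenience lookup; mirrors what the assertions above do by index.
def _find_training_metadata(program, plan_name: str) -> int:
    for plan in program.execution_plan:
        if plan.name == plan_name:
            return plan.values[0].val.int_val
    raise KeyError(plan_name)

# e.g. _find_training_metadata(
#     et.executorch_program, "__et_training_gradients_index_forward"
# ) == 1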
2 changes: 1 addition & 1 deletion extension/training/test/training_loop_test.cpp
@@ -23,7 +23,7 @@
// @lint-ignore-every CLANGTIDY facebook-hte-CArray

using namespace ::testing;
- using namespace torch::executor::training::optimizer;
+ using namespace executorch::extension::training::optimizer;
using namespace torch::executor::testing;
using exec_aten::ScalarType;
using exec_aten::Tensor;