Skip to content

Commit

Permalink
get jit
Browse files Browse the repository at this point in the history
  • Loading branch information
WT-MM committed Dec 30, 2024
1 parent 6e4988d commit 983e010
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions sim/play.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,13 @@
def export_policy_as_jit(actor_critic: Any, path: Union[str, os.PathLike]) -> None:
os.makedirs(path, exist_ok=True)
path = os.path.join(path, "policy_1.pt")
model = get_actor_jit(actor_critic)
model.save(path)

def get_actor_jit(actor_critic: Any) -> Any:
model = copy.deepcopy(actor_critic.actor).to("cpu")
traced_script_module = torch.jit.script(model)
traced_script_module.save(path)

return traced_script_module

def play(args: argparse.Namespace) -> None:
logger.info("Configuring environment and training settings...")
Expand Down Expand Up @@ -186,8 +189,9 @@ def play(args: argparse.Namespace) -> None:
)

if args.export_onnx:
jit_policy = get_actor_jit(policy)
kinfer_policy = export_model(
model=policy,
model=jit_policy,
schema=model_schema,
)
kinfer_policy.save_model("policy.kinfer")
Expand Down

0 comments on commit 983e010

Please sign in to comment.