Trace-Jitting fails when combined with jit-compiled module #1252

Open · sebffischer opened this issue Jan 22, 2025 · 1 comment

@sebffischer (Collaborator):
Below, the input type of the trace-jitted module is inferred as Tensor even though it received a list[Tensor] during tracing.

library(torch)

# TorchScript function with an explicit List[Tensor] annotation
sum_jit = jit_compile("
def sum_jit(xs: List[Tensor]):
  output = torch.zeros_like(xs[0])
  for x in xs:
    output = output + x
  return output
")$sum_jit

# R module whose forward() passes its argument (a list of tensors) to the
# jit-compiled function
net2 = nn_module("nn_mix",
  initialize = function() {
    self$linear = nn_linear(10, 1)
  },
  forward = function(x) {
    self$linear(sum_jit(x))
  }
)()

# Tracing with a list of tensors succeeds ...
net2_jit = jit_trace(net2, list(torch_randn(10, 10), torch_randn(10, 10)))

# ... but calling the traced module with a list fails, because 'x' was inferred as Tensor
net2_jit(list(torch_tensor(1), torch_tensor(2), torch_tensor(3)))
#> Error in cpp_call_jit_script(ptr, inputs): forward() Expected a value of type 'Tensor (inferred)' for argument 'x' but instead found type 'List[Tensor]'.
#> Inferred 'x' to be of type 'Tensor' because it was not annotated with an explicit type.
#> Position: 1
#> Declaration: forward(__torch__.nn_mix self, Tensor x) -> Tensor
#> Exception raised from checkArg at /Users/runner/work/libtorch-mac-m1/libtorch-mac-m1/pytorch/aten/src/ATen/core/function_schema_inl.h:20 (most recent call first):
#> frame #0: c10::Error::Error(c10::SourceLocation, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>) + 52 (0x10375011c in libc10.dylib)
#> frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const&) + 140 (0x10374cd6c in libc10.dylib)
#> frame #2: void c10::FunctionSchema::checkArg<c10::Type>(c10::IValue const&, c10::Argument const&, std::__1::optional<unsigned long>) const + 312 (0x14b376d74 in libtorch_cpu.dylib)
#> frame #3: void c10::FunctionSchema::checkAndNormalizeInputs<c10::Type>(std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>&, std::__1::unordered_map<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>, c10::IValue, std::__1::hash<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>>, std::__1::equal_to<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>>, std::__1::allocator<std::__1::pair<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const, c10::IValue>>> const&) const + 148 (0x14b3767b8 in libtorch_cpu.dylib)
#> frame #4: torch::jit::Method::operator()(std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>, std::__1::unordered_map<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>, c10::IValue, std::__1::hash<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>>, std::__1::equal_to<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>>, std::__1::allocator<std::__1::pair<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const, c10::IValue>>> const&) const + 568 (0x14f14f498 in libtorch_cpu.dylib)
#> frame #5: torch::jit::Module::forward(std::__1::vector<c10::IValue, std::__1::allocator<c10::IValue>>, std::__1::unordered_map<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>, c10::IValue, std::__1::hash<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>>, std::__1::equal_to<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>>, std::__1::allocator<std::__1::pair<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const, c10::IValue>>> const&) + 120 (0x122feb41c in liblantern.dylib)
#> frame #6: _lantern_call_jit_script + 112 (0x122feace0 in liblantern.dylib)
#> frame #7: lantern_call_jit_script + 48 (0x110f258f4 in torchpkg.so)
#> frame #8: cpp_call_jit_script(Rcpp::XPtr<XPtrTorchJITModule, Rcpp::PreserveStorage, &void Rcpp::standard_delete_finalizer<XPtrTorchJITModule>(XPtrTorchJITModule*), false>, XPtrTorchStack) + 96 (0x110f25884 in torchpkg.so)
#> frame #9: _torch_cpp_call_jit_script + 240 (0x110cb99e0 in torchpkg.so)
#> frame #10: R_doDotCall + 268 (0x100bac48c in libR.dylib)
#> frame #11: bcEval_loop + 128060 (0x100c08cfc in libR.dylib)
#> frame #12: bcEval + 684 (0x100bdbeec in libR.dylib)
#> frame #13: Rf_eval + 556 (0x100bdb5ec in libR.dylib)
#> frame #14: R_execClosure + 812 (0x100bde1ac in libR.dylib)
#> frame #15: applyClosure_core + 164 (0x100bdd2a4 in libR.dylib)
#> frame #16: Rf_eval + 1224 (0x100bdb888 in libR.dylib)
#> frame #17: do_eval + 1352 (0x100be2ac8 in libR.dylib)
#> frame #18: bcEval_loop + 40164 (0x100bf35a4 in libR.dylib)
#> frame #19: bcEval + 684 (0x100bdbeec in libR.dylib)
#> frame #20: Rf_eval + 556 (0x100bdb5ec in libR.dylib)
#> frame #21: forcePromise + 232 (0x100bdc128 in libR.dylib)
#> frame #22: Rf_eval + 660 (0x100bdb654 in libR.dylib)
#> frame #23: do_withVisible + 64 (0x100be2e00 in libR.dylib)
#> frame #24: do_internal + 400 (0x100c4bbd0 in libR.dylib)
#> frame #25: bcEval_loop + 40724 (0x100bf37d4 in libR.dylib)
#> frame #26: bcEval + 684 (0x100bdbeec in libR.dylib)
#> frame #27: Rf_eval + 556 (0x100bdb5ec in libR.dylib)
#> frame #28: R_execClosure + 812 (0x100bde1ac in libR.dylib)
#> frame #29: applyClosure_core + 164 (0x100bdd2a4 in libR.dylib)
#> frame #30: Rf_eval + 1224 (0x100bdb888 in libR.dylib)
#> frame #31: R_execClosure + 812 (0x100bde1ac in libR.dylib)
#> frame #32: applyClosure_core + 164 (0x100bdd2a4 in libR.dylib)
#> frame #33: bcEval_loop + 37296 (0x100bf2a70 in libR.dylib)
#> frame #34: bcEval + 684 (0x100bdbeec in libR.dylib)
#> frame #35: Rf_eval + 556 (0x100bdb5ec in libR.dylib)
#> frame #36: R_execClosure + 812 (0x100bde1ac in libR.dylib)
#> frame #37: applyClosure_core + 164 (0x100bdd2a4 in libR.dylib)
#> frame #38: Rf_eval + 1224 (0x100bdb888 in libR.dylib)
#> frame #39: do_begin + 396 (0x100be0a4c in libR.dylib)
#> frame #40: Rf_eval + 1012 (0x100bdb7b4 in libR.dylib)
#> frame #41: R_execClosure + 812 (0x100bde1ac in libR.dylib)
#> frame #42: applyClosure_core + 164 (0x100bdd2a4 in libR.dylib)
#> frame #43: Rf_eval + 1224 (0x100bdb888 in libR.dylib)
#> frame #44: do_docall + 644 (0x100b79f04 in libR.dylib)
#> frame #45: bcEval_loop + 40164 (0x100bf35a4 in libR.dylib)
#> frame #46: bcEval + 684 (0x100bdbeec in libR.dylib)
#> frame #47: Rf_eval + 556 (0x100bdb5ec in libR.dylib)
#> frame #48: R_execClosure + 812 (0x100bde1ac in libR.dylib)
#> frame #49: applyClosure_core + 164 (0x100bdd2a4 in libR.dylib)
#> frame #50: Rf_eval + 1224 (0x100bdb888 in libR.dylib)
#> frame #51: do_docall + 644 (0x100b79f04 in libR.dylib)
#> frame #52: bcEval_loop + 40164 (0x100bf35a4 in libR.dylib)
#> frame #53: bcEval + 684 (0x100bdbeec in libR.dylib)
#> frame #54: Rf_eval + 556 (0x100bdb5ec in libR.dylib)
#> frame #55: R_execClosure + 812 (0x100bde1ac in libR.dylib)
#> frame #56: applyClosure_core + 164 (0x100bdd2a4 in libR.dylib)
#> frame #57: Rf_eval + 1224 (0x100bdb888 in libR.dylib)
#> frame #58: forcePromise + 232 (0x100bdc128 in libR.dylib)
#> frame #59: bcEval_loop + 19716 (0x100bee5c4 in libR.dylib)
#> frame #60: bcEval + 684 (0x100bdbeec in libR.dylib)
#> frame #61: Rf_eval + 556 (0x100bdb5ec in libR.dylib)
#> frame #62: R_execClosure + 812 (0x100bde1ac in libR.dylib)

Created on 2025-01-22 with reprex v2.1.1
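For comparison, a minimal check (assuming the reprex objects above are still in scope): the plain, un-traced R module handles the same list-structured input without error, so the failure is specific to the signature inferred for the traced module.

# Calling the original R module directly with a list of matching shapes works,
# which suggests only the traced module's inferred 'Tensor x' signature is at fault.
net2(list(torch_randn(10, 10), torch_randn(10, 10)))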

@sebffischer (Collaborator, Author):
@dfalbel I introduced this bug in 9b8722e.

I will try to fix this ASAP.
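In the meantime, a possible user-side workaround (a sketch only, not the fix for this issue; net3 and net3_jit are hypothetical names) is to keep the traced forward() signature to a single Tensor and do the list handling outside the traced module:

# Hypothetical workaround sketch: trace a module that takes a single Tensor,
# and apply the jit-compiled list reduction before calling the traced module.
net3 = nn_module("nn_single",
  initialize = function() {
    self$linear = nn_linear(10, 1)
  },
  forward = function(x) {
    self$linear(x)  # x is a single Tensor here
  }
)()

net3_jit = jit_trace(net3, torch_randn(10, 10))

xs = list(torch_randn(10, 10), torch_randn(10, 10))
net3_jit(sum_jit(xs))  # list handling happens before the traced call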

sebffischer added a commit to sebffischer/torch that referenced this issue Jan 23, 2025
I realized that, because TorchScript is statically typed, we would need to infer the types for the forward method that selects trainforward or evalforward depending on the mode.

We can do this later, but for now I think this fix is okay.

Resolves Issue mlverse#1252
dfalbel pushed a commit that referenced this issue Jan 28, 2025
I realized that, because TorchScript is statically typed, we would need to infer the types for the forward method that selects trainforward or evalforward depending on the mode.

We can do this later, but for now I think this fix is okay.

Resolves Issue #1252
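To illustrate the static-typing point made in the commit message above, here is a minimal sketch using only jit_compile, as in the reprex (f and g are hypothetical names): without an annotation TorchScript infers every argument as Tensor, so a list input is rejected, whereas an explicit List[Tensor] annotation accepts it.

# Without an annotation, TorchScript infers 'xs' as Tensor ...
inferred = jit_compile("
def f(xs):
  return xs
")$f

# ... with an explicit annotation, the signature accepts a list of tensors.
annotated = jit_compile("
def g(xs: List[Tensor]):
  return xs[0]
")$g

try(inferred(list(torch_tensor(1))))  # type error: 'xs' was inferred as Tensor
annotated(list(torch_tensor(1)))      # works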