diff --git a/R/RcppExports.R b/R/RcppExports.R index 54fefca412..dcbeecfb0a 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -15345,10 +15345,6 @@ cpp_jit_script_module_add_method <- function(self, method) { invisible(.Call(`_torch_cpp_jit_script_module_add_method`, self, method)) } -cpp_jit_script_module_add_forward <- function(self, list_output) { - invisible(.Call(`_torch_cpp_jit_script_module_add_forward`, self, list_output)) -} - cpp_jit_script_module_find_constant <- function(self, name) { .Call(`_torch_cpp_jit_script_module_find_constant`, self, name) } diff --git a/R/script_module.R b/R/script_module.R index 6f0b0beb40..a7dec7c651 100644 --- a/R/script_module.R +++ b/R/script_module.R @@ -126,6 +126,19 @@ nn_ScriptModule <- R6::R6Class( env = private ) }, + + forward = function(...) { + inputs <- list(...) + + if (is.null(private$find_method("forward"))) { + runtime_error("Forward is not defined. Methods from submodules of traced modules are not traced. Are you trying to call from a submodule?") + } + + out <- cpp_call_jit_script(private$ptr, inputs) + # calling the traced function always returns a stack + # with a single element. + out[[1]] + }, register_parameter = function(name, param) { private$ptr$register_parameter(name, param) }, @@ -139,7 +152,14 @@ nn_ScriptModule <- R6::R6Class( private$ptr$add_constant(name, value) }, graph_for = function(...) { - self$forward$graph_for(...) + if (!private$respects_mode) { + return(private$find_method("forward")$graph_for(...)) + } + if (self$training) { + private$find_method("trainforward")$graph_for(...) + } else { + private$find_method("evalforward")$graph_for(...) + } }, ..ptr.. = function() { private$ptr @@ -148,11 +168,34 @@ nn_ScriptModule <- R6::R6Class( private = list( find_method = function(name) { private$ptr$find_method(name) + }, + respects_mode = FALSE, + update_forward_to_respect_mode = function() { + private$respects_mode = TRUE + unlockBinding("forward", self) + self$forward = function(...) { + inputs <- list(...) + + if (self$training) { + private$find_method("trainforward")(...) + } else { + private$find_method("evalforward")(...) + } + } + lockBinding("forward", self) } ), active = list( graph = function() { - self$forward$graph + if (!private$respects_mode) { + return(private$find_method("forward")$graph) + } + if (self$training) { + private$find_method("trainforward")$graph + } else { + private$find_method("evalforward")$graph + + } }, training = function() { self$is_training() @@ -175,20 +218,16 @@ nn_ScriptModule <- R6::R6Class( } new_script_module <- function(ptr) { + module <- nn_ScriptModule$new(ptr = ptr) f <- function(...) { - inputs <- list(...) - - if (is.null(ptr$find_method("forward"))) { - runtime_error("Forward is not defined. Methods from submodules of traced modules are not traced. Are you trying to call from a submodule?") - } - - out <- cpp_call_jit_script(ptr, inputs) - # calling the traced function always returns a stack - # with a single element. - out[[1]] + module$forward(...) + } + if (!is.null(ptr$find_method("trainforward")) && !is.null(ptr$find_method("evalforward")) && + is.null(ptr$find_method("forward"))) { + module$.__enclos_env__$private$update_forward_to_respect_mode() } class(f) <- c("script_module", "nn_module") - attr(f, "module") <- nn_ScriptModule$new(ptr = ptr) + attr(f, "module") <- module f } diff --git a/R/trace.R b/R/trace.R index a67d228dbc..ac3fa60a13 100644 --- a/R/trace.R +++ b/R/trace.R @@ -30,8 +30,8 @@ #' choice. If you trace such models, you may silently get incorrect results on #' subsequent invocations of the model. The tracer will try to emit warnings when #' doing something that may cause an incorrect trace to be produced. +#' For scripting, see [`jit_compile`]. #' -#' @note Scripting is not yet supported in R. #' #' @param func An R function that will be run with `example_inputs`. func arguments #' and return values must be tensors or (possibly nested) lists that contain tensors. @@ -49,10 +49,13 @@ #' your problem is a constant structure and does not get used as control flow #' (`if`, `for`) conditions. #' @param respect_mode (`logical(1)`)\cr -#' Whether the forward method of the resulting module should respect the mode ('train' or 'eval'). -#' If `TRUE` (default), both passes will be jitted and be available as methods `trainforward` and `evalforward`. -#' The `forward` method will then select the appropriate method based on the mode of the module. -#' If `FALSE`, only the current mode of the module will be jitted. +#' Whether both modes ('train' or 'eval') should be traced. If `TRUE` (default), +#' the underlying C++ ScriptModule will have two methods `trainforward()` and +#' `evalforward()`. +#' The `$forward()` method of the R torch module will then select either based +#' on the mode. +#' If `FALSE`, only the current mode of the module will be jitted and hence only +#' one `forward()` method exists. #' #' @returns An `script_function` if `func` is a function and `script_module` if #' `func` is a `nn_module()`. @@ -215,6 +218,7 @@ new_script_function <- function(ptr) { # calling the traced function always returns a stack # with a single element. out[[1]] + } class(f) <- "script_function" attr(f, "ScriptFunction") <- ScriptFunction$new(ptr = ptr) @@ -335,6 +339,7 @@ jit_trace_module <- function(mod, ..., strict = TRUE, respect_mode = TRUE) { module <- create_script_module(mod) + if ("evalforward" %in% names(inputs) || "trainforward" %in% names(inputs)) { value_error("Methods `evalforward` and `trainforward` are reserved.") } @@ -375,7 +380,6 @@ jit_trace_module <- function(mod, ..., strict = TRUE, respect_mode = TRUE) { cpp_jit_script_module_add_method(module$..ptr..(), ptr_eval) cpp_jit_script_module_add_method(module$..ptr..(), ptr_train) list_output = is.list(with_no_grad(do.call(mod[[name]], inp))) - cpp_jit_script_module_add_forward(module$..ptr..(), list_output) } else { mod$train(was_training) tr_fn <- make_traceable_fn(mod[[name]]) @@ -393,6 +397,10 @@ jit_trace_module <- function(mod, ..., strict = TRUE, respect_mode = TRUE) { } } + if (respect_mode) { + module$.__enclos_env__$private$update_forward_to_respect_mode() + } + module$train(was_training) module diff --git a/inst/include/lantern/lantern.h b/inst/include/lantern/lantern.h index 8aa1d2245e..284002c5bd 100644 --- a/inst/include/lantern/lantern.h +++ b/inst/include/lantern/lantern.h @@ -2080,14 +2080,6 @@ HOST_API void lantern_ScriptModule_add_method (void* self, void* method) _lantern_ScriptModule_add_method(self, method); LANTERN_HOST_HANDLER; -} -LANTERN_API void (LANTERN_PTR _lantern_ScriptModule_add_forward) (void* self, bool list_output); -HOST_API void lantern_ScriptModule_add_forward (void* self, bool list_output) -{ - LANTERN_CHECK_LOADED - _lantern_ScriptModule_add_forward(self, list_output); - LANTERN_HOST_HANDLER; - } LANTERN_API void (LANTERN_PTR _lantern_ScriptModule_save) (void* self, void* path); @@ -10885,7 +10877,6 @@ LOAD_SYMBOL(_lantern_ScriptModule_new); LOAD_SYMBOL(_lantern_ScriptModule_add_constant); LOAD_SYMBOL(_lantern_ScriptModule_find_constant); LOAD_SYMBOL(_lantern_ScriptModule_add_method); -LOAD_SYMBOL(_lantern_ScriptModule_add_forward); LOAD_SYMBOL(_lantern_ScriptModule_save); LOAD_SYMBOL(_lantern_ScriptModule_serialize); LOAD_SYMBOL(_lantern_ScriptModule_unserialize); diff --git a/man/jit_trace.Rd b/man/jit_trace.Rd index 251b4f1a5b..35237dcf4b 100644 --- a/man/jit_trace.Rd +++ b/man/jit_trace.Rd @@ -26,10 +26,13 @@ your problem is a constant structure and does not get used as control flow (\code{if}, \code{for}) conditions.} \item{respect_mode}{(\code{logical(1)})\cr -Whether the forward method of the resulting module should respect the mode ('train' or 'eval'). -If \code{TRUE} (default), both passes will be jitted and be available as methods \code{trainforward} and \code{evalforward}. -The \code{forward} method will then select the appropriate method based on the mode of the module. -If \code{FALSE}, only the current mode of the module will be jitted.} +Whether both modes ('train' or 'eval') should be traced. If \code{TRUE} (default), +the underlying C++ ScriptModule will have two methods \code{trainforward()} and +\code{evalforward()}. +The \verb{$forward()} method of the R torch module will then select either based +on the mode. +If \code{FALSE}, only the current mode of the module will be jitted and hence only +one \code{forward()} method exists.} } \value{ An \code{script_function} if \code{func} is a function and \code{script_module} if @@ -43,9 +46,6 @@ recording the operations performed on all the tensors. \details{ The resulting recording of a standalone function produces a \code{script_function}. } -\note{ -Scripting is not yet supported in R. -} \section{Warning}{ @@ -72,6 +72,7 @@ In cases like these, tracing would not be appropriate and scripting is a better choice. If you trace such models, you may silently get incorrect results on subsequent invocations of the model. The tracer will try to emit warnings when doing something that may cause an incorrect trace to be produced. +For scripting, see \code{\link{jit_compile}}. } \examples{ diff --git a/man/jit_trace_module.Rd b/man/jit_trace_module.Rd index ab11e6bccc..1d270f146f 100644 --- a/man/jit_trace_module.Rd +++ b/man/jit_trace_module.Rd @@ -21,10 +21,13 @@ your problem is a constant structure and does not get used as control flow (\code{if}, \code{for}) conditions.} \item{respect_mode}{(\code{logical(1)})\cr -Whether the forward method of the resulting module should respect the mode ('train' or 'eval'). -If \code{TRUE} (default), both passes will be jitted and be available as methods \code{trainforward} and \code{evalforward}. -The \code{forward} method will then select the appropriate method based on the mode of the module. -If \code{FALSE}, only the current mode of the module will be jitted.} +Whether both modes ('train' or 'eval') should be traced. If \code{TRUE} (default), +the underlying C++ ScriptModule will have two methods \code{trainforward()} and +\code{evalforward()}. +The \verb{$forward()} method of the R torch module will then select either based +on the mode. +If \code{FALSE}, only the current mode of the module will be jitted and hence only +one \code{forward()} method exists.} } \description{ Trace a module and return an executable ScriptModule that will be optimized diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 6f71b83bb0..2a5fcba343 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -47471,17 +47471,6 @@ BEGIN_RCPP return R_NilValue; END_RCPP } -// cpp_jit_script_module_add_forward -void cpp_jit_script_module_add_forward(XPtrTorchScriptModule self, bool list_output); -RcppExport SEXP _torch_cpp_jit_script_module_add_forward(SEXP selfSEXP, SEXP list_outputSEXP) { -BEGIN_RCPP - Rcpp::RNGScope rcpp_rngScope_gen; - Rcpp::traits::input_parameter< XPtrTorchScriptModule >::type self(selfSEXP); - Rcpp::traits::input_parameter< bool >::type list_output(list_outputSEXP); - cpp_jit_script_module_add_forward(self, list_output); - return R_NilValue; -END_RCPP -} // cpp_jit_script_module_find_constant SEXP cpp_jit_script_module_find_constant(XPtrTorchScriptModule self, XPtrTorchstring name); RcppExport SEXP _torch_cpp_jit_script_module_find_constant(SEXP selfSEXP, SEXP nameSEXP) { @@ -51973,7 +51962,6 @@ static const R_CallMethodDef CallEntries[] = { {"_torch_cpp_jit_script_module_new", (DL_FUNC) &_torch_cpp_jit_script_module_new, 2}, {"_torch_cpp_jit_script_module_add_constant", (DL_FUNC) &_torch_cpp_jit_script_module_add_constant, 3}, {"_torch_cpp_jit_script_module_add_method", (DL_FUNC) &_torch_cpp_jit_script_module_add_method, 2}, - {"_torch_cpp_jit_script_module_add_forward", (DL_FUNC) &_torch_cpp_jit_script_module_add_forward, 2}, {"_torch_cpp_jit_script_module_find_constant", (DL_FUNC) &_torch_cpp_jit_script_module_find_constant, 2}, {"_torch_cpp_jit_script_module_save", (DL_FUNC) &_torch_cpp_jit_script_module_save, 2}, {"_torch_cpp_jit_script_module_serialize", (DL_FUNC) &_torch_cpp_jit_script_module_serialize, 1}, diff --git a/src/lantern/include/lantern/lantern.h b/src/lantern/include/lantern/lantern.h index 8aa1d2245e..284002c5bd 100644 --- a/src/lantern/include/lantern/lantern.h +++ b/src/lantern/include/lantern/lantern.h @@ -2080,14 +2080,6 @@ HOST_API void lantern_ScriptModule_add_method (void* self, void* method) _lantern_ScriptModule_add_method(self, method); LANTERN_HOST_HANDLER; -} -LANTERN_API void (LANTERN_PTR _lantern_ScriptModule_add_forward) (void* self, bool list_output); -HOST_API void lantern_ScriptModule_add_forward (void* self, bool list_output) -{ - LANTERN_CHECK_LOADED - _lantern_ScriptModule_add_forward(self, list_output); - LANTERN_HOST_HANDLER; - } LANTERN_API void (LANTERN_PTR _lantern_ScriptModule_save) (void* self, void* path); @@ -10885,7 +10877,6 @@ LOAD_SYMBOL(_lantern_ScriptModule_new); LOAD_SYMBOL(_lantern_ScriptModule_add_constant); LOAD_SYMBOL(_lantern_ScriptModule_find_constant); LOAD_SYMBOL(_lantern_ScriptModule_add_method); -LOAD_SYMBOL(_lantern_ScriptModule_add_forward); LOAD_SYMBOL(_lantern_ScriptModule_save); LOAD_SYMBOL(_lantern_ScriptModule_serialize); LOAD_SYMBOL(_lantern_ScriptModule_unserialize); diff --git a/src/lantern/src/ScriptModule.cpp b/src/lantern/src/ScriptModule.cpp index e547f17405..ccd3d94849 100644 --- a/src/lantern/src/ScriptModule.cpp +++ b/src/lantern/src/ScriptModule.cpp @@ -174,30 +174,6 @@ void _lantern_ScriptModule_add_method(void* self, void* method) { LANTERN_FUNCTION_END_VOID } -void _lantern_ScriptModule_add_forward(void* self, bool list_output) { - LANTERN_FUNCTION_START - auto self_ = reinterpret_cast(self); - if (list_output) { - self_->define(R"( - def forward(self, x) -> List[Tensor]: - if self.training: - return self.trainforward(x) - else: - return self.evalforward(x) - )"); - } else { - self_->define(R"( - def forward(self, x) -> Tensor: - if self.training: - return self.trainforward(x) - else: - return self.evalforward(x) - )"); - - } - LANTERN_FUNCTION_END_VOID -} - void _lantern_ScriptModule_add_constant(void* self, void* name, void* value) { LANTERN_FUNCTION_START auto self_ = reinterpret_cast(self); diff --git a/src/script_module.cpp b/src/script_module.cpp index 4cfd4a9b4a..fd1b3b1add 100644 --- a/src/script_module.cpp +++ b/src/script_module.cpp @@ -119,11 +119,6 @@ void cpp_jit_script_module_add_method(XPtrTorchScriptModule self, lantern_ScriptModule_add_method(self.get(), method->get()); } -// [[Rcpp::export]] -void cpp_jit_script_module_add_forward(XPtrTorchScriptModule self, bool list_output) { - lantern_ScriptModule_add_forward(self.get(), list_output); -} - // [[Rcpp::export]] SEXP cpp_jit_script_module_find_constant(XPtrTorchScriptModule self, XPtrTorchstring name) { diff --git a/tests/testthat/_snaps/script_module.md b/tests/testthat/_snaps/script_module.md index aa6d3e80a2..ef6591fbf1 100644 --- a/tests/testthat/_snaps/script_module.md +++ b/tests/testthat/_snaps/script_module.md @@ -1,74 +1,23 @@ # can print the graph - graph(%self.1 : __torch__.nn_linear, - %x.1 : Tensor): - %training : bool = prim::GetAttr[name="training"](%self.1) - %19 : Tensor = prim::If(%training) # :3:8 - block0(): - %7 : Tensor = prim::CallMethod[name="trainforward"](%self.1, %x.1) # :4:17 - -> (%7) - block1(): - %10 : Tensor = prim::CallMethod[name="evalforward"](%self.1, %x.1) # :6:17 - -> (%10) - return (%19) - ---- - - graph(%self.1 : __torch__.nn_linear, - %x.1 : Tensor): - %training : bool = prim::GetAttr[name="training"](%self.1) - %19 : Tensor = prim::If(%training) # :3:8 - block0(): - %7 : Tensor = prim::CallMethod[name="trainforward"](%self.1, %x.1) # :4:17 - -> (%7) - block1(): - %10 : Tensor = prim::CallMethod[name="evalforward"](%self.1, %x.1) # :6:17 - -> (%10) - return (%19) + graph(%self : __torch__.nn_linear, + %4 : Float(10, 10, strides=[10, 1], requires_grad=0, device=cpu)): + %bias : Tensor = prim::GetAttr[name="bias"](%self) + %weight : Tensor = prim::GetAttr[name="weight"](%self) + %5 : Float(10, 10, strides=[10, 1], requires_grad=1, device=cpu) = aten::linear(%4, %weight, %bias) + return (%5) # graph_for - graph(%self.1 : __torch__.nn_linear1, - %x.1 : Tensor): - %training : bool = prim::GetAttr[name="training"](%self.1) - %3 : Tensor = prim::If(%training) # :3:8 - block0(): - %bias.1 : Tensor = prim::GetAttr[name="bias"](%self.1) - %weight.1 : Tensor = prim::GetAttr[name="weight"](%self.1) - %10 : Tensor = prim::profile[profiled_type=Float(10, 10, strides=[10, 1], requires_grad=0, device=cpu), seen_none=0](%x.1) - %11 : Tensor = prim::profile[profiled_type=Float(10, 10, strides=[10, 1], requires_grad=1, device=cpu), seen_none=0](%weight.1) - %12 : Tensor = prim::profile[profiled_type=Float(10, strides=[1], requires_grad=1, device=cpu), seen_none=0](%bias.1) - %6 : Tensor = aten::linear(%10, %11, %12) - %13 : Tensor = prim::profile[profiled_type=Float(10, 10, strides=[10, 1], requires_grad=1, device=cpu), seen_none=0](%6) - -> (%13) - block1(): - %bias : Tensor = prim::GetAttr[name="bias"](%self.1) - %weight : Tensor = prim::GetAttr[name="weight"](%self.1) - %14 : Tensor = prim::profile[profiled_type=Tensor, seen_none=0](%x.1) - %15 : Tensor = prim::profile[profiled_type=Tensor, seen_none=0](%weight) - %16 : Tensor = prim::profile[profiled_type=Tensor, seen_none=0](%bias) - %9 : Tensor = aten::linear(%14, %15, %16) - %17 : Tensor = prim::profile[profiled_type=Tensor, seen_none=0](%9) - -> (%17) - %18 : Tensor = prim::profile[profiled_type=Float(10, 10, strides=[10, 1], requires_grad=1, device=cpu), seen_none=0](%3) + graph(%self : __torch__.nn_linear1, + %1 : Tensor): + %bias : Tensor = prim::GetAttr[name="bias"](%self) + %weight : Tensor = prim::GetAttr[name="weight"](%self) + %5 : Tensor = prim::profile[profiled_type=Float(10, 10, strides=[10, 1], requires_grad=0, device=cpu), seen_none=0](%1) + %6 : Tensor = prim::profile[profiled_type=Float(10, 10, strides=[10, 1], requires_grad=1, device=cpu), seen_none=0](%weight) + %7 : Tensor = prim::profile[profiled_type=Float(10, strides=[1], requires_grad=1, device=cpu), seen_none=0](%bias) + %4 : Tensor = aten::linear(%5, %6, %7) + %8 : Tensor = prim::profile[profiled_type=Float(10, 10, strides=[10, 1], requires_grad=1, device=cpu), seen_none=0](%4) = prim::profile() - return (%18) - ---- - - graph(%self.1 : __torch__.nn_linear1, - %x.1 : Tensor): - %training : bool = prim::GetAttr[name="training"](%self.1) - %3 : Tensor = prim::If(%training) # :3:8 - block0(): - %bias.1 : Tensor = prim::GetAttr[name="bias"](%self.1) - %weight.1 : Tensor = prim::GetAttr[name="weight"](%self.1) - %42 : Tensor = aten::linear(%x.1, %weight.1, %bias.1) - -> (%42) - block1(): - %bias : Tensor = prim::GetAttr[name="bias"](%self.1) - %weight : Tensor = prim::GetAttr[name="weight"](%self.1) - %47 : Tensor = aten::linear(%x.1, %weight, %bias) - -> (%47) - return (%3) + return (%8) diff --git a/tests/testthat/test-script_module.R b/tests/testthat/test-script_module.R index 09de36bb7f..9656408fdb 100644 --- a/tests/testthat/test-script_module.R +++ b/tests/testthat/test-script_module.R @@ -112,10 +112,6 @@ test_that("can print the graph", { set.seed(1) traced <- jit_trace(nn_linear(10, 10), torch_randn(10, 10)) - expect_snapshot_output({ - print(traced$forward$graph) - }) - expect_snapshot_output({ print(traced$graph) }) @@ -125,10 +121,6 @@ test_that("graph_for", { testthat::local_edition(3) traced <- jit_trace(nn_linear(10, 10), torch_randn(10, 10)) - expect_snapshot_output({ - traced$forward$graph_for(torch_randn(10, 10)) - }) - expect_snapshot_output({ traced$graph_for(torch_randn(10, 10)) }) diff --git a/tests/testthat/test-trace.R b/tests/testthat/test-trace.R index 00e594d467..51967d4939 100644 --- a/tests/testthat/test-trace.R +++ b/tests/testthat/test-trace.R @@ -108,7 +108,7 @@ test_that("can output a list of tensors", { list(x, x + 1) } x <- torch_tensor(1) - tr_fn <- jit_trace(fn, x) + tr_fn <- jit_trace(fn, x, strict = FALSE) expect_equal_to_tensor(fn(x)[[1]], tr_fn(x)[[1]]) expect_equal_to_tensor(fn(x)[[2]], tr_fn(x)[[2]]) }) @@ -121,7 +121,7 @@ test_that("fn can take more than 1 argument", { x <- torch_tensor(1) y <- torch_tensor(2) - tr_fn <- jit_trace(fn, x, y) + tr_fn <- jit_trace(fn, x, y, strict = FALSE) expect_equal_to_tensor(fn(x, y)[[1]], tr_fn(x, y)[[1]]) expect_equal_to_tensor(fn(x, y)[[2]], tr_fn(x, y)[[2]])