Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(tracing): respect mode in R forward function #1253

Merged
merged 1 commit into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -15345,10 +15345,6 @@ cpp_jit_script_module_add_method <- function(self, method) {
invisible(.Call(`_torch_cpp_jit_script_module_add_method`, self, method))
}

cpp_jit_script_module_add_forward <- function(self, list_output) {
invisible(.Call(`_torch_cpp_jit_script_module_add_forward`, self, list_output))
}

cpp_jit_script_module_find_constant <- function(self, name) {
.Call(`_torch_cpp_jit_script_module_find_constant`, self, name)
}
Expand Down
65 changes: 52 additions & 13 deletions R/script_module.R
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,19 @@ nn_ScriptModule <- R6::R6Class(
env = private
)
},

forward = function(...) {
inputs <- list(...)

if (is.null(private$find_method("forward"))) {
runtime_error("Forward is not defined. Methods from submodules of traced modules are not traced. Are you trying to call from a submodule?")
}

out <- cpp_call_jit_script(private$ptr, inputs)
# calling the traced function always returns a stack
# with a single element.
out[[1]]
},
register_parameter = function(name, param) {
private$ptr$register_parameter(name, param)
},
Expand All @@ -139,7 +152,14 @@ nn_ScriptModule <- R6::R6Class(
private$ptr$add_constant(name, value)
},
graph_for = function(...) {
self$forward$graph_for(...)
if (!private$respects_mode) {
return(private$find_method("forward")$graph_for(...))
}
if (self$training) {
private$find_method("trainforward")$graph_for(...)
} else {
private$find_method("evalforward")$graph_for(...)
}
},
..ptr.. = function() {
private$ptr
Expand All @@ -148,11 +168,34 @@ nn_ScriptModule <- R6::R6Class(
private = list(
find_method = function(name) {
private$ptr$find_method(name)
},
respects_mode = FALSE,
update_forward_to_respect_mode = function() {
private$respects_mode = TRUE
unlockBinding("forward", self)
self$forward = function(...) {
inputs <- list(...)

if (self$training) {
private$find_method("trainforward")(...)
} else {
private$find_method("evalforward")(...)
}
}
lockBinding("forward", self)
}
),
active = list(
graph = function() {
self$forward$graph
if (!private$respects_mode) {
return(private$find_method("forward")$graph)
}
if (self$training) {
private$find_method("trainforward")$graph
} else {
private$find_method("evalforward")$graph

}
},
training = function() {
self$is_training()
Expand All @@ -175,20 +218,16 @@ nn_ScriptModule <- R6::R6Class(
}

new_script_module <- function(ptr) {
module <- nn_ScriptModule$new(ptr = ptr)
f <- function(...) {
inputs <- list(...)

if (is.null(ptr$find_method("forward"))) {
runtime_error("Forward is not defined. Methods from submodules of traced modules are not traced. Are you trying to call from a submodule?")
}

out <- cpp_call_jit_script(ptr, inputs)
# calling the traced function always returns a stack
# with a single element.
out[[1]]
module$forward(...)
}
if (!is.null(ptr$find_method("trainforward")) && !is.null(ptr$find_method("evalforward")) &&
is.null(ptr$find_method("forward"))) {
module$.__enclos_env__$private$update_forward_to_respect_mode()
}
class(f) <- c("script_module", "nn_module")
attr(f, "module") <- nn_ScriptModule$new(ptr = ptr)
attr(f, "module") <- module
f
}

Expand Down
20 changes: 14 additions & 6 deletions R/trace.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
#' choice. If you trace such models, you may silently get incorrect results on
#' subsequent invocations of the model. The tracer will try to emit warnings when
#' doing something that may cause an incorrect trace to be produced.
#' For scripting, see [`jit_compile`].
#'
#' @note Scripting is not yet supported in R.
#'
#' @param func An R function that will be run with `example_inputs`. func arguments
#' and return values must be tensors or (possibly nested) lists that contain tensors.
Expand All @@ -49,10 +49,13 @@
#' your problem is a constant structure and does not get used as control flow
#' (`if`, `for`) conditions.
#' @param respect_mode (`logical(1)`)\cr
#' Whether the forward method of the resulting module should respect the mode ('train' or 'eval').
#' If `TRUE` (default), both passes will be jitted and be available as methods `trainforward` and `evalforward`.
#' The `forward` method will then select the appropriate method based on the mode of the module.
#' If `FALSE`, only the current mode of the module will be jitted.
#' Whether both modes ('train' or 'eval') should be traced. If `TRUE` (default),
#' the underlying C++ ScriptModule will have two methods `trainforward()` and
#' `evalforward()`.
#' The `$forward()` method of the R torch module will then select either based
#' on the mode.
#' If `FALSE`, only the current mode of the module will be jitted and hence only
#' one `forward()` method exists.
#'
#' @returns An `script_function` if `func` is a function and `script_module` if
#' `func` is a `nn_module()`.
Expand Down Expand Up @@ -215,6 +218,7 @@ new_script_function <- function(ptr) {
# calling the traced function always returns a stack
# with a single element.
out[[1]]

}
class(f) <- "script_function"
attr(f, "ScriptFunction") <- ScriptFunction$new(ptr = ptr)
Expand Down Expand Up @@ -335,6 +339,7 @@ jit_trace_module <- function(mod, ..., strict = TRUE, respect_mode = TRUE) {

module <- create_script_module(mod)


if ("evalforward" %in% names(inputs) || "trainforward" %in% names(inputs)) {
value_error("Methods `evalforward` and `trainforward` are reserved.")
}
Expand Down Expand Up @@ -375,7 +380,6 @@ jit_trace_module <- function(mod, ..., strict = TRUE, respect_mode = TRUE) {
cpp_jit_script_module_add_method(module$..ptr..(), ptr_eval)
cpp_jit_script_module_add_method(module$..ptr..(), ptr_train)
list_output = is.list(with_no_grad(do.call(mod[[name]], inp)))
cpp_jit_script_module_add_forward(module$..ptr..(), list_output)
} else {
mod$train(was_training)
tr_fn <- make_traceable_fn(mod[[name]])
Expand All @@ -393,6 +397,10 @@ jit_trace_module <- function(mod, ..., strict = TRUE, respect_mode = TRUE) {
}

}
if (respect_mode) {
module$.__enclos_env__$private$update_forward_to_respect_mode()
}

module$train(was_training)

module
Expand Down
9 changes: 0 additions & 9 deletions inst/include/lantern/lantern.h
Original file line number Diff line number Diff line change
Expand Up @@ -2080,14 +2080,6 @@ HOST_API void lantern_ScriptModule_add_method (void* self, void* method)
_lantern_ScriptModule_add_method(self, method);
LANTERN_HOST_HANDLER;

}
LANTERN_API void (LANTERN_PTR _lantern_ScriptModule_add_forward) (void* self, bool list_output);
HOST_API void lantern_ScriptModule_add_forward (void* self, bool list_output)
{
LANTERN_CHECK_LOADED
_lantern_ScriptModule_add_forward(self, list_output);
LANTERN_HOST_HANDLER;

}

LANTERN_API void (LANTERN_PTR _lantern_ScriptModule_save) (void* self, void* path);
Expand Down Expand Up @@ -10885,7 +10877,6 @@ LOAD_SYMBOL(_lantern_ScriptModule_new);
LOAD_SYMBOL(_lantern_ScriptModule_add_constant);
LOAD_SYMBOL(_lantern_ScriptModule_find_constant);
LOAD_SYMBOL(_lantern_ScriptModule_add_method);
LOAD_SYMBOL(_lantern_ScriptModule_add_forward);
LOAD_SYMBOL(_lantern_ScriptModule_save);
LOAD_SYMBOL(_lantern_ScriptModule_serialize);
LOAD_SYMBOL(_lantern_ScriptModule_unserialize);
Expand Down
15 changes: 8 additions & 7 deletions man/jit_trace.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 7 additions & 4 deletions man/jit_trace_module.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 0 additions & 12 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47471,17 +47471,6 @@ BEGIN_RCPP
return R_NilValue;
END_RCPP
}
// cpp_jit_script_module_add_forward
void cpp_jit_script_module_add_forward(XPtrTorchScriptModule self, bool list_output);
RcppExport SEXP _torch_cpp_jit_script_module_add_forward(SEXP selfSEXP, SEXP list_outputSEXP) {
BEGIN_RCPP
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< XPtrTorchScriptModule >::type self(selfSEXP);
Rcpp::traits::input_parameter< bool >::type list_output(list_outputSEXP);
cpp_jit_script_module_add_forward(self, list_output);
return R_NilValue;
END_RCPP
}
// cpp_jit_script_module_find_constant
SEXP cpp_jit_script_module_find_constant(XPtrTorchScriptModule self, XPtrTorchstring name);
RcppExport SEXP _torch_cpp_jit_script_module_find_constant(SEXP selfSEXP, SEXP nameSEXP) {
Expand Down Expand Up @@ -51973,7 +51962,6 @@ static const R_CallMethodDef CallEntries[] = {
{"_torch_cpp_jit_script_module_new", (DL_FUNC) &_torch_cpp_jit_script_module_new, 2},
{"_torch_cpp_jit_script_module_add_constant", (DL_FUNC) &_torch_cpp_jit_script_module_add_constant, 3},
{"_torch_cpp_jit_script_module_add_method", (DL_FUNC) &_torch_cpp_jit_script_module_add_method, 2},
{"_torch_cpp_jit_script_module_add_forward", (DL_FUNC) &_torch_cpp_jit_script_module_add_forward, 2},
{"_torch_cpp_jit_script_module_find_constant", (DL_FUNC) &_torch_cpp_jit_script_module_find_constant, 2},
{"_torch_cpp_jit_script_module_save", (DL_FUNC) &_torch_cpp_jit_script_module_save, 2},
{"_torch_cpp_jit_script_module_serialize", (DL_FUNC) &_torch_cpp_jit_script_module_serialize, 1},
Expand Down
9 changes: 0 additions & 9 deletions src/lantern/include/lantern/lantern.h
Original file line number Diff line number Diff line change
Expand Up @@ -2080,14 +2080,6 @@ HOST_API void lantern_ScriptModule_add_method (void* self, void* method)
_lantern_ScriptModule_add_method(self, method);
LANTERN_HOST_HANDLER;

}
LANTERN_API void (LANTERN_PTR _lantern_ScriptModule_add_forward) (void* self, bool list_output);
HOST_API void lantern_ScriptModule_add_forward (void* self, bool list_output)
{
LANTERN_CHECK_LOADED
_lantern_ScriptModule_add_forward(self, list_output);
LANTERN_HOST_HANDLER;

}

LANTERN_API void (LANTERN_PTR _lantern_ScriptModule_save) (void* self, void* path);
Expand Down Expand Up @@ -10885,7 +10877,6 @@ LOAD_SYMBOL(_lantern_ScriptModule_new);
LOAD_SYMBOL(_lantern_ScriptModule_add_constant);
LOAD_SYMBOL(_lantern_ScriptModule_find_constant);
LOAD_SYMBOL(_lantern_ScriptModule_add_method);
LOAD_SYMBOL(_lantern_ScriptModule_add_forward);
LOAD_SYMBOL(_lantern_ScriptModule_save);
LOAD_SYMBOL(_lantern_ScriptModule_serialize);
LOAD_SYMBOL(_lantern_ScriptModule_unserialize);
Expand Down
24 changes: 0 additions & 24 deletions src/lantern/src/ScriptModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,30 +174,6 @@ void _lantern_ScriptModule_add_method(void* self, void* method) {
LANTERN_FUNCTION_END_VOID
}

void _lantern_ScriptModule_add_forward(void* self, bool list_output) {
LANTERN_FUNCTION_START
auto self_ = reinterpret_cast<torch::jit::script::Module*>(self);
if (list_output) {
self_->define(R"(
def forward(self, x) -> List[Tensor]:
if self.training:
return self.trainforward(x)
else:
return self.evalforward(x)
)");
} else {
self_->define(R"(
def forward(self, x) -> Tensor:
if self.training:
return self.trainforward(x)
else:
return self.evalforward(x)
)");

}
LANTERN_FUNCTION_END_VOID
}

void _lantern_ScriptModule_add_constant(void* self, void* name, void* value) {
LANTERN_FUNCTION_START
auto self_ = reinterpret_cast<torch::jit::script::Module*>(self);
Expand Down
5 changes: 0 additions & 5 deletions src/script_module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,6 @@ void cpp_jit_script_module_add_method(XPtrTorchScriptModule self,
lantern_ScriptModule_add_method(self.get(), method->get());
}

// [[Rcpp::export]]
void cpp_jit_script_module_add_forward(XPtrTorchScriptModule self, bool list_output) {
lantern_ScriptModule_add_forward(self.get(), list_output);
}

// [[Rcpp::export]]
SEXP cpp_jit_script_module_find_constant(XPtrTorchScriptModule self,
XPtrTorchstring name) {
Expand Down
Loading
Loading