Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow augment to be used on newdata for mlogit #1158

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions R/mlogit-tidiers.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@
#' augment(m)
#' glance(m)
#'
#' # augment with newdata
#' Fish2 <- Fish
#' Fish2$price <- ifelse(Fish2$income < 3000, Fish2$price * 0.7, Fish2$price )
#' augment(m, newdata = Fish2)
#'
#' @aliases mlogit_tidiers
#' @export
#' @family mlogit tidiers
Expand Down Expand Up @@ -56,23 +61,37 @@ tidy.mlogit <- function(x, conf.int = FALSE, conf.level = 0.95, ...) {
#'
#' @inherit tidy.mlogit params examples
#' @param data Not currently used
#' @param newdata Data frame on which to predict utility values. See `details`.
#'
#' @details At the moment this only works on the estimation dataset. Need to set
#' it up to predict on another dataset.
#' @details Augmenting a new data frame requires that the data be
#' a `dfidx` data frame with ID and alternative information identified.
#'
#' @export
#' @seealso [augment()]
#' @family mlogit tidiers
#'
#'
augment.mlogit <- function(x, data = x$model, ...) {
augment.mlogit <- function(x, data = x$model, newdata = NULL, ...) {
check_ellipses("newdata", "augment", "mlogit", ...)

# So, the way mlogit handles prediction is kind of silly, because
# the developers have chosen to not implement a model.matrix method.
# Rather, mlogit uses update to create a NEW model object but without
# running a new maximum likelihood estimation, it constrains the parameters
# to their previously estimated values.
# This does unfortunately mean that the data to be predicted has to be
# in a dfidx format.
if (!is.null(newdata)) {
x <- update(x, start = coef(x, fixed = TRUE), data = newdata, iterlim = 0,
print.level = 0)
}

# the ID variables are really messed up, so we're going to do some
# retrofitting because this ends up being a pretty important element of
# what we want to do with the results.
idx <- x$model$idx

# augment
reg <- x$model %>%
as_augment_tibble() %>%
dplyr::select(-idx) %>%
Expand All @@ -85,11 +104,15 @@ augment.mlogit <- function(x, data = x$model, ...) {
# reappend the id columns
dplyr::mutate(
id = idx$id1,
alternative = idx$id2,
.resid = as.vector(x$residuals)
alternative = idx$id2
) %>%
dplyr::select(id, alternative, chosen, everything())


# residuals don't make sense for newdata
if(is.null(newdata)){
reg$.resid = as.vector(x$residuals)
}

reg
}

Expand Down