From 58a3674cc6c4d71aa9aed9617424b3de45c4b9cd Mon Sep 17 00:00:00 2001 From: "Logan C. Brooks" Date: Tue, 9 Jul 2024 15:00:38 -0700 Subject: [PATCH 01/11] Make `layer_predict` forward stored dots_list to `predict()` --- DESCRIPTION | 2 +- NEWS.md | 2 ++ R/epi_workflow.R | 2 +- R/layer_predict.R | 17 ++++++++--- tests/testthat/test-layer_predict.R | 44 +++++++++++++++++++++++++++++ 5 files changed, 61 insertions(+), 6 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 1126f8304..a219637b6 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: epipredict Title: Basic epidemiology forecasting methods -Version: 0.0.16 +Version: 0.0.17 Authors@R: c( person("Daniel", "McDonald", , "daniel@stat.ubc.ca", role = c("aut", "cre")), person("Ryan", "Tibshirani", , "ryantibs@cmu.edu", role = "aut"), diff --git a/NEWS.md b/NEWS.md index bf3f4d9d5..cce52cb51 100644 --- a/NEWS.md +++ b/NEWS.md @@ -47,3 +47,5 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat - Revise `compat-purrr` to use the r-lang `standalone-*` version (via `{usethis}`) - `epi_recipe()` will now warn when given non-`epi_df` data +- `layer_predict()` will now appropriately forward `...` args intended for + `predict.workflow()` diff --git a/R/epi_workflow.R b/R/epi_workflow.R index c6f1e43a9..1ad916c19 100644 --- a/R/epi_workflow.R +++ b/R/epi_workflow.R @@ -123,7 +123,7 @@ fit.epi_workflow <- function(object, data, ..., control = workflows::control_wor #' possible. Specifically, the output will have `time_value` and #' `geo_value` columns as well as the prediction. #' -#' @inheritParams parsnip::predict.model_fit +#' @inheritParams workflows::predict.workflow #' #' @param object An epi_workflow that has been fit by #' [workflows::fit.workflow()] diff --git a/R/layer_predict.R b/R/layer_predict.R index b40c24be5..f161f21c3 100644 --- a/R/layer_predict.R +++ b/R/layer_predict.R @@ -45,12 +45,18 @@ layer_predict <- id = rand_id("predict_default")) { arg_is_chr_scalar(id) arg_is_chr_scalar(type, allow_null = TRUE) + dots_list <- rlang::dots_list(..., .homonyms = "error", .check_assign = TRUE) + if (any(rlang::names2(dots_list) == "")) { + cli_abort("All `...` arguments must be named.", + class = "epipredict__layer_predict__unnamed_dot" + ) + } add_layer( frosting, layer_predict_new( type = type, opts = opts, - dots_list = rlang::list2(...), # can't figure how to use this + dots_list = dots_list, id = id ) ) @@ -63,13 +69,16 @@ layer_predict_new <- function(type, opts, dots_list, id) { #' @export slather.layer_predict <- function(object, components, workflow, new_data, ...) { + rlang::check_dots_empty() + the_fit <- workflows::extract_fit_parsnip(workflow) - components$predictions <- predict( + components$predictions <- rlang::inject(predict( the_fit, components$forged$predictors, - type = object$type, opts = object$opts - ) + type = object$type, opts = object$opts, + !!!object$dots_list + )) components$predictions <- dplyr::bind_cols( components$keys, components$predictions ) diff --git a/tests/testthat/test-layer_predict.R b/tests/testthat/test-layer_predict.R index bd10de08c..f3b33f5db 100644 --- a/tests/testthat/test-layer_predict.R +++ b/tests/testthat/test-layer_predict.R @@ -31,3 +31,47 @@ test_that("prediction with interval works", { expect_equal(nrow(p), 108L) expect_named(p, c("geo_value", "time_value", ".pred_lower", ".pred_upper")) }) + +test_that("layer_predict dots validation", { + # We balk at unnamed arguments, though perhaps not with the most helpful error messages: + expect_error( + frosting() %>% layer_predict("pred_int", list(), tibble::tibble(x = 5)), + class = "epipredict__layer_predict__unnamed_dot" + ) + expect_error( + frosting() %>% layer_predict("pred_int", list(), "maybe_meant_to_be_id"), + class = "epipredict__layer_predict__unnamed_dot" + ) + # We allow arguments that might actually work at prediction time: + expect_no_error(frosting() %>% layer_predict(type = "quantile", interval = "confidence")) + + # We don't detect completely-bogus arg names until predict time: + expect_no_error(f_bad_arg <- frosting() %>% layer_predict(bogus_argument = "something")) + wf_bad_arg <- wf %>% add_frosting(f_bad_arg) + expect_error(predict(wf_bad_arg, latest)) + # Some argument names only apply for some prediction `type`s; we don't check for ignored arguments, and neither does workflows: + expect_no_error(frosting() %>% layer_predict(eval_time = "preferably this would error")) + + # ^ (currently with a truly awful error message, due to an extra comma in parsnip::check_pred_type_dots) + # + # Unfortunately, we outright ignore attempts to pass args via `predict.epi_workflow`: + f_predict <- frosting() %>% layer_predict() + wf_predict <- wf %>% add_frosting(f_predict) + expect_no_error(predict(wf_predict, latest, type = "pred_int")) +}) + +test_that("layer_predict dots are forwarded", { + f_lm_int_level <- frosting() %>% + layer_predict(type = "pred_int", level = 0.8) + wf_lm_int_level <- wf %>% add_frosting(f_lm_int_level) + p <- predict(wf, latest) + p_lm_int_level <- predict(wf_lm_int_level, latest) + expect_contains(names(p_lm_int_level), c(".pred_lower", ".pred_upper")) + expect_equal(nrow(na.omit(p)), nrow(na.omit(p_lm_int_level))) + expect_true(cbind(p, p_lm_int_level[c(".pred_lower", ".pred_upper")]) %>% + na.omit() %>% + mutate(sandwiched = .pred_lower <= .pred & .pred <= .pred_upper) %>% + `[[`("sandwiched") %>% + all()) + # There are many possible other valid configurations that aren't tested here. +}) From 1c9b30856340affc0498b6c19478d7ac730512ef Mon Sep 17 00:00:00 2001 From: "Logan C. Brooks" Date: Thu, 18 Jul 2024 11:16:23 -0700 Subject: [PATCH 02/11] Sometimes allow passing type, opts, ... via predict.epi_workflow() --- NAMESPACE | 1 + R/epi_workflow.R | 4 +-- R/epipredict-package.R | 2 +- R/frosting.R | 22 +++++++++++++-- R/layer_add_forecast_date.R | 1 + R/layer_naomit.R | 1 + R/layer_point_from_distn.R | 4 +-- R/layer_population_scaling.R | 2 +- R/layer_predict.R | 15 ++++++---- R/layer_predictive_distn.R | 1 + R/layer_quantile_distn.R | 2 ++ R/layer_residual_quantiles.R | 2 ++ R/layer_threshold_preds.R | 1 + R/layer_unnest.R | 1 + inst/templates/layer.R | 1 + man/apply_frosting.Rd | 2 +- man/predict-epi_workflow.Rd | 12 +++++++- tests/testthat/test-frosting.R | 43 +++++++++++++++++++++++++++++ tests/testthat/test-layer_predict.R | 36 +++++++++++++++++------- 19 files changed, 127 insertions(+), 26 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 708c91e06..941ea1542 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -208,6 +208,7 @@ import(parsnip) import(recipes) importFrom(checkmate,assert) importFrom(checkmate,assert_character) +importFrom(checkmate,assert_class) importFrom(checkmate,assert_date) importFrom(checkmate,assert_function) importFrom(checkmate,assert_int) diff --git a/R/epi_workflow.R b/R/epi_workflow.R index 1ad916c19..7e1c95f88 100644 --- a/R/epi_workflow.R +++ b/R/epi_workflow.R @@ -152,7 +152,7 @@ fit.epi_workflow <- function(object, data, ..., control = workflows::control_wor #' #' preds <- predict(wf, latest) #' preds -predict.epi_workflow <- function(object, new_data, ...) { +predict.epi_workflow <- function(object, new_data, type = NULL, opts = list(), ...) { if (!workflows::is_trained_workflow(object)) { cli::cli_abort(c( "Can't predict on an untrained epi_workflow.", @@ -168,7 +168,7 @@ predict.epi_workflow <- function(object, new_data, ...) { components$forged, components$mold, new_data ) - components <- apply_frosting(object, components, new_data, ...) + components <- apply_frosting(object, components, new_data, type = type, opts = opts, ...) components$predictions } diff --git a/R/epipredict-package.R b/R/epipredict-package.R index 4bd37c519..7746281ba 100644 --- a/R/epipredict-package.R +++ b/R/epipredict-package.R @@ -6,7 +6,7 @@ #' @importFrom cli cli_abort #' @importFrom checkmate assert assert_character assert_int assert_scalar #' assert_logical assert_numeric assert_number assert_integer -#' assert_integerish assert_date assert_function +#' assert_integerish assert_date assert_function assert_class #' @import epiprocess parsnip ## usethis namespace: end NULL diff --git a/R/frosting.R b/R/frosting.R index f9c5867a4..f293314fb 100644 --- a/R/frosting.R +++ b/R/frosting.R @@ -357,7 +357,7 @@ apply_frosting.default <- function(workflow, components, ...) { #' @importFrom rlang abort #' @export apply_frosting.epi_workflow <- - function(workflow, components, new_data, ...) { + function(workflow, components, new_data, type = NULL, opts = list(), ...) { the_fit <- workflows::extract_fit_parsnip(workflow) if (!has_postprocessor(workflow)) { @@ -397,10 +397,28 @@ apply_frosting.epi_workflow <- layers ) } + if (length(layers) > 1L && + (!is.null(type) || !identical(opts, list()) || rlang::dots_n(...) > 0L)) { + cli_abort(" + Passing `type`, `opts`, or `...` into `predict.epi_workflow()` is not + supported if you have frosting layers other than `layer_predict`. Please + provide these arguments earlier (i.e. while constructing the frosting + object) by passing them into an explicit call to `layer_predict(), and + adjust the remaining layers to account for resulting differences in + output format from these settings. + ", class = "epipredict__apply_frosting__predict_settings_with_unsupported_layers") + } for (l in seq_along(layers)) { la <- layers[[l]] - components <- slather(la, components, workflow, new_data) + if (inherits(la, "layer_predict")) { + components <- slather(la, components, workflow, new_data, type = type, opts = opts, ...) + } else { + # The check above should ensure we have default `type` and `opts` and + # empty `...`; don't forward these default `type` and `opts`, to avoid + # upsetting some slather method validation. + components <- slather(la, components, workflow, new_data) + } } return(components) diff --git a/R/layer_add_forecast_date.R b/R/layer_add_forecast_date.R index 2174b7330..c4bb7d483 100644 --- a/R/layer_add_forecast_date.R +++ b/R/layer_add_forecast_date.R @@ -86,6 +86,7 @@ layer_add_forecast_date_new <- function(forecast_date, id) { #' @export slather.layer_add_forecast_date <- function(object, components, workflow, new_data, ...) { + rlang::check_dots_empty() if (is.null(object$forecast_date)) { max_time_value <- as.Date(max( workflows::extract_preprocessor(workflow)$max_time_value, diff --git a/R/layer_naomit.R b/R/layer_naomit.R index ad6c5606c..85842bfdf 100644 --- a/R/layer_naomit.R +++ b/R/layer_naomit.R @@ -45,6 +45,7 @@ layer_naomit_new <- function(terms, id) { #' @export slather.layer_naomit <- function(object, components, workflow, new_data, ...) { + rlang::check_dots_empty() exprs <- rlang::expr(c(!!!object$terms)) pos <- tidyselect::eval_select(exprs, components$predictions) col_names <- names(pos) diff --git a/R/layer_point_from_distn.R b/R/layer_point_from_distn.R index 52ecef3cc..8f5ed2c33 100644 --- a/R/layer_point_from_distn.R +++ b/R/layer_point_from_distn.R @@ -76,16 +76,16 @@ layer_point_from_distn_new <- function(type, name, id) { #' @export slather.layer_point_from_distn <- function(object, components, workflow, new_data, ...) { - rlang::check_dots_empty() dstn <- components$predictions$.pred if (!inherits(dstn, "distribution")) { rlang::warn( c("`layer_point_from_distn` requires distributional predictions.", i = "These are of class {class(dstn)}. Ignoring this layer." - ) + ) ) return(components) } + rlang::check_dots_empty() dstn <- match.fun(object$type)(dstn) if (is.null(object$name)) { diff --git a/R/layer_population_scaling.R b/R/layer_population_scaling.R index 1d02604e5..33183198d 100644 --- a/R/layer_population_scaling.R +++ b/R/layer_population_scaling.R @@ -128,11 +128,11 @@ layer_population_scaling_new <- #' @export slather.layer_population_scaling <- function(object, components, workflow, new_data, ...) { - rlang::check_dots_empty() stopifnot( "Only one population column allowed for scaling" = length(object$df_pop_col) == 1 ) + rlang::check_dots_empty() if (is.null(object$by)) { object$by <- intersect( diff --git a/R/layer_predict.R b/R/layer_predict.R index f161f21c3..ecc76408a 100644 --- a/R/layer_predict.R +++ b/R/layer_predict.R @@ -45,11 +45,12 @@ layer_predict <- id = rand_id("predict_default")) { arg_is_chr_scalar(id) arg_is_chr_scalar(type, allow_null = TRUE) + assert_class(opts, "list") dots_list <- rlang::dots_list(..., .homonyms = "error", .check_assign = TRUE) if (any(rlang::names2(dots_list) == "")) { cli_abort("All `...` arguments must be named.", - class = "epipredict__layer_predict__unnamed_dot" - ) + class = "epipredict__layer_predict__unnamed_dot" + ) } add_layer( frosting, @@ -68,16 +69,18 @@ layer_predict_new <- function(type, opts, dots_list, id) { } #' @export -slather.layer_predict <- function(object, components, workflow, new_data, ...) { - rlang::check_dots_empty() +slather.layer_predict <- function(object, components, workflow, new_data, type = NULL, opts = list(), ...) { + arg_is_chr_scalar(type, allow_null = TRUE) + assert_class(opts, "list") the_fit <- workflows::extract_fit_parsnip(workflow) components$predictions <- rlang::inject(predict( the_fit, components$forged$predictors, - type = object$type, opts = object$opts, - !!!object$dots_list + type = object$type %||% type, + opts = c(object$opts, opts), + !!!object$dots_list, ... )) components$predictions <- dplyr::bind_cols( components$keys, components$predictions diff --git a/R/layer_predictive_distn.R b/R/layer_predictive_distn.R index 652e42368..9b1a160e1 100644 --- a/R/layer_predictive_distn.R +++ b/R/layer_predictive_distn.R @@ -73,6 +73,7 @@ layer_predictive_distn_new <- function(dist_type, truncate, name, id) { slather.layer_predictive_distn <- function(object, components, workflow, new_data, ...) { the_fit <- workflows::extract_fit_parsnip(workflow) + rlang::check_dots_empty() m <- components$predictions$.pred r <- grab_residuals(the_fit, components) diff --git a/R/layer_quantile_distn.R b/R/layer_quantile_distn.R index d763207a4..734ccec9e 100644 --- a/R/layer_quantile_distn.R +++ b/R/layer_quantile_distn.R @@ -79,6 +79,8 @@ slather.layer_quantile_distn <- "These are of class {.cls {class(dstn)}}." )) } + rlang::check_dots_empty() + dstn <- dist_quantiles( quantile(dstn, object$quantile_levels), object$quantile_levels diff --git a/R/layer_residual_quantiles.R b/R/layer_residual_quantiles.R index 514bddc5f..85c1c6ed0 100644 --- a/R/layer_residual_quantiles.R +++ b/R/layer_residual_quantiles.R @@ -75,6 +75,8 @@ layer_residual_quantiles_new <- function( #' @export slather.layer_residual_quantiles <- function(object, components, workflow, new_data, ...) { + rlang::check_dots_empty() + the_fit <- workflows::extract_fit_parsnip(workflow) if (is.null(object$quantile_levels)) { diff --git a/R/layer_threshold_preds.R b/R/layer_threshold_preds.R index ef1781a3c..8b2b56d1e 100644 --- a/R/layer_threshold_preds.R +++ b/R/layer_threshold_preds.R @@ -98,6 +98,7 @@ snap.dist_quantiles <- function(x, lower, upper, ...) { #' @export slather.layer_threshold <- function(object, components, workflow, new_data, ...) { + rlang::check_dots_empty() exprs <- rlang::expr(c(!!!object$terms)) pos <- tidyselect::eval_select(exprs, components$predictions) col_names <- names(pos) diff --git a/R/layer_unnest.R b/R/layer_unnest.R index 64b17a306..dfc391942 100644 --- a/R/layer_unnest.R +++ b/R/layer_unnest.R @@ -28,6 +28,7 @@ layer_unnest_new <- function(terms, id) { #' @export slather.layer_unnest <- function(object, components, workflow, new_data, ...) { + rlang::check_dots_empty() exprs <- rlang::expr(c(!!!object$terms)) pos <- tidyselect::eval_select(exprs, components$predictions) col_names <- names(pos) diff --git a/inst/templates/layer.R b/inst/templates/layer.R index 3fecb3c33..59556db5f 100644 --- a/inst/templates/layer.R +++ b/inst/templates/layer.R @@ -29,6 +29,7 @@ layer_{{{ name }}}_new <- function(terms, args, more_args, id) { #' @export slather.layer_{{{ name }}} <- function(object, components, workflow, new_data, ...) { + rlang::check_dots_empty() # if layer_ used ... in tidyselect, we need to evaluate it now exprs <- rlang::expr(c(!!!object$terms)) diff --git a/man/apply_frosting.Rd b/man/apply_frosting.Rd index fc01a3461..345f14b19 100644 --- a/man/apply_frosting.Rd +++ b/man/apply_frosting.Rd @@ -11,7 +11,7 @@ apply_frosting(workflow, ...) \method{apply_frosting}{default}(workflow, components, ...) -\method{apply_frosting}{epi_workflow}(workflow, components, new_data, ...) +\method{apply_frosting}{epi_workflow}(workflow, components, new_data, type = NULL, opts = list(), ...) } \arguments{ \item{workflow}{An object of class workflow} diff --git a/man/predict-epi_workflow.Rd b/man/predict-epi_workflow.Rd index d92fd8ca9..130279249 100644 --- a/man/predict-epi_workflow.Rd +++ b/man/predict-epi_workflow.Rd @@ -5,7 +5,7 @@ \alias{predict.epi_workflow} \title{Predict from an epi_workflow} \usage{ -\method{predict}{epi_workflow}(object, new_data, ...) +\method{predict}{epi_workflow}(object, new_data, type = NULL, opts = list(), ...) } \arguments{ \item{object}{An epi_workflow that has been fit by @@ -14,6 +14,16 @@ \item{new_data}{A data frame containing the new predictors to preprocess and predict on} +\item{type}{A single character value or \code{NULL}. Possible values +are \code{"numeric"}, \code{"class"}, \code{"prob"}, \code{"conf_int"}, \code{"pred_int"}, +\code{"quantile"}, \code{"time"}, \code{"hazard"}, \code{"survival"}, or \code{"raw"}. When \code{NULL}, +\code{predict()} will choose an appropriate value based on the model's mode.} + +\item{opts}{A list of optional arguments to the underlying +predict function that will be used when \code{type = "raw"}. The +list should not include options for the model object or the +new data being predicted.} + \item{...}{Additional \code{parsnip}-related options, depending on the value of \code{type}. Arguments to the underlying model's prediction function cannot be passed here (use the \code{opts} argument instead). diff --git a/tests/testthat/test-frosting.R b/tests/testthat/test-frosting.R index 8af9f1c39..9bdce3197 100644 --- a/tests/testthat/test-frosting.R +++ b/tests/testthat/test-frosting.R @@ -86,3 +86,46 @@ test_that("layer_predict is added by default if missing", { expect_equal(forecast(wf1), forecast(wf2)) }) + + +test_that("parsnip settings can be passed through predict.epi_workflow", { + jhu <- case_death_rate_subset %>% + dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny")) + + r <- epi_recipe(jhu) %>% + step_epi_lag(death_rate, lag = c(0, 7, 14)) %>% + step_epi_ahead(death_rate, ahead = 7) %>% + step_epi_naomit() + + wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu) + + latest <- get_test_data(r, jhu) + + f1 <- frosting() %>% layer_predict() + f2 <- frosting() %>% layer_predict(type = "pred_int") + f3 <- frosting() %>% layer_predict(type = "pred_int", level = 0.6) + + pred2 <- wf %>% add_frosting(f2) %>% predict(latest) + pred3 <- wf %>% add_frosting(f3) %>% predict(latest) + + pred2_re <- wf %>% add_frosting(f1) %>% predict(latest, type = "pred_int") + pred3_re <- wf %>% add_frosting(f1) %>% predict(latest, type = "pred_int", level = 0.6) + + expect_identical(pred2, pred2_re) + expect_identical(pred3, pred3_re) + + f4 <- frosting() %>% + layer_predict() %>% + layer_threshold(.pred, lower = 0) + + expect_error(wf %>% add_frosting(f4) %>% predict(latest, type = "pred_int"), + class = "epipredict__apply_frosting__predict_settings_with_unsupported_layers") + + # We also refuse to continue when just passing the level, which might not be ideal: + f5 <- frosting() %>% + layer_predict(type = "pred_int") %>% + layer_threshold(.pred_lower, .pred_upper, lower = 0) + + expect_error(wf %>% add_frosting(f5) %>% predict(latest, level = 0.6), + class = "epipredict__apply_frosting__predict_settings_with_unsupported_layers") +}) diff --git a/tests/testthat/test-layer_predict.R b/tests/testthat/test-layer_predict.R index f3b33f5db..f8efe7ea3 100644 --- a/tests/testthat/test-layer_predict.R +++ b/tests/testthat/test-layer_predict.R @@ -61,17 +61,33 @@ test_that("layer_predict dots validation", { }) test_that("layer_predict dots are forwarded", { - f_lm_int_level <- frosting() %>% + f_lm_int_level_95 <- frosting() %>% + layer_predict(type = "pred_int") + f_lm_int_level_80 <- frosting() %>% layer_predict(type = "pred_int", level = 0.8) - wf_lm_int_level <- wf %>% add_frosting(f_lm_int_level) + wf_lm_int_level_95 <- wf %>% add_frosting(f_lm_int_level_95) + wf_lm_int_level_80 <- wf %>% add_frosting(f_lm_int_level_80) p <- predict(wf, latest) - p_lm_int_level <- predict(wf_lm_int_level, latest) - expect_contains(names(p_lm_int_level), c(".pred_lower", ".pred_upper")) - expect_equal(nrow(na.omit(p)), nrow(na.omit(p_lm_int_level))) - expect_true(cbind(p, p_lm_int_level[c(".pred_lower", ".pred_upper")]) %>% - na.omit() %>% - mutate(sandwiched = .pred_lower <= .pred & .pred <= .pred_upper) %>% - `[[`("sandwiched") %>% - all()) + p_lm_int_level_95 <- predict(wf_lm_int_level_95, latest) + p_lm_int_level_80 <- predict(wf_lm_int_level_80, latest) + expect_contains(names(p_lm_int_level_95), c(".pred_lower", ".pred_upper")) + expect_contains(names(p_lm_int_level_80), c(".pred_lower", ".pred_upper")) + expect_equal(nrow(na.omit(p)), nrow(na.omit(p_lm_int_level_95))) + expect_equal(nrow(na.omit(p)), nrow(na.omit(p_lm_int_level_80))) + expect_true( + cbind( + p, + p_lm_int_level_95 %>% dplyr::select(.pred_lower_95 = .pred_lower, .pred_upper_95 = .pred_upper), + p_lm_int_level_80 %>% dplyr::select(.pred_lower_80 = .pred_lower, .pred_upper_80 = .pred_upper) + ) %>% + na.omit() %>% + mutate(sandwiched = + .pred_lower_95 <= .pred_lower_80 & + .pred_lower_80 <= .pred & + .pred <= .pred_upper_80 & + .pred_upper_80 <= .pred_upper_95) %>% + `[[`("sandwiched") %>% + all() + ) # There are many possible other valid configurations that aren't tested here. }) From ecf2c73e3f573b942a5f490795f578f507aab212 Mon Sep 17 00:00:00 2001 From: "Logan C. Brooks" Date: Thu, 18 Jul 2024 11:27:05 -0700 Subject: [PATCH 03/11] Detect conflicting `type` settings in frosting construction&slather --- R/layer_predict.R | 8 ++++++++ tests/testthat/test-frosting.R | 3 +++ 2 files changed, 11 insertions(+) diff --git a/R/layer_predict.R b/R/layer_predict.R index ecc76408a..5e6f0ace1 100644 --- a/R/layer_predict.R +++ b/R/layer_predict.R @@ -71,6 +71,14 @@ layer_predict_new <- function(type, opts, dots_list, id) { #' @export slather.layer_predict <- function(object, components, workflow, new_data, type = NULL, opts = list(), ...) { arg_is_chr_scalar(type, allow_null = TRUE) + if (!is.null(object$type) && !is.null(type) && !identical(object$type, type)) { + cli_abort(" + Conflicting `type` settings were specified during frosting construction + (in call to `layer_predict()`) and while slathering (in call to + `slather()`/ `predict()`/etc.): {object$type} vs. {type}. Please remove + one of these `type` settings. + ", class = "epipredict__layer_predict__conflicting_type_settings") + } assert_class(opts, "list") the_fit <- workflows::extract_fit_parsnip(workflow) diff --git a/tests/testthat/test-frosting.R b/tests/testthat/test-frosting.R index 9bdce3197..ccebbf1f2 100644 --- a/tests/testthat/test-frosting.R +++ b/tests/testthat/test-frosting.R @@ -114,6 +114,9 @@ test_that("parsnip settings can be passed through predict.epi_workflow", { expect_identical(pred2, pred2_re) expect_identical(pred3, pred3_re) + expect_error(wf %>% add_frosting(f2) %>% predict(latest, type = "raw"), + class = "epipredict__layer_predict__conflicting_type_settings") + f4 <- frosting() %>% layer_predict() %>% layer_threshold(.pred, lower = 0) From 482e4a6311d152e9942877ab29c96328c88a990e Mon Sep 17 00:00:00 2001 From: "Logan C. Brooks" Date: Thu, 18 Jul 2024 13:05:50 -0700 Subject: [PATCH 04/11] @inheritParams directly from predict.model_fit again We were missing `type` and `opts` in documentation because they weren't in the signature of `predict.epi_workflow`, not because of an issue with `predict.model_fit` docs, and it seems like the latter method is the one we'd be directly using. --- R/epi_workflow.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/epi_workflow.R b/R/epi_workflow.R index 7e1c95f88..6a81dfd40 100644 --- a/R/epi_workflow.R +++ b/R/epi_workflow.R @@ -123,14 +123,14 @@ fit.epi_workflow <- function(object, data, ..., control = workflows::control_wor #' possible. Specifically, the output will have `time_value` and #' `geo_value` columns as well as the prediction. #' -#' @inheritParams workflows::predict.workflow -#' #' @param object An epi_workflow that has been fit by #' [workflows::fit.workflow()] #' #' @param new_data A data frame containing the new predictors to preprocess #' and predict on #' +#' @inheritParams parsnip::predict.model_fit +#' #' @return #' A data frame of model predictions, with as many rows as `new_data` has. #' If `new_data` is an `epi_df` or a data frame with `time_value` or From 249954d20dab92f13a813c7fed2472092d3577c0 Mon Sep 17 00:00:00 2001 From: "Logan C. Brooks" Date: Fri, 19 Jul 2024 16:44:48 -0700 Subject: [PATCH 05/11] Also forward type & opts when frosting isn't detected --- R/frosting.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/frosting.R b/R/frosting.R index f293314fb..1548c492e 100644 --- a/R/frosting.R +++ b/R/frosting.R @@ -376,7 +376,7 @@ apply_frosting.epi_workflow <- "Returning unpostprocessed predictions." )) components$predictions <- predict( - the_fit, components$forged$predictors, ... + the_fit, components$forged$predictors, type, opts, ... ) components$predictions <- dplyr::bind_cols( components$keys, components$predictions From bff7c3323df3e9e9d90f8096d8d377d3de788f85 Mon Sep 17 00:00:00 2001 From: "Logan C. Brooks" Date: Fri, 19 Jul 2024 16:48:08 -0700 Subject: [PATCH 06/11] Add missing param docs for apply_frosting.epi_workflow; style, doc --- R/frosting.R | 6 ++++-- R/layer_point_from_distn.R | 2 +- R/layer_predict.R | 4 ++-- man/apply_frosting.Rd | 3 +++ tests/testthat/test-frosting.R | 27 +++++++++++++++++++-------- tests/testthat/test-layer_predict.R | 12 +++++++----- 6 files changed, 36 insertions(+), 18 deletions(-) diff --git a/R/frosting.R b/R/frosting.R index 1548c492e..af72f29f0 100644 --- a/R/frosting.R +++ b/R/frosting.R @@ -355,6 +355,8 @@ apply_frosting.default <- function(workflow, components, ...) { #' @rdname apply_frosting #' @importFrom rlang is_null #' @importFrom rlang abort +#' @param type,opts,... forwarded to [`predict.model_fit()`] and [`slather()`] +#' for supported layers #' @export apply_frosting.epi_workflow <- function(workflow, components, new_data, type = NULL, opts = list(), ...) { @@ -398,7 +400,7 @@ apply_frosting.epi_workflow <- ) } if (length(layers) > 1L && - (!is.null(type) || !identical(opts, list()) || rlang::dots_n(...) > 0L)) { + (!is.null(type) || !identical(opts, list()) || rlang::dots_n(...) > 0L)) { cli_abort(" Passing `type`, `opts`, or `...` into `predict.epi_workflow()` is not supported if you have frosting layers other than `layer_predict`. Please @@ -414,7 +416,7 @@ apply_frosting.epi_workflow <- if (inherits(la, "layer_predict")) { components <- slather(la, components, workflow, new_data, type = type, opts = opts, ...) } else { - # The check above should ensure we have default `type` and `opts` and + # The check above should ensure we have default `type` and `opts`, and # empty `...`; don't forward these default `type` and `opts`, to avoid # upsetting some slather method validation. components <- slather(la, components, workflow, new_data) diff --git a/R/layer_point_from_distn.R b/R/layer_point_from_distn.R index 8f5ed2c33..f415e7bd4 100644 --- a/R/layer_point_from_distn.R +++ b/R/layer_point_from_distn.R @@ -81,7 +81,7 @@ slather.layer_point_from_distn <- rlang::warn( c("`layer_point_from_distn` requires distributional predictions.", i = "These are of class {class(dstn)}. Ignoring this layer." - ) + ) ) return(components) } diff --git a/R/layer_predict.R b/R/layer_predict.R index 5e6f0ace1..46d81be18 100644 --- a/R/layer_predict.R +++ b/R/layer_predict.R @@ -49,8 +49,8 @@ layer_predict <- dots_list <- rlang::dots_list(..., .homonyms = "error", .check_assign = TRUE) if (any(rlang::names2(dots_list) == "")) { cli_abort("All `...` arguments must be named.", - class = "epipredict__layer_predict__unnamed_dot" - ) + class = "epipredict__layer_predict__unnamed_dot" + ) } add_layer( frosting, diff --git a/man/apply_frosting.Rd b/man/apply_frosting.Rd index 345f14b19..a3b627b8a 100644 --- a/man/apply_frosting.Rd +++ b/man/apply_frosting.Rd @@ -34,6 +34,9 @@ here for ease. \item{new_data}{a data frame containing the new predictors to preprocess and predict on} + +\item{type, opts, ...}{forwarded to \code{\link[=predict.model_fit]{predict.model_fit()}} and \code{\link[=slather]{slather()}} +for supported layers} } \description{ This function is intended for internal use. It implements postprocessing diff --git a/tests/testthat/test-frosting.R b/tests/testthat/test-frosting.R index ccebbf1f2..5cab9c494 100644 --- a/tests/testthat/test-frosting.R +++ b/tests/testthat/test-frosting.R @@ -105,24 +105,34 @@ test_that("parsnip settings can be passed through predict.epi_workflow", { f2 <- frosting() %>% layer_predict(type = "pred_int") f3 <- frosting() %>% layer_predict(type = "pred_int", level = 0.6) - pred2 <- wf %>% add_frosting(f2) %>% predict(latest) - pred3 <- wf %>% add_frosting(f3) %>% predict(latest) - - pred2_re <- wf %>% add_frosting(f1) %>% predict(latest, type = "pred_int") - pred3_re <- wf %>% add_frosting(f1) %>% predict(latest, type = "pred_int", level = 0.6) + pred2 <- wf %>% + add_frosting(f2) %>% + predict(latest) + pred3 <- wf %>% + add_frosting(f3) %>% + predict(latest) + + pred2_re <- wf %>% + add_frosting(f1) %>% + predict(latest, type = "pred_int") + pred3_re <- wf %>% + add_frosting(f1) %>% + predict(latest, type = "pred_int", level = 0.6) expect_identical(pred2, pred2_re) expect_identical(pred3, pred3_re) expect_error(wf %>% add_frosting(f2) %>% predict(latest, type = "raw"), - class = "epipredict__layer_predict__conflicting_type_settings") + class = "epipredict__layer_predict__conflicting_type_settings" + ) f4 <- frosting() %>% layer_predict() %>% layer_threshold(.pred, lower = 0) expect_error(wf %>% add_frosting(f4) %>% predict(latest, type = "pred_int"), - class = "epipredict__apply_frosting__predict_settings_with_unsupported_layers") + class = "epipredict__apply_frosting__predict_settings_with_unsupported_layers" + ) # We also refuse to continue when just passing the level, which might not be ideal: f5 <- frosting() %>% @@ -130,5 +140,6 @@ test_that("parsnip settings can be passed through predict.epi_workflow", { layer_threshold(.pred_lower, .pred_upper, lower = 0) expect_error(wf %>% add_frosting(f5) %>% predict(latest, level = 0.6), - class = "epipredict__apply_frosting__predict_settings_with_unsupported_layers") + class = "epipredict__apply_frosting__predict_settings_with_unsupported_layers" + ) }) diff --git a/tests/testthat/test-layer_predict.R b/tests/testthat/test-layer_predict.R index f8efe7ea3..d01544ad1 100644 --- a/tests/testthat/test-layer_predict.R +++ b/tests/testthat/test-layer_predict.R @@ -81,11 +81,13 @@ test_that("layer_predict dots are forwarded", { p_lm_int_level_80 %>% dplyr::select(.pred_lower_80 = .pred_lower, .pred_upper_80 = .pred_upper) ) %>% na.omit() %>% - mutate(sandwiched = - .pred_lower_95 <= .pred_lower_80 & - .pred_lower_80 <= .pred & - .pred <= .pred_upper_80 & - .pred_upper_80 <= .pred_upper_95) %>% + mutate( + sandwiched = + .pred_lower_95 <= .pred_lower_80 & + .pred_lower_80 <= .pred & + .pred <= .pred_upper_80 & + .pred_upper_80 <= .pred_upper_95 + ) %>% `[[`("sandwiched") %>% all() ) From 1cc124789b8b62ca0a67f7e39800804ad1d6e8b0 Mon Sep 17 00:00:00 2001 From: "Logan C. Brooks" Date: Fri, 19 Jul 2024 17:01:44 -0700 Subject: [PATCH 07/11] Use wording tricks to avoid double-doc parm in generic+methods topic --- R/frosting.R | 4 ++-- man/apply_frosting.Rd | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/R/frosting.R b/R/frosting.R index af72f29f0..4fc0caec3 100644 --- a/R/frosting.R +++ b/R/frosting.R @@ -355,8 +355,8 @@ apply_frosting.default <- function(workflow, components, ...) { #' @rdname apply_frosting #' @importFrom rlang is_null #' @importFrom rlang abort -#' @param type,opts,... forwarded to [`predict.model_fit()`] and [`slather()`] -#' for supported layers +#' @param type,opts forwarded (along with `...`) to [`predict.model_fit()`] and +#' [`slather()`] for supported layers #' @export apply_frosting.epi_workflow <- function(workflow, components, new_data, type = NULL, opts = list(), ...) { diff --git a/man/apply_frosting.Rd b/man/apply_frosting.Rd index a3b627b8a..ef18796cc 100644 --- a/man/apply_frosting.Rd +++ b/man/apply_frosting.Rd @@ -35,8 +35,8 @@ here for ease. \item{new_data}{a data frame containing the new predictors to preprocess and predict on} -\item{type, opts, ...}{forwarded to \code{\link[=predict.model_fit]{predict.model_fit()}} and \code{\link[=slather]{slather()}} -for supported layers} +\item{type, opts}{forwarded (along with \code{...}) to \code{\link[=predict.model_fit]{predict.model_fit()}} and +\code{\link[=slather]{slather()}} for supported layers} } \description{ This function is intended for internal use. It implements postprocessing From 2ca6ee80993e86d6c459e25b7c627345e54389e4 Mon Sep 17 00:00:00 2001 From: "Logan C. Brooks" Date: Thu, 25 Jul 2024 16:06:03 -0700 Subject: [PATCH 08/11] Update tests & commentary given predict() arg forwarding --- tests/testthat/test-layer_predict.R | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tests/testthat/test-layer_predict.R b/tests/testthat/test-layer_predict.R index d01544ad1..041516b29 100644 --- a/tests/testthat/test-layer_predict.R +++ b/tests/testthat/test-layer_predict.R @@ -49,15 +49,13 @@ test_that("layer_predict dots validation", { expect_no_error(f_bad_arg <- frosting() %>% layer_predict(bogus_argument = "something")) wf_bad_arg <- wf %>% add_frosting(f_bad_arg) expect_error(predict(wf_bad_arg, latest)) - # Some argument names only apply for some prediction `type`s; we don't check for ignored arguments, and neither does workflows: - expect_no_error(frosting() %>% layer_predict(eval_time = "preferably this would error")) + # ^ (currently with a awful error message, due to an extra comma in parsnip::check_pred_type_dots) - # ^ (currently with a truly awful error message, due to an extra comma in parsnip::check_pred_type_dots) - # - # Unfortunately, we outright ignore attempts to pass args via `predict.epi_workflow`: - f_predict <- frosting() %>% layer_predict() - wf_predict <- wf %>% add_frosting(f_predict) - expect_no_error(predict(wf_predict, latest, type = "pred_int")) + # Some argument names only apply for some prediction `type`s; we don't check + # for invalid pairings, nor does {parsnip}, so we end up producing a forecast + # that silently ignores some arguments some of the time. ({workflows} doesn't + # check for these either.) + expect_no_error(frosting() %>% layer_predict(eval_time = "preferably this would error")) }) test_that("layer_predict dots are forwarded", { From 49484bf7ed7328fc5363127ff0d81b9b02ff1209 Mon Sep 17 00:00:00 2001 From: "Logan C. Brooks" Date: Thu, 25 Jul 2024 16:09:51 -0700 Subject: [PATCH 09/11] Update&correct NEWS.md with predict.epi_workflow() arg forwarding --- NEWS.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/NEWS.md b/NEWS.md index cce52cb51..9cb78770c 100644 --- a/NEWS.md +++ b/NEWS.md @@ -47,5 +47,5 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat - Revise `compat-purrr` to use the r-lang `standalone-*` version (via `{usethis}`) - `epi_recipe()` will now warn when given non-`epi_df` data -- `layer_predict()` will now appropriately forward `...` args intended for - `predict.workflow()` +- `layer_predict()` and `predict.epi_workflow()` will now appropriately forward + `...` args intended for `predict.model_fit()` From be15821390e0bd8fa530b7d45620ee302447ecdb Mon Sep 17 00:00:00 2001 From: "Logan C. Brooks" Date: Thu, 25 Jul 2024 18:28:46 -0700 Subject: [PATCH 10/11] Recalc geo&time type after bake, avoid warning spam from as_epi_df Since epiprocess#472, as_epi_df. will ignore the geo and time type args and re-infer them, plus emit a warning if these args were provided. We could `dplyr_reconstruct` to go back to trusting `meta` instead of re-inferring, but since baking could plausibly (but likely rarely) change them, don't. - Pre-epiprocess#472, this should be a bugfix in any such rare cases. - Post-epiprocess#472, this should remove warning spam. --- NEWS.md | 2 ++ R/epi_recipe.R | 8 ++++++-- tests/testthat/test-pad_to_end.R | 2 +- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/NEWS.md b/NEWS.md index 9cb78770c..4e21f8191 100644 --- a/NEWS.md +++ b/NEWS.md @@ -49,3 +49,5 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat - `epi_recipe()` will now warn when given non-`epi_df` data - `layer_predict()` and `predict.epi_workflow()` will now appropriately forward `...` args intended for `predict.model_fit()` +- `bake.epi_recipe()` will now re-infer the geo and time type in case baking the + steps has changed the appropriate values diff --git a/R/epi_recipe.R b/R/epi_recipe.R index 1a1cd1455..f9c4cb4b2 100644 --- a/R/epi_recipe.R +++ b/R/epi_recipe.R @@ -572,9 +572,13 @@ bake.epi_recipe <- function(object, new_data, ..., composition = "epi_df") { } new_data <- NextMethod("bake") if (!is.null(meta)) { + # Baking should have dropped epi_df-ness and metadata. Re-infer some + # metadata and assume others remain the same as the object/template: new_data <- as_epi_df( - new_data, meta$geo_type, meta$time_type, meta$as_of, - meta$additional_metadata %||% list() + new_data, + as_of = meta$as_of, + # avoid NULL if meta is from saved older epi_df: + additional_metadata = meta$additional_metadata %||% list() ) } new_data diff --git a/tests/testthat/test-pad_to_end.R b/tests/testthat/test-pad_to_end.R index 474b9001b..0ea6244b0 100644 --- a/tests/testthat/test-pad_to_end.R +++ b/tests/testthat/test-pad_to_end.R @@ -32,6 +32,6 @@ test_that("test set padding works", { # make sure it maintains the epi_df dat <- dat %>% dplyr::rename(geo_value = gr1) %>% - as_epi_df(dat) + as_epi_df() expect_s3_class(pad_to_end(dat, "geo_value", 2), "epi_df") }) From 7fd40945b966c1b4e30406d1962b9a4ac0913cfb Mon Sep 17 00:00:00 2001 From: "Logan C. Brooks" Date: Thu, 25 Jul 2024 20:01:41 -0700 Subject: [PATCH 11/11] Fix document() and check() warnings regarding [epiprocess::]epi_df Workaround for epiprocess#493. --- R/arx_classifier.R | 2 +- R/arx_forecaster.R | 2 +- R/cdc_baseline_forecaster.R | 4 ++-- R/data.R | 2 +- R/epi_recipe.R | 2 +- R/epi_workflow.R | 2 +- R/flatline_forecaster.R | 4 ++-- R/get_test_data.R | 2 +- man/get_test_data.Rd | 2 +- 9 files changed, 11 insertions(+), 11 deletions(-) diff --git a/R/arx_classifier.R b/R/arx_classifier.R index de730826c..44acb9b30 100644 --- a/R/arx_classifier.R +++ b/R/arx_classifier.R @@ -1,7 +1,7 @@ #' Direct autoregressive classifier with covariates #' #' This is an autoregressive classification model for -#' [epiprocess::epi_df] data. It does "direct" forecasting, meaning +#' [epiprocess::epi_df][epiprocess::as_epi_df] data. It does "direct" forecasting, meaning #' that it estimates a class at a particular target horizon. #' #' @inheritParams arx_forecaster diff --git a/R/arx_forecaster.R b/R/arx_forecaster.R index 10b2d2bce..1b9e3d503 100644 --- a/R/arx_forecaster.R +++ b/R/arx_forecaster.R @@ -1,7 +1,7 @@ #' Direct autoregressive forecaster with covariates #' #' This is an autoregressive forecasting model for -#' [epiprocess::epi_df] data. It does "direct" forecasting, meaning +#' [epiprocess::epi_df][epiprocess::as_epi_df] data. It does "direct" forecasting, meaning #' that it estimates a model for a particular target horizon. #' #' diff --git a/R/cdc_baseline_forecaster.R b/R/cdc_baseline_forecaster.R index 4af6d6f3f..d5b74a9c3 100644 --- a/R/cdc_baseline_forecaster.R +++ b/R/cdc_baseline_forecaster.R @@ -1,7 +1,7 @@ #' Predict the future with the most recent value #' #' This is a simple forecasting model for -#' [epiprocess::epi_df] data. It uses the most recent observation as the +#' [epiprocess::epi_df][epiprocess::as_epi_df] data. It uses the most recent observation as the #' forecast for any future date, and produces intervals by shuffling the quantiles #' of the residuals of such a "flatline" forecast and incrementing these #' forward over all available training data. @@ -12,7 +12,7 @@ #' This forecaster is meant to produce exactly the CDC Baseline used for #' [COVID19ForecastHub](https://covid19forecasthub.org) #' -#' @param epi_data An [`epiprocess::epi_df`] +#' @param epi_data An [`epiprocess::epi_df`][epiprocess::as_epi_df] #' @param outcome A scalar character for the column name we wish to predict. #' @param args_list A list of additional arguments as created by the #' [cdc_baseline_args_list()] constructor function. diff --git a/R/data.R b/R/data.R index 6641abf44..71e5bdcd3 100644 --- a/R/data.R +++ b/R/data.R @@ -59,7 +59,7 @@ #' Subset of Statistics Canada median employment income for postsecondary graduates #' -#' @format An [epiprocess::epi_df] with 10193 rows and 8 variables: +#' @format An [epiprocess::epi_df][epiprocess::as_epi_df] with 10193 rows and 8 variables: #' \describe{ #' \item{geo_value}{The province in Canada associated with each #' row of measurements.} diff --git a/R/epi_recipe.R b/R/epi_recipe.R index f9c4cb4b2..6d01d718f 100644 --- a/R/epi_recipe.R +++ b/R/epi_recipe.R @@ -245,7 +245,7 @@ is_epi_recipe <- function(x) { #' @details #' `add_epi_recipe` has the same behaviour as #' [workflows::add_recipe()] but sets a different -#' default blueprint to automatically handle [epiprocess::epi_df] data. +#' default blueprint to automatically handle [epiprocess::epi_df][epiprocess::as_epi_df] data. #' #' @param x A `workflow` or `epi_workflow` #' diff --git a/R/epi_workflow.R b/R/epi_workflow.R index 6a81dfd40..0bdeece4f 100644 --- a/R/epi_workflow.R +++ b/R/epi_workflow.R @@ -119,7 +119,7 @@ fit.epi_workflow <- function(object, data, ..., control = workflows::control_wor #' - Call [parsnip::predict.model_fit()] for you using the underlying fit #' parsnip model. #' -#' - Ensure that the returned object is an [epiprocess::epi_df] where +#' - Ensure that the returned object is an [epiprocess::epi_df][epiprocess::as_epi_df] where #' possible. Specifically, the output will have `time_value` and #' `geo_value` columns as well as the prediction. #' diff --git a/R/flatline_forecaster.R b/R/flatline_forecaster.R index fa80dfba5..e14e44a96 100644 --- a/R/flatline_forecaster.R +++ b/R/flatline_forecaster.R @@ -1,7 +1,7 @@ #' Predict the future with today's value #' #' This is a simple forecasting model for -#' [epiprocess::epi_df] data. It uses the most recent observation as the +#' [epiprocess::epi_df][epiprocess::as_epi_df] data. It uses the most recent observation as the #' forcast for any future date, and produces intervals based on the quantiles #' of the residuals of such a "flatline" forecast over all available training #' data. @@ -13,7 +13,7 @@ #' This forecaster is very similar to that used by the #' [COVID19ForecastHub](https://covid19forecasthub.org) #' -#' @param epi_data An [epiprocess::epi_df] +#' @param epi_data An [epiprocess::epi_df][epiprocess::as_epi_df] #' @param outcome A scalar character for the column name we wish to predict. #' @param args_list A list of dditional arguments as created by the #' [flatline_args_list()] constructor function. diff --git a/R/get_test_data.R b/R/get_test_data.R index e76715daf..2a2484749 100644 --- a/R/get_test_data.R +++ b/R/get_test_data.R @@ -1,7 +1,7 @@ #' Get test data for prediction based on longest lag period #' #' Based on the longest lag period in the recipe, -#' `get_test_data()` creates an [epi_df] +#' `get_test_data()` creates an [epi_df][epiprocess::as_epi_df] #' with columns `geo_value`, `time_value` #' and other variables in the original dataset, #' which will be used to create features necessary to produce forecasts. diff --git a/man/get_test_data.Rd b/man/get_test_data.Rd index 392d1dce2..b18685d89 100644 --- a/man/get_test_data.Rd +++ b/man/get_test_data.Rd @@ -37,7 +37,7 @@ keys, as well other variables in the original dataset. } \description{ Based on the longest lag period in the recipe, -\code{get_test_data()} creates an \link{epi_df} +\code{get_test_data()} creates an \link[epiprocess:epi_df]{epi_df} with columns \code{geo_value}, \code{time_value} and other variables in the original dataset, which will be used to create features necessary to produce forecasts.