---
title: "Scalar summaries with wrapper"
output:
  rmarkdown::html_vignette: default
vignette: >
  %\VignetteIndexEntry{Scalar summaries with wrapper}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r ws-knit-opts, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  fig.width = 6,
  fig.height = 4,
  fig.align = "center"
)
```

```{r setup}
library(MetaHunt)
set.seed(1)
```

## Why scalar summaries

`metahunt()` predicts a function on the grid, and the conformal routines can return a band at every grid point. Often, though, the inferential target is a single number derived from that function:

- The average treatment effect (the mean of a CATE function over a reference patient distribution).
- The treatment effect at a specific patient profile.
- The fraction of the population with a positive treatment effect.
- A contrast between two endpoints.

For all of these, MetaHunt accepts a `wrapper` argument that **collapses the predicted function to a scalar** before any further calculation. The same wrapper is applied identically to the predicted functions and to the calibration functions before residuals are formed, so conformal coverage transfers directly to the scalar summary.

## The wrapper protocol

`apply_wrapper(F_mat, wrapper, grid_weights)` defines the contract.

- `F_mat` is an `n`-by-`G_grid` numeric matrix; row `j` is one function evaluated on the grid.
- If `wrapper` is `NULL`, `apply_wrapper()` returns the `grid_weights`-weighted mean of each row, `sum(grid_weights * f) / sum(grid_weights)`, with uniform weights `1/G_grid` by default.
- If `wrapper` is a function, `apply_wrapper()` calls `apply(F_mat, 1, wrapper)`. The wrapper therefore receives **a single numeric vector of length `G_grid`** (one row of `F_mat` at a time) and must return a single numeric value.

The contract is therefore:

```
wrapper :: numeric vector of length G_grid -> numeric scalar
```

Any function satisfying that signature is a valid wrapper.
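To make the contract concrete, the base-R sketch below mimics the behaviour described above. `apply_wrapper_sketch()`, `F_toy`, `w_ref`, and `weighted_mean` are hypothetical names written for this illustration; the function is a stand-in under the stated contract, not MetaHunt's source.

```r
# Hypothetical stand-in for apply_wrapper(), written only to illustrate the
# contract; it is not MetaHunt's implementation.
apply_wrapper_sketch <- function(F_mat, wrapper = NULL, grid_weights = NULL) {
  G_grid <- ncol(F_mat)
  if (is.null(wrapper)) {
    # Default: grid-weighted mean of each row (uniform 1/G_grid weights).
    if (is.null(grid_weights)) grid_weights <- rep(1 / G_grid, G_grid)
    out <- as.vector(F_mat %*% grid_weights) / sum(grid_weights)
  } else {
    # apply(..., 1, ...) passes one row, a length-G_grid vector, at a time.
    out <- apply(F_mat, 1, wrapper)
  }
  # The result must be numeric with exactly one value per row.
  stopifnot(is.numeric(out), length(out) == nrow(F_mat))
  out
}

# Two toy functions evaluated on a 3-point grid (one function per row).
F_toy <- rbind(c(1, 2, 3),
               c(-1, 0, 4))

apply_wrapper_sketch(F_toy)                                     # returns c(2, 1)
apply_wrapper_sketch(F_toy, wrapper = max)                      # returns c(3, 4)
apply_wrapper_sketch(F_toy, wrapper = function(f) f[3] - f[1])  # returns c(2, 5)

# A closure lets a wrapper carry extra information, e.g. weights from a
# reference patient distribution over the grid points (made-up values here).
w_ref <- c(1, 1, 2)
weighted_mean <- function(f) sum(w_ref * f) / sum(w_ref)
apply_wrapper_sketch(F_toy, wrapper = weighted_mean)            # returns c(2.25, 1.75)

# range() returns two values per row, which violates the contract and errors.
try(apply_wrapper_sketch(F_toy, wrapper = range))
```

The closure in `weighted_mean` is one way to target the ATE over a reference patient distribution rather than the grid-uniform ATE: on an evenly spaced grid, use weights proportional to that distribution's density at the grid points. The weights above are arbitrary and only show the mechanics.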
The package then checks post hoc that the result is numeric and has exactly one entry per row.

## An ATE example with `grf::causal_forest`

We simulate a multi-site clinical trial with `m = 8` sites. Each site has its own individual-level data $(Y, X, T)$, where $Y$ is a continuous outcome, $X$ is a single patient covariate (`age`), and $T$ is a binary treatment indicator. The site-level CATE function $\tau^{(i)}(\text{age}) = E[Y(1) - Y(0) \mid \text{age}, \text{site} = i]$ varies across sites in a way that depends on the site's metadata: in this simulation the shift comes from `year`, while `pct_treated` only sets the treatment probability. Each site fits its own `grf::causal_forest` on its individual-level data and shares only the fitted model, not the patient data, with us.

```{r ws-simulate-trials, eval = requireNamespace("grf", quietly = TRUE)}
m <- 8             # number of sites
n_per_site <- 200  # patients per site
G <- 30            # number of grid points

# Site-level metadata
W <- data.frame(
  year = sample(2010:2020, m, replace = TRUE),
  pct_treated = round(runif(m, 0.3, 0.6), 2)
)

site_data_list <- lapply(seq_len(m), function(i) {
  age <- runif(n_per_site, 30, 80)
  trt <- rbinom(n_per_site, 1, W$pct_treated[i])  # binary treatment T
  site_eff <- (W$year[i] - 2015) / 5              # site-level shift in CATE
  tau_age <- 0.02 * (age - 50) + site_eff         # true CATE for each patient
  Y0 <- 0.01 * age + rnorm(n_per_site, sd = 0.5)
  Y1 <- Y0 + tau_age
  Y <- ifelse(trt == 1, Y1, Y0)
  data.frame(Y = Y, age = age, T = trt)
})

# Shared evaluation grid for the CATE functions
grid <- data.frame(age = seq(30, 80, length.out = G))
```

Each site fits its own `causal_forest`. We use `num.trees = 200` to keep the vignette fast; in practice you would use the default (2000) or more.

```{r ws-fit-cf, eval = requireNamespace("grf", quietly = TRUE)}
cf_models <- lapply(site_data_list, function(d) {
  grf::causal_forest(
    X = matrix(d$age, ncol = 1),
    Y = d$Y,
    W = d$T,  # grf's W is the treatment vector, not the site metadata W
    num.trees = 200
  )
})
```

We stack the per-site CATE estimates on the shared `age` grid into the `m`-by-`G` matrix `F_hat`.
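Conceptually, `F_hat` has one row per site and one column per grid point. The self-contained sketch below assembles such a matrix by hand in base R, with `lm()` fits on simulated data standing in for the per-site causal forests. The names `models_toy`, `grid_toy`, and `F_stack` are hypothetical, and this is not necessarily how `f_hat_from_models()` works internally.

```r
# Hypothetical illustration only: lm() fits on simulated data stand in for the
# per-site causal forests.
set.seed(2)
m_toy <- 3
grid_toy <- data.frame(age = seq(30, 80, length.out = 5))

models_toy <- lapply(seq_len(m_toy), function(i) {
  d <- data.frame(age = runif(50, 30, 80))
  d$y <- 0.02 * (d$age - 50) + i + rnorm(50, sd = 0.1)  # site i shifted up by i
  lm(y ~ age, data = d)
})

# Row i = site i's fitted function evaluated on the shared grid.
F_stack <- do.call(rbind, lapply(models_toy, function(mod) {
  as.numeric(predict(mod, newdata = grid_toy))
}))
dim(F_stack)  # 3 5
```

Whatever the model class, the shape is what matters downstream: row `i` of the stacked matrix must correspond to row `i` of the metadata `W`.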
The call to `f_hat_from_models()` below passes an explicit `predict_fn` to illustrate the general pattern. The dispatch table inside `f_hat_from_models()` already knows how to call `causal_forest`, so for standard `grf::causal_forest` fits the default `predict_fn` is sufficient and you can omit the argument.

```{r ws-build-fhat, eval = requireNamespace("grf", quietly = TRUE)}
# Return one forest's CATE predictions on the grid as a plain numeric vector.
cate_predict <- function(model, grid) {
  as.numeric(stats::predict(model, newdata = matrix(grid$age, ncol = 1))$predictions)
}
F_hat <- f_hat_from_models(cf_models, grid, predict_fn = cate_predict)
dim(F_hat)
```

We now fit `metahunt()` on `(F_hat, W)` and ask for the predicted ATE at a hypothetical new site.

```{r ws-fit-metahunt, eval = requireNamespace("grf", quietly = TRUE)}
fit <- metahunt(F_hat, W, K = 3, dfspa_args = list(denoise = FALSE))
W_new <- data.frame(year = 2018, pct_treated = 0.45)
ate_pred <- predict(fit, newdata = W_new, wrapper = mean)
ate_pred
```

The scalar `ate_pred` is the predicted average treatment effect for a hypothetical new site with metadata `(year = 2018, pct_treated = 0.45)`, computed as the unweighted mean over the 30-point age grid. Because the grid is evenly spaced, this approximates the ATE for a population whose ages are uniform between 30 and 80, which is how ages were simulated.

## Three custom wrappers

Below are three short, self-contained wrappers, each illustrating a different idea. All three are applied to the `fit` and `W_new` objects constructed in the previous section.

### Plain `mean`

`mean` already maps a numeric vector to a numeric scalar, so it is a valid wrapper. On an evenly spaced grid it is the unweighted average of the function over the grid, i.e. the grid-uniform ATE.

```{r wrapper-mean, eval = requireNamespace("grf", quietly = TRUE)}
predict(fit, newdata = W_new, wrapper = mean)
```

### Restricted positive mean

Suppose we only credit treatment effects that are positive (for example, in a cost-effectiveness setting).
The wrapper averages `max(f(x), 0)` over the grid:

```{r wrapper-restricted, eval = requireNamespace("grf", quietly = TRUE)}
restricted_pos_mean <- function(f) sum(pmax(f, 0)) / length(f)
predict(fit, newdata = W_new, wrapper = restricted_pos_mean)
```

Because each row of `F_mat` is passed in turn, `f` inside the wrapper is just a numeric vector of length `G_grid`. `length(f)` is therefore the grid size, and dividing by it gives a uniformly weighted average; `mean(pmax(f, 0))` would be equivalent.

### Endpoint contrast

The difference `f(x_G) - f(x_1)` is a useful summary when the grid is ordered (e.g. age, dose, or time). For our age grid it is the gap in CATE between an 80-year-old and a 30-year-old patient at the new site:

```{r wrapper-endpoint, eval = requireNamespace("grf", quietly = TRUE)}
endpoint_contrast <- function(f) f[length(f)] - f[1]
predict(fit, newdata = W_new, wrapper = endpoint_contrast)
```

## Conformal coverage with a wrapper

When you pass `wrapper` to `split_conformal()` (or `cross_conformal()` or `conformal_from_fit()`), the conformity scores are computed *after* the wrapper has been applied, so a **single shared quantile** calibrates the interval. The interval covers the wrapped scalar at the nominal level $1 - \alpha$ (marginally); it makes no claim about covering the underlying function pointwise.

With only `m = 8` sites, we hold out a single site (the 8th) and use the other seven for training plus calibration. The calibration set is small, so we use `alpha = 0.1` rather than `0.05`.
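Why does the calibration-set size matter so much? In the standard split-conformal construction, the interval half-width is the $\lceil (n_{\text{cal}} + 1)(1 - \alpha) \rceil$-th smallest calibration score, and if that rank exceeds $n_{\text{cal}}$ the interval is unbounded. At `alpha = 0.1` this rule needs at least nine calibration scores for a finite interval (nineteen at `alpha = 0.05`), whereas `cal_frac = 0.5` on seven sites leaves only three or four. The sketch below assumes absolute-residual scores on the wrapped scalar; `scalar_conformal_sketch()` is a hypothetical helper, and MetaHunt's score definition and its handling of the unbounded case may differ, so check `?split_conformal`.

```r
# Hypothetical sketch: absolute-residual scores on the wrapped scalar and the
# standard finite-sample split-conformal quantile. MetaHunt may differ.
scalar_conformal_sketch <- function(y_cal, pred_cal, pred_new, alpha = 0.1) {
  # y_cal:    wrapper applied to each calibration site's observed function
  # pred_cal: wrapper applied to each calibration site's predicted function
  # pred_new: wrapper applied to the predicted function at the new site
  scores <- abs(y_cal - pred_cal)
  n_cal <- length(scores)
  k <- ceiling((n_cal + 1) * (1 - alpha))  # rank of the score used as quantile
  q <- if (k > n_cal) Inf else sort(scores)[k]
  c(lower = pred_new - q, upper = pred_new + q)
}

# Four calibration sites at alpha = 0.1: k = ceiling(5 * 0.9) = 5 > 4,
# so the interval is unbounded.
scalar_conformal_sketch(y_cal = c(0.1, 0.3, -0.2, 0.4), pred_cal = rep(0, 4),
                        pred_new = 1, alpha = 0.1)
# returns c(lower = -Inf, upper = Inf)

# Nine calibration sites: k = ceiling(10 * 0.9) = 9, so q is the largest score.
scalar_conformal_sketch(y_cal = (1:9) / 10, pred_cal = rep(0, 9),
                        pred_new = 1, alpha = 0.1)
# returns approximately c(lower = 0.1, upper = 1.9)
```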
```{r ws-split-scalar, eval = requireNamespace("grf", quietly = TRUE)}
# Use 7 sites for training + calibration; predict for the held-out 8th.
tr_cal <- 1:7
new <- 8
res <- split_conformal(
  F_hat[tr_cal, , drop = FALSE],
  W[tr_cal, , drop = FALSE],
  W[new, , drop = FALSE],
  K = 3,
  wrapper = mean,
  alpha = 0.1,
  cal_frac = 0.5,
  seed = 1,
  dfspa_args = list(denoise = FALSE)
)
data.frame(prediction = res$prediction, lower = res$lower, upper = res$upper)
```

With only 8 sites in this example, an empirical-coverage check on a single held-out site is not informative: coverage is a statement about many repetitions, and one site either falls inside its interval or does not. For coverage diagnostics, use a leave-one-out loop or simulate a larger number of studies. See `?coverage` for the helper function and the `conformal-prediction` vignette for split-conformal at scale.

## Pointwise vs scalar: quick reference

| Aspect             | Pointwise (`wrapper = NULL`)                         | Scalar (`wrapper` supplied)                             |
|--------------------|------------------------------------------------------|---------------------------------------------------------|
| Output shape       | `nrow(W_new)` x `G_grid` matrix                      | numeric vector of length `nrow(W_new)`                  |
| Conformal quantile | one per grid point (length `G_grid`)                 | a single scalar                                         |
| Coverage guarantee | per grid point, marginally (not joint over the grid) | for the scalar summary, marginally                      |
| Best for           | visualising the predicted function with a band       | reporting a single number with a valid interval         |
| Example call       | `split_conformal(F, W, W_new, K = 3)`                | `split_conformal(F, W, W_new, K = 3, wrapper = mean)`   |

A pointwise band is a visualisation aid; a scalar interval is the right object for an inferential claim about a specific functional. Pick the wrapper that matches the question you actually want to answer, and let the conformal machinery do the rest.

## See also

- `vignette("data-prep")`: building `F_hat` from per-site fitted models (including the `grf::causal_forest` dispatch and the `predict_fn` escape hatch used here).
- `vignette("conformal-prediction")`: split- and cross-conformal routines at scale, including empirical-coverage diagnostics that need more than a handful of held-out sites.