---
title: "ShrinkageTrees: Bayesian Tree Ensembles for Survival Analysis and Causal Inference"
author: "Tijn Jacobs"
date: "`r Sys.Date()`"
output:
  rmarkdown::html_vignette:
    toc: true
    toc_depth: 3
vignette: >
  %\VignetteEncoding{UTF-8}
  %\VignetteIndexEntry{ShrinkageTrees: Introduction and Usage}
  %\VignetteEngine{knitr::rmarkdown}
---

```{r setup, include=FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  message = FALSE,
  warning = FALSE,
  fig.width = 6,
  fig.height = 3.5,
  out.width = "100%"
)
library(ShrinkageTrees)
set.seed(42)
```

## Introduction

**ShrinkageTrees** is an R package that brings Bayesian Additive Regression Trees (BART; Chipman, George & McCulloch, 2010) to **survival analysis** and **causal inference**, with a particular focus on high-dimensional data.

The package implements BART-based models for **right-censored** and **interval-censored** survival outcomes using an accelerated failure time (AFT) formulation. Censored event times are handled through Bayesian data augmentation in the Gibbs sampler, enabling full posterior inference without proportional-hazards assumptions.

For causal inference, the package provides Bayesian Causal Forests (BCF; Hahn, Murray & Carvalho, 2020), which decompose the outcome into a prognostic function $\mu(\mathbf{x})$ and a treatment-effect function $\tau(\mathbf{x})$, each estimated by a separate tree ensemble. This two-forest structure supports estimation of heterogeneous treatment effects (CATEs) and the average treatment effect (ATE).

A key feature is the availability of multiple **regularisation strategies** that can be freely combined within a single model:

- **Classical BART priors** on the tree structure and leaf parameters.
- **Dirichlet splitting priors** (DART; Linero, 2018) for structural variable selection.
- **Horseshoe shrinkage on the leaf step heights** (Jacobs, van Wieringen & van der Pas, 2025) — a global–local prior that aggressively shrinks uninformative leaves toward zero while preserving strong signals. This is the main methodological novelty implemented in the package.
- **Half-Cauchy shrinkage** — a lighter-weight alternative that provides local shrinkage without a global scale parameter.

### Package map

| Function | Task | Prior |
|---|---|---|
| `HorseTrees()` | Prediction (continuous / binary / survival*) | Horseshoe |
| `ShrinkageTrees()` | Prediction — flexible prior choice* | Horseshoe, DART, BART, … |
| `SurvivalBART()` | Survival prediction* | Classical BART |
| `SurvivalDART()` | Sparse survival prediction* | DART (Dirichlet) |
| `SurvivalBCF()` | Causal survival inference* | BCF (classical) |
| `SurvivalShrinkageBCF()` | Sparse causal survival inference* | BCF + DART |
| `CausalHorseForest()` | Causal inference (all outcomes*) | Horseshoe |
| `CausalShrinkageForest()` | Causal inference — flexible prior* | Horseshoe, DART, BART, … |

\* All survival functions support both right-censored and interval-censored outcomes.

All model-fitting functions return an S3 object with consistent `print()`, `summary()`, `predict()`, and `plot()` methods.

## Key Concepts

Before diving into examples, we clarify a few concepts that appear throughout the package interface.

### Outcome types and the `timescale` parameter

Every model-fitting function accepts an `outcome_type` argument:

- `"continuous"` — standard regression (default for most functions).
- `"binary"` — probit BART for binary outcomes (0/1).
- `"right-censored"` — accelerated failure time model for survival data. The outcome `y` contains (possibly censored) follow-up times, and the `status` vector indicates events (1) vs. censored observations (0).
- `"interval-censored"` — AFT model for interval-censored survival data.
  Instead of `y` and `status`, provide `left_time` and `right_time` vectors specifying the lower and upper bounds of the observation window for each individual. Three cases are distinguished:

  - **Exact events**: `left_time == right_time` (event observed exactly).
  - **Interval-censored**: `left_time < right_time` with finite `right_time` (event occurred somewhere in the interval).
  - **Right-censored**: `right_time = Inf` (event not yet observed).

  This convention follows `survival::Surv(type = "interval2")`.

For survival outcomes, the `timescale` argument controls how the package treats the times:

- `timescale = "time"` (default): the supplied values are on the **original time scale** (positive numbers). The package internally applies a log-transform, i.e. it models $\log(T) = f(\mathbf{x}) + \varepsilon$. Predictions from `summary()` and `predict()` are back-transformed to the time scale automatically.
- `timescale = "log"`: the supplied values are **already log-transformed**. No further transformation is applied, and predictions stay on the log scale.

In most applications you should use `timescale = "time"` and pass the raw survival times directly.

### Shrinkage priors on the step heights

In a BART ensemble each tree contributes a step height — the parameter of the leaf an observation falls into — to the overall prediction. Classical BART assigns these step heights a fixed-variance Gaussian prior, which regularises all leaves equally. In high-dimensional settings, stronger and more adaptive regularisation is desirable. **ShrinkageTrees** implements two shrinkage priors that are placed directly on the step heights via a scale mixture of normals:

$$
h_\ell \mid \lambda_\ell, \tau, \omega \sim \mathcal{N}(0,\; \omega\, \lambda_\ell^2\, \tau^2).
$$

Here $\tau$ is a **global** shrinkage parameter shared across all leaves, $\lambda_\ell$ is a **local** scale specific to leaf $\ell$, and $\omega$ is a fixed scaling constant. The two currently implemented instantiations are:

- **Horseshoe** (`prior_type = "horseshoe"`).
  Both $\lambda_\ell$ and $\tau$ receive independent half-Cauchy priors. The heavy tails of the half-Cauchy allow individual leaves to escape shrinkage when the data support a strong effect, while the global parameter $\tau$ pulls the bulk of the estimates toward zero. This is the default prior in `HorseTrees()` and `CausalHorseForest()`.
- **Half-Cauchy** (`prior_type = "half-cauchy"`).
  Only the local scales $\lambda_\ell$ receive a half-Cauchy prior; there is no global shrinkage parameter. This provides per-leaf adaptivity without the additional pooling across the ensemble.

A forest-wide variant of the horseshoe (`prior_type = "horseshoe_fw"`) shares a single global $\tau$ across all trees in the forest rather than one per tree. These priors can be selected in `ShrinkageTrees()` and `CausalShrinkageForest()` via the `prior_type` argument, and can be combined with DART's Dirichlet splitting prior for simultaneous structural and parametric regularisation.

### Hyperparameter selection

The two most important hyperparameters are `local_hp` and `global_hp`. These control the horseshoe prior on the step heights (leaf parameters):

$$
\mu_{jl} \mid \lambda_{jl}, \tau_j \sim \mathcal{N}(0, \lambda_{jl}^2 \tau_j^2),
\qquad
\lambda_{jl} \sim \text{C}^+(0, \texttt{local\_hp}),
\qquad
\tau_j \sim \text{C}^+(0, \texttt{global\_hp}),
$$

where $\text{C}^+$ denotes the half-Cauchy distribution. Smaller values produce stronger shrinkage toward zero; larger values allow more variation.

**`HorseTrees()` and `CausalHorseForest()`** provide a convenience parameter `k` that sets both scales automatically: `local_hp = global_hp = k / sqrt(number_of_trees)`. The default `k = 0.1` works well in many settings and is a good starting point.

**`ShrinkageTrees()` and `CausalShrinkageForest()`** expose `local_hp` and `global_hp` directly (no `k`). A common rule of thumb is:

- `local_hp = k / sqrt(number_of_trees)` with `k` in [0.05, 0.5].
- `global_hp = local_hp` (symmetric) or a larger value if you want less overall shrinkage.

The **survival functions** (`SurvivalBART`, `SurvivalDART`) use `k` to calibrate the standard BART leaf prior: `local_hp = range(log(y)) / (2 * k * sqrt(number_of_trees))`. The default `k = 2` follows Chipman et al. (2010).

### The `store_posterior_sample` flag

When `store_posterior_sample = TRUE`, the fitted object stores the full $N_\text{post} \times n$ matrix of posterior draws for the predictions. This is needed for:

- `predict()` on new data (it re-runs the sampler internally, so posterior samples are always produced);
- `plot(fit, type = "ate")` and `plot(fit, type = "cate")`, which require the full posterior distribution;
- `plot(fit, type = "survival")` — full posterior credible bands over both $\mu_i$ and $\sigma$ (without samples, only sigma-uncertainty bands are available);
- any custom posterior analysis (e.g. computing posterior credible intervals for individual predictions).

When `FALSE`, only posterior *means* and $\sigma$ draws are stored, saving memory. `print()` and `summary()` work in both cases, but `predict()` will not be available.

### Treatment coding (`treatment_coding`)

All causal model functions — `CausalHorseForest()`, `CausalShrinkageForest()`, `SurvivalBCF()`, and `SurvivalShrinkageBCF()` — decompose the outcome as

$$
Y_i = \mu(\mathbf{x}_i) + b_i \cdot \tau(\mathbf{x}_i) + \varepsilon_i,
$$

where $b_i$ is a scalar that depends on the treatment assignment $Z_i$. The `treatment_coding` argument controls how $b_i$ is defined. Four options are available:

**`"centered"`** (default). $b_i = Z_i - 1/2$, so that $b_i \in \{-1/2,\; 1/2\}$. This is the original BCF parameterisation.

**`"binary"`**. $b_i = Z_i$, so that $b_i \in \{0,\; 1\}$. Standard binary coding; the treatment forest captures the full effect of treatment on the treated.

**`"adaptive"`**. $b_i = Z_i - \hat{e}(\mathbf{x}_i)$, where $\hat{e}(\mathbf{x}_i)$ is the estimated propensity score.
This follows Hahn, Murray & Carvalho (2020) and is the coding used in the `bcf` R package. When using this option, a `propensity` vector must be supplied.

**`"invariant"`**. Parameter-expanded (invariant) treatment coding. The coding parameters $b_0$ and $b_1$ are assigned $N(0,\; 1/2)$ priors and estimated within the Gibbs sampler via conjugate normal updates:

$$
Y_i = \mu(\mathbf{x}_i) + b_{Z_i} \cdot \tilde{\tau}(\mathbf{x}_i) + \varepsilon_i,
\qquad
b_0,\; b_1 \sim N(0,\; 1/2).
$$

The treatment effect is $\tau(\mathbf{x}_i) = (b_1 - b_0) \cdot \tilde{\tau}(\mathbf{x}_i)$, and the posterior draws of $b_0$ and $b_1$ are returned in the fitted object. This parameterisation is invariant to the coding of the treatment indicator (Hahn et al., 2020, Section 5.2).

The examples below illustrate each option on a simple continuous-outcome causal model.

```{r tc-data}
set.seed(50)
n_tc <- 60; p_tc <- 5
X_tc <- matrix(rnorm(n_tc * p_tc), n_tc, p_tc)
W_tc <- rbinom(n_tc, 1, 0.5)
tau_tc <- 1.5 * (X_tc[, 1] > 0)
y_tc <- X_tc[, 1] + W_tc * tau_tc + rnorm(n_tc, sd = 0.5)
```

```{r tc-centered}
# Centered (default)
fit_tc_cen <- CausalHorseForest(
  y = y_tc,
  X_train_control = X_tc,
  X_train_treat = X_tc,
  treatment_indicator_train = W_tc,
  treatment_coding = "centered",
  number_of_trees = 5,
  N_post = 50, N_burn = 25,
  store_posterior_sample = TRUE,
  verbose = FALSE
)
cat("Centered — ATE:", round(mean(fit_tc_cen$train_predictions_treat), 3), "\n")
```

```{r tc-binary}
# Binary
fit_tc_bin <- CausalHorseForest(
  y = y_tc,
  X_train_control = X_tc,
  X_train_treat = X_tc,
  treatment_indicator_train = W_tc,
  treatment_coding = "binary",
  number_of_trees = 5,
  N_post = 50, N_burn = 25,
  store_posterior_sample = TRUE,
  verbose = FALSE
)
cat("Binary — ATE:", round(mean(fit_tc_bin$train_predictions_treat), 3), "\n")
```

```{r tc-adaptive}
# Adaptive (requires propensity scores)
ps_tc <- pnorm(0.3 * X_tc[, 1])  # simple propensity model for illustration
fit_tc_ada <- CausalHorseForest(
  y = y_tc,
  X_train_control = X_tc,
  X_train_treat = X_tc,
  treatment_indicator_train = W_tc,
  treatment_coding = "adaptive",
  propensity = ps_tc,
  number_of_trees = 5,
  N_post = 50, N_burn = 25,
  store_posterior_sample = TRUE,
  verbose = FALSE
)
cat("Adaptive — ATE:", round(mean(fit_tc_ada$train_predictions_treat), 3), "\n")
```

```{r tc-invariant}
# Invariant (parameter-expanded)
fit_tc_inv <- CausalHorseForest(
  y = y_tc,
  X_train_control = X_tc,
  X_train_treat = X_tc,
  treatment_indicator_train = W_tc,
  treatment_coding = "invariant",
  number_of_trees = 5,
  N_post = 50, N_burn = 25,
  store_posterior_sample = TRUE,
  verbose = FALSE
)
cat("Invariant — ATE:", round(mean(fit_tc_inv$train_predictions_treat), 3), "\n")

# Posterior draws of b0 and b1 are stored in the fitted object
cat("b0 posterior mean:", round(mean(fit_tc_inv$b0), 3), "\n")
cat("b1 posterior mean:", round(mean(fit_tc_inv$b1), 3), "\n")
```

The survival functions inherit `treatment_coding` support. For example, `SurvivalBCF(..., treatment_coding = "invariant")` works out of the box.

## Included Datasets

The package ships with two TCGA datasets for high-dimensional survival analysis and causal inference:

- **`pdac`** — TCGA pancreatic ductal adenocarcinoma (PAAD) cohort (n = 178). A data frame with overall survival times, a binary treatment indicator (radiation therapy vs. control), clinical covariates, and expression values of ~3,000 genes selected by median absolute deviation.
- **`ovarian`** — TCGA ovarian cancer (OV) cohort (n = 357). A list with `X` (357 x 2,000 gene expression matrix, log2-normalised TPM) and `clinical` (data frame with OS time/event, age, FIGO stage, tumor grade, and treatment: carboplatin vs. cisplatin). See `?ovarian` for details.

### The PDAC Dataset

We use `pdac` for the survival and causal analyses below; see `?pdac` for full variable descriptions.
```{r load-data}
library(ShrinkageTrees)
data("pdac")

# Dimensions and column overview
cat("Patients:", nrow(pdac), "\n")
cat("Columns :", ncol(pdac), "\n")
cat("Clinical columns:", paste(names(pdac)[1:13], collapse = ", "), "\n")
cat("Survival: time (months), censoring rate =", round(1 - mean(pdac$status), 2), "\n")
cat("Treatment: radiation =", sum(pdac$treatment), "/ control =", sum(1 - pdac$treatment), "\n")
```

We separate the outcome, treatment, and covariate matrix for the analyses below.

```{r prepare-data}
time <- pdac$time
status <- pdac$status
treatment <- pdac$treatment
X <- as.matrix(pdac[, !(names(pdac) %in% c("time", "status", "treatment"))])
```

## Prediction Models

This section demonstrates the single-forest models in **ShrinkageTrees**. These models estimate a single function $f(\mathbf{x})$ of the covariates, applicable to continuous, binary, and survival outcomes. We begin with a binary-outcome example that will also serve as the propensity score model for the causal analyses later.

### HorseTrees — binary outcome (propensity scores)

Before fitting causal models we estimate propensity scores $\hat{e}(\mathbf{x}) = P(W = 1 \mid \mathbf{x})$ using `HorseTrees()` with a binary outcome. The probit link is used internally: predictions are on the latent Gaussian scale and can be converted to probabilities with `pnorm()`.

The code block below shows the full analysis with realistic MCMC settings (`N_post = 5000, N_burn = 5000`); it is not evaluated when the vignette is built.

```{r horsetrees-binary, eval=FALSE}
ps_fit <- HorseTrees(
  y = treatment,
  X_train = X,
  outcome_type = "binary",
  k = 0.1,
  N_post = 5000, N_burn = 5000,
  verbose = FALSE
)
propensity <- pnorm(ps_fit$train_predictions)
```

For the remainder of this vignette we use a short synthetic run to keep build time low.
```{r horsetrees-binary-small}
set.seed(1)
n <- 80; p <- 10
X_syn <- matrix(rnorm(n * p), n, p)
W_syn <- rbinom(n, 1, pnorm(0.8 * X_syn[, 1]))

ps_fit <- HorseTrees(
  y = W_syn,
  X_train = X_syn,
  outcome_type = "binary",
  number_of_trees = 5,
  k = 0.5,
  N_post = 50, N_burn = 25,
  verbose = FALSE
)
propensity_syn <- pnorm(ps_fit$train_predictions)
cat("Propensity scores — range: [", round(range(propensity_syn), 3), "]\n")
```

### HorseTrees — survival outcome

`HorseTrees()` handles right-censored data via an AFT model. Pass `outcome_type = "right-censored"` and provide the `status` vector (1 = event, 0 = censored). When `timescale = "time"` (the default), the package log-transforms survival times internally and returns predictions on the log scale (see *Key Concepts* above).

```{r horsetrees-survival}
set.seed(2)
log_T <- X_syn[, 1] + rnorm(n)
C <- rexp(n, 0.5)
y_syn <- pmin(exp(log_T), C)
d_syn <- as.integer(exp(log_T) <= C)

ht_surv <- HorseTrees(
  y = y_syn,
  status = d_syn,
  X_train = X_syn,
  outcome_type = "right-censored",
  timescale = "time",
  number_of_trees = 5,
  N_post = 50, N_burn = 25,
  store_posterior_sample = TRUE,
  verbose = FALSE
)
cat("Posterior mean log-time (first 5 obs):", round(ht_surv$train_predictions[1:5], 3), "\n")
cat("Posterior sigma — mean:", round(mean(ht_surv$sigma), 3), "\n")
```

### HorseTrees — interval-censored outcome

When event times are not observed exactly but are known to lie within an interval, the package supports **interval censoring**. Instead of providing `y` and `status`, pass `left_time` and `right_time` with `outcome_type = "interval-censored"`. The three censoring types are encoded as follows:

- **Exact event**: `left_time[i] == right_time[i]`
- **Interval-censored**: `left_time[i] < right_time[i]` (both finite)
- **Right-censored**: `right_time[i] = Inf`

This convention matches `survival::Surv(type = "interval2")`.
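For readers who already use the **survival** package, the same three cases can be written with `Surv(type = "interval2")`. One difference to keep in mind: `Surv()` marks an unknown upper bound with `NA`, whereas this package's interface uses `Inf`. A small standalone sketch (not package code):

```r
library(survival)

left  <- c(2.0, 1.5, 4.0)  # lower bounds
right <- c(2.0, 3.0, Inf)  # exact, interval-censored, right-censored (package coding)

# survival::Surv(type = "interval2") expects NA (not Inf) for an unknown upper bound
s <- Surv(time = left, time2 = ifelse(is.finite(right), right, NA), type = "interval2")
s
```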
```{r horsetrees-ic}
set.seed(20)

# Generate true event times
true_T <- rexp(n, rate = exp(-0.5 * X_syn[, 1]))

# Create interval-censored observations
left_syn <- true_T * runif(n, 0.5, 1.0)
right_syn <- true_T * runif(n, 1.0, 1.5)

# Mark some as exact observations and some as right-censored
exact_idx <- sample(n, 25)
left_syn[exact_idx] <- true_T[exact_idx]
right_syn[exact_idx] <- true_T[exact_idx]
rc_idx <- sample(setdiff(seq_len(n), exact_idx), 15)
right_syn[rc_idx] <- Inf

cat("Exact events:", sum(left_syn == right_syn), "\n")
cat("Interval-censored:", sum(left_syn < right_syn & is.finite(right_syn)), "\n")
cat("Right-censored:", sum(!is.finite(right_syn)), "\n")

ht_ic <- HorseTrees(
  left_time = left_syn,
  right_time = right_syn,
  X_train = X_syn,
  outcome_type = "interval-censored",
  timescale = "time",
  number_of_trees = 5,
  N_post = 50, N_burn = 25,
  store_posterior_sample = TRUE,
  verbose = FALSE
)
cat("Posterior mean log-time (first 5 obs):", round(ht_ic$train_predictions[1:5], 3), "\n")
cat("Posterior sigma — mean:", round(mean(ht_ic$sigma), 3), "\n")
```

All survival functions (`SurvivalBART`, `SurvivalDART`, `SurvivalBCF`, `SurvivalShrinkageBCF`) and the general-purpose functions (`ShrinkageTrees`, `CausalShrinkageForest`, `CausalHorseForest`) accept `left_time` and `right_time` in the same way. For example, using `SurvivalBART`:

```{r sbart-ic}
set.seed(21)
fit_sbart_ic <- SurvivalBART(
  left_time = left_syn,
  right_time = right_syn,
  X_train = X_syn,
  number_of_trees = 5,
  N_post = 50, N_burn = 25,
  verbose = FALSE
)
cat("SurvivalBART (IC) class:", class(fit_sbart_ic), "\n")
```

### ShrinkageTrees — flexible prior choice

While `HorseTrees()` fixes the prior to the horseshoe, the more general `ShrinkageTrees()` function exposes the `prior_type` argument, allowing the user to select among all implemented regularisation strategies.
Available options are `"horseshoe"`, `"horseshoe_fw"` (forest-wide), `"half-cauchy"`, `"standard"` (classical BART), and `"dirichlet"` (DART). Below we compare the per-tree horseshoe and the forest-wide horseshoe on a continuous outcome.

```{r shrinkage-continuous}
set.seed(3)
y_cont <- X_syn[, 1] + 0.5 * X_syn[, 2] + rnorm(n)

# Horseshoe prior (default for HorseTrees)
fit_hs <- ShrinkageTrees(
  y = y_cont,
  X_train = X_syn,
  outcome_type = "continuous",
  prior_type = "horseshoe",
  local_hp = 0.1 / sqrt(5),
  global_hp = 0.1 / sqrt(5),
  number_of_trees = 5,
  N_post = 50, N_burn = 25,
  verbose = FALSE
)

# Forest-wide horseshoe (horseshoe_fw)
fit_fw <- ShrinkageTrees(
  y = y_cont,
  X_train = X_syn,
  outcome_type = "continuous",
  prior_type = "horseshoe_fw",
  local_hp = 0.1 / sqrt(5),
  global_hp = 0.1 / sqrt(5),
  number_of_trees = 5,
  N_post = 50, N_burn = 25,
  verbose = FALSE
)

cat("Horseshoe    — train RMSE:", round(sqrt(mean((fit_hs$train_predictions - y_cont)^2)), 3), "\n")
cat("Horseshoe FW — train RMSE:", round(sqrt(mean((fit_fw$train_predictions - y_cont)^2)), 3), "\n")
```

### SurvivalBART and SurvivalDART

`SurvivalBART()` and `SurvivalDART()` fit classical BART and DART models for right-censored survival data under the AFT formulation. They calibrate prior hyperparameters automatically from the data range, providing a simple interface when horseshoe shrinkage is not needed.
```{r survival-bart-dart}
set.seed(4)

# SurvivalBART: classical BART prior, AFT likelihood
fit_sbart <- SurvivalBART(
  time = y_syn,
  status = d_syn,
  X_train = X_syn,
  number_of_trees = 5,
  k = 2.0,
  N_post = 50, N_burn = 25,
  verbose = FALSE
)

# SurvivalDART: Dirichlet (DART) splitting prior
fit_sdart <- SurvivalDART(
  time = y_syn,
  status = d_syn,
  X_train = X_syn,
  number_of_trees = 5,
  k = 2.0,
  N_post = 50, N_burn = 25,
  verbose = FALSE
)

cat("SurvivalBART class:", class(fit_sbart), "\n")
cat("SurvivalDART class:", class(fit_sdart), "\n")
```

## High-Dimensional Survival Analysis

A key motivation for horseshoe shrinkage and the Dirichlet (DART) sparsity prior is their behaviour in the **$p \gg n$ regime**: many covariates are available but only a small subset drives the outcome. Classical BART may struggle here because the standard Gaussian leaf prior is non-sparse and does not concentrate on a small number of predictors.

We illustrate both priors on a sparse AFT simulation: $n = 60$ observations, $p = 200$ predictors, and only three active predictors.

```{r hd-data}
set.seed(20)
n_hd <- 60; p_hd <- 200
X_hd <- matrix(rnorm(n_hd * p_hd), n_hd, p_hd)

# True log-survival depends only on predictors 1, 2, and 3
log_T_hd <- 1.5 * X_hd[, 1] - 1.0 * X_hd[, 2] + 0.5 * X_hd[, 3] + rnorm(n_hd)
C_hd <- rexp(n_hd, rate = 0.5)
y_hd <- pmin(exp(log_T_hd), C_hd)
d_hd <- as.integer(exp(log_T_hd) <= C_hd)

cat("n =", n_hd, "| p =", p_hd, "| active predictors = 3",
    "| censoring rate =", round(1 - mean(d_hd), 2), "\n")
```

**ShrinkageTrees (horseshoe)** places global–local shrinkage on the step heights of every leaf, automatically regularising all 200 predictors toward zero while preserving the signal in the three active ones.
```{r hd-horseshoe}
set.seed(21)
fit_hd_hs <- ShrinkageTrees(
  y = y_hd,
  status = d_hd,
  X_train = X_hd,
  outcome_type = "right-censored",
  prior_type = "horseshoe",
  local_hp = 0.1 / sqrt(10),
  global_hp = 0.1 / sqrt(10),
  number_of_trees = 10,
  N_post = 50, N_burn = 25,
  verbose = FALSE
)
```

**SurvivalDART** uses a Dirichlet prior on split probabilities to induce structural sparsity: after burn-in, most splitting probability is concentrated on truly predictive variables. Setting `rho_dirichlet = 3` encodes the prior belief that approximately three predictors are active.

```{r hd-dart}
set.seed(22)
fit_hd_dart <- SurvivalDART(
  time = y_hd,
  status = d_hd,
  X_train = X_hd,
  number_of_trees = 10,
  rho_dirichlet = 3,
  N_post = 50, N_burn = 25,
  verbose = FALSE
)
```

Both models run without error in the $p > n$ regime. We compare their posterior mean predictions in log-time against the latent true values used to generate the data.

```{r hd-compare}
rmse_hs <- sqrt(mean((fit_hd_hs$train_predictions - log_T_hd)^2))
rmse_dart <- sqrt(mean((fit_hd_dart$train_predictions - log_T_hd)^2))
cat(sprintf("%-18s train RMSE (log-time): %.3f\n", "Horseshoe", rmse_hs))
cat(sprintf("%-18s train RMSE (log-time): %.3f\n", "DART", rmse_dart))
```

The DART model also produces **variable importance** plots that display the posterior distribution of each predictor's splitting probability. With only 50 posterior draws the top-10 plot below should already concentrate most probability mass near the three truly active predictors.
```{r hd-vi, eval=requireNamespace("ggplot2", quietly=TRUE)}
plot(fit_hd_dart, type = "vi", n_vi = 10)
```

## Causal Forest Models

For causal inference, **ShrinkageTrees** provides Bayesian Causal Forest (BCF) models that decompose the outcome into a prognostic component and a treatment-effect component:

$$
Y_i = \mu(\mathbf{x}_i) + W_i \cdot \tau(\mathbf{x}_i) + \varepsilon_i,
$$

where $\mu(\cdot)$ is the prognostic (control) function modelled by one tree ensemble, and $\tau(\cdot)$ is the heterogeneous treatment effect modelled by a second ensemble. This two-forest structure allows each component to have its own regularisation, number of trees, and prior — for instance, a standard BART prior for the prognostic forest and horseshoe shrinkage for the treatment-effect forest.

The package provides four causal model functions with increasing generality: `SurvivalBCF()` (classical BCF for survival), `SurvivalShrinkageBCF()` (BCF + DART for survival), `CausalHorseForest()` (horseshoe BCF for all outcome types), and `CausalShrinkageForest()` (fully configurable BCF). We illustrate each below on synthetic data with a known treatment effect.

```{r causal-data}
set.seed(5)
tau_true <- 1.5 * (X_syn[, 1] > 0)  # heterogeneous treatment effect
y_causal <- X_syn[, 1] + W_syn * tau_true + rnorm(n, sd = 0.5)
```

### SurvivalBCF — classical BCF for survival

`SurvivalBCF()` fits a BCF model for right-censored survival outcomes using classical BART priors.
```{r survbcf, eval=FALSE}
# Full analysis (eval=FALSE — use larger MCMC settings in practice)
fit_sbcf <- SurvivalBCF(
  time = time,
  status = status,
  X_train = X,
  treatment = treatment,
  propensity = propensity,  # from HorseTrees above
  N_post = 5000, N_burn = 5000,
  verbose = FALSE
)
```

```{r survbcf-small}
set.seed(6)
fit_sbcf <- SurvivalBCF(
  time = y_syn,
  status = d_syn,
  X_train = X_syn,
  treatment = W_syn,
  number_of_trees_control = 5,
  number_of_trees_treat = 5,
  N_post = 50, N_burn = 25,
  verbose = FALSE
)
cat("SurvivalBCF class:", class(fit_sbcf), "\n")
cat("ATE (posterior mean):", round(mean(fit_sbcf$train_predictions_treat), 3), "\n")
```

### SurvivalShrinkageBCF — sparse causal survival forest

`SurvivalShrinkageBCF()` extends BCF with a Dirichlet splitting prior on both forests, inducing sparsity in high-dimensional settings.

```{r survsbcf-small}
set.seed(7)
fit_ssbcf <- SurvivalShrinkageBCF(
  time = y_syn,
  status = d_syn,
  X_train = X_syn,
  treatment = W_syn,
  number_of_trees_control = 5,
  number_of_trees_treat = 5,
  N_post = 50, N_burn = 25,
  verbose = FALSE
)
cat("SurvivalShrinkageBCF class:", class(fit_ssbcf), "\n")
```

### CausalHorseForest — horseshoe causal forest

`CausalHorseForest()` is the primary novel contribution of this package. It applies horseshoe shrinkage to the leaf parameters of both the prognostic and treatment-effect forests. This enables effective regularisation when many covariates are available but few are truly predictive of heterogeneous treatment effects.
```{r causal-horse}
set.seed(8)
fit_chf <- CausalHorseForest(
  y = y_causal,
  X_train_control = X_syn,
  X_train_treat = X_syn,
  treatment_indicator_train = W_syn,
  outcome_type = "continuous",
  number_of_trees = 5,
  N_post = 50, N_burn = 25,
  store_posterior_sample = TRUE,
  verbose = FALSE
)
cat("CausalHorseForest class:", class(fit_chf), "\n")

# Posterior mean CATE
cate_mean <- fit_chf$train_predictions_treat
cat("CATE — posterior mean (first 5):", round(cate_mean[1:5], 3), "\n")

# Posterior ATE
ate_samples <- rowMeans(fit_chf$train_predictions_sample_treat)
cat("ATE posterior mean:", round(mean(ate_samples), 3),
    " 95% CI: [", round(quantile(ate_samples, 0.025), 3), ",",
    round(quantile(ate_samples, 0.975), 3), "]\n")
```

The fitted object stores the posterior mean CATE for each training observation in `train_predictions_treat`. When `store_posterior_sample = TRUE`, the full posterior sample matrix is available in `train_predictions_sample_treat`, from which the posterior ATE distribution and credible intervals can be computed as shown above.

You can also supply separate test matrices to obtain out-of-sample CATE predictions.

```{r causal-horse-test}
set.seed(9)
X_test <- matrix(rnorm(20 * p), 20, p)
W_test <- rbinom(20, 1, 0.5)

fit_chf_test <- CausalHorseForest(
  y = y_causal,
  X_train_control = X_syn,
  X_train_treat = X_syn,
  treatment_indicator_train = W_syn,
  X_test_control = X_test,
  X_test_treat = X_test,
  treatment_indicator_test = W_test,
  outcome_type = "continuous",
  number_of_trees = 5,
  N_post = 50, N_burn = 25,
  store_posterior_sample = TRUE,
  verbose = FALSE
)
cat("Test CATE (first 5):", round(fit_chf_test$test_predictions_treat[1:5], 3), "\n")
```

### CausalShrinkageForest — flexible causal priors

`CausalShrinkageForest()` is the most general causal model interface. It allows independent prior choices for the prognostic and treatment-effect forests via `prior_type_control` and `prior_type_treat`.
For example, one could use a standard BART prior for the prognostic forest (where variable selection is less critical) and horseshoe shrinkage for the treatment forest (where most covariates are expected to be irrelevant for the treatment effect).

```{r causal-shrinkage}
set.seed(10)
lh <- 0.1 / sqrt(5)

fit_csf <- CausalShrinkageForest(
  y = y_causal,
  X_train_control = X_syn,
  X_train_treat = X_syn,
  treatment_indicator_train = W_syn,
  outcome_type = "continuous",
  prior_type_control = "horseshoe",
  prior_type_treat = "horseshoe",
  local_hp_control = lh,
  global_hp_control = lh,
  local_hp_treat = lh,
  global_hp_treat = lh,
  number_of_trees_control = 5,
  number_of_trees_treat = 5,
  N_post = 50, N_burn = 25,
  store_posterior_sample = TRUE,
  verbose = FALSE
)
cat("CausalShrinkageForest class:", class(fit_csf), "\n")
cat("Acceptance ratio (control):", round(fit_csf$acceptance_ratio_control, 3), "\n")
cat("Acceptance ratio (treat) :", round(fit_csf$acceptance_ratio_treat, 3), "\n")
```

The `horseshoe_fw` prior adds a forest-wide shrinkage parameter that is tracked in the fitted object.

```{r causal-shrinkage-fw}
set.seed(11)
fit_fw2 <- CausalShrinkageForest(
  y = y_causal,
  X_train_control = X_syn,
  X_train_treat = X_syn,
  treatment_indicator_train = W_syn,
  outcome_type = "continuous",
  prior_type_control = "horseshoe_fw",
  prior_type_treat = "horseshoe_fw",
  local_hp_control = lh,
  global_hp_control = lh,
  local_hp_treat = lh,
  global_hp_treat = lh,
  number_of_trees_control = 5,
  number_of_trees_treat = 5,
  N_post = 50, N_burn = 25,
  verbose = FALSE
)
cat("Forest-wide shrinkage (control, first 5 draws):\n")
print(round(fit_fw2$forestwide_shrinkage_control[1:5], 4))
```

## S3 Methods

All fitted objects — whether from prediction models or causal models — support a consistent set of S3 methods: `print()`, `summary()`, `predict()`, and `plot()`. This section illustrates each method using the models fitted above.
### print()

Calling `print()` (or just typing the object name) displays a concise model summary.

```{r print}
print(fit_chf)
```

For causal models the output additionally shows the number of trees in each forest and prior details for both components.

```{r print-csf}
print(fit_csf)
```

### summary()

`summary()` returns a structured list and displays a richer description including posterior statistics for $\sigma$, acceptance ratios, and treatment effect estimates.

```{r summary-shrinkage}
smry <- summary(fit_hs)
print(smry)
```

For causal models the summary includes the posterior ATE with a 95% credible interval (when `store_posterior_sample = TRUE`).

```{r summary-causal}
smry_c <- summary(fit_chf)
print(smry_c)

# Access the ATE directly
cat("ATE mean :", round(smry_c$treatment_effect$ate, 3), "\n")
cat("ATE 95% CI: [", round(smry_c$treatment_effect$ate_lower, 3), ",",
    round(smry_c$treatment_effect$ate_upper, 3), "]\n")
```

#### Population vs. mixed ATE

By default the ATE credible interval is obtained by a **Bayesian bootstrap**: at each MCMC iteration $s$ the observation-level CATEs $\tau^{(s)}(x_i)$ are reweighted with $\mathrm{Dir}(1, \dots, 1)$ weights,

$$
\widehat{\mathrm{PATE}}^{(s)} \;=\; \sum_{i=1}^n w_i^{(s)}\, \tau^{(s)}(x_i),
\qquad
(w_1^{(s)}, \dots, w_n^{(s)}) \sim \mathrm{Dir}(1, \dots, 1).
$$

The collection $\{\widehat{\mathrm{PATE}}^{(s)}\}$ approximates the posterior of the **population ATE** and therefore propagates uncertainty in both $\tau(\cdot)$ and the covariate distribution $F_X$. Setting `bayesian_bootstrap = FALSE` reverts to equal $1/n$ weights, giving the **mixed ATE** (MATE) that conditions on the observed covariates and has a narrower credible interval.
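The reweighting above can be sketched by hand in a few lines of base R. This is an illustration only: `tau_draws` below is a simulated stand-in for a posterior CATE matrix (iterations in rows, observations in columns), not package output.

```r
set.seed(123)
N_post <- 50; n <- 40
# Stand-in for a posterior CATE matrix (N_post x n)
tau_draws <- matrix(rnorm(N_post * n, mean = 0.75, sd = 0.3), N_post, n)

# PATE draws: one Dirichlet(1, ..., 1) weight vector per iteration,
# obtained by normalising unit-rate exponential (Gamma(1, 1)) draws
pate <- apply(tau_draws, 1, function(tau_s) {
  g <- rexp(n)
  sum(g / sum(g) * tau_s)
})

# MATE draws: equal 1/n weights (plain row means)
mate <- rowMeans(tau_draws)

quantile(pate, c(0.025, 0.975))
quantile(mate, c(0.025, 0.975))
```

The extra Dirichlet reweighting is what widens the PATE interval relative to the MATE interval.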
```{r summary-pate-mate}
smry_pate <- summary(fit_chf, bayesian_bootstrap = TRUE)   # default
smry_mate <- summary(fit_chf, bayesian_bootstrap = FALSE)
```

The standalone helper `bayesian_bootstrap_ate()` returns both posteriors and their draws in a single list. It also works on a `CausalShrinkageForestPrediction` returned by `predict()`, so that the PATE integrates over a prespecified target population.

```{r bb-ate}
bb <- bayesian_bootstrap_ate(fit_chf)
cat("PATE:", round(bb$pate_mean, 3),
    " 95% CI: [", round(bb$pate_ci$lower, 3), ",", round(bb$pate_ci$upper, 3), "]\n")
cat("MATE:", round(bb$mate_mean, 3),
    " 95% CI: [", round(bb$mate_ci$lower, 3), ",", round(bb$mate_ci$upper, 3), "]\n")
```

### predict()

`predict()` computes the posterior predictive distribution on new data. It returns a `ShrinkageTreesPrediction` object with posterior mean and credible-interval vectors.

```{r predict}
X_new <- matrix(rnorm(10 * p), 10, p)
pred <- predict(fit_hs, newdata = X_new)
print(pred)
```

```{r predict-ci}
# Point estimates and 95% credible intervals
head(data.frame(
  mean  = round(pred$mean, 3),
  lower = round(pred$lower, 3),
  upper = round(pred$upper, 3)
))
```

#### Causal predictions

For causal models (`CausalShrinkageForest` and `CausalHorseForest`), `predict()` returns three sets of posterior summaries:

- **prognostic**: the control-forest prediction $\mu(\mathbf{x})$ — the expected outcome under control.
- **cate**: the Conditional Average Treatment Effect $\tau(\mathbf{x})$ — the additional effect of treatment for each individual.
- **total**: the combined prediction $\mu(\mathbf{x}) + \tau(\mathbf{x})$ — the expected outcome under treatment.

For survival models with `timescale = "time"`, the prognostic and total components are back-transformed to the original time scale (via $\exp(\cdot)$), and the CATE becomes a **multiplicative time ratio**: $\exp(\tau) > 1$ means treatment prolongs survival.
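As a toy numerical illustration of the time-ratio reading (the draws below are simulated, not taken from a fitted model):

```r
# Hypothetical posterior draws of one subject's CATE on the log-time scale
set.seed(2)
tau_draws <- rnorm(500, mean = 0.3, sd = 0.15)

time_ratio <- exp(tau_draws)  # multiplicative effect on survival time
mean(time_ratio)              # posterior mean time ratio
mean(time_ratio > 1)          # posterior probability that treatment prolongs survival
```

A posterior mean time ratio of, say, 1.35 would mean survival times under treatment are roughly 35% longer for this subject.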
The `predict()` method requires two covariate matrices — one for each forest — matching the columns used at fit time.

```{r predict-causal}
X_new_ctrl  <- matrix(rnorm(10 * p), 10, p)
X_new_treat <- matrix(rnorm(10 * p), 10, p)
pred_c <- predict(fit_chf,
                  newdata_control = X_new_ctrl,
                  newdata_treat = X_new_treat)
print(pred_c)
```

```{r predict-causal-detail}
# Extract individual components
head(data.frame(
  prognostic = round(pred_c$prognostic$mean, 3),
  cate       = round(pred_c$cate$mean, 3),
  total      = round(pred_c$total$mean, 3)
))
```

### plot()

The `plot()` method produces diagnostic and inferential graphics using the **ggplot2** package (a suggested dependency).

#### Sigma traceplot

```{r plot-trace, eval=requireNamespace("ggplot2", quietly=TRUE)}
plot(fit_hs, type = "trace")
```

#### Posterior ATE distribution (causal models)

The ATE density uses the Bayesian-bootstrap PATE posterior by default; pass `bayesian_bootstrap = FALSE` to plot the (narrower) mixed-ATE density instead. See the summary section above for the definitions.

```{r plot-ate, eval=requireNamespace("ggplot2", quietly=TRUE)}
plot(fit_chf, type = "ate")                               # PATE (default)
plot(fit_chf, type = "ate", bayesian_bootstrap = FALSE)   # MATE
```

#### CATE caterpillar plot

```{r plot-cate, eval=requireNamespace("ggplot2", quietly=TRUE)}
plot(fit_chf, type = "cate")
```

#### Variable importance (Dirichlet prior)

Variable importance plots are available when `prior_type = "dirichlet"`. For causal models with `prior_type_control = "dirichlet"` or `prior_type_treat = "dirichlet"`, use `forest = "control"`, `"treat"`, or `"both"`.
```{r vi-fit}
set.seed(12)
fit_dart <- ShrinkageTrees(
  y = y_cont,
  X_train = X_syn,
  outcome_type = "continuous",
  prior_type = "dirichlet",
  local_hp = 0.1 / sqrt(5),
  number_of_trees = 5,
  N_post = 50,
  N_burn = 25,
  verbose = FALSE
)
```

```{r plot-vi, eval=requireNamespace("ggplot2", quietly=TRUE)}
plot(fit_dart, type = "vi", n_vi = 10)
```

#### Survival curves

For survival models (`outcome_type = "right-censored"` or `"interval-censored"`), the `plot()` method can draw posterior survival curves derived from the fitted AFT log-normal model:

$$
S(t \mid \mathbf{x}_i) = 1 - \Phi\!\left(\frac{\log t - \mu_i}{\sigma}\right),
$$

where $\mu_i = f(\mathbf{x}_i)$ is the BART ensemble prediction and $\sigma$ is the residual standard deviation on the log-time scale.

The `type = "survival"` option supports two modes, controlled by the `obs` argument:

- **Population-averaged curve** (`obs = NULL`, the default): computes $\bar{S}(t) = n^{-1}\sum_i S(t \mid \mathbf{x}_i)$ at each MCMC iteration, giving credible bands that reflect full posterior uncertainty.
- **Individual curves** (`obs = c(1, 5, ...)`): one curve per selected training observation, each with its own credible band.

Additional options:

| Argument | Description |
|---|---|
| `level` | Width of the pointwise credible band (default `0.95`). |
| `t_grid` | Custom time grid (original scale). Auto-generated if `NULL`. |
| `km` | If `TRUE`, overlay the Kaplan–Meier estimate (population-average only). |

We use the survival fit from the earlier section:

```{r surv-curve-pop, eval=requireNamespace("ggplot2", quietly=TRUE)}
# Population-averaged survival curve with 95% credible band
plot(ht_surv, type = "survival")
```

```{r surv-curve-km, eval=requireNamespace("ggplot2", quietly=TRUE) && requireNamespace("survival", quietly=TRUE)}
# Same curve with the Kaplan-Meier estimate overlaid for comparison
plot(ht_surv, type = "survival", km = TRUE)
```

```{r surv-curve-ind, eval=requireNamespace("ggplot2", quietly=TRUE)}
# Individual survival curves for observations 1, 20, 40, 60, and 80
plot(ht_surv, type = "survival", obs = c(1, 20, 40, 60, 80))
```

```{r surv-curve-single, eval=requireNamespace("ggplot2", quietly=TRUE)}
# Single individual with a narrower 90% credible band
plot(ht_surv, type = "survival", obs = 1, level = 0.90)
```

When `store_posterior_sample = FALSE`, the credible bands reflect only the uncertainty in $\sigma$ (using the plug-in posterior mean $\hat{\mu}_i$). The survival functions (`SurvivalBART`, `SurvivalDART`, etc.) store posterior samples by default, so full posterior bands are available out of the box.

#### Posterior predictive survival curves

The survival curves above are based on the **training** data — they show $S(t \mid \mathbf{x}_i)$ for the observations used to fit the model. For **new** (out-of-sample) data, call `predict()` first and then `plot()` on the prediction object.
This produces *posterior predictive* survival curves that propagate full parameter uncertainty through to the new covariate values:

```{r surv-pred-setup}
# New observations for prediction
set.seed(99)
X_new <- matrix(rnorm(20 * p), ncol = p)
pred_surv <- predict(ht_surv, newdata = X_new)
```

```{r surv-pred-pop, eval=requireNamespace("ggplot2", quietly=TRUE)}
# Population-averaged posterior predictive survival curve
plot(pred_surv, type = "survival")
```

```{r surv-pred-ind, eval=requireNamespace("ggplot2", quietly=TRUE)}
# Individual posterior predictive curves for selected new observations
plot(pred_surv, type = "survival", obs = c(1, 5, 10))
```

The same `level` and `t_grid` arguments are available as for the training-data survival curves. The Kaplan–Meier overlay (`km = TRUE`) is not available for prediction objects, since observed event times are only known for the training set.

## Multi-Chain MCMC

Running multiple independent chains improves mixing diagnostics and reduces sensitivity to starting values. Pass `n_chains > 1` to any model-fitting function; chains are run in parallel via `parallel::mclapply` on Unix-like systems.

```{r multi-chain}
set.seed(13)
fit_2chain <- ShrinkageTrees(
  y = y_cont,
  X_train = X_syn,
  outcome_type = "continuous",
  prior_type = "horseshoe",
  local_hp = 0.1 / sqrt(5),
  global_hp = 0.1 / sqrt(5),
  number_of_trees = 5,
  N_post = 50,
  N_burn = 25,
  n_chains = 2,
  verbose = FALSE
)

cat("n_chains stored :", fit_2chain$mcmc$n_chains, "\n")
cat("Total sigma draws:", length(fit_2chain$sigma), " (2 chains x 50 draws)\n")
cat("Per-chain acceptance ratios:\n")
print(round(fit_2chain$chains$acceptance_ratios, 3))
```

The same interface works for causal models.
```{r multi-chain-causal}
set.seed(14)
fit_causal_2chain <- CausalShrinkageForest(
  y = y_causal,
  X_train_control = X_syn,
  X_train_treat = X_syn,
  treatment_indicator_train = W_syn,
  outcome_type = "continuous",
  prior_type_control = "horseshoe",
  prior_type_treat = "horseshoe",
  local_hp_control = lh,
  global_hp_control = lh,
  local_hp_treat = lh,
  global_hp_treat = lh,
  number_of_trees_control = 5,
  number_of_trees_treat = 5,
  N_post = 50,
  N_burn = 25,
  n_chains = 2,
  verbose = FALSE
)

cat("Pooled sigma draws:", length(fit_causal_2chain$sigma), "\n")
cat("Per-chain acceptance ratios (control):\n")
print(round(fit_causal_2chain$chains$acceptance_ratios_control, 3))
```

With multiple chains the traceplot shows one line per chain, and the overlaid density plot compares the marginal posterior of $\sigma$ across chains — both displays require `n_chains > 1`.

```{r plot-trace-2chain, eval=requireNamespace("ggplot2", quietly=TRUE)}
plot(fit_2chain, type = "trace")
```

```{r plot-density-2chain, eval=requireNamespace("ggplot2", quietly=TRUE)}
plot(fit_2chain, type = "density")
```

## Convergence Diagnostics

MCMC methods require careful assessment of convergence before the posterior summaries can be trusted. Here are practical guidelines for **ShrinkageTrees** models.

### Sigma traceplot

The traceplot of the error standard deviation $\sigma$ (via `plot(fit, type = "trace")`) is the primary diagnostic. A well-mixing chain should show:

- **No trend**: the trace should fluctuate around a stable level after burn-in. A persistent upward or downward drift indicates the sampler has not yet converged — increase `N_burn`.
- **Good mixing**: the chain should move freely across its stationary range. A "sticky" trace, one that stays in the same region for long stretches, indicates poor mixing.
- **Chain agreement** (when `n_chains > 1`): separate chains should overlap substantially. The density overlay (`plot(fit, type = "density")`) makes this easy to check visually.
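The chain-agreement check can also be made numerical. Below is a minimal base-R sketch of the (non-split) Gelman–Rubin statistic, computed on simulated draws rather than on a fitted object; with **coda** installed, `summary()` reports this for you, as shown in the next section.

```r
# Potential scale reduction factor for m chains of equal length n (sketch).
# `sigma_chains` is hypothetical: a list with one vector of sigma draws per chain.
set.seed(3)
sigma_chains <- list(rnorm(500, mean = 1, sd = 0.05),
                     rnorm(500, mean = 1, sd = 0.05))

rhat <- function(chains) {
  n <- length(chains[[1]])
  B <- n * var(sapply(chains, mean))   # between-chain variance
  W <- mean(sapply(chains, var))       # within-chain variance
  sqrt(((n - 1) / n * W + B / n) / W)  # Gelman-Rubin R-hat
}

rhat(sigma_chains)  # close to 1: the chains agree
```

Values near 1 (common thresholds are 1.01, or 1.1 as a looser rule) are taken as evidence that the chains have converged to the same distribution.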
### Acceptance ratio

The `summary()` output reports the average Metropolis–Hastings acceptance ratio for the tree-structure proposals (grow/prune moves). As a rough guide:

- **0.15–0.50** is typical and healthy for tree-based MCMC.
- **Very low** (< 0.05) means most proposals are rejected. The sampler is barely exploring tree space. Consider increasing `N_burn` and `N_post`, or relaxing the tree-structure prior (lower `power`, higher `base`).
- **Very high** (> 0.70) means almost all proposals are accepted, which typically indicates the trees are staying very small (e.g. stumps). This is less concerning but may limit the model's expressiveness.

### Formal diagnostics with coda

When the suggested package **coda** is installed, `summary()` automatically reports the **effective sample size** (ESS) and — for multi-chain fits — the **Gelman–Rubin $\hat{R}$**.

```{r coda-summary, eval=requireNamespace("coda", quietly=TRUE)}
# summary() includes convergence diagnostics when coda is available
summary(fit_2chain)
```

For more detailed diagnostics, convert the fitted object to a `coda::mcmc.list` with `as.mcmc.list()`:

```{r coda-diagnostics, eval=requireNamespace("coda", quietly=TRUE)}
library(coda)
mcmc_obj <- as.mcmc.list(fit_2chain)

# Gelman-Rubin R-hat (values near 1 indicate convergence)
coda::gelman.diag(mcmc_obj)

# Effective sample size
coda::effectiveSize(mcmc_obj)

# Geweke diagnostic (per chain)
coda::geweke.diag(mcmc_obj[[1]])
```

The returned `mcmc.list` object is compatible with all **coda** functions, including `coda::autocorr.plot()`, `coda::gelman.plot()`, `coda::heidel.diag()`, and `coda::raftery.diag()`.

### Recommended MCMC settings

The examples in this vignette use very small `N_post` and `N_burn` to keep build time low. For a real analysis:

- **Minimum**: `N_post = 2000, N_burn = 2000`.
- **Recommended**: `N_post = 5000, N_burn = 5000`.
- **High-dimensional or survival**: `N_post = 5000, N_burn = 10000` (the AFT data-augmentation step can slow mixing, so a longer burn-in helps).
- **Multiple chains**: `n_chains = 2` or `4` to verify convergence and produce pooled posterior samples.

## Case Study: TCGA PAAD (Full Analysis)

The full analysis of the `pdac` dataset replicates the case study from Jacobs, van Wieringen & van der Pas (2025). Because of the high-dimensional covariate space (~3,000 genes) and the large MCMC settings needed for reliable inference, the code below is provided for reference but is not evaluated during vignette building. Pre-computed results can be reproduced by running the `pdac_analysis` demo: `demo("pdac_analysis", package = "ShrinkageTrees")`.

### Step 1: Propensity score estimation

```{r pdac-ps, eval=FALSE}
data("pdac")

time      <- pdac$time
status    <- pdac$status
treatment <- pdac$treatment
X <- as.matrix(pdac[, !(names(pdac) %in% c("time", "status", "treatment"))])

set.seed(2025)
ps_fit <- HorseTrees(
  y = treatment,
  X_train = X,
  outcome_type = "binary",
  k = 0.1,
  N_post = 5000,
  N_burn = 5000,
  verbose = FALSE
)
propensity <- pnorm(ps_fit$train_predictions)
```

```{r pdac-ps-overlap, eval=FALSE}
# Overlap plot
p0 <- propensity[treatment == 0]
p1 <- propensity[treatment == 1]
hist(p0, breaks = 15, col = rgb(1, 0.5, 0, 0.5), xlim = range(propensity),
     xlab = "Propensity score", main = "Propensity score overlap")
hist(p1, breaks = 15, col = rgb(0, 0.5, 0, 0.5), add = TRUE)
legend("topright", legend = c("Control", "Treated"),
       fill = c(rgb(1, 0.5, 0, 0.5), rgb(0, 0.5, 0, 0.5)))
```

### Step 2: Causal survival forest

```{r pdac-causal, eval=FALSE}
# Augment control matrix with propensity scores (BCF-style)
X_control <- cbind(propensity, X)

# Log-transform and centre survival times
log_time <- log(time) - mean(log(time))

set.seed(2025)
fit_pdac <- CausalHorseForest(
  y = log_time,
  status = status,
  X_train_control = X_control,
  X_train_treat = X,
  treatment_indicator_train = treatment,
  outcome_type = "right-censored",
  timescale = "log",
  number_of_trees = 200,
  N_post = 5000,
  N_burn = 5000,
  store_posterior_sample = TRUE,
  verbose = FALSE
)
```

### Step 3: ATE and CATE estimation

```{r pdac-ate, eval=FALSE}
# Print model summary
print(fit_pdac)
smry_pdac <- summary(fit_pdac)
print(smry_pdac)

# ATE
cat("ATE posterior mean:", round(smry_pdac$treatment_effect$ate, 3), "\n")
cat("95% CI: [", round(smry_pdac$treatment_effect$ate_lower, 3), ",",
    round(smry_pdac$treatment_effect$ate_upper, 3), "]\n")
```

### Step 4: Diagnostics

```{r pdac-diag, eval=FALSE}
# Sigma convergence
plot(fit_pdac, type = "trace")

# Posterior ATE density
plot(fit_pdac, type = "ate")

# CATE caterpillar plot
plot(fit_pdac, type = "cate")
```

## References

Chipman, H. A., George, E. I., & McCulloch, R. E. (2010). BART: Bayesian Additive Regression Trees. *Annals of Applied Statistics*, 4(1), 266–298.

Hahn, P. R., Murray, J. S., & Carvalho, C. M. (2020). Bayesian Regression Tree Models for Causal Inference: Regularization, Confounding, and Heterogeneous Effects (with Discussion). *Bayesian Analysis*, 15(3), 965–1056.

Jacobs, T., van Wieringen, W. N., & van der Pas, S. L. (2025). Horseshoe Forests for High-Dimensional Causal Survival Analysis. *arXiv preprint* arXiv:2507.22004.

Linero, A. R. (2018). Bayesian Regression Trees for High-Dimensional Prediction and Variable Selection. *Journal of the American Statistical Association*, 113(522), 626–636.

Sparapani, R., Spanbauer, C., & McCulloch, R. (2021). Nonparametric Machine Learning and Efficient Computation with Bayesian Additive Regression Trees: The BART R Package. *Journal of Statistical Software*, 97(1), 1–66.