ShrinkageTrees: Bayesian Tree Ensembles for Survival Analysis and Causal Inference

Tijn Jacobs

2026-04-21

Introduction

ShrinkageTrees is an R package that brings Bayesian Additive Regression Trees (BART; Chipman, George & McCulloch, 2010) to survival analysis and causal inference, with a particular focus on high-dimensional data.

The package implements BART-based models for right-censored and interval-censored survival outcomes using an accelerated failure time (AFT) formulation. Censored event times are handled through Bayesian data augmentation in the Gibbs sampler, enabling full posterior inference without proportional-hazards assumptions. For causal inference, the package provides Bayesian Causal Forests (BCF; Hahn, Murray & Carvalho, 2020), which decompose the outcome into a prognostic function \(\mu(\mathbf{x})\) and a treatment-effect function \(\tau(\mathbf{x})\), each estimated by a separate tree ensemble. This two-forest structure supports estimation of heterogeneous treatment effects (CATEs) and the average treatment effect (ATE).
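To make the data-augmentation idea concrete, the sketch below shows one imputation step in base R. It assumes Gaussian errors on the log-time scale and uses inverse-CDF sampling from a truncated normal. The function name impute_log_times() and the toy inputs are hypothetical; this is an illustration of the technique, not the package's internal code.

```r
# Hypothetical sketch of one data-augmentation step for an AFT model with
# Gaussian errors on the log-time scale: log T_i = f(x_i) + sigma * eps_i.
# Censored log-times are drawn from N(f_i, sigma^2) truncated to their
# censoring interval (log L_i, log R_i); exact events (L_i == R_i) are kept.
impute_log_times <- function(f, sigma, left, right) {
  lo <- log(left)                 # lower bound on the log scale
  hi <- log(right)                # log(Inf) = Inf encodes right-censoring
  z  <- lo                        # exact events keep their observed log-time
  cens <- left < right            # censored observations need imputation
  # Inverse-CDF sampling: draw u uniformly between the CDF values of the bounds
  p_lo <- pnorm(lo[cens], mean = f[cens], sd = sigma)
  p_hi <- pnorm(hi[cens], mean = f[cens], sd = sigma)
  u    <- runif(sum(cens), min = p_lo, max = p_hi)
  z[cens] <- qnorm(u, mean = f[cens], sd = sigma)
  z
}

set.seed(1)
f     <- c(0.5, 1.0, 1.5)         # current fitted values (log-time scale)
left  <- c(2.0, 1.0, 3.0)         # lower event/censoring times
right <- c(2.0, Inf, 6.0)         # exact, right-censored, interval-censored
z <- impute_log_times(f, sigma = 1, left, right)
z
```

Within a Gibbs sampler, a step like this would alternate with updates of the trees and of sigma, so that the trees are always fitted to complete (partly imputed) log-times.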

A key feature is the availability of multiple regularisation strategies that can be freely combined within a single model:

Package map

Function                  Task                                           Prior
------------------------  ---------------------------------------------  ------------------------
HorseTrees()              Prediction (continuous / binary / survival*)   Horseshoe
ShrinkageTrees()          Prediction, flexible prior choice*             Horseshoe, DART, BART, …
SurvivalBART()            Survival prediction*                           Classical BART
SurvivalDART()            Sparse survival prediction*                    DART (Dirichlet)
SurvivalBCF()             Causal survival inference*                     BCF (classical)
SurvivalShrinkageBCF()    Sparse causal survival inference*              BCF + DART
CausalHorseForest()       Causal inference (all outcomes*)               Horseshoe
CausalShrinkageForest()   Causal inference, flexible prior*              Horseshoe, DART, BART, …

* All survival functions support both right-censored and interval-censored outcomes.

All model-fitting functions return an S3 object with consistent print(), summary(), predict(), and plot() methods.

Key Concepts

Before diving into examples, we clarify a few concepts that appear throughout the package interface.

Outcome types and the timescale parameter

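Every model-fitting function accepts an outcome_type argument. The values used in this vignette are "continuous", "binary", "right-censored", and "interval-censored"; the package map above shows which functions support which outcomes.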

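For survival outcomes, the timescale argument controls how the package treats the times. With timescale = "time" (the default), times are supplied on their original scale, log-transformed internally, and predictions are returned on the log-time scale.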

In most applications you should use timescale = "time" and pass the raw survival times directly.

Shrinkage priors on the step heights

In a BART ensemble each tree contributes a step height (leaf parameter) to the overall prediction. Classical BART assigns these step heights a fixed-variance Gaussian prior, which regularises all leaves equally. In high-dimensional settings, stronger and more adaptive regularisation is desirable. ShrinkageTrees implements two shrinkage priors that are placed directly on the step heights via a scale mixture of normals:

\[ h_\ell \mid \lambda_\ell, \tau, \omega \sim \mathcal{N}(0,\; \omega\, \lambda_\ell^2\, \tau^2). \]

Here \(\tau\) is a global shrinkage parameter shared across all leaves, \(\lambda_\ell\) is a local scale specific to leaf \(\ell\), and \(\omega\) is a fixed scaling constant. The two currently implemented instantiations are:

A forest-wide variant of the horseshoe (prior_type = "horseshoe_fw") shares a single global \(\tau\) across all trees in the forest rather than one per tree. These priors can be selected in ShrinkageTrees() and CausalShrinkageForest() via the prior_type argument, and can be combined with DART’s Dirichlet splitting prior for simultaneous structural and parametric regularisation.
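The following base-R simulation illustrates the scale mixture above. It assumes ω = 1 and unit-scale half-Cauchy priors for the local and global scales (choices made only for illustration), contrasts a per-tree global scale with a single forest-wide one, and adds fixed-variance Gaussian step heights (sd = 0.5, an arbitrary value) as a stand-in for classical BART. It does not call the package.

```r
# Illustrative simulation (not package code) of step heights
# h ~ N(0, omega * lambda^2 * tau^2) with half-Cauchy scales.
rhalfcauchy <- function(n, scale = 1) abs(rcauchy(n, location = 0, scale = scale))

set.seed(42)
n_trees <- 10; n_leaves <- 20; omega <- 1

# Per-tree horseshoe: one global scale per tree (rows), one local scale per leaf
tau_tree <- rhalfcauchy(n_trees)
lambda   <- matrix(rhalfcauchy(n_trees * n_leaves), n_trees, n_leaves)
# lambda * tau_tree recycles tau_tree down the columns, so row j uses tau_tree[j]
h_tree <- matrix(rnorm(n_trees * n_leaves, sd = sqrt(omega) * lambda * tau_tree),
                 n_trees, n_leaves)

# Forest-wide variant: a single global scale shared by every tree
tau_fw <- rhalfcauchy(1)
h_fw   <- matrix(rnorm(n_trees * n_leaves, sd = sqrt(omega) * lambda * tau_fw),
                 n_trees, n_leaves)

# Classical BART stand-in: fixed-variance Gaussian step heights
h_bart <- rnorm(n_trees * n_leaves, sd = 0.5)

summary(abs(as.vector(h_tree)))   # horseshoe: heavy tails
summary(abs(h_bart))              # Gaussian: light tails
```

The horseshoe draws typically show many step heights close to zero together with a few very large ones, which is what allows the prior to shrink noise while leaving strong signals largely untouched.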

Hyperparameter selection

The two most important hyperparameters are local_hp and global_hp. These control the horseshoe prior on the step heights (leaf parameters):

\[ \mu_{jl} \mid \lambda_{jl}, \tau_j \sim \mathcal{N}(0, \lambda_{jl}^2 \tau_j^2), \qquad \lambda_{jl} \sim \text{C}^+(0, \texttt{local\_hp}), \qquad \tau_j \sim \text{C}^+(0, \texttt{global\_hp}), \]

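where \(\mu_{jl}\) is the step height of leaf \(l\) in tree \(j\), \(\tau_j\) is the global scale of tree \(j\), and \(\text{C}^+(0, s)\) denotes the half-Cauchy distribution with scale \(s\). Smaller values of local_hp and global_hp produce stronger shrinkage toward zero; larger values allow more variation.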

HorseTrees() and CausalHorseForest() provide a convenience parameter k that sets both scales automatically: local_hp = global_hp = k / sqrt(number_of_trees). The default k = 0.1 works well in many settings and is a good starting point.
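The k convention is a one-line computation. The helper below is hypothetical (not a package function) and only reproduces the formula stated above.

```r
# Hypothetical helper reproducing local_hp = global_hp = k / sqrt(number_of_trees)
hs_scales <- function(k, number_of_trees) {
  s <- k / sqrt(number_of_trees)
  c(local_hp = s, global_hp = s)
}

hs_scales(k = 0.1, number_of_trees = 200)  # both approx. 0.00707
hs_scales(k = 0.1, number_of_trees = 5)    # 0.1 / sqrt(5), as in the examples below
```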

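ShrinkageTrees() and CausalShrinkageForest() expose local_hp and global_hp directly (no k). A common rule of thumb is to apply the same convention by hand, for example local_hp = global_hp = 0.1 / sqrt(number_of_trees), which is what the examples below use.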

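The survival functions (SurvivalBART, SurvivalDART) use k to calibrate the standard BART leaf prior: local_hp = diff(range(log(y))) / (2 * k * sqrt(number_of_trees)), that is, the width of the observed log-time range divided by 2k times the square root of the number of trees. The default k = 2 follows Chipman et al. (2010).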
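A minimal sketch of this calibration, written as a hypothetical helper rather than a package function. In R, range() returns the minimum and maximum, so the width of the log-time range is diff(range(log(y))).

```r
# Hypothetical helper for the survival leaf-prior calibration:
# local_hp = diff(range(log(y))) / (2 * k * sqrt(number_of_trees))
bart_leaf_scale <- function(y, k, number_of_trees) {
  diff(range(log(y))) / (2 * k * sqrt(number_of_trees))
}

y_demo <- c(0.5, 1, 2, 8)   # toy survival times; log-range = log(16)
bart_leaf_scale(y_demo, k = 2, number_of_trees = 200)   # approx. 0.049
```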

The store_posterior_sample flag

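When store_posterior_sample = TRUE, the fitted object stores the full \(N_\text{post} \times n\) matrix of posterior draws for predictions. This is needed for predict(), for the ATE credible interval that summary() reports for causal models, and for computing posterior summaries directly from the stored draws (for example train_predictions_sample_treat in causal models).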

When FALSE, only posterior means and \(\sigma\) draws are stored, saving memory. print() and summary() work in both cases, but predict() will not be available.
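To judge whether storing the draws is affordable, a rough estimate is N_post × n × 8 bytes per stored matrix of doubles. The helper below is a hypothetical back-of-the-envelope calculation, not a measurement of the fitted object, whose actual size will differ somewhat.

```r
# Approximate size (in MiB) of one N_post x n matrix of doubles (8 bytes each)
draws_mb <- function(N_post, n) N_post * n * 8 / 1024^2

draws_mb(N_post = 5000, n = 130)   # pdac-sized training set, about 5 MiB
```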

Treatment coding (treatment_coding)

All causal model functions — CausalHorseForest(), CausalShrinkageForest(), SurvivalBCF(), and SurvivalShrinkageBCF() — decompose the outcome as \[ Y_i = \mu(\mathbf{x}_i) + b_i \cdot \tau(\mathbf{x}_i) + \varepsilon_i, \] where \(b_i\) is a scalar that depends on the treatment assignment \(Z_i\). The treatment_coding argument controls how \(b_i\) is defined. Four options are available:

"centered" (default). \(b_i = Z_i - 1/2\), so that \(b_i \in \{-1/2,\; 1/2\}\). This is the original BCF parameterisation.

"binary". \(b_i = Z_i\), so that \(b_i \in \{0,\; 1\}\). Standard binary coding; the treatment forest captures the full effect of treatment on the treated.

"adaptive". \(b_i = Z_i - \hat{e}(\mathbf{x}_i)\), where \(\hat{e}(\mathbf{x}_i)\) is the estimated propensity score. This follows Hahn, Murray & Carvalho (2020) and is the coding used in the bcf R package. When using this option, a propensity vector must be supplied.

"invariant". Parameter-expanded (invariant) treatment coding. The coding parameters \(b_0\) and \(b_1\) are assigned \(N(0,\; 1/2)\) priors and estimated within the Gibbs sampler via conjugate normal updates: \[ Y_i = \mu(\mathbf{x}_i) + b_{Z_i} \cdot \tilde{\tau}(\mathbf{x}_i) + \varepsilon_i, \qquad b_0,\; b_1 \sim N(0,\; 1/2). \] The treatment effect is \(\tau(\mathbf{x}_i) = (b_1 - b_0) \cdot \tilde{\tau}(\mathbf{x}_i)\), and the posterior draws of \(b_0\) and \(b_1\) are returned in the fitted object. This parameterisation is invariant to the coding of the treatment indicator (Hahn et al., 2020, Section 5.2).

The examples below illustrate each option on a simple continuous-outcome causal model.

library(ShrinkageTrees)
set.seed(50)
n_tc <- 60;  p_tc <- 5
X_tc <- matrix(rnorm(n_tc * p_tc), n_tc, p_tc)
W_tc <- rbinom(n_tc, 1, 0.5)
tau_tc <- 1.5 * (X_tc[, 1] > 0)
y_tc <- X_tc[, 1] + W_tc * tau_tc + rnorm(n_tc, sd = 0.5)
# Centered (default)
fit_tc_cen <- CausalHorseForest(
  y = y_tc,
  X_train_control = X_tc, X_train_treat = X_tc,
  treatment_indicator_train = W_tc,
  treatment_coding = "centered",
  number_of_trees = 5, N_post = 50, N_burn = 25,
  store_posterior_sample = TRUE, verbose = FALSE
)
cat("Centered — ATE:",
    round(mean(fit_tc_cen$train_predictions_treat), 3), "\n")
#> Centered — ATE: 0.238
# Binary
fit_tc_bin <- CausalHorseForest(
  y = y_tc,
  X_train_control = X_tc, X_train_treat = X_tc,
  treatment_indicator_train = W_tc,
  treatment_coding = "binary",
  number_of_trees = 5, N_post = 50, N_burn = 25,
  store_posterior_sample = TRUE, verbose = FALSE
)
cat("Binary — ATE:",
    round(mean(fit_tc_bin$train_predictions_treat), 3), "\n")
#> Binary — ATE: 0.521
# Adaptive (requires propensity scores)
ps_tc <- pnorm(0.3 * X_tc[, 1])   # simple propensity model for illustration

fit_tc_ada <- CausalHorseForest(
  y = y_tc,
  X_train_control = X_tc, X_train_treat = X_tc,
  treatment_indicator_train = W_tc,
  treatment_coding = "adaptive",
  propensity = ps_tc,
  number_of_trees = 5, N_post = 50, N_burn = 25,
  store_posterior_sample = TRUE, verbose = FALSE
)
cat("Adaptive — ATE:",
    round(mean(fit_tc_ada$train_predictions_treat), 3), "\n")
#> Adaptive — ATE: 0.154
# Invariant (parameter-expanded)
fit_tc_inv <- CausalHorseForest(
  y = y_tc,
  X_train_control = X_tc, X_train_treat = X_tc,
  treatment_indicator_train = W_tc,
  treatment_coding = "invariant",
  number_of_trees = 5, N_post = 50, N_burn = 25,
  store_posterior_sample = TRUE, verbose = FALSE
)
cat("Invariant — ATE:",
    round(mean(fit_tc_inv$train_predictions_treat), 3), "\n")
#> Invariant — ATE: 0.232

# Posterior draws of b0 and b1 are stored in the fitted object
cat("b0 posterior mean:", round(mean(fit_tc_inv$b0), 3), "\n")
#> b0 posterior mean: 0.068
cat("b1 posterior mean:", round(mean(fit_tc_inv$b1), 3), "\n")
#> b1 posterior mean: 0.153
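Because the treatment effect in this simulation is known, \(\tau(\mathbf{x}) = 1.5 \cdot 1\{x_1 > 0\}\), the in-sample ATE can be computed directly for comparison. The snippet regenerates the same covariates with the same seed. With only five trees and 50 posterior draws, the estimates above illustrate the interface rather than provide accurate estimates.

```r
# Recompute the true in-sample ATE of the simulation above
set.seed(50)
n_tc <- 60; p_tc <- 5
X_tc   <- matrix(rnorm(n_tc * p_tc), n_tc, p_tc)  # same draws as above
tau_tc <- 1.5 * (X_tc[, 1] > 0)                   # known treatment effect
true_ate <- mean(tau_tc)
true_ate
```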

The causal survival functions also accept treatment_coding. For example, SurvivalBCF(..., treatment_coding = "invariant") works out of the box.
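The sketch below spells out how \(b_i\) is formed from a 0/1 treatment indicator under the three fixed codings, and how the invariant coding turns one draw of \((b_0, b_1)\) into a treatment effect. The helper treatment_b() and all numbers are hypothetical; the code only restates the definitions given in this section.

```r
# Hypothetical helper: the scalar b_i multiplying tau(x_i) for a 0/1 indicator Z
treatment_b <- function(Z, coding = c("centered", "binary", "adaptive"),
                        propensity = NULL) {
  coding <- match.arg(coding)
  switch(coding,
    centered = Z - 0.5,                  # b_i in {-1/2, 1/2}
    binary   = Z,                        # b_i in {0, 1}
    adaptive = {
      if (is.null(propensity)) stop("adaptive coding needs a propensity vector")
      Z - propensity                     # b_i = Z_i - e_hat(x_i)
    }
  )
}

Z  <- c(0, 1, 1, 0)
ps <- c(0.2, 0.7, 0.5, 0.4)                   # illustrative propensity scores
treatment_b(Z, "centered")                    # -0.5  0.5  0.5 -0.5
treatment_b(Z, "binary")                      #  0    1    1    0
treatment_b(Z, "adaptive", propensity = ps)   # -0.2  0.3  0.5 -0.4

# Invariant coding: for one posterior draw of (b0, b1) and the fitted
# tilde_tau, the treatment effect is (b1 - b0) * tilde_tau
b0 <- 0.1; b1 <- 0.9; tilde_tau <- c(1.0, 2.0)
(b1 - b0) * tilde_tau                         # 0.8 1.6
```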

Included Datasets

The package ships with two TCGA datasets for high-dimensional survival analysis and causal inference:

The PDAC Dataset

The pdac dataset contains overall survival times, a binary treatment indicator (radiation therapy vs. control), clinical covariates, and expression values of approximately 3,000 genes selected by median absolute deviation.

library(ShrinkageTrees)
data("pdac")

# Dimensions and column overview
cat("Patients:", nrow(pdac), "\n")
#> Patients: 130
cat("Columns :", ncol(pdac), "\n")
#> Columns : 3032
cat("Clinical columns:", paste(names(pdac)[1:13], collapse = ", "), "\n")
#> Clinical columns: time, status, treatment, age, sex, grade, tumor.cellularity, tumor.purity, absolute.purity, moffitt.cluster, meth.leukocyte.percent, meth.purity.mode, stage
cat("Survival: time (months), censoring rate =",
    round(1 - mean(pdac$status), 2), "\n")
#> Survival: time (months), censoring rate = 0.47
cat("Treatment: radiation =", sum(pdac$treatment),
    "/ control =", sum(1 - pdac$treatment), "\n")
#> Treatment: radiation = 36 / control = 94

We separate the outcome, treatment, and covariate matrix for the analyses below.

time      <- pdac$time
status    <- pdac$status
treatment <- pdac$treatment
X         <- as.matrix(pdac[, !(names(pdac) %in% c("time", "status", "treatment"))])

Prediction Models

This section demonstrates the single-forest models in ShrinkageTrees. These models estimate a single function \(f(\mathbf{x})\) of the covariates, applicable to continuous, binary, and survival outcomes. We begin with a binary-outcome example that will also serve as the propensity score model for the causal analyses later.

HorseTrees — binary outcome (propensity scores)

Before fitting causal models we estimate propensity scores \(\hat{e}(\mathbf{x}) = P(W=1 \mid \mathbf{x})\) using HorseTrees() with a binary outcome. The probit link is used internally: predictions are on the latent Gaussian scale and can be converted to probabilities with pnorm().
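One subtlety when converting: pnorm() applied to the posterior mean latent value is a plug-in estimate and generally differs from the posterior mean probability, which averages pnorm() over the draws. The sketch below uses simulated latent draws as a stand-in for posterior output and does not call the package.

```r
# Simulated stand-in for posterior draws of the latent probit value f(x)
set.seed(3)
latent_draws <- rnorm(1000, mean = 1.0, sd = 0.8)

p_plugin <- pnorm(mean(latent_draws))   # pnorm of the posterior mean
p_post   <- mean(pnorm(latent_draws))   # posterior mean of pnorm
c(plug_in = p_plugin, posterior_mean = p_post)
```

For a normal latent distribution, the posterior mean probability equals pnorm(mean / sqrt(1 + sd^2)), so the plug-in value overstates probabilities above 0.5 when the latent uncertainty is large.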

The code block below shows the full analysis with N_post = 5000 and N_burn = 5000; it is not run when building this vignette.

ps_fit <- HorseTrees(
  y            = treatment,
  X_train      = X,
  outcome_type = "binary",
  k            = 0.1,
  N_post       = 5000,
  N_burn       = 5000,
  verbose      = FALSE
)

propensity <- pnorm(ps_fit$train_predictions)

For the remainder of this vignette we use a short synthetic run to keep build time low.

set.seed(1)
n <- 80;  p <- 10
X_syn  <- matrix(rnorm(n * p), n, p)
W_syn  <- rbinom(n, 1, pnorm(0.8 * X_syn[, 1]))

ps_fit <- HorseTrees(
  y            = W_syn,
  X_train      = X_syn,
  outcome_type = "binary",
  number_of_trees = 5,
  k            = 0.5,
  N_post       = 50,
  N_burn       = 25,
  verbose      = FALSE
)

propensity_syn <- pnorm(ps_fit$train_predictions)
cat("Propensity scores — range: [",
    round(range(propensity_syn), 3), "]\n")
#> Propensity scores — range: [ 0.469 0.539 ]

HorseTrees — survival outcome

HorseTrees() handles right-censored data via an AFT model. Pass outcome_type = "right-censored" and provide the status vector (1 = event, 0 = censored). When timescale = "time" (the default), the package log-transforms survival times internally and returns predictions on the log scale (see Key Concepts above).

set.seed(2)
log_T <- X_syn[, 1] + rnorm(n)
C     <- rexp(n, 0.5)
y_syn   <- pmin(exp(log_T), C)
d_syn   <- as.integer(exp(log_T) <= C)

ht_surv <- HorseTrees(
  y               = y_syn,
  status          = d_syn,
  X_train         = X_syn,
  outcome_type    = "right-censored",
  timescale       = "time",
  number_of_trees = 5,
  N_post          = 50,
  N_burn          = 25,
  store_posterior_sample = TRUE,
  verbose         = FALSE
)

cat("Posterior mean log-time (first 5 obs):",
    round(ht_surv$train_predictions[1:5], 3), "\n")
#> Posterior mean log-time (first 5 obs): 1.204 1.319 1.354 1.389 1.247
cat("Posterior sigma — mean:",
    round(mean(ht_surv$sigma), 3), "\n")
#> Posterior sigma — mean: 0.981
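To report results on the time scale, a sketch assuming the AFT model has Gaussian errors on the log scale, log T = f(x) + σε: then exp(f(x)) is the median survival time and exp(f(x) + σ²/2) the mean survival time. The values below are copied from the output above and plugged in as point estimates, which ignores posterior uncertainty.

```r
# Posterior mean log-times and sigma copied from the output above
f_hat     <- c(1.204, 1.319, 1.354, 1.389, 1.247)
sigma_hat <- 0.981

# Log-normal AFT (assumed): median = exp(f), mean = exp(f + sigma^2 / 2)
median_time <- exp(f_hat)
mean_time   <- exp(f_hat + sigma_hat^2 / 2)
round(rbind(median = median_time, mean = mean_time), 2)
```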

HorseTrees — interval-censored outcome

When event times are not observed exactly but known to lie within an interval, the package supports interval censoring. Instead of providing y and status, pass left_time and right_time with outcome_type = "interval-censored".

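The three censoring types are encoded as follows:

- exact event: left_time == right_time;
- interval-censored: left_time < right_time < Inf;
- right-censored: right_time = Inf.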

This convention matches survival::Surv(type = "interval2").
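Right-censored data in (time, status) form can be rewritten in this encoding by mapping each event to the interval [t, t] and each censored observation to [t, Inf). The helper to_interval() below is hypothetical, not a package function.

```r
# Hypothetical helper: convert (time, status) right-censored data to the
# left_time / right_time encoding (1 = event, 0 = censored)
to_interval <- function(time, status) {
  stopifnot(length(time) == length(status), all(status %in% c(0, 1)))
  list(left_time  = time,
       right_time = ifelse(status == 1, time, Inf))
}

iv <- to_interval(time = c(2.5, 4.0, 1.2), status = c(1, 0, 1))
iv$right_time   # 2.5 Inf 1.2
```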

set.seed(20)

# Generate true event times
true_T <- rexp(n, rate = exp(-0.5 * X_syn[, 1]))

# Create interval-censored observations
left_syn  <- true_T * runif(n, 0.5, 1.0)
right_syn <- true_T * runif(n, 1.0, 1.5)

# Mark some as exact observations and some as right-censored
exact_idx <- sample(n, 25)
left_syn[exact_idx]  <- true_T[exact_idx]
right_syn[exact_idx] <- true_T[exact_idx]

rc_idx <- sample(setdiff(seq_len(n), exact_idx), 15)
right_syn[rc_idx] <- Inf

cat("Exact events:", sum(left_syn == right_syn), "\n")
#> Exact events: 25
cat("Interval-censored:", sum(left_syn < right_syn & is.finite(right_syn)), "\n")
#> Interval-censored: 40
cat("Right-censored:", sum(!is.finite(right_syn)), "\n")
#> Right-censored: 15

ht_ic <- HorseTrees(
  left_time       = left_syn,
  right_time      = right_syn,
  X_train         = X_syn,
  outcome_type    = "interval-censored",
  timescale       = "time",
  number_of_trees = 5,
  N_post          = 50,
  N_burn          = 25,
  store_posterior_sample = TRUE,
  verbose         = FALSE
)

cat("Posterior mean log-time (first 5 obs):",
    round(ht_ic$train_predictions[1:5], 3), "\n")
#> Posterior mean log-time (first 5 obs): 0.666 0.698 0.596 0.785 0.726
cat("Posterior sigma — mean:",
    round(mean(ht_ic$sigma), 3), "\n")
#> Posterior sigma — mean: 1.016

All survival functions (SurvivalBART, SurvivalDART, SurvivalBCF, SurvivalShrinkageBCF) and the general-purpose functions (ShrinkageTrees, CausalShrinkageForest, CausalHorseForest) accept left_time and right_time in the same way. For example, using SurvivalBART:

set.seed(21)
fit_sbart_ic <- SurvivalBART(
  left_time       = left_syn,
  right_time      = right_syn,
  X_train         = X_syn,
  number_of_trees = 5,
  N_post          = 50,
  N_burn          = 25,
  verbose         = FALSE
)

cat("SurvivalBART (IC) class:", class(fit_sbart_ic), "\n")
#> SurvivalBART (IC) class: ShrinkageTrees

ShrinkageTrees — flexible prior choice

While HorseTrees() fixes the prior to the horseshoe, the more general ShrinkageTrees() function exposes the prior_type argument, allowing the user to select among all implemented regularisation strategies. Available options are "horseshoe", "horseshoe_fw" (forest-wide), "half-cauchy", "standard" (classical BART), and "dirichlet" (DART). Below we compare the per-tree horseshoe and the forest-wide horseshoe on a continuous outcome.

set.seed(3)
y_cont <- X_syn[, 1] + 0.5 * X_syn[, 2] + rnorm(n)

# Horseshoe prior (the prior used by HorseTrees)
fit_hs <- ShrinkageTrees(
  y               = y_cont,
  X_train         = X_syn,
  outcome_type    = "continuous",
  prior_type      = "horseshoe",
  local_hp        = 0.1 / sqrt(5),
  global_hp       = 0.1 / sqrt(5),
  number_of_trees = 5,
  N_post          = 50,
  N_burn          = 25,
  verbose         = FALSE
)

# Forest-wide horseshoe (horseshoe_fw)
fit_fw <- ShrinkageTrees(
  y               = y_cont,
  X_train         = X_syn,
  outcome_type    = "continuous",
  prior_type      = "horseshoe_fw",
  local_hp        = 0.1 / sqrt(5),
  global_hp       = 0.1 / sqrt(5),
  number_of_trees = 5,
  N_post          = 50,
  N_burn          = 25,
  verbose         = FALSE
)

cat("Horseshoe   — train RMSE:",
    round(sqrt(mean((fit_hs$train_predictions - y_cont)^2)), 3), "\n")
#> Horseshoe   — train RMSE: 1.217
cat("Horseshoe FW— train RMSE:",
    round(sqrt(mean((fit_fw$train_predictions - y_cont)^2)), 3), "\n")
#> Horseshoe FW— train RMSE: 1.23

SurvivalBART and SurvivalDART

SurvivalBART() and SurvivalDART() fit classical BART and DART models for right-censored (and, as shown above, interval-censored) survival data under the AFT formulation. They calibrate the prior hyperparameters automatically from the range of the observed log-times, providing a simple interface when horseshoe shrinkage is not needed.

set.seed(4)

# SurvivalBART: classical BART prior, AFT likelihood
fit_sbart <- SurvivalBART(
  time            = y_syn,
  status          = d_syn,
  X_train         = X_syn,
  number_of_trees = 5,
  k               = 2.0,
  N_post          = 50,
  N_burn          = 25,
  verbose         = FALSE
)

# SurvivalDART: Dirichlet (DART) splitting prior
fit_sdart <- SurvivalDART(
  time            = y_syn,
  status          = d_syn,
  X_train         = X_syn,
  number_of_trees = 5,
  k               = 2.0,
  N_post          = 50,
  N_burn          = 25,
  verbose         = FALSE
)

cat("SurvivalBART  class:", class(fit_sbart), "\n")
#> SurvivalBART  class: ShrinkageTrees
cat("SurvivalDART  class:", class(fit_sdart), "\n")
#> SurvivalDART  class: ShrinkageTrees

High-Dimensional Survival Analysis

A key motivation for horseshoe shrinkage and the Dirichlet (DART) sparsity prior is their behaviour in the \(p \gg n\) regime: many covariates are available but only a small subset drives the outcome. Classical BART may struggle here because its uniform splitting prior spreads splits over all predictors and its fixed-variance Gaussian leaf prior shrinks every step height by the same amount.

We illustrate both priors on a sparse AFT simulation: \(n = 60\) observations, \(p = 200\) predictors, and only three active predictors.

set.seed(20)
n_hd <- 60;  p_hd <- 200
X_hd <- matrix(rnorm(n_hd * p_hd), n_hd, p_hd)

# True log-survival depends only on predictors 1, 2, and 3
log_T_hd <- 1.5 * X_hd[, 1] - 1.0 * X_hd[, 2] + 0.5 * X_hd[, 3] + rnorm(n_hd)
C_hd     <- rexp(n_hd, rate = 0.5)
y_hd     <- pmin(exp(log_T_hd), C_hd)
d_hd     <- as.integer(exp(log_T_hd) <= C_hd)

cat("n =", n_hd, "| p =", p_hd,
    "| active predictors = 3",
    "| censoring rate =", round(1 - mean(d_hd), 2), "\n")
#> n = 60 | p = 200 | active predictors = 3 | censoring rate = 0.4

ShrinkageTrees (horseshoe) places global–local shrinkage on the step heights of every leaf: most step heights are pulled strongly toward zero, while the heavy-tailed local scales leave room for the large step heights needed to capture the signal from the three active predictors.

set.seed(21)
fit_hd_hs <- ShrinkageTrees(
  y               = y_hd,
  status          = d_hd,
  X_train         = X_hd,
  outcome_type    = "right-censored",
  prior_type      = "horseshoe",
  local_hp        = 0.1 / sqrt(10),
  global_hp       = 0.1 / sqrt(10),
  number_of_trees = 10,
  N_post          = 50,
  N_burn          = 25,
  verbose         = FALSE
)

SurvivalDART uses a Dirichlet prior on split probabilities to induce structural sparsity: after burn-in, the posterior tends to concentrate splitting probability on the truly predictive variables. Setting rho_dirichlet = 3 encodes the prior belief that approximately three predictors are active.
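A small base-R illustration of why a Dirichlet prior on split probabilities produces sparsity: a symmetric Dirichlet(α/p, …, α/p) draw with small α puts most of its mass on a few predictors, whereas a large α gives nearly uniform probabilities. This sketch fixes α by hand; how rho_dirichlet enters the package's prior on this concentration is not shown here and is an assumption on our part.

```r
# Symmetric Dirichlet draw via normalised independent gamma variates
rdirichlet_sym <- function(p, alpha) {
  g <- rgamma(p, shape = alpha / p, rate = 1)
  g / sum(g)
}

set.seed(7)
p_dir <- 200
s_sparse <- rdirichlet_sym(p_dir, alpha = 2)     # small concentration: sparse
s_flat   <- rdirichlet_sym(p_dir, alpha = 2000)  # large concentration: near uniform

# Share of splitting probability held by the 5 most likely predictors
top5 <- function(s) sum(sort(s, decreasing = TRUE)[1:5])
c(sparse = top5(s_sparse), flat = top5(s_flat))
```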

set.seed(22)
fit_hd_dart <- SurvivalDART(
  time            = y_hd,
  status          = d_hd,
  X_train         = X_hd,
  number_of_trees = 10,
  rho_dirichlet   = 3,
  N_post          = 50,
  N_burn          = 25,
  verbose         = FALSE
)

Both models run in the \(p > n\) regime. We compare their posterior mean predictions on the log-time scale against the latent true log-times used to generate the data; with such short chains the numbers are illustrative only.

rmse_hs   <- sqrt(mean((fit_hd_hs$train_predictions  - log_T_hd)^2))
rmse_dart <- sqrt(mean((fit_hd_dart$train_predictions - log_T_hd)^2))

cat(sprintf("%-18s  train RMSE (log-time): %.3f\n", "Horseshoe", rmse_hs))
#> Horseshoe           train RMSE (log-time): 2.530
cat(sprintf("%-18s  train RMSE (log-time): %.3f\n", "DART",      rmse_dart))
#> DART                train RMSE (log-time): 2.402

The DART model also produces variable importance plots that display the posterior distribution of each predictor's splitting probability. With only 50 posterior draws, the top-10 plot below may already place most probability mass on the three truly active predictors, but longer chains are needed for a reliable assessment.

plot(fit_hd_dart, type = "vi", n_vi = 10)

Causal Forest Models

For causal inference, ShrinkageTrees provides Bayesian Causal Forest (BCF) models that decompose the outcome into a prognostic component and a treatment effect component. With binary treatment coding (treatment_coding = "binary"), the model reads \[ Y_i = \mu(\mathbf{x}_i) + W_i \cdot \tau(\mathbf{x}_i) + \varepsilon_i, \] where \(W_i\) is the treatment indicator, \(\mu(\cdot)\) is the prognostic (control) function modelled by one tree ensemble, and \(\tau(\cdot)\) is the heterogeneous treatment effect modelled by a second ensemble; the other codings replace \(W_i\) by the corresponding \(b_i\) defined under Treatment coding. This two-forest structure allows each component to have its own regularisation, number of trees, and prior, for instance a standard BART prior for the prognostic forest and horseshoe shrinkage for the treatment effect forest.

The package provides four causal model functions with increasing generality: SurvivalBCF() (classical BCF for survival), SurvivalShrinkageBCF() (BCF + DART for survival), CausalHorseForest() (horseshoe BCF for all outcome types), and CausalShrinkageForest() (fully configurable BCF). We illustrate each below on synthetic data with a known treatment effect.

set.seed(5)
tau_true <- 1.5 * (X_syn[, 1] > 0)    # heterogeneous treatment effect
y_causal <- X_syn[, 1] + W_syn * tau_true + rnorm(n, sd = 0.5)

SurvivalBCF — classical BCF for survival

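SurvivalBCF() fits a BCF model for right-censored survival outcomes using classical BART priors. The first call below shows a full analysis of the pdac data with production MCMC settings and is not run here; the second runs a short chain on the synthetic survival data from above. Because y_syn was simulated without any treatment effect, its output only illustrates the interface.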

# Full analysis on the pdac data (not run here; use these MCMC settings in practice)
fit_sbcf <- SurvivalBCF(
  time       = time,
  status     = status,
  X_train    = X,
  treatment  = treatment,
  propensity = propensity,   # from HorseTrees above
  N_post     = 5000,
  N_burn     = 5000,
  verbose    = FALSE
)

# Short synthetic run used in this vignette
set.seed(6)
fit_sbcf <- SurvivalBCF(
  time            = y_syn,
  status          = d_syn,
  X_train         = X_syn,
  treatment       = W_syn,
  number_of_trees_control = 5,
  number_of_trees_treat   = 5,
  N_post          = 50,
  N_burn          = 25,
  verbose         = FALSE
)
cat("SurvivalBCF class:", class(fit_sbcf), "\n")
#> SurvivalBCF class: CausalShrinkageForest
cat("ATE (posterior mean):",
    round(mean(fit_sbcf$train_predictions_treat), 3), "\n")
#> ATE (posterior mean): 1.475

SurvivalShrinkageBCF — sparse causal survival forest

SurvivalShrinkageBCF() extends BCF with a Dirichlet splitting prior on both forests, inducing sparsity in high-dimensional settings.

set.seed(7)
fit_ssbcf <- SurvivalShrinkageBCF(
  time            = y_syn,
  status          = d_syn,
  X_train         = X_syn,
  treatment       = W_syn,
  number_of_trees_control = 5,
  number_of_trees_treat   = 5,
  N_post          = 50,
  N_burn          = 25,
  verbose         = FALSE
)
cat("SurvivalShrinkageBCF class:", class(fit_ssbcf), "\n")
#> SurvivalShrinkageBCF class: CausalShrinkageForest

CausalHorseForest — horseshoe causal forest

CausalHorseForest() is the primary novel contribution of this package. It applies horseshoe shrinkage to the leaf parameters of both the prognostic and treatment-effect forests. This enables effective regularisation when many covariates are available but few are truly predictive of heterogeneous treatment effects.

set.seed(8)
fit_chf <- CausalHorseForest(
  y                         = y_causal,
  X_train_control           = X_syn,
  X_train_treat             = X_syn,
  treatment_indicator_train = W_syn,
  outcome_type              = "continuous",
  number_of_trees           = 5,
  N_post                    = 50,
  N_burn                    = 25,
  store_posterior_sample    = TRUE,
  verbose                   = FALSE
)

cat("CausalHorseForest class:", class(fit_chf), "\n")
#> CausalHorseForest class: CausalShrinkageForest

# Posterior mean CATE
cate_mean <- fit_chf$train_predictions_treat
cat("CATE — posterior mean (first 5):",
    round(cate_mean[1:5], 3), "\n")
#> CATE — posterior mean (first 5): 1.48 1.559 1.504 1.281 1.421

# Posterior ATE
ate_samples <- rowMeans(fit_chf$train_predictions_sample_treat)
cat("ATE posterior mean:",
    round(mean(ate_samples), 3),
    "  95% CI: [",
    round(quantile(ate_samples, 0.025), 3), ",",
    round(quantile(ate_samples, 0.975), 3), "]\n")
#> ATE posterior mean: 1.422   95% CI: [ 0.567 , 2.154 ]

The fitted object stores the posterior mean CATE for each training observation in train_predictions_treat. When store_posterior_sample = TRUE, the full posterior sample matrix is available in train_predictions_sample_treat, from which the posterior ATE distribution and credible intervals can be computed as shown above.

You can also supply separate test matrices to obtain out-of-sample CATE predictions.

set.seed(9)
X_test <- matrix(rnorm(20 * p), 20, p)
W_test <- rbinom(20, 1, 0.5)

fit_chf_test <- CausalHorseForest(
  y                         = y_causal,
  X_train_control           = X_syn,
  X_train_treat             = X_syn,
  treatment_indicator_train = W_syn,
  X_test_control            = X_test,
  X_test_treat              = X_test,
  treatment_indicator_test  = W_test,
  outcome_type              = "continuous",
  number_of_trees           = 5,
  N_post                    = 50,
  N_burn                    = 25,
  store_posterior_sample    = TRUE,
  verbose                   = FALSE
)

cat("Test CATE (first 5):",
    round(fit_chf_test$test_predictions_treat[1:5], 3), "\n")
#> Test CATE (first 5): 1.599 1.503 1.615 1.497 1.5

CausalShrinkageForest — flexible causal priors

CausalShrinkageForest() is the most general causal model interface. It allows independent prior choices for the prognostic and treatment effect forests via prior_type_control and prior_type_treat. For example, one could use a standard BART prior for the prognostic forest (where variable selection is less critical) and horseshoe shrinkage for the treatment forest (where most covariates are expected to be irrelevant for the treatment effect). For simplicity, the example below uses the horseshoe for both forests; the mixed configuration is obtained by changing prior_type_control and its hyperparameters.

set.seed(10)
lh <- 0.1 / sqrt(5)

fit_csf <- CausalShrinkageForest(
  y                         = y_causal,
  X_train_control           = X_syn,
  X_train_treat             = X_syn,
  treatment_indicator_train = W_syn,
  outcome_type              = "continuous",
  prior_type_control        = "horseshoe",
  prior_type_treat          = "horseshoe",
  local_hp_control          = lh,
  global_hp_control         = lh,
  local_hp_treat            = lh,
  global_hp_treat           = lh,
  number_of_trees_control   = 5,
  number_of_trees_treat     = 5,
  N_post                    = 50,
  N_burn                    = 25,
  store_posterior_sample    = TRUE,
  verbose                   = FALSE
)

cat("CausalShrinkageForest class:", class(fit_csf), "\n")
#> CausalShrinkageForest class: CausalShrinkageForest
cat("Acceptance ratio (control):",
    round(fit_csf$acceptance_ratio_control, 3), "\n")
#> Acceptance ratio (control): 0.348
cat("Acceptance ratio (treat)  :",
    round(fit_csf$acceptance_ratio_treat, 3), "\n")
#> Acceptance ratio (treat)  : 0.248

The horseshoe_fw prior adds a forest-wide shrinkage parameter that is tracked in the fitted object.

set.seed(11)
fit_fw2 <- CausalShrinkageForest(
  y                         = y_causal,
  X_train_control           = X_syn,
  X_train_treat             = X_syn,
  treatment_indicator_train = W_syn,
  outcome_type              = "continuous",
  prior_type_control        = "horseshoe_fw",
  prior_type_treat          = "horseshoe_fw",
  local_hp_control          = lh,
  global_hp_control         = lh,
  local_hp_treat            = lh,
  global_hp_treat           = lh,
  number_of_trees_control   = 5,
  number_of_trees_treat     = 5,
  N_post                    = 50,
  N_burn                    = 25,
  verbose                   = FALSE
)

cat("Forest-wide shrinkage (control, first 5 draws):\n")
#> Forest-wide shrinkage (control, first 5 draws):
print(round(fit_fw2$forestwide_shrinkage_control[1:5], 4))
#> [1] 1 1 1 1 1

S3 Methods

All fitted objects — whether from prediction models or causal models — support a consistent set of S3 methods: print(), summary(), predict(), and plot(). This section illustrates each method using the models fitted above.

print()

Calling print() (or just typing the object name) displays a concise model summary.

print(fit_chf)
#> 
#> CausalShrinkageForest model
#> ---------------------------
#> Outcome type:         Continuous
#> Training size (n):    80
#> Posterior draws:      50 (burn-in 25)
#> Posterior mean sigma: 0.821
#> 
#>                       Control             Treatment           
#>                       ------------------- -------------------
#> Prior:                horseshoe           horseshoe           
#> Number of trees:      5                   5                   
#> Number of features:   10                  10                  
#> Acceptance ratio:     0.428               0.404

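For causal models, the output reports the prior, number of trees, number of features, and acceptance ratio separately for the control (prognostic) and treatment forests. The model fitted with CausalShrinkageForest() uses the same layout.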

print(fit_csf)
#> 
#> CausalShrinkageForest model
#> ---------------------------
#> Outcome type:         Continuous
#> Training size (n):    80
#> Posterior draws:      50 (burn-in 25)
#> Posterior mean sigma: 0.814
#> 
#>                       Control             Treatment           
#>                       ------------------- -------------------
#> Prior:                horseshoe           horseshoe           
#> Number of trees:      5                   5                   
#> Number of features:   10                  10                  
#> Acceptance ratio:     0.348               0.248

summary()

summary() returns a structured list and displays a richer description, including posterior statistics for \(\sigma\), prediction summaries, variable importance, acceptance ratios, and, for causal models, treatment effect estimates.

smry <- summary(fit_hs)
print(smry)
#> 
#> ShrinkageTrees model summary
#> ============================
#> Call: ShrinkageTrees(y = y_cont, X_train = X_syn, outcome_type = "continuous", 
#>     number_of_trees = 5, prior_type = "horseshoe", local_hp = 0.1/sqrt(5), 
#>     global_hp = 0.1/sqrt(5), N_post = 50, N_burn = 25, verbose = FALSE)
#> 
#> Outcome: Continuous | Prior: horseshoe | Trees: 5
#> Data:    n = 80, p = 10 | Draws: 50 (burn-in 25)
#> 
#> Posterior sigma:
#>   Mean: 0.987  SD: 0.089  95% CI: [0.851, 1.205]
#> 
#> Predictions (posterior mean):
#>   Train: mean = 0.053, sd = 0.053, range = [-0.094, 0.159]
#>   Test:  mean = 0.065, sd = NA, range = [0.065, 0.065]
#> 
#> Variable importance (posterior inclusion probability):
#>   X3: 0.164   X9: 0.143   X4: 0.142   X2: 0.132   X8: 0.131   X5: 0.091   X7: 0.084   X6: 0.064   X10: 0.037   X1: 0.011 
#> 
#> MCMC acceptance ratio: 0.484
#> 
#> Convergence diagnostics (coda):
#>   Effective sample size: sigma = 50

For causal models the summary includes the posterior ATE with a 95% credible interval (when store_posterior_sample = TRUE).

smry_c <- summary(fit_chf)
print(smry_c)
#> 
#> CausalShrinkageForest model summary
#> =====================================
#> Call: CausalHorseForest(y = y_causal, X_train_control = X_syn, X_train_treat = X_syn, 
#>     treatment_indicator_train = W_syn, outcome_type = "continuous", 
#>     number_of_trees = 5, N_post = 50, N_burn = 25, store_posterior_sample = TRUE, 
#>     verbose = FALSE)
#> 
#> Outcome: Continuous
#> Prior:   control = horseshoe, treatment = horseshoe
#> Trees:   control = 5, treatment = 5
#> Data:    n = 80, p_control = 10, p_treat = 10 | Draws: 50 (burn-in 25)
#> 
#> Treatment effect:
#>   PATE:    1.4224  95% CI (Bayesian bootstrap): [0.621, 2.1773]
#>   CATE SD: 0.1203
#> 
#> Prognostic function (mu):
#>   Mean: 0.597  SD: 0.023  Range: [0.542, 0.662]
#> 
#> Posterior sigma:
#>   Mean: 0.821  SD: 0.078  95% CI: [0.695, 0.99]
#> 
#> Variable importance - control forest (posterior inclusion probability):
#>   X6: 0.195   X3: 0.162   X2: 0.137   X8: 0.119   X10: 0.098   X9: 0.076   X7: 0.073   X5: 0.07   X4: 0.037   X1: 0.033 
#> 
#> Variable importance - treatment forest (posterior inclusion probability):
#>   X2: 0.177   X6: 0.173   X3: 0.123   X8: 0.11   X9: 0.098   X10: 0.087   X1: 0.086   X5: 0.068   X4: 0.046   X7: 0.032 
#> 
#> MCMC acceptance ratios: control = 0.428, treatment = 0.404

# Access the ATE directly
cat("ATE mean  :", round(smry_c$treatment_effect$ate, 3), "\n")
#> ATE mean  : 1.422
cat("ATE 95% CI: [",
    round(smry_c$treatment_effect$ate_lower, 3), ",",
    round(smry_c$treatment_effect$ate_upper, 3), "]\n")
#> ATE 95% CI: [ 0.621 , 2.177 ]

Population vs. mixed ATE

By default the ATE credible interval is obtained by a Bayesian bootstrap: at each MCMC iteration \(s\) the observation-level CATEs \(\tau^{(s)}(x_i)\) are reweighted with Dirichlet(1, …, 1) weights,

\[ \widehat{\mathrm{PATE}}^{(s)} \;=\; \sum_{i=1}^n w_i^{(s)}\, \tau^{(s)}(x_i), \qquad (w_1^{(s)}, \dots, w_n^{(s)}) \sim \mathrm{Dir}(1, \dots, 1). \]

The collection \(\{\widehat{\mathrm{PATE}}^{(s)}\}\) approximates the posterior of the population ATE and therefore propagates uncertainty in both \(\tau(\cdot)\) and the covariate distribution \(F_X\). Setting bayesian_bootstrap = FALSE reverts to equal \(1/n\) weights, giving the mixed ATE (MATE), which conditions on the observed covariates and therefore typically has a narrower credible interval.
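As a rough illustration of the two weighting schemes, the base-R sketch below computes PATE and MATE draws from a matrix of CATE draws. The matrix tau_draws and the helper rdirichlet1() are hypothetical stand-ins, not ShrinkageTrees objects, and the package's internal implementation may differ in detail.

```r
# Hypothetical CATE draws: S MCMC iterations (rows) x n observations (columns).
# Simulated for illustration; not taken from a fitted ShrinkageTrees model.
set.seed(1)
S <- 200
n <- 80
tau_draws <- matrix(rnorm(S * n, mean = 1.4, sd = 0.4), nrow = S, ncol = n)

# One Dirichlet(1, ..., 1) weight vector: normalised independent Exp(1) draws
rdirichlet1 <- function(n) {
  g <- rexp(n)
  g / sum(g)
}

# PATE: a fresh weight vector at every MCMC iteration s
pate_draws <- apply(tau_draws, 1, function(tau_s) {
  sum(rdirichlet1(length(tau_s)) * tau_s)
})

# MATE: equal 1/n weights, i.e. the plain row means
mate_draws <- rowMeans(tau_draws)

# Posterior means and 95% equal-tailed credible intervals
rbind(
  PATE = c(mean = mean(pate_draws), quantile(pate_draws, c(0.025, 0.975))),
  MATE = c(mean = mean(mate_draws), quantile(mate_draws, c(0.025, 0.975)))
)
```

Normalising independent Exp(1) variables is a standard way to draw from Dirichlet(1, …, 1) without an extra package. The extra randomness in the weights is what widens the PATE interval relative to the MATE.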

smry_pate <- summary(fit_chf, bayesian_bootstrap = TRUE)   # default
smry_mate <- summary(fit_chf, bayesian_bootstrap = FALSE)

The standalone helper bayesian_bootstrap_ate() returns the PATE and MATE posterior summaries, together with their draws, in a single list. It also accepts a CausalShrinkageForestPrediction object returned by predict(), in which case the PATE integrates over the covariates supplied at prediction time, i.e. a prespecified target population.

bb <- bayesian_bootstrap_ate(fit_chf)
cat("PATE:", round(bb$pate_mean, 3),
    " 95% CI: [", round(bb$pate_ci$lower, 3), ",",
                  round(bb$pate_ci$upper, 3), "]\n")
#> PATE: 1.428  95% CI: [ 0.532 , 2.141 ]
cat("MATE:", round(bb$mate_mean, 3),
    " 95% CI: [", round(bb$mate_ci$lower, 3), ",",
                  round(bb$mate_ci$upper, 3), "]\n")
#> MATE: 1.422  95% CI: [ 0.567 , 2.154 ]

predict()

predict() evaluates the fitted model on new data. It returns a ShrinkageTreesPrediction object containing, for each new observation, the posterior mean of the fitted value and the bounds of its pointwise credible interval.

X_new  <- matrix(rnorm(10 * p), 10, p)

pred <- predict(fit_hs, newdata = X_new)
print(pred)
#> 
#> ShrinkageTrees predictions
#> --------------------------
#> Observations:         10
#> Credible interval:    95%
#> Scale:                fitted value
#> 
#>              mean     lower     upper
#>          --------  --------  --------
#>   [  1]    -0.008    -0.398     0.248
#>   [  2]     0.009    -0.492     0.390
#>   [  3]     0.089    -0.312     0.546
#>   [  4]     0.045    -0.237     0.326
#>   [  5]     0.002    -0.385     0.433
#>   [  6]    -0.022    -0.464     0.248
#>   ... (4 more)
# Point estimates and 95% credible intervals
head(data.frame(
  mean  = round(pred$mean,  3),
  lower = round(pred$lower, 3),
  upper = round(pred$upper, 3)
))
#>     mean  lower upper
#> 1 -0.008 -0.398 0.248
#> 2  0.009 -0.492 0.390
#> 3  0.089 -0.312 0.546
#> 4  0.045 -0.237 0.326
#> 5  0.002 -0.385 0.433
#> 6 -0.022 -0.464 0.248
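Conceptually, the mean, lower and upper vectors are column-wise summaries of a matrix of posterior draws. The sketch below shows this on a simulated draw matrix f_draws (a hypothetical stand-in, not package output); the package may compute its quantiles slightly differently.

```r
# Hypothetical posterior draws of the fitted function at m = 10 new points
# (S draws x m points), simulated for illustration only.
set.seed(2)
S <- 500
m <- 10
f_draws <- matrix(rnorm(S * m, mean = 0, sd = 0.2), nrow = S, ncol = m)

level <- 0.95
alpha <- 1 - level

# Column-wise posterior mean and pointwise equal-tailed interval bounds
pred_summary <- data.frame(
  mean  = colMeans(f_draws),
  lower = apply(f_draws, 2, quantile, probs = alpha / 2),
  upper = apply(f_draws, 2, quantile, probs = 1 - alpha / 2)
)
head(round(pred_summary, 3))
```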

Causal predictions

For causal models (CausalShrinkageForest and CausalHorseForest), predict() returns three sets of posterior summaries:

  • prognostic: the control-forest prediction \(\mu(\mathbf{x})\) — the expected outcome under control.
  • cate: the Conditional Average Treatment Effect \(\tau(\mathbf{x})\), i.e. the expected effect of treatment for an individual with covariates \(\mathbf{x}\).
  • total: the combined prediction \(\mu(\mathbf{x}) + \tau(\mathbf{x})\) — the expected outcome under treatment.

For survival models with timescale = "time", the prognostic and total components are back-transformed to the original time scale (via \(\exp(\cdot)\)), and the CATE becomes a multiplicative time ratio: \(\exp(\tau) > 1\) means treatment prolongs survival.
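To make the back-transformation concrete, the sketch below works on hypothetical posterior draws of \(\mu(\mathbf{x})\) and \(\tau(\mathbf{x})\) for a single individual on the log-time scale (simulated, not package output). It checks that, draw by draw, the treated value on the time scale equals the control value multiplied by the time ratio, and estimates the posterior probability that treatment prolongs survival.

```r
# Hypothetical posterior draws for one individual on the log-time scale
set.seed(3)
mu_draws  <- rnorm(1000, mean = 2.0, sd = 0.1)   # mean log time under control
tau_draws <- rnorm(1000, mean = 0.3, sd = 0.15)  # CATE on the log-time scale

time_ratio <- exp(tau_draws)              # multiplicative effect on survival time
ctrl_time  <- exp(mu_draws)               # control: median time in a log-normal AFT
treat_time <- exp(mu_draws + tau_draws)   # treatment: median time in a log-normal AFT

# Draw by draw, treated time = control time x time ratio
stopifnot(isTRUE(all.equal(treat_time, ctrl_time * time_ratio)))

# Posterior probability that treatment prolongs survival, P(exp(tau) > 1 | data)
mean(time_ratio > 1)
```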

The predict() method requires two covariate matrices — one for each forest — matching the columns used at fit time.

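# In practice, row i of both matrices describes the same individual (both
# forests were fitted on X_syn); independent draws are used here for illustration.
X_new_ctrl  <- matrix(rnorm(10 * p), 10, p)
X_new_treat <- matrix(rnorm(10 * p), 10, p)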

pred_c <- predict(fit_chf, newdata_control = X_new_ctrl,
                  newdata_treat = X_new_treat)
print(pred_c)
#> 
#> CausalShrinkageForest predictions
#> ----------------------------------
#> Observations:         10
#> Credible interval:    95%
#> Outcome type:         continuous
#> 
#> PATE: 1.372  95% CI (Bayesian bootstrap): [0.34, 2.013]
#> 
#> Prognostic (mu):
#>              mean     lower     upper
#>          --------  --------  --------
#>   [  1]     0.559     0.474     0.751
#>   [  2]     0.555     0.474     0.739
#>   [  3]     0.557     0.433     0.674
#>   [  4]     0.553     0.469     0.654
#>   [  5]     0.548     0.429     0.659
#>   [  6]     0.554     0.468     0.659
#>   ... (4 more)
#> 
#> CATE (tau):
#>              mean     lower     upper
#>          --------  --------  --------
#>   [  1]     1.380     0.205     2.122
#>   [  2]     1.366     0.179     2.014
#>   [  3]     1.332     0.137     2.135
#>   [  4]     1.351     0.121     2.077
#>   [  5]     1.350     0.205     2.049
#>   [  6]     1.374     0.079     2.064
#>   ... (4 more)
#> 
#> Total (mu + tau):
#>              mean     lower     upper
#>          --------  --------  --------
#>   [  1]     1.939     0.738     2.713
#>   [  2]     1.921     0.643     2.579
#>   [  3]     1.889     0.626     2.728
#>   [  4]     1.904     0.603     2.694
#>   [  5]     1.898     0.735     2.580
#>   [  6]     1.929     0.584     2.572
#>   ... (4 more)
# Extract individual components
head(data.frame(
  prognostic = round(pred_c$prognostic$mean, 3),
  cate       = round(pred_c$cate$mean, 3),
  total      = round(pred_c$total$mean, 3)
))
#>   prognostic  cate total
#> 1      0.559 1.380 1.939
#> 2      0.555 1.366 1.921
#> 3      0.557 1.332 1.889
#> 4      0.553 1.351 1.904
#> 5      0.548 1.350 1.898
#> 6      0.554 1.374 1.929

plot()

The plot() method produces diagnostic and inferential graphics using the ggplot2 package (a suggested dependency).

Sigma traceplot

plot(fit_hs, type = "trace")

Posterior ATE distribution (causal models)

The ATE density uses the Bayesian-bootstrap PATE posterior by default; pass bayesian_bootstrap = FALSE to plot the (typically narrower) mixed ATE density instead. Both quantities are defined in the "Population vs. mixed ATE" subsection above.

plot(fit_chf, type = "ate")                         # PATE (default)

plot(fit_chf, type = "ate", bayesian_bootstrap = FALSE)  # MATE

CATE caterpillar plot

plot(fit_chf, type = "cate")

Variable importance (Dirichlet prior)

Variable importance plots are available when prior_type = "dirichlet". For causal models fitted with prior_type_control = "dirichlet" or prior_type_treat = "dirichlet", the forest argument selects which forest to display: "control", "treat", or "both".

set.seed(12)
fit_dart <- ShrinkageTrees(
  y               = y_cont,
  X_train         = X_syn,
  outcome_type    = "continuous",
  prior_type      = "dirichlet",
  local_hp        = 0.1 / sqrt(5),
  number_of_trees = 5,
  N_post          = 50,
  N_burn          = 25,
  verbose         = FALSE
)
plot(fit_dart, type = "vi", n_vi = 10)

Survival curves

For survival models (outcome_type = "right-censored" or "interval-censored"), the plot() method can draw posterior survival curves derived from the fitted AFT log-normal model: \[ S(t \mid \mathbf{x}_i) = 1 - \Phi\!\left(\frac{\log t - \mu_i}{\sigma}\right), \] where \(\mu_i = f(\mathbf{x}_i)\) is the BART ensemble prediction and \(\sigma\) is the residual standard deviation on the log-time scale.
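For a single posterior draw \((\mu_i, \sigma)\), the curve above is one line of base R. The function name surv_lognormal() and the input values are illustrative only, not part of the package.

```r
# S(t | x) = 1 - Phi((log t - mu) / sigma) for one posterior draw (mu, sigma).
# lower.tail = FALSE computes 1 - Phi(z) accurately in the far tail.
surv_lognormal <- function(t, mu, sigma) {
  pnorm((log(t) - mu) / sigma, lower.tail = FALSE)
}

# Illustrative values: median survival exp(mu) = 10 time units, sigma = 0.8
t_grid <- c(1, 5, 10, 20, 50)
surv_lognormal(t_grid, mu = log(10), sigma = 0.8)
```

Because the model states that \(\log T \sim N(\mu_i, \sigma^2)\), the same values come from plnorm(t, meanlog = mu, sdlog = sigma, lower.tail = FALSE), and the curve crosses 0.5 at the median time \(\exp(\mu_i)\).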

The type = "survival" option supports two modes controlled by the obs argument:

  • Population-averaged curve (obs = NULL, the default): computes \(\bar{S}(t) = n^{-1}\sum_i S(t \mid \mathbf{x}_i)\) at each MCMC iteration, giving credible bands that reflect full posterior uncertainty.
  • Individual curves (obs = c(1, 5, ...)): one curve per selected training observation with its own credible band.
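The population-averaged mode can be sketched as follows: average \(S(t \mid \mathbf{x}_i)\) over observations within each MCMC draw, then take pointwise quantiles across draws. The draw matrices mu_draws and sigma_draws are simulated stand-ins, assumed here only for illustration.

```r
# Hypothetical posterior draws: mu (n_draws x n_obs) and sigma (n_draws)
set.seed(5)
n_draws <- 300
n_obs   <- 80
mu_draws    <- matrix(rnorm(n_draws * n_obs, mean = 2, sd = 0.5), n_draws, n_obs)
sigma_draws <- rgamma(n_draws, shape = 50, rate = 50)   # values near 1
t_grid      <- exp(seq(0, 4, length.out = 25))          # original time scale

# Row s holds the averaged curve S_bar^(s)(t) over the time grid
S_bar <- t(sapply(seq_len(n_draws), function(s) {
  z <- outer(log(t_grid), mu_draws[s, ], "-") / sigma_draws[s]  # grid x obs
  rowMeans(pnorm(z, lower.tail = FALSE))                      # average over obs
}))

# Posterior mean curve and pointwise 95% credible band
band <- data.frame(
  t     = t_grid,
  mean  = colMeans(S_bar),
  lower = apply(S_bar, 2, quantile, probs = 0.025),
  upper = apply(S_bar, 2, quantile, probs = 0.975)
)
head(round(band, 3))
```

Averaging inside each draw, rather than averaging posterior means, is what lets the band reflect joint uncertainty in all \(\mu_i\) and \(\sigma\).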

Additional options:

Argument Description
level Credible level of the pointwise band (default 0.95).
t_grid Custom time grid (original scale). Auto-generated if NULL.
km If TRUE, overlay the Kaplan–Meier estimate (population-average only).

We use the survival fit from the earlier section:

# Population-averaged survival curve with 95% credible band
plot(ht_surv, type = "survival")

# Same curve with the Kaplan-Meier estimate overlaid for comparison
plot(ht_surv, type = "survival", km = TRUE)

# Individual survival curves for observations 1, 20, 40, 60, and 80
plot(ht_surv, type = "survival", obs = c(1, 20, 40, 60, 80))

# Single individual with a narrower 90% credible band
plot(ht_surv, type = "survival", obs = 1, level = 0.90)

When store_posterior_sample = FALSE, the credible bands reflect only uncertainty in \(\sigma\), because each \(\mu_i\) is replaced by its plug-in posterior mean \(\hat{\mu}_i\). The dedicated survival wrappers (SurvivalBART(), SurvivalDART(), etc.) store posterior samples by default, so full posterior bands are available out of the box.

Posterior predictive survival curves

The survival curves above are based on the training data — they show \(S(t \mid \mathbf{x}_i)\) for the observations used to fit the model. For new (out-of-sample) data, call predict() first and then plot() on the prediction object. This produces posterior predictive survival curves that propagate full parameter uncertainty through to the new covariate values:

# New observations for prediction
set.seed(99)
X_new <- matrix(rnorm(20 * p), ncol = p)
pred_surv <- predict(ht_surv, newdata = X_new)
# Population-averaged posterior predictive survival curve
plot(pred_surv, type = "survival")

# Individual posterior predictive curves for selected new observations
plot(pred_surv, type = "survival", obs = c(1, 5, 10))

The same level and t_grid arguments are available as for the training-data survival curves. The Kaplan–Meier overlay (km = TRUE) is not available for prediction objects, since observed event times are only known for the training set.

Multi-Chain MCMC

Running multiple independent chains reduces sensitivity to starting values and enables between-chain convergence checks such as the Gelman–Rubin \(\hat{R}\). Pass n_chains > 1 to any model-fitting function; chains are run in parallel via parallel::mclapply on Unix-like systems.

set.seed(13)
fit_2chain <- ShrinkageTrees(
  y               = y_cont,
  X_train         = X_syn,
  outcome_type    = "continuous",
  prior_type      = "horseshoe",
  local_hp        = 0.1 / sqrt(5),
  global_hp       = 0.1 / sqrt(5),
  number_of_trees = 5,
  N_post          = 50,
  N_burn          = 25,
  n_chains        = 2,
  verbose         = FALSE
)

cat("n_chains stored  :", fit_2chain$mcmc$n_chains, "\n")
#> n_chains stored  : 2
cat("Total sigma draws:", length(fit_2chain$sigma),
    " (2 chains x 50 draws)\n")
#> Total sigma draws: 100  (2 chains x 50 draws)
cat("Per-chain acceptance ratios:\n")
#> Per-chain acceptance ratios:
print(round(fit_2chain$chains$acceptance_ratios, 3))
#> [1] 0.484 0.460

The same interface works for causal models.

set.seed(14)
fit_causal_2chain <- CausalShrinkageForest(
  y                         = y_causal,
  X_train_control           = X_syn,
  X_train_treat             = X_syn,
  treatment_indicator_train = W_syn,
  outcome_type              = "continuous",
  prior_type_control        = "horseshoe",
  prior_type_treat          = "horseshoe",
  local_hp_control          = lh,
  global_hp_control         = lh,
  local_hp_treat            = lh,
  global_hp_treat           = lh,
  number_of_trees_control   = 5,
  number_of_trees_treat     = 5,
  N_post                    = 50,
  N_burn                    = 25,
  n_chains                  = 2,
  verbose                   = FALSE
)

cat("Pooled sigma draws:", length(fit_causal_2chain$sigma), "\n")
#> Pooled sigma draws: 100
cat("Per-chain acceptance ratios (control):\n")
#> Per-chain acceptance ratios (control):
print(round(fit_causal_2chain$chains$acceptance_ratios_control, 3))
#> [1] 0.352 0.340

With multiple chains the traceplot shows one line per chain. The overlaid density plot (type = "density") compares the marginal posterior of \(\sigma\) across chains and therefore requires n_chains > 1.

plot(fit_2chain, type = "trace")

plot(fit_2chain, type = "density")
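For readers curious about the parallel execution mentioned at the start of this section, the generic pattern looks roughly like the sketch below. The sampler run_toy_chain() is a hypothetical AR(1) stand-in, not the package's internal code, and the package's own seeding scheme may differ.

```r
library(parallel)

# Hypothetical stand-in for one MCMC chain: an AR(1) process started from a
# dispersed initial value. This is NOT the package's sampler.
run_toy_chain <- function(chain_id, n_iter = 500, rho = 0.5) {
  x <- numeric(n_iter)
  x[1] <- rnorm(1, sd = 3)
  for (i in 2:n_iter) x[i] <- rho * x[i - 1] + rnorm(1)
  x
}

# L'Ecuyer-CMRG gives each forked worker its own reproducible RNG stream
RNGkind("L'Ecuyer-CMRG")
set.seed(13)

# mclapply relies on forking, which Windows does not support; use one core there
n_cores <- if (.Platform$OS.type == "windows") 1L else 2L
chains  <- mclapply(1:2, run_toy_chain, mc.cores = n_cores, mc.set.seed = TRUE)

sapply(chains, mean)   # one summary per chain
```

RNGkind() changes the session-wide generator; call RNGkind("default") afterwards if later code relies on the default Mersenne-Twister stream.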

Convergence Diagnostics

MCMC methods require careful assessment of convergence before trusting the posterior summaries. Here are practical guidelines for ShrinkageTrees models.

Sigma traceplot

The traceplot of the error standard deviation \(\sigma\) (via plot(fit, type = "trace")) is the primary diagnostic. A well-mixing chain should fluctuate rapidly around a stable level, with no visible trend and no long excursions.

Acceptance ratio

The summary() output reports the average Metropolis–Hastings acceptance ratio for the tree structure proposals (grow/prune moves). As a rough guide:

Formal diagnostics with coda

When the suggested package coda is installed, summary() automatically reports effective sample size (ESS) and — for multi-chain fits — the Gelman–Rubin \(\hat{R}\).

# summary() includes convergence diagnostics when coda is available
summary(fit_2chain)
#> 
#> ShrinkageTrees model summary
#> ============================
#> Call: ShrinkageTrees(y = y_cont, X_train = X_syn, outcome_type = "continuous", 
#>     number_of_trees = 5, prior_type = "horseshoe", local_hp = 0.1/sqrt(5), 
#>     global_hp = 0.1/sqrt(5), N_post = 50, N_burn = 25, n_chains = 2, 
#>     verbose = FALSE)
#> 
#> Outcome: Continuous | Prior: horseshoe | Trees: 5
#> Data:    n = 80, p = 10 | Draws: 50 x 2 chains (burn-in 25)
#> 
#> Posterior sigma:
#>   Mean: 0.971  SD: 0.08  95% CI: [0.82, 1.131]
#> 
#> Predictions (posterior mean):
#>   Train: mean = 0.075, sd = 0.082, range = [-0.151, 0.201]
#>   Test:  mean = 0.055, sd = NA, range = [0.055, 0.055]
#> 
#> Variable importance (posterior inclusion probability):
#>   X3: 0.174   X8: 0.143   X5: 0.137   X2: 0.116   X7: 0.086   X1: 0.084   X10: 0.076   X4: 0.074   X6: 0.065   X9: 0.046 
#> 
#> MCMC acceptance ratio (per chain): 0.484, 0.46
#> 
#> Convergence diagnostics (coda):
#>   Gelman-Rubin R-hat:  = 1.019
#>   Effective sample size: sigma = 79
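The effective sample size measures how many independent draws an autocorrelated chain is worth. A minimal sketch of one common estimator, \(N / (1 + 2\sum_k \rho_k)\) with the sum truncated at the first negative sample autocorrelation, is shown below; simple_ess() is a hypothetical helper, and coda::effectiveSize() uses a spectral estimate instead, so its values will differ somewhat.

```r
# Simple ESS estimator: N / (1 + 2 * sum of autocorrelations), truncating the
# sum at the first negative sample autocorrelation.
simple_ess <- function(x, max_lag = 100) {
  rho <- acf(x, lag.max = max_lag, plot = FALSE)$acf[-1]   # drop lag 0
  first_neg <- which(rho < 0)[1]
  if (!is.na(first_neg)) rho <- rho[seq_len(first_neg - 1)]
  length(x) / (1 + 2 * sum(rho))
}

set.seed(16)
iid_draws <- rnorm(1000)
ar_draws  <- as.numeric(stats::filter(rnorm(1000), filter = 0.9, method = "recursive"))

simple_ess(iid_draws)   # close to the number of draws
simple_ess(ar_draws)    # far fewer effective draws
```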

For more detailed diagnostics, convert the fitted object to a coda::mcmc.list with as.mcmc.list():

library(coda)
mcmc_obj <- as.mcmc.list(fit_2chain)

# Gelman-Rubin R-hat (values near 1 indicate convergence)
coda::gelman.diag(mcmc_obj)
#> Potential scale reduction factors:
#> 
#>       Point est. Upper C.I.
#> sigma       1.02       1.11

# Effective sample size
coda::effectiveSize(mcmc_obj)
#>    sigma 
#> 78.73458

# Geweke diagnostic (per chain)
coda::geweke.diag(mcmc_obj[[1]])
#> 
#> Fraction in 1st window = 0.1
#> Fraction in 2nd window = 0.5 
#> 
#>  sigma 
#> 0.5847

The returned mcmc.list object is compatible with all coda functions, including coda::autocorr.plot(), coda::gelman.plot(), coda::heidel.diag(), and coda::raftery.diag().
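To see what \(\hat{R}\) compares, the sketch below computes the basic (uncorrected) potential scale reduction factor by hand from an iterations-by-chains matrix. The helper gelman_rubin() is hypothetical; coda::gelman.diag() additionally applies a small-sample degrees-of-freedom correction, so its point estimate can differ slightly.

```r
# Basic Gelman-Rubin PSRF for one scalar parameter.
# draws: n_iter x n_chains matrix of post-burn-in draws
gelman_rubin <- function(draws) {
  n <- nrow(draws)
  W <- mean(apply(draws, 2, var))        # mean within-chain variance
  B <- n * var(colMeans(draws))          # between-chain variance
  var_hat <- (n - 1) / n * W + B / n     # pooled estimate of posterior variance
  sqrt(var_hat / W)
}

set.seed(15)
mixed     <- matrix(rnorm(2000), ncol = 2)                      # same target
not_mixed <- cbind(rnorm(1000, mean = 0), rnorm(1000, mean = 3))  # disagreeing chains

gelman_rubin(mixed)      # close to 1
gelman_rubin(not_mixed)  # well above 1
```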

Case Study: TCGA PAAD (Full Analysis)

The full analysis of the pdac dataset replicates the case study from Jacobs, van Wieringen & van der Pas (2025). Due to the high-dimensional covariate space (~3,000 genes) and the large MCMC settings needed for reliable inference, the code below is provided for reference but is not evaluated during vignette building. Pre-computed results can be reproduced by running the pdac_analysis demo: demo("pdac_analysis", package = "ShrinkageTrees").

Step 1: Propensity score estimation

data("pdac")

time      <- pdac$time
status    <- pdac$status
treatment <- pdac$treatment
X         <- as.matrix(pdac[, !(names(pdac) %in% c("time","status","treatment"))])

set.seed(2025)
ps_fit <- HorseTrees(
  y            = treatment,
  X_train      = X,
  outcome_type = "binary",
  k            = 0.1,
  N_post       = 5000,
  N_burn       = 5000,
  verbose      = FALSE
)

propensity <- pnorm(ps_fit$train_predictions)
# Overlap plot
p0 <- propensity[treatment == 0]
p1 <- propensity[treatment == 1]

hist(p0, breaks = 15, col = rgb(1, 0.5, 0, 0.5), xlim = range(propensity),
     xlab = "Propensity score", main = "Propensity score overlap")
hist(p1, breaks = 15, col = rgb(0, 0.5, 0, 0.5), add = TRUE)
legend("topright", legend = c("Control", "Treated"),
       fill = c(rgb(1,0.5,0,0.5), rgb(0,0.5,0,0.5)))

Step 2: Causal survival forest

# Augment control matrix with propensity scores (BCF-style)
X_control <- cbind(propensity, X)

# Log-transform and centre survival times
log_time <- log(time) - mean(log(time))

set.seed(2025)
fit_pdac <- CausalHorseForest(
  y                         = log_time,
  status                    = status,
  X_train_control           = X_control,
  X_train_treat             = X,
  treatment_indicator_train = treatment,
  outcome_type              = "right-censored",
  timescale                 = "log",
  number_of_trees           = 200,
  N_post                    = 5000,
  N_burn                    = 5000,
  store_posterior_sample    = TRUE,
  verbose                   = FALSE
)

Step 3: ATE and CATE estimation

# Print model summary
print(fit_pdac)
smry_pdac <- summary(fit_pdac)
print(smry_pdac)

# ATE
cat("ATE posterior mean:",
    round(smry_pdac$treatment_effect$ate, 3), "\n")
cat("95% CI: [",
    round(smry_pdac$treatment_effect$ate_lower, 3), ",",
    round(smry_pdac$treatment_effect$ate_upper, 3), "]\n")

Step 4: Diagnostics

# Sigma convergence
plot(fit_pdac, type = "trace")

# Posterior ATE density
plot(fit_pdac, type = "ate")

# CATE caterpillar plot
plot(fit_pdac, type = "cate")

References

Chipman, H. A., George, E. I., & McCulloch, R. E. (2010). BART: Bayesian Additive Regression Trees. Annals of Applied Statistics, 4(1), 266–298.

Hahn, P. R., Murray, J. S., & Carvalho, C. M. (2020). Bayesian Regression Tree Models for Causal Inference: Regularization, Confounding, and Heterogeneous Treatment Effects. Bayesian Analysis, 15(3), 965–1056.

Jacobs, T., van Wieringen, W. N., & van der Pas, S. L. (2025). Horseshoe Forests for High-Dimensional Causal Survival Analysis. arXiv preprint arXiv:2507.22004.

Linero, A. R. (2018). Bayesian Regression Trees for High-Dimensional Prediction and Variable Selection. Journal of the American Statistical Association, 113(522), 626–636.

Sparapani, R., Spanbauer, C., & McCulloch, R. (2021). Nonparametric Machine Learning and Efficient Computation with Bayesian Additive Regression Trees: The BART R Package. Journal of Statistical Software, 97(1), 1–66.