ShrinkageTrees is an R package that brings Bayesian Additive Regression Trees (BART; Chipman, George & McCulloch, 2010) to survival analysis and causal inference, with a particular focus on high-dimensional data.
The package implements BART-based models for right-censored and interval-censored survival outcomes using an accelerated failure time (AFT) formulation. Censored event times are handled through Bayesian data augmentation in the Gibbs sampler, enabling full posterior inference without proportional-hazards assumptions. For causal inference, the package provides Bayesian Causal Forests (BCF; Hahn, Murray & Carvalho, 2020), which decompose the outcome into a prognostic function \(\mu(\mathbf{x})\) and a treatment-effect function \(\tau(\mathbf{x})\), each estimated by a separate tree ensemble. This two-forest structure supports estimation of heterogeneous treatment effects (CATEs) and the average treatment effect (ATE).
A key feature is the availability of multiple regularisation strategies that can be freely combined within a single model:
| Function | Task | Prior |
|---|---|---|
| HorseTrees() | Prediction (continuous / binary / survival*) | Horseshoe |
| ShrinkageTrees() | Prediction — flexible prior choice* | Horseshoe, DART, BART, … |
| SurvivalBART() | Survival prediction* | Classical BART |
| SurvivalDART() | Sparse survival prediction* | DART (Dirichlet) |
| SurvivalBCF() | Causal survival inference* | BCF (classical) |
| SurvivalShrinkageBCF() | Sparse causal survival inference* | BCF + DART |
| CausalHorseForest() | Causal inference (all outcomes*) | Horseshoe |
| CausalShrinkageForest() | Causal inference — flexible prior* | Horseshoe, DART, BART, … |

\* All survival functions support both right-censored and interval-censored outcomes.
All model-fitting functions return an S3 object with consistent print(), summary(), predict(), and plot() methods.
Before diving into examples, we clarify a few concepts that appear throughout the package interface.
Every model-fitting function accepts an outcome_type argument:

- "continuous" — standard regression (default for most functions).
- "binary" — probit BART for binary outcomes (0/1).
- "right-censored" — accelerated failure time model for survival data. The outcome y contains (possibly censored) follow-up times, and the status vector indicates events (1) vs. censored observations (0).
- "interval-censored" — AFT model for interval-censored survival data. Instead of y and status, provide left_time and right_time vectors specifying the lower and upper bounds of the observation window for each individual. Three cases are distinguished: left_time == right_time (event observed exactly); left_time < right_time with finite right_time (event occurred somewhere in the interval); right_time = Inf (event not yet observed). This convention matches survival::Surv(type = "interval2").

For survival outcomes, the timescale argument controls how the package treats the times:

- timescale = "time" (default): the supplied values are on the original time scale (positive numbers). The package internally applies a log-transform, i.e. models \(\log(T) = f(\mathbf{x}) + \varepsilon\). Predictions from summary() and predict() are back-transformed to the time scale automatically.
- timescale = "log": the supplied values are already log-transformed. No further transformation is applied. Predictions stay on the log scale.

In most applications you should use timescale = "time" and pass the raw survival times directly.
In a BART ensemble each tree contributes a step height (leaf parameter) to the overall prediction. Classical BART assigns these step heights a fixed-variance Gaussian prior, which regularises all leaves equally. In high-dimensional settings, stronger and more adaptive regularisation is desirable. ShrinkageTrees implements two shrinkage priors that are placed directly on the step heights via a scale mixture of normals:
\[ h_\ell \mid \lambda_\ell, \tau, \omega \sim \mathcal{N}(0,\; \omega\, \lambda_\ell^2\, \tau^2). \]
Here \(\tau\) is a global shrinkage parameter shared across all leaves, \(\lambda_\ell\) is a local scale specific to leaf \(\ell\), and \(\omega\) is a fixed scaling constant. The two currently implemented instantiations are:
- Horseshoe prior (prior_type = "horseshoe"). Both \(\lambda_\ell\) and \(\tau\) receive independent half-Cauchy priors. The heavy tails of the half-Cauchy allow individual leaves to escape shrinkage when the data support a strong effect, while the global parameter \(\tau\) pulls the bulk of the estimates toward zero. This is the default prior in HorseTrees() and CausalHorseForest().
- Half-Cauchy prior (prior_type = "half-cauchy"). Only the local scales \(\lambda_\ell\) receive a half-Cauchy prior; there is no global shrinkage parameter. This provides per-leaf adaptivity without the additional pooling across the ensemble.

A forest-wide variant of the horseshoe (prior_type = "horseshoe_fw") shares a single global \(\tau\) across all trees in the forest rather than one per tree. These priors can be selected in ShrinkageTrees() and CausalShrinkageForest() via the prior_type argument, and can be combined with DART’s Dirichlet splitting prior for simultaneous structural and parametric regularisation.
The two most important hyperparameters are local_hp and
global_hp. These control the horseshoe prior on the step
heights (leaf parameters):
\[ \mu_{jl} \mid \lambda_{jl}, \tau_j \sim \mathcal{N}(0, \lambda_{jl}^2 \tau_j^2), \qquad \lambda_{jl} \sim \text{C}^+(0, \texttt{local\_hp}), \qquad \tau_j \sim \text{C}^+(0, \texttt{global\_hp}), \]
where \(\text{C}^+\) denotes the half-Cauchy distribution. Smaller values produce stronger shrinkage toward zero; larger values allow more variation.
HorseTrees() and
CausalHorseForest() provide a convenience
parameter k that sets both scales automatically:
local_hp = global_hp = k / sqrt(number_of_trees). The
default k = 0.1 works well in many settings and is a good
starting point.
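As a quick numeric check of this convenience rule, the default k = 0.1 with 200 trees gives:

```r
# Scale implied by the convenience parameter k (here the default k = 0.1)
k <- 0.1
number_of_trees <- 200
local_hp <- k / sqrt(number_of_trees)  # the same value is used for global_hp
round(local_hp, 4)
#> [1] 0.0071
```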
ShrinkageTrees() and
CausalShrinkageForest() expose
local_hp and global_hp directly (no
k). A common rule of thumb is:

- local_hp = k / sqrt(number_of_trees) with k in [0.05, 0.5];
- global_hp = local_hp (symmetric), or a larger value if you want less overall shrinkage.

The survival functions (SurvivalBART, SurvivalDART) use k to calibrate the standard BART leaf prior: local_hp = range(log(y)) / (2 * k * sqrt(number_of_trees)). The default k = 2 follows Chipman et al. (2010).
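A small numeric illustration of this calibration; note that interpreting range(log(y)) as the spread (max − min) of the log-times is an assumption about the package's internal convention:

```r
# Hypothetical illustration of the calibration formula above.
# We read range(log(y)) as the spread of the log survival times.
y_demo <- c(2, 5, 10, 40)            # toy survival times
spread <- diff(range(log(y_demo)))   # log(40) - log(2) = log(20)
k <- 2                               # default, Chipman et al. (2010)
number_of_trees <- 50
round(spread / (2 * k * sqrt(number_of_trees)), 3)
#> [1] 0.106
```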
When store_posterior_sample = TRUE, the fitted object stores the full \(N_\text{post} \times n\) matrix of posterior draws for predictions. This is needed for:

- predict() on new data (it re-runs the sampler internally, so posterior samples are always produced);
- plot(fit, type = "ate") and plot(fit, type = "cate"), which require the full posterior distribution;
- plot(fit, type = "survival") — full posterior credible bands over both \(\mu_i\) and \(\sigma\) (without samples, only sigma-uncertainty bands are available).

When FALSE, only posterior means and \(\sigma\) draws are stored, saving memory. print() and summary() work in both cases, but predict() will not be available.
All causal model functions — CausalHorseForest(),
CausalShrinkageForest(), SurvivalBCF(), and
SurvivalShrinkageBCF() — decompose the outcome as \[
Y_i = \mu(\mathbf{x}_i) + b_i \cdot \tau(\mathbf{x}_i) + \varepsilon_i,
\] where \(b_i\) is a scalar
that depends on the treatment assignment \(Z_i\). The treatment_coding
argument controls how \(b_i\) is
defined. Four options are available:
- "centered" (default). \(b_i = Z_i - 1/2\), so that \(b_i \in \{-1/2,\; 1/2\}\). This is the original BCF parameterisation.
- "binary". \(b_i = Z_i\), so that \(b_i \in \{0,\; 1\}\). Standard binary coding; the treatment forest captures the full effect of treatment on the treated.
- "adaptive". \(b_i = Z_i - \hat{e}(\mathbf{x}_i)\), where \(\hat{e}(\mathbf{x}_i)\) is the estimated propensity score. This follows Hahn, Murray & Carvalho (2020) and is the coding used in the bcf R package. When using this option, a propensity vector must be supplied.
- "invariant". Parameter-expanded (invariant) treatment coding. The coding parameters \(b_0\) and \(b_1\) are assigned \(N(0,\; 1/2)\) priors and estimated within the Gibbs sampler via conjugate normal updates: \[
Y_i = \mu(\mathbf{x}_i) + b_{Z_i} \cdot \tilde{\tau}(\mathbf{x}_i) + \varepsilon_i,
\qquad b_0,\; b_1 \sim N(0,\; 1/2).
\] The treatment effect is \(\tau(\mathbf{x}_i) = (b_1 - b_0) \cdot \tilde{\tau}(\mathbf{x}_i)\), and the posterior draws of \(b_0\) and \(b_1\) are returned in the fitted object. This parameterisation is invariant to the coding of the treatment indicator (Hahn et al., 2020, Section 5.2).
The examples below illustrate each option on a simple continuous-outcome causal model.
set.seed(50)
n_tc <- 60; p_tc <- 5
X_tc <- matrix(rnorm(n_tc * p_tc), n_tc, p_tc)
W_tc <- rbinom(n_tc, 1, 0.5)
tau_tc <- 1.5 * (X_tc[, 1] > 0)
y_tc <- X_tc[, 1] + W_tc * tau_tc + rnorm(n_tc, sd = 0.5)

# Centered (default)
fit_tc_cen <- CausalHorseForest(
y = y_tc,
X_train_control = X_tc, X_train_treat = X_tc,
treatment_indicator_train = W_tc,
treatment_coding = "centered",
number_of_trees = 5, N_post = 50, N_burn = 25,
store_posterior_sample = TRUE, verbose = FALSE
)
cat("Centered — ATE:",
round(mean(fit_tc_cen$train_predictions_treat), 3), "\n")
#> Centered — ATE: 0.238

# Binary
fit_tc_bin <- CausalHorseForest(
y = y_tc,
X_train_control = X_tc, X_train_treat = X_tc,
treatment_indicator_train = W_tc,
treatment_coding = "binary",
number_of_trees = 5, N_post = 50, N_burn = 25,
store_posterior_sample = TRUE, verbose = FALSE
)
cat("Binary — ATE:",
round(mean(fit_tc_bin$train_predictions_treat), 3), "\n")
#> Binary — ATE: 0.521

# Adaptive (requires propensity scores)
ps_tc <- pnorm(0.3 * X_tc[, 1]) # simple propensity model for illustration
fit_tc_ada <- CausalHorseForest(
y = y_tc,
X_train_control = X_tc, X_train_treat = X_tc,
treatment_indicator_train = W_tc,
treatment_coding = "adaptive",
propensity = ps_tc,
number_of_trees = 5, N_post = 50, N_burn = 25,
store_posterior_sample = TRUE, verbose = FALSE
)
cat("Adaptive — ATE:",
round(mean(fit_tc_ada$train_predictions_treat), 3), "\n")
#> Adaptive — ATE: 0.154

# Invariant (parameter-expanded)
fit_tc_inv <- CausalHorseForest(
y = y_tc,
X_train_control = X_tc, X_train_treat = X_tc,
treatment_indicator_train = W_tc,
treatment_coding = "invariant",
number_of_trees = 5, N_post = 50, N_burn = 25,
store_posterior_sample = TRUE, verbose = FALSE
)
cat("Invariant — ATE:",
round(mean(fit_tc_inv$train_predictions_treat), 3), "\n")
#> Invariant — ATE: 0.232
# Posterior draws of b0 and b1 are stored in the fitted object
cat("b0 posterior mean:", round(mean(fit_tc_inv$b0), 3), "\n")
#> b0 posterior mean: 0.068
cat("b1 posterior mean:", round(mean(fit_tc_inv$b1), 3), "\n")
#> b1 posterior mean: 0.153

The survival functions inherit treatment_coding support.
For example,
SurvivalBCF(..., treatment_coding = "invariant") works out
of the box.
The package ships with two TCGA datasets for high-dimensional survival analysis and causal inference:
- pdac — TCGA pancreatic ductal adenocarcinoma (PAAD) cohort (n = 178). A data frame with overall survival times, a binary treatment indicator (radiation therapy vs. control), clinical covariates, and expression values of ~3,000 genes selected by median absolute deviation.
- ovarian — TCGA ovarian cancer (OV) cohort (n = 357). A list with X (357 x 2,000 gene expression matrix, log2-normalised TPM) and clinical (data frame with OS time/event, age, FIGO stage, tumor grade, and treatment: carboplatin vs cisplatin). See ?ovarian for details.

The analyses below use the pdac dataset.
library(ShrinkageTrees)
data("pdac")
# Dimensions and column overview
cat("Patients:", nrow(pdac), "\n")
#> Patients: 130
cat("Columns :", ncol(pdac), "\n")
#> Columns : 3032
cat("Clinical columns:", paste(names(pdac)[1:13], collapse = ", "), "\n")
#> Clinical columns: time, status, treatment, age, sex, grade, tumor.cellularity, tumor.purity, absolute.purity, moffitt.cluster, meth.leukocyte.percent, meth.purity.mode, stage
cat("Survival: time (months), censoring rate =",
round(1 - mean(pdac$status), 2), "\n")
#> Survival: time (months), censoring rate = 0.47
cat("Treatment: radiation =", sum(pdac$treatment),
"/ control =", sum(1 - pdac$treatment), "\n")
#> Treatment: radiation = 36 / control = 94

We separate the outcome, treatment, and covariate matrix for the analyses below.
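One possible split is sketched below; it assumes, as in the column overview above, that the first 13 columns of pdac are the clinical variables and the remaining columns are gene expression. The later analyses use only the genes as covariates; clinical variables could be added after numeric encoding.

```r
# Outcome and treatment
time      <- pdac$time
status    <- pdac$status
treatment <- pdac$treatment

# Covariate matrix: gene-expression columns only (assumption: the first
# 13 columns of pdac hold the clinical variables listed above)
X <- as.matrix(pdac[, -(1:13)])
```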
This section demonstrates the single-forest models in ShrinkageTrees. These models estimate a single function \(f(\mathbf{x})\) of the covariates, applicable to continuous, binary, and survival outcomes. We begin with a binary-outcome example that will also serve as the propensity score model for the causal analyses later.
Before fitting causal models we estimate propensity scores \(\hat{e}(\mathbf{x}) = P(W=1 \mid
\mathbf{x})\) using HorseTrees() with a binary
outcome. The probit link is used internally: predictions are on the
latent Gaussian scale and can be converted to probabilities with
pnorm().
The code block below uses reduced MCMC settings for illustration. A
real analysis would use N_post = 5000, N_burn = 5000.
ps_fit <- HorseTrees(
y = treatment,
X_train = X,
outcome_type = "binary",
k = 0.1,
N_post = 5000,
N_burn = 5000,
verbose = FALSE
)
propensity <- pnorm(ps_fit$train_predictions)

For the remainder of this vignette we use a short synthetic run to keep build time low.
set.seed(1)
n <- 80; p <- 10
X_syn <- matrix(rnorm(n * p), n, p)
W_syn <- rbinom(n, 1, pnorm(0.8 * X_syn[, 1]))
ps_fit <- HorseTrees(
y = W_syn,
X_train = X_syn,
outcome_type = "binary",
number_of_trees = 5,
k = 0.5,
N_post = 50,
N_burn = 25,
verbose = FALSE
)
propensity_syn <- pnorm(ps_fit$train_predictions)
cat("Propensity scores — range: [",
round(range(propensity_syn), 3), "]\n")
#> Propensity scores — range: [ 0.469 0.539 ]

HorseTrees() handles right-censored data via an AFT
model. Pass outcome_type = "right-censored" and provide the
status vector (1 = event, 0 = censored). When
timescale = "time" (the default), the package
log-transforms survival times internally and returns predictions on the
log scale (see Key Concepts above).
set.seed(2)
log_T <- X_syn[, 1] + rnorm(n)
C <- rexp(n, 0.5)
y_syn <- pmin(exp(log_T), C)
d_syn <- as.integer(exp(log_T) <= C)
ht_surv <- HorseTrees(
y = y_syn,
status = d_syn,
X_train = X_syn,
outcome_type = "right-censored",
timescale = "time",
number_of_trees = 5,
N_post = 50,
N_burn = 25,
store_posterior_sample = TRUE,
verbose = FALSE
)
cat("Posterior mean log-time (first 5 obs):",
round(ht_surv$train_predictions[1:5], 3), "\n")
#> Posterior mean log-time (first 5 obs): 1.204 1.319 1.354 1.389 1.247
cat("Posterior sigma — mean:",
round(mean(ht_surv$sigma), 3), "\n")
#> Posterior sigma — mean: 0.981

When event times are not observed exactly but known to lie within an
interval, the package supports interval censoring.
Instead of providing y and status, pass
left_time and right_time with
outcome_type = "interval-censored".
The three censoring types are encoded as follows:

- left_time[i] == right_time[i] — event observed exactly;
- left_time[i] < right_time[i] (both finite) — interval-censored;
- right_time[i] = Inf — right-censored.

This convention matches survival::Surv(type = "interval2").
set.seed(20)
# Generate true event times
true_T <- rexp(n, rate = exp(-0.5 * X_syn[, 1]))
# Create interval-censored observations
left_syn <- true_T * runif(n, 0.5, 1.0)
right_syn <- true_T * runif(n, 1.0, 1.5)
# Mark some as exact observations and some as right-censored
exact_idx <- sample(n, 25)
left_syn[exact_idx] <- true_T[exact_idx]
right_syn[exact_idx] <- true_T[exact_idx]
rc_idx <- sample(setdiff(seq_len(n), exact_idx), 15)
right_syn[rc_idx] <- Inf
cat("Exact events:", sum(left_syn == right_syn), "\n")
#> Exact events: 25
cat("Interval-censored:", sum(left_syn < right_syn & is.finite(right_syn)), "\n")
#> Interval-censored: 40
cat("Right-censored:", sum(!is.finite(right_syn)), "\n")
#> Right-censored: 15
ht_ic <- HorseTrees(
left_time = left_syn,
right_time = right_syn,
X_train = X_syn,
outcome_type = "interval-censored",
timescale = "time",
number_of_trees = 5,
N_post = 50,
N_burn = 25,
store_posterior_sample = TRUE,
verbose = FALSE
)
cat("Posterior mean log-time (first 5 obs):",
round(ht_ic$train_predictions[1:5], 3), "\n")
#> Posterior mean log-time (first 5 obs): 0.666 0.698 0.596 0.785 0.726
cat("Posterior sigma — mean:",
round(mean(ht_ic$sigma), 3), "\n")
#> Posterior sigma — mean: 1.016

All survival functions (SurvivalBART,
SurvivalDART, SurvivalBCF,
SurvivalShrinkageBCF) and the general-purpose functions
(ShrinkageTrees, CausalShrinkageForest,
CausalHorseForest) accept left_time and
right_time in the same way. For example, using
SurvivalBART:
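A sketch of such a call is shown below (not run). The argument names mirror the right-censored SurvivalBART example later in this vignette, with left_time/right_time replacing time/status as described above; the exact signature is an assumption, so consult ?SurvivalBART.

```r
# Sketch: interval-censored AFT fit with classical BART priors.
# left_time/right_time replace time/status; other arguments as in the
# right-censored SurvivalBART example below.
fit_sbart_ic <- SurvivalBART(
  left_time  = left_syn,
  right_time = right_syn,
  X_train    = X_syn,
  number_of_trees = 5,
  k = 2.0,
  N_post = 50,
  N_burn = 25,
  verbose = FALSE
)
```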
While HorseTrees() fixes the prior to the horseshoe, the
more general ShrinkageTrees() function exposes the
prior_type argument, allowing the user to select among all
implemented regularisation strategies. Available options are
"horseshoe", "horseshoe_fw" (forest-wide),
"half-cauchy", "standard" (classical BART),
and "dirichlet" (DART). Below we compare the per-tree
horseshoe and the forest-wide horseshoe on a continuous outcome.
set.seed(3)
y_cont <- X_syn[, 1] + 0.5 * X_syn[, 2] + rnorm(n)
# Horseshoe prior (default for HorseTrees)
fit_hs <- ShrinkageTrees(
y = y_cont,
X_train = X_syn,
outcome_type = "continuous",
prior_type = "horseshoe",
local_hp = 0.1 / sqrt(5),
global_hp = 0.1 / sqrt(5),
number_of_trees = 5,
N_post = 50,
N_burn = 25,
verbose = FALSE
)
# Forest-wide horseshoe (horseshoe_fw)
fit_fw <- ShrinkageTrees(
y = y_cont,
X_train = X_syn,
outcome_type = "continuous",
prior_type = "horseshoe_fw",
local_hp = 0.1 / sqrt(5),
global_hp = 0.1 / sqrt(5),
number_of_trees = 5,
N_post = 50,
N_burn = 25,
verbose = FALSE
)
cat("Horseshoe — train RMSE:",
round(sqrt(mean((fit_hs$train_predictions - y_cont)^2)), 3), "\n")
#> Horseshoe — train RMSE: 1.217
cat("Horseshoe FW— train RMSE:",
round(sqrt(mean((fit_fw$train_predictions - y_cont)^2)), 3), "\n")
#> Horseshoe FW— train RMSE: 1.23

SurvivalBART() and SurvivalDART() fit
classical BART and DART models for right-censored survival data under
the AFT formulation. They calibrate prior hyperparameters automatically
from the data range, providing a simple interface when horseshoe
shrinkage is not needed.
set.seed(4)
# SurvivalBART: classical BART prior, AFT likelihood
fit_sbart <- SurvivalBART(
time = y_syn,
status = d_syn,
X_train = X_syn,
number_of_trees = 5,
k = 2.0,
N_post = 50,
N_burn = 25,
verbose = FALSE
)
# SurvivalDART: Dirichlet (DART) splitting prior
fit_sdart <- SurvivalDART(
time = y_syn,
status = d_syn,
X_train = X_syn,
number_of_trees = 5,
k = 2.0,
N_post = 50,
N_burn = 25,
verbose = FALSE
)
cat("SurvivalBART class:", class(fit_sbart), "\n")
#> SurvivalBART class: ShrinkageTrees
cat("SurvivalDART class:", class(fit_sdart), "\n")
#> SurvivalDART class: ShrinkageTrees

A key motivation for horseshoe shrinkage and the Dirichlet (DART) sparsity prior is their behaviour in the \(p \gg n\) regime: many covariates are available but only a small subset drives the outcome. Classical BART may struggle here because the standard Gaussian leaf prior is non-sparse and does not concentrate on a small number of predictors.
We illustrate both priors on a sparse AFT simulation: \(n = 60\) observations, \(p = 200\) predictors, and only three active predictors.
set.seed(20)
n_hd <- 60; p_hd <- 200
X_hd <- matrix(rnorm(n_hd * p_hd), n_hd, p_hd)
# True log-survival depends only on predictors 1, 2, and 3
log_T_hd <- 1.5 * X_hd[, 1] - 1.0 * X_hd[, 2] + 0.5 * X_hd[, 3] + rnorm(n_hd)
C_hd <- rexp(n_hd, rate = 0.5)
y_hd <- pmin(exp(log_T_hd), C_hd)
d_hd <- as.integer(exp(log_T_hd) <= C_hd)
cat("n =", n_hd, "| p =", p_hd,
"| active predictors = 3",
"| censoring rate =", round(1 - mean(d_hd), 2), "\n")
#> n = 60 | p = 200 | active predictors = 3 | censoring rate = 0.4

ShrinkageTrees (horseshoe) places global–local shrinkage on the step heights of every leaf, automatically regularising all 200 predictors toward zero while preserving the signal in the three active ones.
set.seed(21)
fit_hd_hs <- ShrinkageTrees(
y = y_hd,
status = d_hd,
X_train = X_hd,
outcome_type = "right-censored",
prior_type = "horseshoe",
local_hp = 0.1 / sqrt(10),
global_hp = 0.1 / sqrt(10),
number_of_trees = 10,
N_post = 50,
N_burn = 25,
verbose = FALSE
)

SurvivalDART uses a Dirichlet prior on split
probabilities to induce structural sparsity: after burn-in, most
splitting probability is concentrated on truly predictive variables.
Setting rho_dirichlet = 3 encodes the prior belief that
approximately three predictors are active.
set.seed(22)
fit_hd_dart <- SurvivalDART(
time = y_hd,
status = d_hd,
X_train = X_hd,
number_of_trees = 10,
rho_dirichlet = 3,
N_post = 50,
N_burn = 25,
verbose = FALSE
)

Both models run without error in the \(p > n\) regime. We compare their posterior mean predictions in log-time against the latent true values used to generate the data.
rmse_hs <- sqrt(mean((fit_hd_hs$train_predictions - log_T_hd)^2))
rmse_dart <- sqrt(mean((fit_hd_dart$train_predictions - log_T_hd)^2))
cat(sprintf("%-18s train RMSE (log-time): %.3f\n", "Horseshoe", rmse_hs))
#> Horseshoe train RMSE (log-time): 2.530
cat(sprintf("%-18s train RMSE (log-time): %.3f\n", "DART", rmse_dart))
#> DART               train RMSE (log-time): 2.402

The DART model also produces variable importance plots that display the posterior distribution of each predictor’s splitting probability. With only 50 posterior draws the top-10 plot below should already concentrate most probability mass near the three truly active predictors.
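A sketch of such a call follows; the plot type name is an assumption on our part, not a documented value, so check the package's plot() method documentation for the exact interface.

```r
# Hypothetical call: posterior splitting-probability plot for the DART fit.
# type = "importance" is an assumed name — see the package's plot() docs.
plot(fit_hd_dart, type = "importance")
```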
For causal inference, ShrinkageTrees provides Bayesian Causal Forest (BCF) models that decompose the outcome into a prognostic component and a treatment effect component: \[ Y_i = \mu(\mathbf{x}_i) + W_i \cdot \tau(\mathbf{x}_i) + \varepsilon_i, \] where \(\mu(\cdot)\) is the prognostic (control) function modelled by one tree ensemble, and \(\tau(\cdot)\) is the heterogeneous treatment effect modelled by a second ensemble. This two-forest structure allows each component to have its own regularisation, number of trees, and prior — for instance, a standard BART prior for the prognostic forest and horseshoe shrinkage for the treatment effect forest.
The package provides four causal model functions with increasing
generality: SurvivalBCF() (classical BCF for survival),
SurvivalShrinkageBCF() (BCF + DART for survival),
CausalHorseForest() (horseshoe BCF for all outcome types),
and CausalShrinkageForest() (fully configurable BCF). We
illustrate each below on synthetic data with a known treatment
effect.
set.seed(5)
tau_true <- 1.5 * (X_syn[, 1] > 0) # heterogeneous treatment effect
y_causal <- X_syn[, 1] + W_syn * tau_true + rnorm(n, sd = 0.5)

SurvivalBCF() fits a BCF model for right-censored
survival outcomes using classical BART priors.
# Full analysis (eval=FALSE — use larger MCMC settings in practice)
fit_sbcf <- SurvivalBCF(
time = time,
status = status,
X_train = X,
treatment = treatment,
propensity = propensity, # from HorseTrees above
N_post = 5000,
N_burn = 5000,
verbose = FALSE
)

set.seed(6)
fit_sbcf <- SurvivalBCF(
time = y_syn,
status = d_syn,
X_train = X_syn,
treatment = W_syn,
number_of_trees_control = 5,
number_of_trees_treat = 5,
N_post = 50,
N_burn = 25,
verbose = FALSE
)
cat("SurvivalBCF class:", class(fit_sbcf), "\n")
#> SurvivalBCF class: CausalShrinkageForest
cat("ATE (posterior mean):",
round(mean(fit_sbcf$train_predictions_treat), 3), "\n")
#> ATE (posterior mean): 1.475

SurvivalShrinkageBCF() extends BCF with a Dirichlet
splitting prior on both forests, inducing sparsity in high-dimensional
settings.
set.seed(7)
fit_ssbcf <- SurvivalShrinkageBCF(
time = y_syn,
status = d_syn,
X_train = X_syn,
treatment = W_syn,
number_of_trees_control = 5,
number_of_trees_treat = 5,
N_post = 50,
N_burn = 25,
verbose = FALSE
)
cat("SurvivalShrinkageBCF class:", class(fit_ssbcf), "\n")
#> SurvivalShrinkageBCF class: CausalShrinkageForest

CausalHorseForest() is the primary novel contribution of
this package. It applies horseshoe shrinkage to the leaf parameters of
both the prognostic and treatment-effect forests. This enables effective
regularisation when many covariates are available but few are truly
predictive of heterogeneous treatment effects.
set.seed(8)
fit_chf <- CausalHorseForest(
y = y_causal,
X_train_control = X_syn,
X_train_treat = X_syn,
treatment_indicator_train = W_syn,
outcome_type = "continuous",
number_of_trees = 5,
N_post = 50,
N_burn = 25,
store_posterior_sample = TRUE,
verbose = FALSE
)
cat("CausalHorseForest class:", class(fit_chf), "\n")
#> CausalHorseForest class: CausalShrinkageForest
# Posterior mean CATE
cate_mean <- fit_chf$train_predictions_treat
cat("CATE — posterior mean (first 5):",
round(cate_mean[1:5], 3), "\n")
#> CATE — posterior mean (first 5): 1.48 1.559 1.504 1.281 1.421
# Posterior ATE
ate_samples <- rowMeans(fit_chf$train_predictions_sample_treat)
cat("ATE posterior mean:",
round(mean(ate_samples), 3),
" 95% CI: [",
round(quantile(ate_samples, 0.025), 3), ",",
round(quantile(ate_samples, 0.975), 3), "]\n")
#> ATE posterior mean: 1.422  95% CI: [ 0.567 , 2.154 ]

The fitted object stores the posterior mean CATE for each training
observation in train_predictions_treat. When
store_posterior_sample = TRUE, the full posterior sample
matrix is available in train_predictions_sample_treat, from
which the posterior ATE distribution and credible intervals can be
computed as shown above.
You can also supply separate test matrices to obtain out-of-sample CATE predictions.
set.seed(9)
X_test <- matrix(rnorm(20 * p), 20, p)
W_test <- rbinom(20, 1, 0.5)
fit_chf_test <- CausalHorseForest(
y = y_causal,
X_train_control = X_syn,
X_train_treat = X_syn,
treatment_indicator_train = W_syn,
X_test_control = X_test,
X_test_treat = X_test,
treatment_indicator_test = W_test,
outcome_type = "continuous",
number_of_trees = 5,
N_post = 50,
N_burn = 25,
store_posterior_sample = TRUE,
verbose = FALSE
)
cat("Test CATE (first 5):",
round(fit_chf_test$test_predictions_treat[1:5], 3), "\n")
#> Test CATE (first 5): 1.599 1.503 1.615 1.497 1.5

CausalShrinkageForest() is the most general causal model
interface. It allows independent prior choices for the prognostic and
treatment effect forests via prior_type_control and
prior_type_treat. For example, one could use a standard
BART prior for the prognostic forest (where variable selection is less
critical) and horseshoe shrinkage for the treatment forest (where most
covariates are expected to be irrelevant for the treatment effect).
set.seed(10)
lh <- 0.1 / sqrt(5)
fit_csf <- CausalShrinkageForest(
y = y_causal,
X_train_control = X_syn,
X_train_treat = X_syn,
treatment_indicator_train = W_syn,
outcome_type = "continuous",
prior_type_control = "horseshoe",
prior_type_treat = "horseshoe",
local_hp_control = lh,
global_hp_control = lh,
local_hp_treat = lh,
global_hp_treat = lh,
number_of_trees_control = 5,
number_of_trees_treat = 5,
N_post = 50,
N_burn = 25,
store_posterior_sample = TRUE,
verbose = FALSE
)
cat("CausalShrinkageForest class:", class(fit_csf), "\n")
#> CausalShrinkageForest class: CausalShrinkageForest
cat("Acceptance ratio (control):",
round(fit_csf$acceptance_ratio_control, 3), "\n")
#> Acceptance ratio (control): 0.348
cat("Acceptance ratio (treat) :",
round(fit_csf$acceptance_ratio_treat, 3), "\n")
#> Acceptance ratio (treat) : 0.248

The horseshoe_fw prior adds a forest-wide shrinkage
parameter that is tracked in the fitted object.
set.seed(11)
fit_fw2 <- CausalShrinkageForest(
y = y_causal,
X_train_control = X_syn,
X_train_treat = X_syn,
treatment_indicator_train = W_syn,
outcome_type = "continuous",
prior_type_control = "horseshoe_fw",
prior_type_treat = "horseshoe_fw",
local_hp_control = lh,
global_hp_control = lh,
local_hp_treat = lh,
global_hp_treat = lh,
number_of_trees_control = 5,
number_of_trees_treat = 5,
N_post = 50,
N_burn = 25,
verbose = FALSE
)
cat("Forest-wide shrinkage (control, first 5 draws):\n")
#> Forest-wide shrinkage (control, first 5 draws):
print(round(fit_fw2$forestwide_shrinkage_control[1:5], 4))
#> [1] 1 1 1 1 1

All fitted objects — whether from prediction models or causal models
— support a consistent set of S3 methods: print(),
summary(), predict(), and plot().
This section illustrates each method using the models fitted above.
Calling print() (or just typing the object name)
displays a concise model summary.
print(fit_chf)
#>
#> CausalShrinkageForest model
#> ---------------------------
#> Outcome type: Continuous
#> Training size (n): 80
#> Posterior draws: 50 (burn-in 25)
#> Posterior mean sigma: 0.821
#>
#> Control Treatment
#> ------------------- -------------------
#> Prior: horseshoe horseshoe
#> Number of trees: 5 5
#> Number of features: 10 10
#> Acceptance ratio: 0.428 0.404

For causal models the output additionally shows the number of trees in each forest and prior details for both components.
print(fit_csf)
#>
#> CausalShrinkageForest model
#> ---------------------------
#> Outcome type: Continuous
#> Training size (n): 80
#> Posterior draws: 50 (burn-in 25)
#> Posterior mean sigma: 0.814
#>
#> Control Treatment
#> ------------------- -------------------
#> Prior: horseshoe horseshoe
#> Number of trees: 5 5
#> Number of features: 10 10
#> Acceptance ratio: 0.348 0.248

summary() returns a structured list and displays a
richer description including posterior statistics for \(\sigma\), acceptance ratios, and treatment
effect estimates.
smry <- summary(fit_hs)
print(smry)
#>
#> ShrinkageTrees model summary
#> ============================
#> Call: ShrinkageTrees(y = y_cont, X_train = X_syn, outcome_type = "continuous",
#> number_of_trees = 5, prior_type = "horseshoe", local_hp = 0.1/sqrt(5),
#> global_hp = 0.1/sqrt(5), N_post = 50, N_burn = 25, verbose = FALSE)
#>
#> Outcome: Continuous | Prior: horseshoe | Trees: 5
#> Data: n = 80, p = 10 | Draws: 50 (burn-in 25)
#>
#> Posterior sigma:
#> Mean: 0.987 SD: 0.089 95% CI: [0.851, 1.205]
#>
#> Predictions (posterior mean):
#> Train: mean = 0.053, sd = 0.053, range = [-0.094, 0.159]
#> Test: mean = 0.065, sd = NA, range = [0.065, 0.065]
#>
#> Variable importance (posterior inclusion probability):
#> X3: 0.164 X9: 0.143 X4: 0.142 X2: 0.132 X8: 0.131 X5: 0.091 X7: 0.084 X6: 0.064 X10: 0.037 X1: 0.011
#>
#> MCMC acceptance ratio: 0.484
#>
#> Convergence diagnostics (coda):
#> Effective sample size: sigma = 50

For causal models the summary includes the posterior ATE with a 95%
credible interval (when store_posterior_sample = TRUE).
smry_c <- summary(fit_chf)
print(smry_c)
#>
#> CausalShrinkageForest model summary
#> =====================================
#> Call: CausalHorseForest(y = y_causal, X_train_control = X_syn, X_train_treat = X_syn,
#> treatment_indicator_train = W_syn, outcome_type = "continuous",
#> number_of_trees = 5, N_post = 50, N_burn = 25, store_posterior_sample = TRUE,
#> verbose = FALSE)
#>
#> Outcome: Continuous
#> Prior: control = horseshoe, treatment = horseshoe
#> Trees: control = 5, treatment = 5
#> Data: n = 80, p_control = 10, p_treat = 10 | Draws: 50 (burn-in 25)
#>
#> Treatment effect:
#> PATE: 1.4224 95% CI (Bayesian bootstrap): [0.621, 2.1773]
#> CATE SD: 0.1203
#>
#> Prognostic function (mu):
#> Mean: 0.597 SD: 0.023 Range: [0.542, 0.662]
#>
#> Posterior sigma:
#> Mean: 0.821 SD: 0.078 95% CI: [0.695, 0.99]
#>
#> Variable importance - control forest (posterior inclusion probability):
#> X6: 0.195 X3: 0.162 X2: 0.137 X8: 0.119 X10: 0.098 X9: 0.076 X7: 0.073 X5: 0.07 X4: 0.037 X1: 0.033
#>
#> Variable importance - treatment forest (posterior inclusion probability):
#> X2: 0.177 X6: 0.173 X3: 0.123 X8: 0.11 X9: 0.098 X10: 0.087 X1: 0.086 X5: 0.068 X4: 0.046 X7: 0.032
#>
#> MCMC acceptance ratios: control = 0.428, treatment = 0.404
# Access the ATE directly
cat("ATE mean :", round(smry_c$treatment_effect$ate, 3), "\n")
#> ATE mean : 1.422
cat("ATE 95% CI: [",
round(smry_c$treatment_effect$ate_lower, 3), ",",
round(smry_c$treatment_effect$ate_upper, 3), "]\n")
#> ATE 95% CI: [ 0.621 , 2.177 ]

By default the ATE credible interval is obtained by a Bayesian bootstrap: at each MCMC iteration \(s\) the observation-level CATEs \(\tau^{(s)}(x_i)\) are reweighted with Dirichlet(1, …, 1) weights,
\[ \widehat{\mathrm{PATE}}^{(s)} \;=\; \sum_{i=1}^n w_i^{(s)}\, \tau^{(s)}(x_i), \qquad (w_1^{(s)}, \dots, w_n^{(s)}) \sim \mathrm{Dir}(1, \dots, 1). \]
The collection \(\{\widehat{\mathrm{PATE}}^{(s)}\}\)
approximates the posterior of the population ATE and
therefore propagates uncertainty in both \(\tau(\cdot)\) and the covariate
distribution \(F_X\). Setting
bayesian_bootstrap = FALSE reverts to equal \(1/n\) weights, giving the mixed
ATE (MATE) that conditions on the observed covariates and has a
narrower credible interval.
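As a self-contained sketch of this computation (base R only; the matrix tau_draws below is a simulated stand-in for the stored posterior CATE samples, not package output):

```r
# Illustration only: `tau_draws` simulates posterior CATE samples,
# one row per MCMC draw, one column per observation.
set.seed(1)
S <- 200; n <- 80
tau_draws <- matrix(rnorm(S * n, mean = 1.4, sd = 0.3), nrow = S)

# Dirichlet(1, ..., 1) weights via normalised Exp(1) draws
rdirichlet1 <- function(n) { g <- rexp(n); g / sum(g) }

# One Bayesian-bootstrap PATE draw per MCMC iteration
pate_draws <- apply(tau_draws, 1, function(tau_s) sum(rdirichlet1(n) * tau_s))

# Equal 1/n weights give the mixed-ATE (MATE) draws instead
mate_draws <- rowMeans(tau_draws)

# The PATE interval is typically wider: it also propagates uncertainty in F_X
quantile(pate_draws, c(0.025, 0.975))
quantile(mate_draws, c(0.025, 0.975))
```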
smry_pate <- summary(fit_chf, bayesian_bootstrap = TRUE) # default
smry_mate <- summary(fit_chf, bayesian_bootstrap = FALSE)

The standalone helper bayesian_bootstrap_ate() returns
both posteriors and their draws in a single list, and also works on a
CausalShrinkageForestPrediction returned by
predict() so that the PATE integrates over a prespecified
target population.
bb <- bayesian_bootstrap_ate(fit_chf)
cat("PATE:", round(bb$pate_mean, 3),
" 95% CI: [", round(bb$pate_ci$lower, 3), ",",
round(bb$pate_ci$upper, 3), "]\n")
#> PATE: 1.428 95% CI: [ 0.532 , 2.141 ]
cat("MATE:", round(bb$mate_mean, 3),
" 95% CI: [", round(bb$mate_ci$lower, 3), ",",
round(bb$mate_ci$upper, 3), "]\n")
#> MATE: 1.422 95% CI: [ 0.567 , 2.154 ]

predict() computes the posterior predictive distribution
on new data. It returns a ShrinkageTreesPrediction object
with posterior mean and credible-interval vectors.
X_new <- matrix(rnorm(10 * p), 10, p)
pred <- predict(fit_hs, newdata = X_new)
print(pred)
#>
#> ShrinkageTrees predictions
#> --------------------------
#> Observations: 10
#> Credible interval: 95%
#> Scale: fitted value
#>
#> mean lower upper
#> -------- -------- --------
#> [ 1] -0.008 -0.398 0.248
#> [ 2] 0.009 -0.492 0.390
#> [ 3] 0.089 -0.312 0.546
#> [ 4] 0.045 -0.237 0.326
#> [ 5] 0.002 -0.385 0.433
#> [ 6] -0.022 -0.464 0.248
#> ... (4 more)

# Point estimates and 95% credible intervals
head(data.frame(
mean = round(pred$mean, 3),
lower = round(pred$lower, 3),
upper = round(pred$upper, 3)
))
#> mean lower upper
#> 1 -0.008 -0.398 0.248
#> 2 0.009 -0.492 0.390
#> 3 0.089 -0.312 0.546
#> 4 0.045 -0.237 0.326
#> 5 0.002 -0.385 0.433
#> 6 -0.022 -0.464 0.248

For causal models (CausalShrinkageForest and
CausalHorseForest), predict() returns three
sets of posterior summaries: the prognostic component \(\mu(\mathbf{x})\),
the CATE \(\tau(\mathbf{x})\), and the total effect
\(\mu(\mathbf{x}) + \tau(\mathbf{x})\).
For survival models with timescale = "time", the
prognostic and total components are back-transformed to the original
time scale (via \(\exp(\cdot)\)), and
the CATE becomes a multiplicative time ratio: \(\exp(\tau) > 1\) means treatment
prolongs survival.
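As a tiny numeric illustration of this back-transformation (assumed value, not package output):

```r
# Hypothetical posterior draw of tau(x) on the log-time (AFT) scale
tau_draw <- 0.22

# exp() back-transforms to a multiplicative survival-time ratio;
# a ratio above 1 means treatment prolongs survival
time_ratio <- exp(tau_draw)
round(time_ratio, 2)  # about 1.25, i.e. roughly 25% longer survival time
```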
The predict() method requires two covariate matrices —
one for each forest — matching the columns used at fit time.
X_new_ctrl <- matrix(rnorm(10 * p), 10, p)
X_new_treat <- matrix(rnorm(10 * p), 10, p)
pred_c <- predict(fit_chf, newdata_control = X_new_ctrl,
newdata_treat = X_new_treat)
print(pred_c)
#>
#> CausalShrinkageForest predictions
#> ----------------------------------
#> Observations: 10
#> Credible interval: 95%
#> Outcome type: continuous
#>
#> PATE: 1.372 95% CI (Bayesian bootstrap): [0.34, 2.013]
#>
#> Prognostic (mu):
#> mean lower upper
#> -------- -------- --------
#> [ 1] 0.559 0.474 0.751
#> [ 2] 0.555 0.474 0.739
#> [ 3] 0.557 0.433 0.674
#> [ 4] 0.553 0.469 0.654
#> [ 5] 0.548 0.429 0.659
#> [ 6] 0.554 0.468 0.659
#> ... (4 more)
#>
#> CATE (tau):
#> mean lower upper
#> -------- -------- --------
#> [ 1] 1.380 0.205 2.122
#> [ 2] 1.366 0.179 2.014
#> [ 3] 1.332 0.137 2.135
#> [ 4] 1.351 0.121 2.077
#> [ 5] 1.350 0.205 2.049
#> [ 6] 1.374 0.079 2.064
#> ... (4 more)
#>
#> Total (mu + tau):
#> mean lower upper
#> -------- -------- --------
#> [ 1] 1.939 0.738 2.713
#> [ 2] 1.921 0.643 2.579
#> [ 3] 1.889 0.626 2.728
#> [ 4] 1.904 0.603 2.694
#> [ 5] 1.898 0.735 2.580
#> [ 6] 1.929 0.584 2.572
#> ... (4 more)

# Extract individual components
head(data.frame(
prognostic = round(pred_c$prognostic$mean, 3),
cate = round(pred_c$cate$mean, 3),
total = round(pred_c$total$mean, 3)
))
#> prognostic cate total
#> 1 0.559 1.380 1.939
#> 2 0.555 1.366 1.921
#> 3 0.557 1.332 1.889
#> 4 0.553 1.351 1.904
#> 5 0.548 1.350 1.898
#> 6 0.554 1.374 1.929

The plot() method produces diagnostic and inferential
graphics using the ggplot2 package (a suggested
dependency).
The ATE density uses the Bayesian-bootstrap PATE posterior by
default; pass bayesian_bootstrap = FALSE to plot the
(narrower) mixed ATE density instead. See the summary section above for
the definitions.
Variable importance plots are available when
prior_type = "dirichlet". For causal models with
prior_type_control = "dirichlet" or
prior_type_treat = "dirichlet", use
forest = "control", "treat", or
"both".
set.seed(12)
fit_dart <- ShrinkageTrees(
y = y_cont,
X_train = X_syn,
outcome_type = "continuous",
prior_type = "dirichlet",
local_hp = 0.1 / sqrt(5),
number_of_trees = 5,
N_post = 50,
N_burn = 25,
verbose = FALSE
)

For survival models (outcome_type = "right-censored" or
"interval-censored"), the plot() method can
draw posterior survival curves derived from the fitted AFT log-normal
model: \[
S(t \mid \mathbf{x}_i) = 1 - \Phi\!\left(\frac{\log t -
\mu_i}{\sigma}\right),
\] where \(\mu_i =
f(\mathbf{x}_i)\) is the BART ensemble prediction and \(\sigma\) is the residual standard deviation
on the log-time scale.
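This construction can be sketched in a few lines of base R; mu_s and sigma_s below are simulated stand-ins for the stored posterior draws of \(f(\mathbf{x}_i)\) and \(\sigma\), not values produced by the package:

```r
# Simulated stand-ins for posterior draws (one element per MCMC iteration)
set.seed(2)
mu_s    <- rnorm(100, mean = 0.5, sd = 0.05)   # f(x_i) on the log-time scale
sigma_s <- runif(100, min = 0.7, max = 0.9)    # residual SD on the log-time scale
t_grid  <- seq(0.1, 10, length.out = 50)       # time grid on the original scale

# S(t | x_i) = 1 - pnorm((log t - mu_i) / sigma), evaluated per posterior draw
surv <- sapply(t_grid, function(t) 1 - pnorm((log(t) - mu_s) / sigma_s))

post_mean <- colMeans(surv)                                # pointwise posterior mean
band <- apply(surv, 2, quantile, probs = c(0.025, 0.975))  # 95% credible band
```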
The type = "survival" option supports two modes
controlled by the obs argument:

- Population-average curve (obs = NULL, the default): computes \(\bar{S}(t) = n^{-1}\sum_i S(t \mid \mathbf{x}_i)\) at each MCMC iteration, giving credible bands that reflect full posterior uncertainty.
- Individual curves (obs = c(1, 5, ...)): one curve per selected training observation with its own credible band.

Additional options:
| Argument | Description |
|---|---|
| level | Width of the pointwise credible band (default 0.95). |
| t_grid | Custom time grid (original scale). Auto-generated if NULL. |
| km | If TRUE, overlay the Kaplan–Meier estimate (population-average only). |
We use the survival fit from the earlier section:
# Same curve with the Kaplan-Meier estimate overlaid for comparison
plot(ht_surv, type = "survival", km = TRUE)

# Individual survival curves for observations 1, 20, 40, 60, and 80
plot(ht_surv, type = "survival", obs = c(1, 20, 40, 60, 80))

# Single individual with a narrower 90% credible band
plot(ht_surv, type = "survival", obs = 1, level = 0.90)

When store_posterior_sample = FALSE, the credible bands
only reflect uncertainty in \(\sigma\)
(using plug-in posterior mean \(\hat{\mu}_i\)). The survival functions
(SurvivalBART, SurvivalDART, etc.) store
posterior samples by default, so full posterior bands are available out
of the box.
The survival curves above are based on the training
data — they show \(S(t \mid
\mathbf{x}_i)\) for the observations used to fit the model. For
new (out-of-sample) data, call predict()
first and then plot() on the prediction object. This
produces posterior predictive survival curves that propagate
full parameter uncertainty through to the new covariate values:
# New observations for prediction
set.seed(99)
X_new <- matrix(rnorm(20 * p), ncol = p)
pred_surv <- predict(ht_surv, newdata = X_new)

# Individual posterior predictive curves for selected new observations
plot(pred_surv, type = "survival", obs = c(1, 5, 10))

The same level and t_grid arguments are
available as for the training-data survival curves. The Kaplan–Meier
overlay (km = TRUE) is not available for prediction
objects, since observed event times are only known for the training
set.
Running multiple independent chains improves mixing diagnostics and
reduces sensitivity to starting values. Pass
n_chains > 1 to any model-fitting function; chains are
run in parallel via parallel::mclapply on Unix-like
systems.
set.seed(13)
fit_2chain <- ShrinkageTrees(
y = y_cont,
X_train = X_syn,
outcome_type = "continuous",
prior_type = "horseshoe",
local_hp = 0.1 / sqrt(5),
global_hp = 0.1 / sqrt(5),
number_of_trees = 5,
N_post = 50,
N_burn = 25,
n_chains = 2,
verbose = FALSE
)
cat("n_chains stored :", fit_2chain$mcmc$n_chains, "\n")
#> n_chains stored : 2
cat("Total sigma draws:", length(fit_2chain$sigma),
" (2 chains x 50 draws)\n")
#> Total sigma draws: 100 (2 chains x 50 draws)
cat("Per-chain acceptance ratios:\n")
#> Per-chain acceptance ratios:
print(round(fit_2chain$chains$acceptance_ratios, 3))
#> [1] 0.484 0.460

The same interface works for causal models.
set.seed(14)
fit_causal_2chain <- CausalShrinkageForest(
y = y_causal,
X_train_control = X_syn,
X_train_treat = X_syn,
treatment_indicator_train = W_syn,
outcome_type = "continuous",
prior_type_control = "horseshoe",
prior_type_treat = "horseshoe",
local_hp_control = lh,
global_hp_control = lh,
local_hp_treat = lh,
global_hp_treat = lh,
number_of_trees_control = 5,
number_of_trees_treat = 5,
N_post = 50,
N_burn = 25,
n_chains = 2,
verbose = FALSE
)
cat("Pooled sigma draws:", length(fit_causal_2chain$sigma), "\n")
#> Pooled sigma draws: 100
cat("Per-chain acceptance ratios (control):\n")
#> Per-chain acceptance ratios (control):
print(round(fit_causal_2chain$chains$acceptance_ratios_control, 3))
#> [1] 0.352 0.340

With multiple chains the traceplot shows one line per chain, and the
overlaid density plot compares the marginal posterior of \(\sigma\) across chains — both require
n_chains > 1.
MCMC methods require careful assessment of convergence before trusting the posterior summaries. Here are practical guidelines for ShrinkageTrees models.
The traceplot of the error standard deviation \(\sigma\) (via
plot(fit, type = "trace")) is the primary diagnostic. A
well-mixing chain should show:

- Stationarity: rapid fluctuation around a stable level, with no remaining trend after N_burn.
- Agreement across chains (n_chains > 1): separate chains should overlap
substantially. The density overlay
(plot(fit, type = "density")) makes this easy to check
visually.

The summary() output reports the average
Metropolis–Hastings acceptance ratio for the tree structure proposals
(grow/prune moves). As a rough guide:

- Ratios in the range seen in these examples (roughly 0.3–0.5) indicate that tree proposals are accepted at a healthy rate.
- Very low ratios suggest the sampler is rejecting most proposals; consider
increasing N_burn and N_post, or relaxing the
tree structure prior (lower power, higher
base).

When the suggested package coda is installed,
summary() automatically reports effective sample
size (ESS) and — for multi-chain fits — the
Gelman–Rubin \(\hat{R}\).
# summary() includes convergence diagnostics when coda is available
summary(fit_2chain)
#>
#> ShrinkageTrees model summary
#> ============================
#> Call: ShrinkageTrees(y = y_cont, X_train = X_syn, outcome_type = "continuous",
#> number_of_trees = 5, prior_type = "horseshoe", local_hp = 0.1/sqrt(5),
#> global_hp = 0.1/sqrt(5), N_post = 50, N_burn = 25, n_chains = 2,
#> verbose = FALSE)
#>
#> Outcome: Continuous | Prior: horseshoe | Trees: 5
#> Data: n = 80, p = 10 | Draws: 50 x 2 chains (burn-in 25)
#>
#> Posterior sigma:
#> Mean: 0.971 SD: 0.08 95% CI: [0.82, 1.131]
#>
#> Predictions (posterior mean):
#> Train: mean = 0.075, sd = 0.082, range = [-0.151, 0.201]
#> Test: mean = 0.055, sd = NA, range = [0.055, 0.055]
#>
#> Variable importance (posterior inclusion probability):
#> X3: 0.174 X8: 0.143 X5: 0.137 X2: 0.116 X7: 0.086 X1: 0.084 X10: 0.076 X4: 0.074 X6: 0.065 X9: 0.046
#>
#> MCMC acceptance ratio (per chain): 0.484, 0.46
#>
#> Convergence diagnostics (coda):
#> Gelman-Rubin R-hat: = 1.019
#>   Effective sample size: sigma = 79

For more detailed diagnostics, convert the fitted object to a
coda::mcmc.list with as.mcmc.list():
library(coda)
mcmc_obj <- as.mcmc.list(fit_2chain)
# Gelman-Rubin R-hat (values near 1 indicate convergence)
coda::gelman.diag(mcmc_obj)
#> Potential scale reduction factors:
#>
#> Point est. Upper C.I.
#> sigma 1.02 1.11
# Effective sample size
coda::effectiveSize(mcmc_obj)
#> sigma
#> 78.73458
# Geweke diagnostic (per chain)
coda::geweke.diag(mcmc_obj[[1]])
#>
#> Fraction in 1st window = 0.1
#> Fraction in 2nd window = 0.5
#>
#> sigma
#> 0.5847

The returned mcmc.list object is compatible with all
coda functions, including
coda::autocorr.plot(), coda::gelman.plot(),
coda::heidel.diag(), and
coda::raftery.diag().
The examples in this vignette use very small N_post and
N_burn to keep build time low. For a real analysis:
- At a minimum, use N_post = 2000, N_burn = 2000.
- For more demanding (e.g. high-dimensional) problems, use N_post = 5000, N_burn = 5000.
- For survival models, use N_post = 5000, N_burn = 10000 (the AFT data augmentation
step can slow mixing, so a longer burn-in helps).
- Run n_chains = 2 or 4 to verify convergence and produce pooled posterior
samples.

The full analysis of the pdac dataset replicates the
case study from Jacobs, van Wieringen & van der Pas (2025). Due to
the high-dimensional covariate space (~3,000 genes) and the large MCMC
settings needed for reliable inference, the code below is provided for
reference but is not evaluated during vignette building. Pre-computed
results can be reproduced by running the pdac_analysis
demo:
demo("pdac_analysis", package = "ShrinkageTrees").
data("pdac")
time <- pdac$time
status <- pdac$status
treatment <- pdac$treatment
X <- as.matrix(pdac[, !(names(pdac) %in% c("time","status","treatment"))])
set.seed(2025)
ps_fit <- HorseTrees(
y = treatment,
X_train = X,
outcome_type = "binary",
k = 0.1,
N_post = 5000,
N_burn = 5000,
verbose = FALSE
)
propensity <- pnorm(ps_fit$train_predictions)

# Overlap plot
p0 <- propensity[treatment == 0]
p1 <- propensity[treatment == 1]
hist(p0, breaks = 15, col = rgb(1, 0.5, 0, 0.5), xlim = range(propensity),
xlab = "Propensity score", main = "Propensity score overlap")
hist(p1, breaks = 15, col = rgb(0, 0.5, 0, 0.5), add = TRUE)
legend("topright", legend = c("Control", "Treated"),
       fill = c(rgb(1,0.5,0,0.5), rgb(0,0.5,0,0.5)))

# Augment control matrix with propensity scores (BCF-style)
X_control <- cbind(propensity, X)
# Log-transform and centre survival times
log_time <- log(time) - mean(log(time))
set.seed(2025)
fit_pdac <- CausalHorseForest(
y = log_time,
status = status,
X_train_control = X_control,
X_train_treat = X,
treatment_indicator_train = treatment,
outcome_type = "right-censored",
timescale = "log",
number_of_trees = 200,
N_post = 5000,
N_burn = 5000,
store_posterior_sample = TRUE,
verbose = FALSE
)

References

Chipman, H. A., George, E. I., & McCulloch, R. E. (2010). Bayesian Additive Regression Trees. Annals of Applied Statistics, 4(1), 266–298.
Hahn, P. R., Murray, J. S., & Carvalho, C. M. (2020). Bayesian Regression Tree Models for Causal Inference: Regularization, Confounding, and Heterogeneous Treatment Effects. Bayesian Analysis, 15(3), 965–1056.
Jacobs, T., van Wieringen, W. N., & van der Pas, S. L. (2025). Horseshoe Forests for High-Dimensional Causal Survival Analysis. arXiv preprint arXiv:2507.22004.
Linero, A. R. (2018). Bayesian Regression Trees for High-Dimensional Prediction and Variable Selection. Journal of the American Statistical Association, 113(522), 626–636.
Sparapani, R., Spanbauer, C., & McCulloch, R. (2021). Nonparametric Machine Learning and Efficient Computation with Bayesian Additive Regression Trees: The BART R Package. Journal of Statistical Software, 97(1), 1–66.