summary(), plot(type = "ate"), and
predict() for CausalShrinkageForest (and
CausalHorseForest) now default to a Bayesian-bootstrap
posterior for the average treatment effect: at each MCMC iteration the
per-observation CATEs are reweighted with Dirichlet(1, …, 1) weights
before being summed, giving a draw from the posterior of the
population ATE (PATE). Credible intervals are correspondingly
wider than before because they now propagate uncertainty in the
covariate distribution, not only in tau(x).
The default is now bayesian_bootstrap = TRUE on
summary.CausalShrinkageForest(),
plot.CausalShrinkageForest() (for
type = "ate"), and
predict.CausalShrinkageForest(). Set it to FALSE
to recover the previous equal-weight mixed ATE (MATE).
predict() now also returns an ate summary
for the new data and retains the posterior CATE sample matrix as
cate_samples.
bayesian_bootstrap_ate() returns
PATE and MATE summaries (means, CIs, and full posterior draws) from
either a fitted CausalShrinkageForest or a
CausalShrinkageForestPrediction.
This is a breaking change for printed/plotted numerics:
existing scripts will report wider CIs than before. Use
bayesian_bootstrap = FALSE to reproduce previous
output.
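The difference between the two summaries can be sketched in base R. This is a minimal illustration under stated assumptions, not the package's implementation: cate_draws is a simulated stand-in for a posterior CATE sample matrix (MCMC iterations in rows, observations in columns), and the Dirichlet(1, …, 1) weights are generated by normalising independent Exp(1) draws.

```r
set.seed(1)

# Hypothetical stand-in for posterior CATE draws:
# rows are MCMC iterations, columns are observations.
n_iter <- 500
n_obs  <- 40
cate_draws <- matrix(rnorm(n_iter * n_obs, mean = 1, sd = 0.5),
                     nrow = n_iter, ncol = n_obs)

# MATE: equal weights, i.e. the plain row mean at each iteration.
mate_draws <- rowMeans(cate_draws)

# PATE: a fresh Dirichlet(1, ..., 1) weight vector per iteration,
# obtained by normalising independent Exp(1) variables row by row.
dirichlet_weights <- matrix(rexp(n_iter * n_obs), nrow = n_iter)
dirichlet_weights <- dirichlet_weights / rowSums(dirichlet_weights)
pate_draws <- rowSums(dirichlet_weights * cate_draws)

# 95% credible intervals from the two sets of draws.
ci_mate <- quantile(mate_draws, c(0.025, 0.975))
ci_pate <- quantile(pate_draws, c(0.025, 0.975))
```

Comparing ci_mate and ci_pate shows how much of the interval width comes from the random weights rather than from the CATE draws themselves.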
ovarian dataset restructured
The ovarian dataset is now a single data frame
(previously a list with $clinical and $X
elements). Clinical columns (OS_time,
OS_event, treatment, age,
figo_stage, tumor_grade) and the 2000 gene
expression columns are combined into one data frame with 2006 columns.
This simplifies data access and aligns the format with the
pdac dataset.
Code that previously used ovarian$clinical or
ovarian$X must be updated — see ?ovarian for
the new structure.
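A possible migration pattern is sketched below. The toy object old_ovarian only mimics the previous list layout (three made-up gene columns instead of 2000, invented values and gene names); selecting the clinical columns by name avoids any assumption about column order in the real data frame.

```r
set.seed(1)

# Toy stand-in for the previous list layout (hypothetical values).
old_ovarian <- list(
  clinical = data.frame(
    OS_time     = c(10.5, 22.1, 7.3),
    OS_event    = c(1, 0, 1),
    treatment   = c(1, 0, 1),
    age         = c(61, 55, 70),
    figo_stage  = c("III", "IV", "III"),
    tumor_grade = c("G3", "G2", "G3")
  ),
  X = matrix(rnorm(9), nrow = 3,
             dimnames = list(NULL, c("gene_a", "gene_b", "gene_c")))
)

# Single data frame in the spirit of the new layout.
new_ovarian <- cbind(old_ovarian$clinical, as.data.frame(old_ovarian$X))

clinical_cols <- c("OS_time", "OS_event", "treatment",
                   "age", "figo_stage", "tumor_grade")

# Equivalents of the old ovarian$clinical and ovarian$X.
clinical <- new_ovarian[, clinical_cols]
X <- as.matrix(new_ovarian[, setdiff(names(new_ovarian), clinical_cols)])
```

With the real dataset, the same two selections applied to ovarian would replace the old $clinical and $X accessors.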
Fixed plot(fit, type = "ate") and
plot(fit, type = "cate") for
CausalShrinkageForest models incorrectly subtracting the
control forest predictions from the treatment forest predictions.
New dataset (ovarian)
Added the ovarian dataset: a processed TCGA-OV cohort (n = 357)
for high-dimensional survival prediction and causal
inference.
See ?ovarian and examples/test-ovarian.R
for a full worked example covering survival prediction (SurvivalBART,
SurvivalDART, HorseTrees) and causal inference (SurvivalBCF,
SurvivalShrinkageBCF, CausalHorseForest).
Treatment coding (treatment_coding)
All causal model functions (CausalHorseForest(),
CausalShrinkageForest(), SurvivalBCF(), and
SurvivalShrinkageBCF()) now accept a
treatment_coding argument controlling how the treatment
indicator enters the BCF decomposition y = f(x) + b * tau(x) + epsilon.
Four options are available:
"centered" (default): b_i in {-1/2, 1/2}. This is the
original behaviour.
"binary": b_i in {0, 1}. Standard binary coding.
"adaptive": b_i = z_i - e_hat(x_i), where e_hat(x_i) is
the estimated propensity score. This follows Hahn, Murray & Carvalho
(2020) and is implemented in the bcf package. Requires a
propensity vector.
"invariant": Parameter-expanded (invariant) treatment
coding. The coding parameters b_0 and b_1 are assigned N(0, 1/2) priors
and estimated within the Gibbs sampler via conjugate normal updates,
yielding a parameterisation that is invariant to the coding of the
treatment indicator (Hahn et al., 2020, Section 5.2). The treatment
effect is tau(x) = (b_1 - b_0) * tau_tilde(x). Posterior draws of b_0
and b_1 are returned in the fitted object.
The predict() method for
CausalShrinkageForest objects automatically carries forward
the treatment coding used at training time. A
propensity_test argument is available for supplying
test-set propensity scores (defaults to 0.5).
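The codings can be written out in a few lines of base R. The helper coding_weights() is hypothetical and only mirrors the definitions above; it is not a package function. For the invariant coding, the sketch assumes b_i equals b_1 for treated and b_0 for control units, and uses fixed illustrative values where the sampler would supply posterior draws.

```r
# Hypothetical treatment indicator and estimated propensity scores.
z     <- c(1, 0, 1, 0, 1)
e_hat <- c(0.7, 0.4, 0.5, 0.2, 0.9)

# Illustrative helper (not part of the package): map z to b under each coding.
coding_weights <- function(z, coding = c("centered", "binary", "adaptive"),
                           propensity = NULL) {
  coding <- match.arg(coding)
  switch(coding,
    centered = z - 0.5,   # b_i in {-1/2, 1/2}
    binary   = z,         # b_i in {0, 1}
    adaptive = {
      if (is.null(propensity)) stop("adaptive coding needs a propensity vector")
      z - propensity      # b_i = z_i - e_hat(x_i)
    }
  )
}

b_centered <- coding_weights(z, "centered")
b_binary   <- coding_weights(z, "binary")
b_adaptive <- coding_weights(z, "adaptive", propensity = e_hat)

# Invariant coding with fixed illustrative b_0 and b_1 (assumed b_i = b_{z_i}).
b0 <- -0.4
b1 <- 0.6
b_invariant <- ifelse(z == 1, b1, b0)
tau_tilde   <- c(1.0, 2.0, 0.5, 1.5, 3.0)  # made-up values of tau_tilde(x_i)
tau         <- (b1 - b0) * tau_tilde       # tau(x) = (b_1 - b_0) * tau_tilde(x)
```

Under each coding the difference in b between a treated and a control unit with the same covariates is the factor that multiplies the forest output: 1 for the first three codings and b_1 - b_0 for the invariant one.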
All survival-capable functions now support
interval-censored data in addition to right-censored
data. Supply left_time and right_time vectors
(with outcome_type = "interval-censored") instead of
y and status. Three censoring types are
distinguished:
Exact event: left_time == right_time.
Interval-censored: left_time < right_time.
Right-censored: right_time = Inf.
This convention follows
survival::Surv(type = "interval2"). Censored event times
are augmented within the AFT Gibbs sampler. The following functions are
affected:
HorseTrees(), ShrinkageTrees()
(single-forest models)
CausalHorseForest(),
CausalShrinkageForest() (causal models)
SurvivalBART(), SurvivalDART(),
SurvivalBCF(), SurvivalShrinkageBCF()
(survival wrappers)
Multi-chain MCMC (n_chains)
All four primary model-fitting functions (ShrinkageTrees, HorseTrees,
CausalHorseForest, and CausalShrinkageForest)
now accept an n_chains argument (default 1).
When n_chains > 1:
Chains run in parallel via parallel::mclapply (falls back to sequential execution on
Windows).
The number of cores used is min(n_chains, parallel::detectCores()).
Draws from all chains are pooled, giving N_post * n_chains total draws.
The result is a single ShrinkageTrees or
CausalShrinkageForest instance, so all existing
print, summary, and predict
methods work without modification.
SurvivalBART,
SurvivalDART, SurvivalBCF, and
SurvivalShrinkageBCF inherit n_chains support
through ....
print and summary output adapts
automatically: single-chain models show Posterior draws,
multi-chain models show Chains and Draws per chain,
with per-chain acceptance ratios listed separately.
S3 classes ShrinkageTrees and
CausalShrinkageForest with constructors in
constructors.R.
print methods for both classes, displaying model
specification, MCMC settings, acceptance ratio, and posterior mean
sigma.
summary methods for both classes, returning an
inspectable object with posterior sigma (mean, SD, 95% CI), prediction
summaries, variable importance (posterior inclusion probabilities), and,
for causal models, ATE with credible interval (when
store_posterior_sample = TRUE) and CATE heterogeneity.
predict method for ShrinkageTrees,
enabling posterior predictive inference on new data by re-running the
sampler with stored training data and hyperparameters. Returns a
ShrinkageTreesPrediction object with posterior mean and
credible interval bounds. For survival models, the prediction object
additionally stores predictions_sample (full posterior
draws on the original scale) and sigma (posterior draws on
the log-time scale), enabling posterior predictive survival curve
plotting.
predict method for
CausalShrinkageForest, returning a
CausalShrinkageForestPrediction object with three
components: prognostic (\(\mu(X)\)), cate (\(\tau(X)\)), and total (\(\mu(X) + \tau(X)\)), each with posterior
mean and credible interval bounds. For survival models with
timescale = "time", predictions are back-transformed to the
original time scale and the CATE becomes a multiplicative time
ratio.
print and summary methods for
ShrinkageTreesPrediction and
CausalShrinkageForestPrediction.
Convergence diagnostics (coda)
as.mcmc.list() S3 method for
ShrinkageTrees objects, converting the sigma posterior
(split by chain) into a coda::mcmc.list. This enables all
standard coda diagnostics: Gelman–Rubin R-hat, effective
sample size, Geweke test, Heidelberger–Welch test, autocorrelation
plots, and more.
summary() now automatically reports effective
sample size (ESS) and, for multi-chain fits, the
Gelman–Rubin R-hat when the suggested package
coda is installed.
Added coda to Suggests in
DESCRIPTION.
Plotting (plot)
S3 plot() methods added for ShrinkageTrees,
CausalShrinkageForest, and
ShrinkageTreesPrediction. Requires the suggested package
ggplot2.
plot(fit, type = "trace"): sigma traceplot; one line
per chain, useful for assessing mixing.
plot(fit, type = "density"): overlaid posterior
density of sigma, one curve per chain.
plot(fit, type = "vi"): posterior credible intervals
for variable inclusion probabilities (top n_vi
predictors).
plot(fit, type = "ate"): posterior density of the ATE
with 95% credible region (causal models only; requires
store_posterior_sample = TRUE).
plot(fit, type = "cate"): point estimates and 95%
credible intervals for the CATE of each training observation, sorted by
posterior mean (causal models only; requires
store_posterior_sample = TRUE).
plot(fit, type = "vi", forest = "both"): side-by-side
VI for the control and treatment forests (causal models
only).
plot(fit, type = "survival"): posterior survival
curves \(S(t | x_i) = 1 - \Phi((\log t -
\mu_i) / \sigma)\) derived from the AFT log-normal model
(survival outcomes only).
Population-averaged (obs = NULL): computes \(\bar{S}(t) = n^{-1} \sum_i S(t | x_i)\) at
each MCMC iteration with pointwise credible bands.
Individual (obs = c(1, 5, ...)): one curve per selected training
observation with its own credible band.
level controls the credible band width (default
0.95).
t_grid allows a custom time grid; auto-generated if
NULL.
km = TRUE overlays the Kaplan–Meier estimate as a
dashed black step function (population-averaged plot only; requires
the survival package). Ignored with a message when
obs is not NULL.
plot(pred, type = "survival"): posterior
predictive survival curves for new (out-of-sample) data
from predict(). Same obs, t_grid,
and level arguments as above. The KM overlay is not
available for prediction objects.
Model functions (HorseTrees,
ShrinkageTrees, SurvivalBART,
SurvivalDART, SurvivalBCF,
SurvivalShrinkageBCF, CausalHorseForest,
CausalShrinkageForest), all S3 methods (print,
summary, predict, plot),
multi-chain MCMC, and a full TCGA PAAD case study.
SurvivalBART(), SurvivalDART(),
SurvivalBCF(), and SurvivalShrinkageBCF() now
accept store_posterior_sample as an explicit parameter
(default TRUE), avoiding a “matched by multiple actual
arguments” error when passing it via ....
Fixed CausalHorseForest and
CausalShrinkageForest failing with
"argument 'y_train' is missing" when called directly
(broken constructor call introduced in the 2.0.0 S3 refactor).
Fixed plot(..., type = "vi") crashing with
"argument must be coercible to non-negative integer" in all
four model functions: covariate matrices were being stored as flat
numeric vectors instead of matrices, making ncol() return
NULL.
Fixed a bug in HorseTrees and
ShrinkageTrees where sigma_hat,
y_mean, and lambda were not initialised in the
binary (probit) branch.
SurvivalBCF wrapper for AFT-based Bayesian Causal
Forests.
SurvivalShrinkageBCF with Dirichlet structural
sparsity.
SurvivalDART for sparse high-dimensional survival
modeling.
SurvivalBART for survival modeling using standard
BART.
....
Replaced the std::map structure with a more
efficient vector-based lookup, improving overall computational speed by
approximately 30%.
standard prior type option for
ShrinkageTrees, corresponding to the conventional BART
implementation without reversible jumps.
CONTRIBUTING.md file with guidelines for
contributors.
🎉 First CRAN release of ShrinkageTrees!
This package provides Bayesian regression tree models with shrinkage priors, supporting:
It includes four core functions:
HorseTrees(): fits a single regression tree with a
standard Horseshoe prior.
ShrinkageTrees(): fits a single tree with customizable
shrinkage priors.
CausalHorseForest(): fits a causal forest using the
standard Horseshoe prior.
CausalShrinkageForest(): fits a flexible causal forest
with user-defined shrinkage priors and tuning options.
The ...Trees functions use a single learner to estimate
the outcome model directly. In contrast, the
Causal...Forest variants fit separate models for the
treated and control regression functions. This enables estimation of
conditional average treatment effects (CATEs).
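The idea of estimating CATEs from two separately fitted regression functions can be illustrated with a deliberately simple stand-in. The sketch below uses ordinary least squares (lm) on simulated data in place of the package's Bayesian tree ensembles, so it shows only the structure of the estimate, not the package's sampler or priors; all variable names and the data-generating process are made up.

```r
set.seed(42)

# Simulated data with a known treatment effect tau(x) = 0.5 + x.
n <- 200
x <- runif(n)
z <- rbinom(n, 1, 0.5)
y <- 1 + 2 * x + z * (0.5 + x) + rnorm(n, sd = 0.1)
dat <- data.frame(y = y, x = x, z = z)

# One regression function per arm (lm as a stand-in learner).
fit_treated <- lm(y ~ x, data = dat[dat$z == 1, ])
fit_control <- lm(y ~ x, data = dat[dat$z == 0, ])

# CATE estimate: difference of the two fitted regression functions.
cate_hat  <- predict(fit_treated, newdata = dat) -
             predict(fit_control, newdata = dat)
cate_true <- 0.5 + dat$x
```

Comparing cate_hat with cate_true shows how the difference of the two fits tracks the treatment effect across x.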