| Title: | Cross-Fitting Engine for Double/Debiased Machine Learning |
| Version: | 0.1.1 |
| Description: | Provides a general cross-fitting engine for double / debiased machine learning and other meta-learners. The core functions implement flexible graphs of nuisance models with per-node training fold widths, target-specific evaluation windows, and several fold allocation schemes ("independence", "overlap", "disjoint"). The engine supports both numeric estimators (mode = "estimate") and cross-fitted prediction functions (mode = "predict"), with configurable aggregation over panels and repetitions. |
| License: | GPL-3 |
| URL: | https://github.com/EtiennePeyrot/crossfit-R |
| BugReports: | https://github.com/EtiennePeyrot/crossfit-R/issues |
| Encoding: | UTF-8 |
| RoxygenNote: | 7.3.1 |
| Depends: | R (≥ 4.1.0) |
| Imports: | stats, utils |
| Suggests: | knitr, rmarkdown, testthat (≥ 3.0.0) |
| Config/testthat/edition: | 3 |
| VignetteBuilder: | knitr |
| NeedsCompilation: | no |
| Packaged: | 2026-02-16 15:52:07 UTC; skoua |
| Author: | Etienne Peyrot |
| Maintainer: | Etienne Peyrot <etienne.peyrot@inserm.fr> |
| Repository: | CRAN |
| Date/Publication: | 2026-02-19 20:00:08 UTC |
crossfit: Cross-Fitting Engine for Double / Debiased Machine Learning
Description
Provides a general cross-fitting engine for double / debiased machine learning and other meta-learners. The core functions implement flexible graphs of nuisance models with per-node training fold widths, target-specific evaluation windows, and several fold allocation schemes ("independence", "overlap", "disjoint"). The engine supports both numeric estimators (mode = "estimate") and cross-fitted prediction functions (mode = "predict"), with configurable aggregation over panels and repetitions.
Author(s)
Maintainer: Etienne Peyrot <etienne.peyrot@inserm.fr> (ORCID)
See Also
Useful links:
Report bugs at https://github.com/EtiennePeyrot/crossfit-R/issues
Internal: map nuisance names to instance keys for each node
Description
Builds, for each instance, a mapping from nuisance names (as used in
fit_deps_names / pred_deps_names) to child instance
keys. This allows the engine to route predictions correctly from
child nuisances into parent fit() and predict()
calls.
Usage
build_child_maps(insts)
Arguments
insts |
A named list of instances as produced by
|
Value
A list with two components:
fit: Named list mapping instance keys to named character vectors (args → child instance keys) for fit() dependencies.
pred: Same for predict() dependencies.
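As a rough illustration of the mapping returned above, the following sketch routes each instance's arg → nuisance dependencies to child instance keys. The function sketch_child_maps() and the per-instance fields fit_deps, pred_deps and children are hypothetical. The package's internal instance objects may be organized differently.

```r
# Hypothetical sketch: for each instance, turn (arg -> nuisance name)
# dependencies into (arg -> child instance key) maps.
sketch_child_maps <- function(insts) {
  route <- function(deps, children) {
    if (length(deps) == 0L) return(character(0))  # no dependencies
    out <- unname(children[deps])  # look up each nuisance's instance key
    names(out) <- names(deps)      # keep the argument names
    out
  }
  list(
    fit  = lapply(insts, function(i) route(i$fit_deps,  i$children)),
    pred = lapply(insts, function(i) route(i$pred_deps, i$children))
  )
}

# Toy instances: the target consumes predictions of nuis_y as argument m_hat
insts <- list(
  "__TARGET__" = list(fit_deps = NULL,
                      pred_deps = c(m_hat = "nuis_y"),
                      children = c(nuis_y = "nuis_y#1")),
  "nuis_y#1" = list(fit_deps = NULL, pred_deps = NULL,
                    children = character(0))
)
maps <- sketch_child_maps(insts)
maps$pred[["__TARGET__"]]  # c(m_hat = "nuis_y#1")
```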
Internal: build per-method instance graph and fold geometry
Description
Given a list of validated methods (as returned by
validate_batch), constructs the global instance graph
used by the cross-fitting engine. Each nuisance (including the
synthetic "__TARGET__" node) is expanded into one or more
instances, depending on the fold allocation strategy.
Usage
build_instances(methods)
Arguments
methods |
A list of fully validated and normalized method
specifications, typically the output of
|
Details
This function:
builds a DAG of instances per method, starting from "__TARGET__",
computes the per-method evaluation width (eval_width),
assigns training window offsets (inst_offset) according to fold_allocation,
determines the minimal required number of folds (K_required) per method and harmonizes a global K,
constructs child maps (via build_child_maps),
computes structural signatures and marks instances that are worth caching.
The resulting "plan" object is consumed by the core engine
(crossfit_multi()).
Value
A list with components including (but not limited to):
methods: The (possibly updated) methods list, with folds harmonized.
instances: Named list of instance nodes.
roots: Per-method root instance keys (corresponding to "__TARGET__").
topo: Instance keys in topological order.
method_inst_keys: Per-method instance keys in topological order.
eval_width: Per-method evaluation window width.
inst_offset: Per-instance training window offset (in folds).
K_required: Per-method minimal required number of folds.
K: Global number of folds used in the plan.
child_maps: List of fit/predict child maps as returned by build_child_maps.
method_structs: Per-method set of structural signatures used by that method.
method_fold_allocation: Per-method fold_allocation values.
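The topo component lists instances so that every dependency appears before the instances that consume it. Below is a minimal sketch of such an ordering using Kahn's algorithm on a hypothetical dependency list. The function topo_order() and the deps layout are illustrative assumptions, not the engine's internals.

```r
# Hypothetical sketch of a topological order: dependencies come before the
# instances that use them. `deps` maps each key to the keys it depends on.
topo_order <- function(deps) {
  keys <- names(deps)
  indeg <- vapply(deps, length, integer(1))  # unmet dependencies per key
  order <- character(0)
  ready <- keys[indeg == 0L]
  while (length(ready) > 0L) {
    k <- ready[1L]
    ready <- ready[-1L]
    order <- c(order, k)
    # every instance that depends on k now has one fewer unmet dependency
    for (parent in keys) {
      if (k %in% deps[[parent]]) {
        indeg[[parent]] <- indeg[[parent]] - 1L
        if (indeg[[parent]] == 0L) ready <- c(ready, parent)
      }
    }
  }
  if (length(order) < length(keys)) stop("dependency graph has a cycle")
  order
}

deps <- list(
  "__TARGET__" = c("nuis_m", "nuis_e"),
  nuis_m = "nuis_e",
  nuis_e = character(0)
)
topo_order(deps)  # "nuis_e" "nuis_m" "__TARGET__"
```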
Internal: check for cycles in the nuisance graph
Description
Validates that the combined dependency graph over all nuisances is
acyclic. The graph edges are obtained from the union of
fit_deps and pred_deps for each nuisance (including the
synthetic "__TARGET__" node).
Usage
check_cycles(nfs)
Arguments
nfs |
A named list of normalized nuisance specifications. Each
element must contain components |
Details
If a cycle is detected, an error is thrown with a textual
representation of the cycle. Otherwise the function returns
TRUE invisibly.
Value
Invisibly returns TRUE on success, or throws an error
if a cycle is found.
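A cycle check of this kind can be sketched as a depth-first search that tracks the current path, so that a back edge can be reported as a readable cycle. The function find_cycle() and the edges layout are assumptions for illustration. In that layout, each node maps to the nodes it depends on, and every node must appear as a name.

```r
# Hypothetical DFS-based cycle finder. `edges` maps each node to the nodes
# it depends on (e.g. the union of fit and predict dependencies).
find_cycle <- function(edges) {
  state <- integer(length(edges))            # 0 = new, 1 = on stack, 2 = done
  names(state) <- names(edges)
  visit <- function(v, path) {
    state[[v]] <<- 1L
    for (w in edges[[v]]) {
      if (state[[w]] == 1L) {
        # back edge: the cycle runs from the first occurrence of w to here
        return(c(path[seq(match(w, path), length(path))], w))
      }
      if (state[[w]] == 0L) {
        cyc <- visit(w, c(path, w))
        if (!is.null(cyc)) return(cyc)
      }
    }
    state[[v]] <<- 2L
    NULL
  }
  for (v in names(edges)) {
    if (state[[v]] == 0L) {
      cyc <- visit(v, v)
      if (!is.null(cyc)) return(cyc)
    }
  }
  NULL
}

cyc <- find_cycle(list(a = "b", b = "c", c = "a"))
if (!is.null(cyc)) message("cycle: ", paste(cyc, collapse = " -> "))
```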
Create a cross-fitting method specification
Description
Helper to create a method specification for
crossfit / crossfit_multi. A method
bundles together:
a target functional target(),
a named list of nuisance specifications,
the cross-fitting geometry (folds, repeats, eval_fold, mode, fold_allocation),
and panel / repetition aggregation functions.
Usage
create_method(
target,
list_nuisance = NULL,
folds,
repeats,
mode = c("estimate", "predict"),
eval_fold = if (mode == "estimate") 1L else 0L,
fold_allocation = c("independence", "overlap", "disjoint"),
aggregate_panels = NULL,
aggregate_repeats = NULL
)
Arguments
target |
A function representing the target functional. It must
accept nuisance predictions as arguments (named after nuisances) and
optionally a |
list_nuisance |
Optional named list of nuisance specifications
created by |
folds |
Positive integer giving the number of folds |
repeats |
Positive integer giving the number of repetitions. |
mode |
Cross-fitting mode. Either |
eval_fold |
Integer giving the width (in folds) of the
evaluation window for the target. Must be |
fold_allocation |
Fold allocation strategy; one of
|
aggregate_panels |
Aggregation function for panel-level
results, typically one of |
aggregate_repeats |
Aggregation function for repetition-level
results, typically one of |
Details
The returned list is validated by validate_method() to ensure
structural soundness, but the validated object is not stored: you are
free to modify the returned method before passing it to
crossfit or crossfit_multi.
By default, eval_fold is chosen to be 1L when
mode = "estimate" and 0L when mode = "predict".
If you override eval_fold, it must be a positive integer when
mode = "estimate" and 0 when mode = "predict".
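The default and the constraints on eval_fold can be restated as a small hypothetical checker. The function check_eval_fold() is not part of the crossfit API. Because R evaluates default arguments lazily, the default for eval_fold sees mode after match.arg() has resolved it.

```r
# Hypothetical helper restating the eval_fold rules; not part of crossfit.
check_eval_fold <- function(mode = c("estimate", "predict"),
                            eval_fold = if (mode == "estimate") 1L else 0L) {
  mode <- match.arg(mode)  # the lazy default above is evaluated after this
  ok_int <- is.numeric(eval_fold) && length(eval_fold) == 1L &&
    !is.na(eval_fold) && eval_fold == round(eval_fold)
  if (mode == "estimate" && !(ok_int && eval_fold >= 1))
    stop("eval_fold must be a positive integer when mode = \"estimate\"")
  if (mode == "predict" && !(ok_int && eval_fold == 0))
    stop("eval_fold must be 0 when mode = \"predict\"")
  as.integer(eval_fold)
}

check_eval_fold()            # 1L
check_eval_fold("predict")   # 0L
check_eval_fold("estimate", 2)
```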
Value
A method specification list suitable for use in
crossfit or crossfit_multi.
Examples
set.seed(1)
n <- 50
x <- rnorm(n)
y <- x + rnorm(n)
# Nuisance: regression for E[Y | X]
nuis_y <- create_nuisance(
fit = function(data, ...) lm(y ~ x, data = data),
predict = function(model, data, ...) predict(model, newdata = data)
)
# Target: mean squared error of the nuisance predictor
target_mse <- function(data, nuis_y, ...) {
mean((data$y - nuis_y)^2)
}
m <- create_method(
target = target_mse,
list_nuisance = list(nuis_y = nuis_y),
folds = 2,
repeats = 1,
eval_fold = 1L,
mode = "estimate",
fold_allocation = "independence",
aggregate_panels = mean_estimate,
aggregate_repeats = mean_estimate
)
str(m)
Create a nuisance specification
Description
Helper to create a nuisance specification with basic structural
checks. A nuisance is defined by a fit function, a
predict function, and optional dependency mappings.
Usage
create_nuisance(
fit,
predict,
train_fold = 1L,
fit_deps = NULL,
pred_deps = NULL
)
Arguments
fit |
A function |
predict |
A function |
train_fold |
Positive integer giving the width (in folds) of the
training window used for this nuisance. Defaults to |
fit_deps |
Optional named character vector mapping
|
pred_deps |
Optional named character vector mapping
|
Value
A list representing a nuisance specification, suitable for
inclusion in the list_nuisance argument of
create_method.
Examples
# Simple linear regression nuisance: E[Y | X]
set.seed(1)
n <- 50
x <- rnorm(n)
y <- x + rnorm(n)
nuis <- create_nuisance(
fit = function(data, ...) lm(y ~ x, data = data),
predict = function(model, data, ...) predict(model, newdata = data)
)
str(nuis)
Cross-fitting for a single method
Description
Convenience wrapper around crossfit_multi for the
common case of a single method. It enforces that method is a
single method specification and forwards the aggregation functions
stored inside method.
Usage
crossfit(
data,
method,
fold_split = function(data, K) sample(rep_len(1:K, nrow(data))),
seed = NULL,
max_fail = Inf,
verbose = FALSE
)
Arguments
data |
Data frame or matrix with the observations. |
method |
A single method specification (list) created by
|
fold_split |
A function producing a K-fold split of the data
(see |
seed |
Integer base random seed. |
max_fail |
Non-negative integer or |
verbose |
Logical; if |
Value
The same structure as crossfit_multi, but with
a single method named "method". The final estimate is in
$estimates$method.
Examples
set.seed(1)
n <- 100
x <- rnorm(n)
y <- x + rnorm(n)
data <- data.frame(x = x, y = y)
# Nuisance: E[Y | X]
nuis_y <- create_nuisance(
fit = function(data, ...) lm(y ~ x, data = data),
predict = function(model, data, ...) predict(model, newdata = data)
)
# Target: mean squared error of the nuisance predictor
target_mse <- function(data, nuis_y, ...) {
mean((data$y - nuis_y)^2)
}
method <- create_method(
target = target_mse,
list_nuisance = list(nuis_y = nuis_y),
folds = 2,
repeats = 2,
eval_fold = 1L,
mode = "estimate",
fold_allocation = "independence",
aggregate_panels = mean_estimate,
aggregate_repeats = mean_estimate
)
cf <- crossfit(data, method)
cf$estimates
Cross-fitting for multiple methods
Description
Runs cross-fitting for one or more methods defined via
create_method and create_nuisance. This
is the main engine that:
validates and normalizes method specifications,
builds the global instance graph and fold geometry,
repeatedly draws K-fold splits and evaluates all active methods,
aggregates results across panels and repetitions.
Usage
crossfit_multi(
data,
methods,
fold_split = function(data, K) sample(rep_len(1:K, nrow(data))),
seed = NULL,
aggregate_panels = identity,
aggregate_repeats = identity,
max_fail = Inf,
verbose = FALSE
)
Arguments
data |
Data frame or matrix of size |
methods |
A (named) list of method specifications, typically
created with |
fold_split |
A function of the form |
seed |
Integer base random seed used for the K-fold splits; each
repetition uses |
aggregate_panels |
Function used as the default aggregator
over panels (folds) for each method. It is applied to the list of
per-panel values. Methods can override this via their own
|
aggregate_repeats |
Function used as the default
aggregator over repetitions for each method. It is applied to the
list of per-repetition aggregated values. Methods can override this
via their own |
max_fail |
Non-negative integer or |
verbose |
Logical; if |
Details
Each method can operate in either mode = "estimate" (target
returns numeric values) or mode = "predict" (target returns a
prediction function). Cross-fitting ensures that nuisance models are
always trained on folds disjoint from the folds on which their
predictions are used in the target.
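The disjointness guarantee can be illustrated in its simplest form with the default fold_split from the Usage section. The steps are to draw fold labels, group row indices by fold, and check that rows used for evaluation never enter training. This is a simplified leave-one-fold-out sketch. The engine's actual training and evaluation windows also depend on train_fold, eval_fold and fold_allocation.

```r
# Default splitter as given in the Usage section
fold_split <- function(data, K) sample(rep_len(1:K, nrow(data)))

set.seed(1)
data <- data.frame(x = rnorm(10))
K <- 3
folds <- fold_split(data, K)
# row indices of each fold, as a list indexed by fold label
fold_idx <- split(seq_len(nrow(data)), factor(folds, levels = 1:K))

# simplified geometry: evaluate on fold p, train on all other folds
for (p in 1:K) {
  eval_rows  <- fold_idx[[p]]
  train_rows <- unlist(fold_idx[-p], use.names = FALSE)
  stopifnot(length(intersect(eval_rows, train_rows)) == 0L)
}
lengths(fold_idx)  # fold sizes differ by at most one
```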
Value
A list with components:
estimates: Named list of final estimates per method (after aggregating over panels and repetitions).
per_method: For each method, a list with values (per-repetition aggregated results) and errors (error traces).
repeats_done: Number of repetitions successfully completed for each method.
K: Number of folds used in the plan.
K_required: Per-method minimal required K based on each method's dependency structure.
methods: The validated and normalized method specifications.
plan: The cross-fitting plan produced by build_instances().
Examples
set.seed(1)
n <- 100
x <- rnorm(n)
y <- x + rnorm(n)
data <- data.frame(x = x, y = y)
# Shared nuisance: E[Y | X]
nuis_y <- create_nuisance(
fit = function(data, ...) lm(y ~ x, data = data),
predict = function(model, data, ...) predict(model, newdata = data)
)
# Method 1: MSE of nuisance predictor
target_mse <- function(data, nuis_y, ...) {
mean((data$y - nuis_y)^2)
}
# Method 2: mean fitted value
target_mean <- function(data, nuis_y, ...) {
mean(nuis_y)
}
m1 <- create_method(
target = target_mse,
list_nuisance = list(nuis_y = nuis_y),
folds = 2,
repeats = 2,
eval_fold = 1L,
mode = "estimate",
fold_allocation = "independence"
)
m2 <- create_method(
target = target_mean,
list_nuisance = list(nuis_y = nuis_y),
folds = 2,
repeats = 2,
eval_fold = 1L,
mode = "estimate",
fold_allocation = "overlap"
)
cf_multi <- crossfit_multi(
data = data,
methods = list(mse = m1, mean = m2),
aggregate_panels = mean_estimate,
aggregate_repeats = mean_estimate
)
cf_multi$estimates
Internal: ensure a fitted model exists for an instance and panel
Description
For a given instance and token, this function either retrieves a cached fitted model (based on its model signature) or fits a new one using the appropriate training folds, then optionally caches it.
Usage
ensure_model(
inst_key,
token,
data,
methods,
plan,
fit_child_map,
pred_child_map,
K,
fold_idx,
model_cache
)
Arguments
inst_key |
Instance key in |
token |
An evaluation or training token (see
|
data |
Training data (matrix or data frame). |
methods |
Methods list used to build |
plan |
Cross-fitting plan as returned by
|
fit_child_map |
Child map for |
pred_child_map |
Child map for |
K |
Number of folds. |
fold_idx |
List mapping folds |
model_cache |
Environment used to store fitted models. |
Details
Child nuisances are recursively predicted on the training window and
passed as additional arguments into the instance's fit()
function.
Structural failures are recorded in plan$fail_env so that
methods relying on the same structural model can be skipped for the
current repetition.
Value
A fitted model object for the given instance and token.
Internal: compute a code signature for a function
Description
Builds a crude but stable code signature for a function based on its
argument names and body. Argument names are sorted (to avoid
dependence on declaration order) and the body is deparse()d
into a single string. The result is used as a key for function
deduplication.
Usage
fun_code_sig(f)
Arguments
f |
A function object. |
Value
A single character string representing the function's code signature.
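A signature of this kind can be sketched using only what is described above: sorted argument names plus a deparsed body. The function code_sig() is hypothetical, and the exact string format used by fun_code_sig may differ.

```r
# Hypothetical code signature: sorted formals + deparsed body in one string.
code_sig <- function(f) {
  args <- sort(names(formals(f)))
  body_txt <- paste(deparse(body(f)), collapse = "\n")
  paste0("(", paste(args, collapse = ","), ")", body_txt)
}

f1 <- function(data, ...) lm(y ~ x, data = data)
f2 <- function(data, ...) lm(y ~ x, data = data)  # same code, separate object
identical(code_sig(f1), code_sig(f2))             # TRUE
```

Because only code is compared in this sketch, two closures with identical code but different captured variables receive the same signature. That is one sense in which such a signature is crude.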
Internal: assign an integer id to a function signature
Description
Maps a function code signature (as produced by
fun_code_sig) to a small integer id in a registry. If
the signature is new, it is appended to the registry; otherwise, the
existing id is returned. This is used to keep structural signatures
compact.
Usage
fun_registry_id(reg, sig)
Arguments
reg |
A registry environment created by
|
sig |
A character string code signature. |
Value
An integer id corresponding to sig within
reg.
Internal: create a function registry
Description
Constructs a small environment used to deduplicate identical
fit() / predict() functions across methods. Functions
are represented by string signatures (via fun_code_sig)
stored in $sigs.
Usage
fun_registry_new()
Value
An environment with a character vector component sigs
used by fun_registry_id.
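A registry of this shape can be sketched as an environment holding the sigs vector, where a signature's id is its position in that vector. The functions registry_new() and registry_id() are hypothetical stand-ins for fun_registry_new and fun_registry_id.

```r
# Hypothetical registry: an environment with a growing character vector.
registry_new <- function() {
  reg <- new.env(parent = emptyenv())
  reg$sigs <- character(0)
  reg
}
registry_id <- function(reg, sig) {
  id <- match(sig, reg$sigs)
  if (is.na(id)) {
    reg$sigs <- c(reg$sigs, sig)  # environments are mutable: no return needed
    id <- length(reg$sigs)
  }
  id
}

reg <- registry_new()
registry_id(reg, "sig-A")  # 1
registry_id(reg, "sig-B")  # 2
registry_id(reg, "sig-A")  # 1 again
```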
Internal helper utilities
Description
Small utilities used internally by the cross-fitting engine.
Usage
idx_mod(n, mod)
pass_named(fun, args)
is.int(n)
required_no_default(fun)
Arguments
n |
Integer index or count used in modulo arithmetic. |
mod |
Modulus (number of folds). |
fun |
A function to be called. |
args |
A named list of arguments. |
Details
- idx_mod() wraps panel indices modulo the number of folds, using 1-based indexing.
- pass_named() calls a function with only the arguments it declares, dropping unknown ones defensively.
- is.int() checks whether a value is a non-negative scalar integer (within floating-point tolerance).
- required_no_default() returns the names of required formals (arguments without defaults) of a function.
These helpers are not intended for direct use by end users.
Value
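Various small values used for internal control flow and argument handling: wrapped integer indices (idx_mod()), the result of the call (pass_named()), a single logical (is.int()), and a character vector of argument names (required_no_default()).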
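The one-line descriptions above can be turned into minimal sketches. These re-implementations are hypothetical and are based only on the Details. In particular, the tolerance in is.int() and the strict filtering in pass_named() are assumptions.

```r
# Hypothetical sketches of the internal helpers, written from their descriptions.
idx_mod <- function(n, mod) ((n - 1L) %% mod) + 1L  # 1-based wrap-around

pass_named <- function(fun, args) {
  # keep only arguments whose names appear among fun's formals
  do.call(fun, args[names(args) %in% names(formals(fun))])
}

is.int <- function(n, tol = sqrt(.Machine$double.eps)) {
  is.numeric(n) && length(n) == 1L && !is.na(n) && n >= 0 &&
    abs(n - round(n)) < tol
}

required_no_default <- function(fun) {
  fmls <- formals(fun)
  # a formal without a default is stored as the empty symbol
  req <- vapply(fmls, function(a) is.symbol(a) && !nzchar(as.character(a)),
                logical(1))
  setdiff(names(fmls)[req], "...")
}

idx_mod(1:5, 3L)                                              # 1 2 3 1 2
pass_named(function(a, b) a + b, list(a = 1, b = 2, z = 3))   # 3
is.int(3); is.int(2.5)                                        # TRUE, FALSE
required_no_default(function(data, nuis_y, k = 2, ...) NULL)  # "data" "nuis_y"
```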
Internal: wrap a model and its child predictors
Description
Wraps a fitted model and its prediction function into a layered
predictor that also calls dependency predictors for nuisance inputs.
This is used internally in mode = "predict" to build
cross-fitted predictors that compose multiple nuisance learners.
Usage
layer(pred, model, pred_deps_predict = NULL)
Arguments
pred |
Prediction function for the nuisance node (typically the
|
model |
Fitted model object returned by the corresponding
|
pred_deps_predict |
Optional list of child predictors, each of
which will be called on |
Value
A function f(newdata, ...) calling
pred(model, data = newdata, ...) with extra arguments coming
from pred_deps_predict. Intended for internal use.
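The layering can be sketched as a closure over the fitted model. Child predictors are evaluated on newdata first, and their outputs are passed to pred() as named arguments. The function layer_sketch() and the toy model below are hypothetical.

```r
# Hypothetical sketch of layering a fitted model with its child predictors.
layer_sketch <- function(pred, model, pred_deps_predict = NULL) {
  force(pred); force(model); force(pred_deps_predict)
  function(newdata, ...) {
    # evaluate each child predictor on the same newdata (names are kept)
    child_vals <- lapply(pred_deps_predict, function(f) f(newdata, ...))
    do.call(pred, c(list(model, data = newdata), child_vals, list(...)))
  }
}

child <- function(newdata, ...) 2 * newdata$x            # toy child nuisance
pred  <- function(model, data, m_hat, ...) model$b * m_hat
f <- layer_sketch(pred, model = list(b = 3), list(m_hat = child))
f(data.frame(x = 1:3))  # 6 12 18
```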
Internal logging utilities
Description
These helpers print training and prediction schedules for debugging
the cross-fitting plan. Logging is controlled by the global
show_log flag: when show_log is TRUE, both
log_train() and log_pred() write messages to the console.
Usage
log_train(inst, token, folds)
log_pred(inst, token, folds, insts)
Arguments
inst |
An instance descriptor from the internal plan object. |
token |
A token created by |
folds |
Integer vector of fold indices involved in the operation. |
insts |
List of all instances (for |
Value
These functions are called for their side-effect (printing);
they return NULL invisibly.
Internal: create an evaluation token
Description
Constructs a token describing an evaluation context for a given method
and panel index. If inst_key is NULL, the token refers
to the synthetic "__TARGET__" node of method mi.
Usage
make_eval_token(inst_key = NULL, mi, p)
Arguments
inst_key |
Optional instance key; if |
mi |
Integer method index. |
p |
Integer panel index (fold position in the cyclic schedule). |
Value
A list token with fields type = "eval",
inst_key, mi, p.
Internal: create a training token
Description
Constructs a token describing a training context for a given instance key, method index and panel index.
Usage
make_train_token(inst_key, mi, p)
Arguments
inst_key |
Instance key for the nuisance to be trained. |
mi |
Integer method index. |
p |
Integer panel index (fold position in the cyclic schedule). |
Value
A list token with fields type = "train",
inst_key, mi, p.
Aggregators for scalar estimates
Description
These helpers implement simple aggregation schemes for panel-level
and repetition-level estimates in crossfit and
crossfit_multi.
Usage
mean_estimate(xs)
median_estimate(xs)
Arguments
xs |
A list of numeric values or numeric vectors. Elements are
unlisted and concatenated prior to aggregation, so |
Details
In mode = "estimate", each repetition typically produces a list
of numeric values (one per evaluation panel). The functions
mean_estimate() and median_estimate() aggregate such
lists into a single numeric value.
Value
A single numeric value (the mean or median of all entries in
xs).
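Given the description of xs, the behaviour amounts to flattening the list and applying mean() or median(). The functions below are illustrative stand-ins, not the package's source.

```r
# Illustrative stand-ins: pool all entries, then aggregate.
mean_estimate_sketch   <- function(xs) mean(unlist(xs, use.names = FALSE))
median_estimate_sketch <- function(xs) stats::median(unlist(xs, use.names = FALSE))

mean_estimate_sketch(list(c(1, 2, 3), 4, c(5, 6)))  # 3.5
median_estimate_sketch(list(c(1, 100), 10, 20))     # 15
```

Because entries are pooled before aggregation, an element holding several values carries more weight than a scalar element.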
Examples
xs <- list(c(1, 2, 3), 4, c(5, 6))
mean_estimate(xs)
xs <- list(c(1, 100), 10, 20)
median_estimate(xs)
Aggregators for cross-fitted predictors
Description
These helpers aggregate several cross-fitted predictors into a single
ensemble predictor. They are designed for methods run with
mode = "predict" in crossfit and
crossfit_multi.
Usage
mean_predictor(fs)
median_predictor(fs)
Arguments
fs |
A list of prediction functions. Each function must accept
at least a |
Value
A function of the form function(newdata, ...), which
returns a numeric vector of predictions. If fs is empty, the
returned function always returns numeric(0).
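An ensemble of this kind can be sketched by stacking each predictor's output into one column per predictor and combining row-wise. The names ensemble_predictor(), mean_predictor_sketch() and median_predictor_sketch() are hypothetical. The sketch assumes every predictor returns one numeric value per row of newdata.

```r
# Hypothetical ensemble factory: combine predictions row-wise.
ensemble_predictor <- function(fs, combine) {
  force(fs); force(combine)
  function(newdata, ...) {
    if (length(fs) == 0L) return(numeric(0))
    # one column per predictor, one row per observation in newdata
    preds <- do.call(cbind, lapply(fs, function(f, ...) as.numeric(f(newdata, ...)), ...))
    apply(preds, 1L, combine)
  }
}
mean_predictor_sketch   <- function(fs) ensemble_predictor(fs, mean)
median_predictor_sketch <- function(fs) ensemble_predictor(fs, stats::median)

f1 <- function(newdata, ...) newdata$x
f2 <- function(newdata, ...) 2 * newdata$x
f3 <- function(newdata, ...) 10 * newdata$x
nd <- data.frame(x = 1:3)
mean_predictor_sketch(list(f1, f2, f3))(nd)    # (13/3) * (1:3)
median_predictor_sketch(list(f1, f2, f3))(nd)  # 2 4 6
```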
Examples
# Two simple prediction functions of x
f1 <- function(newdata, ...) newdata$x
f2 <- function(newdata, ...) 2 * newdata$x
ens_mean <- mean_predictor(list(f1, f2))
newdata <- data.frame(x = 1:5)
ens_mean(newdata)
ens_median <- median_predictor(list(f1, f2))
ens_median(newdata)
Internal: get a model from the cache
Description
Simple environment-based cache lookup used to store fitted models keyed by their model signature.
Usage
model_cache_get(cache_env, key)
Arguments
cache_env |
Environment acting as a hash table for models. |
key |
Character key, typically produced by
|
Value
The cached model object or NULL if absent.
Internal: set a model in the cache
Description
Stores a fitted model in an environment-based cache under the given key.
Usage
model_cache_set(cache_env, key, value)
Arguments
cache_env |
Environment acting as a hash table for models. |
key |
Character key, typically produced by
|
value |
Fitted model object to store. |
Value
Invisibly returns NULL.
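Environments are mutable reference objects, so they can serve as a hash table that caching code updates in place without returning it. The sketch below shows get/set helpers plus the get-or-fit pattern described for ensure_model. All names are hypothetical. In this sketch, a stored NULL would be indistinguishable from a missing entry.

```r
# Hypothetical environment-backed model cache keyed by character signatures.
cache_new <- function() new.env(hash = TRUE, parent = emptyenv())
cache_get <- function(cache_env, key) get0(key, envir = cache_env, inherits = FALSE)
cache_set <- function(cache_env, key, value) {
  assign(key, value, envir = cache_env)
  invisible(NULL)
}
get_or_fit <- function(cache_env, key, fit_fun) {
  model <- cache_get(cache_env, key)
  if (is.null(model)) {           # cache miss: fit once and store
    model <- fit_fun()
    cache_set(cache_env, key, model)
  }
  model
}

n_fits <- 0L
fit_fun <- function() { n_fits <<- n_fits + 1L; list(coef = 2) }  # toy "model"
cache <- cache_new()
m1 <- get_or_fit(cache, "struct#1|folds=2,3", fit_fun)
m2 <- get_or_fit(cache, "struct#1|folds=2,3", fit_fun)  # served from the cache
n_fits  # 1
```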
Internal: compute a model cache key for an instance
Description
Computes a unique model signature for a given instance under a given evaluation token. The signature combines:
the instance's structural signature (inst$struct_sig),
the set of training folds used for this eval panel.
Usage
model_signature(inst_key, eval_token, methods, plan, K)
Arguments
inst_key |
Instance key (character) in |
eval_token |
An evaluation token created by
|
methods |
The methods list used to build |
plan |
The cross-fitting plan as returned by
|
K |
Number of folds. |
Details
This is used as a key in the model cache so that models are reused exactly when both structure and training data coincide.
Value
A character string uniquely identifying the model for this instance and training-fold configuration.
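A key with these two ingredients can be sketched as a string that pairs the structural signature with the sorted set of training folds. Sorting means the order in which folds are listed does not matter. The function model_key() and the "lm#1" signature format are hypothetical.

```r
# Hypothetical cache key: structural signature + sorted set of training folds.
model_key <- function(struct_sig, train_folds) {
  paste0(struct_sig, "|folds=", paste(sort(unique(train_folds)), collapse = ","))
}

model_key("lm#1", c(3L, 1L))  # "lm#1|folds=1,3"
identical(model_key("lm#1", c(3L, 1L)), model_key("lm#1", c(1L, 3L)))  # TRUE
```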
Internal: indices used for target evaluation
Description
Given an evaluation token and the cross-fitting plan, returns the row
indices and fold labels used to evaluate the target model for that
panel. The evaluation window has width plan$eval_width[mi] and
is wrapped cyclically on K folds.
Usage
obs_for_eval(eval_token, K, fold_idx, plan)
Arguments
eval_token |
An evaluation token created by
|
K |
Total number of folds. |
fold_idx |
A list mapping fold labels |
plan |
The cross-fitting plan as returned by
|
Value
A list with components idx (row indices) and
folds (fold labels).
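Cyclic wrapping of the evaluation window can be sketched with 1-based modular arithmetic. The sketch assumes that the window starts at fold p and that fold_idx is a list indexed by fold number. Both are assumptions, and obs_for_eval_sketch() is a hypothetical name.

```r
# Hypothetical sketch: eval_width consecutive folds starting at panel p,
# wrapped cyclically on K folds (1-based labels).
obs_for_eval_sketch <- function(p, eval_width, K, fold_idx) {
  folds <- ((p - 1L + seq_len(eval_width) - 1L) %% K) + 1L
  list(idx = unlist(fold_idx[folds], use.names = FALSE), folds = folds)
}

fold_idx <- list(1:2, 3:4, 5:6)  # row indices of folds 1, 2, 3
out <- obs_for_eval_sketch(p = 3L, eval_width = 2L, K = 3L, fold_idx)
out$folds  # 3 1
out$idx    # 5 6 1 2
```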
Internal: indices used for nuisance training
Description
Given a training token and the cross-fitting plan, returns the row
indices and fold labels used to train a given nuisance instance. The
training window is determined by the instance-specific offset and
train_fold value.
Usage
obs_for_train(train_token, K, fold_idx, plan, methods)
Arguments
train_token |
A training token created by
|
K |
Total number of folds. |
fold_idx |
A list mapping fold labels |
plan |
The cross-fitting plan as returned by
|
methods |
The methods list used to build |
Value
A list with components idx (row indices) and
folds (fold labels).
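A training window can be sketched in a similar way, as train_fold consecutive folds starting offset folds after panel p and wrapped modulo K. Measuring the offset from the start of the evaluation window is an assumption made for this illustration. Under that convention, with an offset equal to the evaluation width, the two windows are disjoint whenever the evaluation width plus train_fold does not exceed K.

```r
# Hypothetical sketch: folds of a window of `width` folds starting `offset`
# folds after panel p, wrapped cyclically on K folds (1-based).
fold_window <- function(p, offset, width, K) {
  ((p - 1L + offset + seq_len(width) - 1L) %% K) + 1L
}

K <- 4L
for (p in seq_len(K)) {
  eval_folds  <- fold_window(p, offset = 0L, width = 1L, K = K)      # eval width 1
  train_folds <- fold_window(p, offset = 1L, width = K - 1L, K = K)  # train_fold = K - 1
  stopifnot(!any(eval_folds %in% train_folds))  # windows never overlap
}
fold_window(2L, offset = 1L, width = 3L, K = K)  # 3 4 1
```

The training rows would then be the union of the row indices of the returned folds.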
Internal: predict nuisance values for an instance and token
Description
Recursively produces predictions for a given instance under a given
token. This ensures the corresponding model is fitted (via
ensure_model), gathers predictions from child nuisances
as additional inputs, and then calls the instance's predict()
function.
Usage
predict_instance_for_token(
inst_key,
token,
data,
methods,
plan,
fit_child_map,
pred_child_map,
K,
fold_idx,
model_cache,
mode
)
Arguments
inst_key |
Instance key in |
token |
An evaluation or training token. |
data |
Training data (matrix or data frame). |
methods |
Methods list. |
plan |
Cross-fitting plan. |
fit_child_map |
Child map for |
pred_child_map |
Child map for |
K |
Number of folds. |
fold_idx |
List mapping folds to row indices. |
model_cache |
Environment used to store fitted models. |
mode |
Either |
Details
In mode = "predict", this returns a layered prediction
function on newdata, suitable for building cross-fitted predictors
for the target. In mode = "estimate", it returns the
prediction values on the appropriate folds (eval or train).
Value
In mode = "predict", a prediction function of
newdata. In mode = "estimate", a vector or matrix of
predictions on the relevant subset of data.
Internal: validate and normalize a batch of methods
Description
Standardizes a list of method specifications:
ensures every method has a name,
checks and sets mode and fold_allocation,
fills in per-method aggregation functions from global defaults,
calls validate_method() on each method.
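Two of these steps, naming methods and filling in default aggregators, can be sketched as below. The function normalize_batch() and the "methodN" naming scheme are hypothetical. The mode / fold_allocation checks and the call to validate_method() are omitted.

```r
# Hypothetical sketch: name unnamed methods and fill missing aggregators.
normalize_batch <- function(methods, default_agg_panels, default_agg_repeats) {
  nms <- names(methods)
  if (is.null(nms)) nms <- rep("", length(methods))
  unnamed <- is.na(nms) | nms == ""
  nms[unnamed] <- paste0("method", seq_along(methods))[unnamed]
  names(methods) <- nms
  lapply(methods, function(m) {
    if (is.null(m$aggregate_panels))  m$aggregate_panels  <- default_agg_panels
    if (is.null(m$aggregate_repeats)) m$aggregate_repeats <- default_agg_repeats
    m
  })
}

out <- normalize_batch(list(list(), a = list(aggregate_panels = median)),
                       default_agg_panels = mean, default_agg_repeats = identity)
names(out)  # "method1" "a"
```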
Usage
validate_batch(methods, default_agg_panels, default_agg_repeats)
Arguments
methods |
A list of method specifications. |
default_agg_panels |
Default aggregation function for panels
(used when a method does not provide |
default_agg_repeats |
Default aggregation function for
repetitions (used when a method does not provide
|
Details
This is the entry point used by crossfit_multi to
validate user-specified methods. It is not intended for direct use by
end users.
Value
A named list of fully validated and normalized methods.
Internal: validate and normalize a single method specification
Description
Validates a method specification and normalizes its nuisance graph. This includes:
checking mode, eval_fold, folds, and repeats,
validating the nuisance container,
checking that the target is a function and that its required arguments are nuisances,
attaching a synthetic "__TARGET__" nuisance node,
inferring fit_deps and pred_deps from required arguments,
enforcing coverage of required arguments,
and running a cycle check on the full graph.
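Inferring target dependencies from required arguments can be sketched as follows. The sketch treats data and ... as non-nuisance arguments. That is an assumption based on the create_method examples, where the target takes data plus nuisance arguments. The function infer_target_deps() is a hypothetical name.

```r
# Hypothetical sketch: required target arguments (other than data and ...)
# must name nuisances, and become the target's arg -> nuisance dependencies.
infer_target_deps <- function(target, nuisance_names) {
  fmls <- formals(target)
  no_default <- vapply(fmls, function(a) is.symbol(a) && !nzchar(as.character(a)),
                       logical(1))
  required <- setdiff(names(fmls)[no_default], c("data", "..."))
  unknown <- setdiff(required, nuisance_names)
  if (length(unknown) > 0L)
    stop("target requires arguments that are not nuisances: ",
         paste(unknown, collapse = ", "))
  stats::setNames(required, required)  # arg -> nuisance (same name)
}

target_mse <- function(data, nuis_y, ...) mean((data$y - nuis_y)^2)
infer_target_deps(target_mse, nuisance_names = "nuis_y")  # c(nuis_y = "nuis_y")
```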
Usage
validate_method(met, mname = NULL)
Arguments
met |
A method specification list. |
mname |
Optional method name (used for error messages). |
Details
This function is normally called from validate_batch() and is
not intended to be used directly by end users.
Value
A normalized method specification with a fully validated
nuisance graph and an added "__TARGET__" node.
Internal: basic structural checks for a single nuisance
Description
Performs basic validation and normalization for a single nuisance specification. This ensures that the nuisance:
is a list with fit and predict functions,
has a valid positive integer train_fold,
has fit_deps and pred_deps that are either NULL or named character vectors (arg → nuisance).
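These checks can be sketched as a small validator. The function check_nuisance() is hypothetical. It defaults train_fold to 1L, as create_nuisance does, and keeps fit_deps and pred_deps as explicit NULL entries so that the returned list always has the five documented components.

```r
# Hypothetical structural validator for a single nuisance specification.
check_nuisance <- function(nf, nm = "<nuisance>") {
  if (!is.list(nf) || !is.function(nf[["fit"]]) || !is.function(nf[["predict"]]))
    stop(nm, ": needs `fit` and `predict` functions")
  tf <- if (is.null(nf[["train_fold"]])) 1L else nf[["train_fold"]]
  if (!(is.numeric(tf) && length(tf) == 1L && !is.na(tf) && tf >= 1 && tf == round(tf)))
    stop(nm, ": `train_fold` must be a positive integer")
  nf[["train_fold"]] <- as.integer(tf)
  for (d in c("fit_deps", "pred_deps")) {
    v <- nf[[d]]
    ok <- is.null(v) ||
      (is.character(v) && !is.null(names(v)) && all(nzchar(names(v))))
    if (!ok) stop(nm, ": `", d, "` must be NULL or a named character vector")
    if (is.null(v)) nf[d] <- list(NULL)  # keep the component, set to NULL
  }
  nf
}

fit_fun  <- function(data, ...) NULL
pred_fun <- function(model, data, ...) NULL
ok <- check_nuisance(list(fit = fit_fun, predict = pred_fun))
str(ok$train_fold)  # int 1
```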
Usage
validate_nuisance(nf, nm = NULL, mname = NULL)
Arguments
nf |
A list representing a nuisance specification. |
nm |
Optional nuisance name for error messages. |
mname |
Optional method name for error messages. |
Details
This function does not infer dependencies or check cycles; that is
done at the method level by validate_method().
Value
A normalized nuisance list with guaranteed components
fit, predict, train_fold,
fit_deps, and pred_deps.