Title: Cross-Fitting Engine for Double/Debiased Machine Learning
Version: 0.1.1
Description: Provides a general cross-fitting engine for double / debiased machine learning and other meta-learners. The core functions implement flexible graphs of nuisance models with per-node training fold widths, target-specific evaluation windows, and several fold allocation schemes ("independence", "overlap", "disjoint"). The engine supports both numeric estimators (mode = "estimate") and cross-fitted prediction functions (mode = "predict"), with configurable aggregation over panels and repetitions.
License: GPL-3
URL: https://github.com/EtiennePeyrot/crossfit-R
BugReports: https://github.com/EtiennePeyrot/crossfit-R/issues
Encoding: UTF-8
RoxygenNote: 7.3.1
Depends: R (≥ 4.1.0)
Imports: stats, utils
Suggests: knitr, rmarkdown, testthat (≥ 3.0.0)
Config/testthat/edition: 3
VignetteBuilder: knitr
NeedsCompilation: no
Packaged: 2026-02-16 15:52:07 UTC; skoua
Author: Etienne Peyrot [aut, cre] (ORCID)
Maintainer: Etienne Peyrot <etienne.peyrot@inserm.fr>
Repository: CRAN
Date/Publication: 2026-02-19 20:00:08 UTC

crossfit: Cross-Fitting Engine for Double / Debiased Machine Learning

Description

Provides a general cross-fitting engine for double / debiased machine learning and other meta-learners. The core functions implement flexible graphs of nuisance models with per-node training fold widths, target-specific evaluation windows, and several fold allocation schemes ("independence", "overlap", "disjoint"). The engine supports both numeric estimators (mode = "estimate") and cross-fitted prediction functions (mode = "predict"), with configurable aggregation over panels and repetitions.

Author(s)

Maintainer: Etienne Peyrot <etienne.peyrot@inserm.fr> (ORCID)

See Also

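Useful links:

https://github.com/EtiennePeyrot/crossfit-R

Report bugs at https://github.com/EtiennePeyrot/crossfit-R/issues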

Internal: map nuisance names to instance keys for each node

Description

Builds, for each instance, a mapping from nuisance names (as used in fit_deps_names / pred_deps_names) to child instance keys. This allows the engine to route predictions correctly from child nuisances into parent fit() and predict() calls.

Usage

build_child_maps(insts)

Arguments

insts

A named list of instances as produced by build_instances.

Value

A list with two components:

fit

Named list mapping instance keys to named character vectors (argument name → child instance key) for fit() dependencies.

pred

Same for predict() dependencies.
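To make the shape of these maps concrete, here is a small base-R sketch that builds them for two toy instances. It is illustrative only: the instance fields fit_deps / pred_deps (argument name to nuisance name) and the lookup table nuis_to_key are assumptions about the layout, and toy_child_maps is a hypothetical name, not the package's build_child_maps.

```r
# Hypothetical sketch: map argument names to child instance keys.
# Each toy instance holds fit_deps / pred_deps as named character vectors
# (argument name -> nuisance name); nuis_to_key resolves a nuisance name
# to the instance key used for that nuisance (assumed layout).
toy_child_maps <- function(insts, nuis_to_key) {
  resolve <- function(deps) {
    if (length(deps) == 0L) return(character(0))
    keys <- unname(nuis_to_key[deps])   # look up child instance keys
    names(keys) <- names(deps)          # keep the argument names
    keys
  }
  list(
    fit  = lapply(insts, function(inst) resolve(inst$fit_deps)),
    pred = lapply(insts, function(inst) resolve(inst$pred_deps))
  )
}

insts <- list(
  "M1|nuis_y"     = list(fit_deps = character(0), pred_deps = character(0)),
  "M1|__TARGET__" = list(fit_deps = character(0),
                         pred_deps = c(nuis_y = "nuis_y"))
)
nuis_to_key <- c(nuis_y = "M1|nuis_y", "__TARGET__" = "M1|__TARGET__")

maps <- toy_child_maps(insts, nuis_to_key)
maps$pred[["M1|__TARGET__"]]
```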


Internal: build per-method instance graph and fold geometry

Description

Given a list of validated methods (as returned by validate_batch), constructs the global instance graph used by the cross-fitting engine. Each nuisance (including the synthetic "__TARGET__" node) is expanded into one or more instances, depending on the fold allocation strategy.

Usage

build_instances(methods)

Arguments

methods

A list of fully validated and normalized method specifications, typically the output of validate_batch.

Details

This function:

The resulting "plan" object is consumed by the core engine (crossfit_multi()).

Value

A list with components including (but not limited to):

methods

The (possibly updated) methods list, with folds harmonized.

instances

Named list of instance nodes.

roots

Per-method root instance keys (corresponding to "__TARGET__").

topo

Instance keys in topological order.

method_inst_keys

Per-method instance keys in topological order.

eval_width

Per-method evaluation window width.

inst_offset

Per-instance training window offset (in folds).

K_required

Per-method minimal required number of folds.

K

Global number of folds used in the plan.

child_maps

List of fit/predict child maps as returned by build_child_maps.

method_structs

Per-method set of structural signatures used by that method.

method_fold_allocation

Per-method fold_allocation values.
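The topo component can be produced by any topological sort over the dependency edges. The sketch below uses Kahn's algorithm on a toy nuisance graph; topo_order and the deps layout (node name to the nodes it depends on) are hypothetical and only illustrate the ordering property, not how build_instances computes it.

```r
# Hypothetical sketch (Kahn's algorithm): every node appears after all of
# the nodes it depends on. `deps` maps node -> character vector of its deps.
topo_order <- function(deps) {
  nodes <- names(deps)
  deps  <- lapply(deps, unique)
  indeg <- vapply(deps, length, integer(1))   # unresolved dependencies per node
  ready <- nodes[indeg == 0L]
  out   <- character(0)
  while (length(ready) > 0L) {
    v <- ready[1L]
    ready <- ready[-1L]
    out <- c(out, v)
    for (u in nodes) {                        # nodes that depend on v lose one
      if (v %in% deps[[u]]) {
        indeg[[u]] <- indeg[[u]] - 1L
        if (indeg[[u]] == 0L) ready <- c(ready, u)
      }
    }
  }
  if (length(out) < length(nodes)) stop("dependency graph contains a cycle")
  out
}

deps <- list(
  "__TARGET__" = c("nuis_m", "nuis_e"),
  nuis_m       = character(0),
  nuis_e       = "nuis_m"
)
topo_order(deps)
```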


Internal: check for cycles in the nuisance graph

Description

Validates that the combined dependency graph over all nuisances is acyclic. The graph edges are obtained from the union of fit_deps and pred_deps for each nuisance (including the synthetic "__TARGET__" node).

Usage

check_cycles(nfs)

Arguments

nfs

A named list of normalized nuisance specifications. Each element must contain components fit_deps and pred_deps, as constructed by validate_method().

Details

If a cycle is detected, an error is thrown with a textual representation of the cycle. Otherwise the function returns TRUE invisibly.

Value

Invisibly returns TRUE on success, or throws an error if a cycle is found.
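One common way to detect a cycle and report it as text is a depth-first search that tracks the current path. The following sketch is an assumption about the technique, not the package's check_cycles code; find_cycle and the edges layout (node name to the union of its fit and predict dependencies) are hypothetical.

```r
# Hypothetical sketch: depth-first search returning the first cycle found
# as a character path (first node repeated at the end), or NULL if acyclic.
find_cycle <- function(edges) {
  state <- rep(0L, length(edges))   # 0 = unvisited, 1 = on current path, 2 = done
  names(state) <- names(edges)
  path <- character(0)
  visit <- function(v) {
    state[[v]] <<- 1L
    path <<- c(path, v)
    for (w in edges[[v]]) {
      if (state[[w]] == 1L)         # back edge: w is already on the path
        return(c(path[match(w, path):length(path)], w))
      if (state[[w]] == 0L) {
        cyc <- visit(w)
        if (!is.null(cyc)) return(cyc)
      }
    }
    state[[v]] <<- 2L               # fully explored, leave the path
    path <<- path[-length(path)]
    NULL
  }
  for (v in names(edges)) {
    if (state[[v]] == 0L) {
      cyc <- visit(v)
      if (!is.null(cyc)) return(cyc)
    }
  }
  NULL
}

cyc <- find_cycle(list(a = "b", b = "c", c = "a"))
if (!is.null(cyc)) message("cycle: ", paste(cyc, collapse = " -> "))
```

The returned path can be pasted into an error message, which matches the documented behaviour of reporting the cycle textually.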


Create a cross-fitting method specification

Description

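Helper to create a method specification for crossfit / crossfit_multi. A method bundles together a target functional, an optional list of nuisance specifications, the fold geometry (folds, repeats, eval_fold, fold_allocation), the cross-fitting mode, and the panel- and repetition-level aggregation functions.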

Usage

create_method(
  target,
  list_nuisance = NULL,
  folds,
  repeats,
  mode = c("estimate", "predict"),
  eval_fold = if (mode == "estimate") 1L else 0L,
  fold_allocation = c("independence", "overlap", "disjoint"),
  aggregate_panels = NULL,
  aggregate_repeats = NULL
)

Arguments

target

A function representing the target functional. It must accept nuisance predictions as arguments (named after nuisances) and optionally a data argument.

list_nuisance

Optional named list of nuisance specifications created by create_nuisance.

folds

Positive integer giving the number of folds K. May be NULL, in which case crossfit_multi will infer a minimal feasible K from the dependency structure.

repeats

Positive integer giving the number of repetitions.

mode

Cross-fitting mode. Either "estimate" (target returns numeric estimates) or "predict" (target returns a cross-fitted predictor).

eval_fold

Integer giving the width (in folds) of the evaluation window for the target. Must be > 0 for mode = "estimate" and 0 for mode = "predict". If omitted, the default is 1L for "estimate" and 0L for "predict".

fold_allocation

Fold allocation strategy; one of "independence", "overlap", or "disjoint".

aggregate_panels

Aggregation function for panel-level results, typically one of mean_estimate, median_estimate, mean_predictor, median_predictor, or a custom function. May be NULL, in which case a global default can be supplied via crossfit_multi.

aggregate_repeats

Aggregation function for repetition-level results, typically one of mean_estimate, median_estimate, mean_predictor, median_predictor, or a custom function. May be NULL, in which case a global default can be supplied via crossfit_multi.

Details

The returned list is validated by validate_method() to ensure structural soundness, but the validated object is not stored: you are free to modify the returned method before passing it to crossfit or crossfit_multi.

By default, eval_fold is chosen to be 1L when mode = "estimate" and 0L when mode = "predict". If you override eval_fold, it must satisfy these constraints: positive integer for "estimate", zero for "predict".
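The eval_fold rule above can be expressed as a small validator. This is a sketch under the documented constraints only; check_eval_fold is a hypothetical helper, and the exact checks and error messages in create_method / validate_method may differ.

```r
# Hypothetical sketch of the documented rule: default 1L for "estimate",
# 0L for "predict"; overrides must be a positive integer for "estimate"
# and exactly zero for "predict".
check_eval_fold <- function(mode = c("estimate", "predict"),
                            eval_fold = if (mode == "estimate") 1L else 0L) {
  mode <- match.arg(mode)           # resolved before eval_fold's default is forced
  ok <- length(eval_fold) == 1L && is.numeric(eval_fold) &&
    !is.na(eval_fold) && eval_fold == round(eval_fold)
  if (mode == "estimate" && !(ok && eval_fold > 0))
    stop("eval_fold must be a positive integer when mode = \"estimate\"")
  if (mode == "predict" && !(ok && eval_fold == 0))
    stop("eval_fold must be 0 when mode = \"predict\"")
  as.integer(eval_fold)
}

check_eval_fold("estimate")        # default window width 1
check_eval_fold("predict")         # default window width 0
check_eval_fold("estimate", 2)     # explicit override
```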

Value

A method specification list suitable for use in crossfit or crossfit_multi.

Examples

set.seed(1)
n <- 50
x <- rnorm(n)
y <- x + rnorm(n)

# Nuisance: regression for E[Y | X]
nuis_y <- create_nuisance(
  fit = function(data, ...) lm(y ~ x, data = data),
  predict = function(model, data, ...) predict(model, newdata = data)
)

# Target: mean squared error of the nuisance predictor
target_mse <- function(data, nuis_y, ...) {
  mean((data$y - nuis_y)^2)
}

m <- create_method(
  target = target_mse,
  list_nuisance = list(nuis_y = nuis_y),
  folds = 2,
  repeats = 1,
  eval_fold = 1L,
  mode = "estimate",
  fold_allocation = "independence",
  aggregate_panels  = mean_estimate,
  aggregate_repeats = mean_estimate
)

str(m)

Create a nuisance specification

Description

Helper to create a nuisance specification with basic structural checks. A nuisance is defined by a fit function, a predict function, and optional dependency mappings.

Usage

create_nuisance(
  fit,
  predict,
  train_fold = 1L,
  fit_deps = NULL,
  pred_deps = NULL
)

Arguments

fit

A function fit(data, ...) that trains the nuisance model on a subset of the data and returns a fitted model object.

predict

A function predict(model, data, ...) that returns predictions for the nuisance on new data.

train_fold

Positive integer giving the width (in folds) of the training window used for this nuisance. Defaults to 1L.

fit_deps

Optional named character vector mapping fit() argument names to nuisance names, used to specify nuisance inputs to the fit function. If NULL, the dependencies are inferred later from required arguments whose names match nuisance names.

pred_deps

Optional named character vector mapping predict() argument names to nuisance names, used to specify nuisance inputs to the predict function.

Value

A list representing a nuisance specification, suitable for inclusion in the list_nuisance argument of create_method.

Examples

# Simple linear regression nuisance: E[Y | X]
set.seed(1)
n <- 50
x <- rnorm(n)
y <- x + rnorm(n)

nuis <- create_nuisance(
  fit = function(data, ...) lm(y ~ x, data = data),
  predict = function(model, data, ...) predict(model, newdata = data)
)

str(nuis)
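The fit_deps argument notes that, when it is NULL, dependencies are inferred from required arguments whose names match nuisance names. A minimal sketch of that rule, assuming "required" means "no default value" and that `...` is excluded; required_args and infer_deps are hypothetical names, not exported functions.

```r
# Hypothetical sketch: names of arguments without a default, minus `...`.
required_args <- function(fun) {
  fmls <- formals(fun)
  no_default <- vapply(names(fmls),
                       function(nm) identical(fmls[[nm]], quote(expr = )),
                       logical(1))
  setdiff(names(fmls)[no_default], "...")
}

# Infer dependencies: required arguments that are also nuisance names,
# returned as a named character vector (argument name -> nuisance name).
infer_deps <- function(fun, nuisance_names) {
  args <- intersect(required_args(fun), nuisance_names)
  names(args) <- args
  args
}

fit_m <- function(data, nuis_e, ...) NULL
infer_deps(fit_m, c("nuis_e", "nuis_y"))
```

An argument with a default (for example nuis_e = NULL) would not be treated as a dependency under this rule.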

Cross-fitting for a single method

Description

Convenience wrapper around crossfit_multi for the common case of a single method. It enforces that method is a single method specification and forwards the aggregation functions stored inside method.

Usage

crossfit(
  data,
  method,
  fold_split = function(data, K) sample(rep_len(1:K, nrow(data))),
  seed = NULL,
  max_fail = Inf,
  verbose = FALSE
)

Arguments

data

Data frame or matrix with the observations.

method

A single method specification (list) created by create_method. It must contain a target function, and its aggregate_panels and aggregate_repeats components must be functions.

fold_split

A function producing a K-fold split of the data (see crossfit_multi).

seed

Integer base random seed, forwarded to crossfit_multi (each repetition uses seed + rep_id - 1).

max_fail

Non-negative integer or Inf controlling how many repetitions the method may fail before being disabled.

verbose

Logical; if TRUE, prints a compact status line per repetition.

Value

The same structure as crossfit_multi, but with a single method named "method". The final estimate is in $estimates$method.

Examples

set.seed(1)
n <- 100
x <- rnorm(n)
y <- x + rnorm(n)

data <- data.frame(x = x, y = y)

# Nuisance: E[Y | X]
nuis_y <- create_nuisance(
  fit = function(data, ...) lm(y ~ x, data = data),
  predict = function(model, data, ...) predict(model, newdata = data)
)

# Target: mean squared error of the nuisance predictor
target_mse <- function(data, nuis_y, ...) {
  mean((data$y - nuis_y)^2)
}

method <- create_method(
  target = target_mse,
  list_nuisance = list(nuis_y = nuis_y),
  folds = 2,
  repeats = 2,
  eval_fold = 1L,
  mode = "estimate",
  fold_allocation = "independence",
  aggregate_panels  = mean_estimate,
  aggregate_repeats = mean_estimate
)

cf <- crossfit(data, method)
cf$estimates

Cross-fitting for multiple methods

Description

Runs cross-fitting for one or more methods defined via create_method and create_nuisance. This is the main engine that:

Usage

crossfit_multi(
  data,
  methods,
  fold_split = function(data, K) sample(rep_len(1:K, nrow(data))),
  seed = NULL,
  aggregate_panels = identity,
  aggregate_repeats = identity,
  max_fail = Inf,
  verbose = FALSE
)

Arguments

data

Data frame or matrix of size n × p containing the observations.

methods

A (named) list of method specifications, typically created with create_method.

fold_split

A function of the form function(data, K) returning a vector of length nrow(data) with integer fold labels in 1:K. It must assign at least one observation to each fold.

seed

Integer base random seed used for the K-fold splits; each repetition uses seed + rep_id - 1.

aggregate_panels

Function used as the default aggregator over panels (folds) for each method. It is applied to the list of per-panel values. Methods can override this via their own aggregate_panels.

aggregate_repeats

Function used as the default aggregator over repetitions for each method. It is applied to the list of per-repetition aggregated values. Methods can override this via their own aggregate_repeats.

max_fail

Non-negative integer or Inf controlling how many repetitions a method is allowed to fail before being disabled. Structural model failures and panel-level errors both count toward this limit.

verbose

Logical; if TRUE, prints a compact status line per repetition.

Details

Each method can operate in either mode = "estimate" (target returns numeric values) or mode = "predict" (target returns a prediction function). Cross-fitting ensures that nuisance models are always trained on folds disjoint from the folds on which their predictions are used in the target.
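The defaults for fold_split and seed can be combined into a reproducible per-repetition split. The sketch below reuses the documented default fold_split and the seed + rep_id - 1 rule; make_fold_idx is a hypothetical helper, and its checks only mirror the documented requirement that every fold receives at least one observation.

```r
# Documented default: balanced labels 1:K, shuffled.
fold_split <- function(data, K) sample(rep_len(1:K, nrow(data)))

# Hypothetical helper: seed the repetition, split, validate, and return
# a list mapping folds 1:K to row indices.
make_fold_idx <- function(data, K, seed, rep_id) {
  set.seed(seed + rep_id - 1)
  labels <- fold_split(data, K)
  if (length(labels) != nrow(data) || !all(labels %in% 1:K) ||
      length(unique(labels)) != K)
    stop("fold_split must label every row with 1:K and fill every fold")
  split(seq_len(nrow(data)), factor(labels, levels = 1:K))
}

data <- data.frame(x = seq_len(10))
f1       <- make_fold_idx(data, K = 3, seed = 42, rep_id = 1)
f1_again <- make_fold_idx(data, K = 3, seed = 42, rep_id = 1)
lengths(f1)   # fold sizes
```

Re-running a repetition with the same seed and rep_id reproduces the same split, while rep_id = 2 would draw a new one.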

Value

A list with components:

estimates

Named list of final estimates per method (after aggregating over panels and repetitions).

per_method

For each method, a list with values (per-repetition aggregated results) and errors (error traces).

repeats_done

Number of repetitions successfully completed for each method.

K

Number of folds used in the plan.

K_required

Per-method minimal required K based on their dependency structure.

methods

The validated and normalized method specifications.

plan

The cross-fitting plan produced by build_instances().

Examples

set.seed(1)
n <- 100
x <- rnorm(n)
y <- x + rnorm(n)

data <- data.frame(x = x, y = y)

# Shared nuisance: E[Y | X]
nuis_y <- create_nuisance(
  fit = function(data, ...) lm(y ~ x, data = data),
  predict = function(model, data, ...) predict(model, newdata = data)
)

# Method 1: MSE of nuisance predictor
target_mse <- function(data, nuis_y, ...) {
  mean((data$y - nuis_y)^2)
}

# Method 2: mean fitted value
target_mean <- function(data, nuis_y, ...) {
  mean(nuis_y)
}

m1 <- create_method(
  target = target_mse,
  list_nuisance = list(nuis_y = nuis_y),
  folds = 2,
  repeats = 2,
  eval_fold = 1L,
  mode = "estimate",
  fold_allocation = "independence"
)

m2 <- create_method(
  target = target_mean,
  list_nuisance = list(nuis_y = nuis_y),
  folds = 2,
  repeats = 2,
  eval_fold = 1L,
  mode = "estimate",
  fold_allocation = "overlap"
)

cf_multi <- crossfit_multi(
  data    = data,
  methods = list(mse = m1, mean = m2),
  aggregate_panels  = mean_estimate,
  aggregate_repeats = mean_estimate
)

cf_multi$estimates

Internal: ensure a fitted model exists for an instance and panel

Description

For a given instance and token, this function either retrieves a cached fitted model (based on its model signature) or fits a new one using the appropriate training folds, then optionally caches it.

Usage

ensure_model(
  inst_key,
  token,
  data,
  methods,
  plan,
  fit_child_map,
  pred_child_map,
  K,
  fold_idx,
  model_cache
)

Arguments

inst_key

Instance key in plan$instances.

token

An evaluation or training token (see make_eval_token and make_train_token).

data

Training data (matrix or data frame).

methods

Methods list used to build plan.

plan

Cross-fitting plan as returned by build_instances().

fit_child_map

Child map for fit() dependencies (from build_child_maps).

pred_child_map

Child map for predict() dependencies.

K

Number of folds.

fold_idx

List mapping folds 1:K to row indices.

model_cache

Environment used to store fitted models.

Details

Child nuisances are recursively predicted on the training window and passed as additional arguments into the instance's fit() function.

Structural failures are recorded in plan$fail_env so that methods relying on the same structural model can be skipped for the current repetition.
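The cache-then-fit behaviour described above follows a standard memoization pattern. A minimal sketch, assuming a string key stands in for the model signature and a plain environment acts as the cache; get_or_fit and counting_fit are hypothetical, and the real function additionally handles child predictions and failure recording.

```r
# Environment used as a hash table of fitted models (assumed layout).
model_cache <- new.env(hash = TRUE, parent = emptyenv())
n_fits <- 0L

# Hypothetical get-or-fit: reuse a cached model when the key exists,
# otherwise fit on the training rows and store the result.
get_or_fit <- function(key, fit_fun, data, train_rows, cache) {
  if (exists(key, envir = cache, inherits = FALSE))
    return(get(key, envir = cache, inherits = FALSE))
  model <- fit_fun(data[train_rows, , drop = FALSE])
  assign(key, model, envir = cache)
  model
}

# Toy fit() that counts how often it is actually called.
counting_fit <- function(d) {
  n_fits <<- n_fits + 1L
  stats::lm(y ~ x, data = d)
}

set.seed(1)
dat <- data.frame(x = rnorm(20))
dat$y <- dat$x + rnorm(20)
m1 <- get_or_fit("nuis_y|folds=1,2", counting_fit, dat, 1:10, model_cache)
m2 <- get_or_fit("nuis_y|folds=1,2", counting_fit, dat, 1:10, model_cache)
n_fits   # the second call is served from the cache
```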

Value

A fitted model object for the given instance and token.


Internal: compute a code signature for a function

Description

Builds a crude but stable code signature for a function based on its argument names and body. Argument names are sorted (to avoid dependence on declaration order) and the body is deparse()d into a single string. The result is used as a key for function deduplication.

Usage

fun_code_sig(f)

Arguments

f

A function object.

Value

A single character string representing the function's code signature.
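A signature of this kind can be built from base R alone. The sketch below is an assumption about the format (code_sig is hypothetical); it shows why sorting argument names makes the signature insensitive to declaration order.

```r
# Hypothetical sketch: sorted argument names plus the deparsed body.
code_sig <- function(f) {
  args <- sort(names(formals(f)))
  body_txt <- paste(deparse(body(f)), collapse = "\n")
  paste0("args:", paste(args, collapse = ","), "|body:", body_txt)
}

f_a <- function(data, ...) lm(y ~ x, data = data)
f_b <- function(data, ...) lm(y ~ x, data = data)   # same code as f_a
f_c <- function(..., data) lm(y ~ x, data = data)   # same code, args reordered
f_d <- function(data, ...) glm(y ~ x, data = data)  # different body
```

Because only argument names and the body are used, two closures with identical code but different enclosing environments (or different default values) share a signature in this sketch, which is one sense in which such a signature is crude.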


Internal: assign an integer id to a function signature

Description

Maps a function code signature (as produced by fun_code_sig) to a small integer id in a registry. If the signature is new, it is appended to the registry; otherwise, the existing id is returned. This is used to keep structural signatures compact.

Usage

fun_registry_id(reg, sig)

Arguments

reg

A registry environment created by fun_registry_new.

sig

A character string code signature.

Value

An integer id corresponding to sig within reg.


Internal: create a function registry

Description

Constructs a small environment used to deduplicate identical fit() / predict() functions across methods. Functions are represented by string signatures (via fun_code_sig) stored in $sigs.

Usage

fun_registry_new()

Value

An environment with a character vector component sigs used by fun_registry_id.
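Together, fun_registry_new and fun_registry_id amount to an append-or-lookup table. A sketch of that pattern under the documented layout (an environment holding a character vector sigs); registry_new and registry_id are hypothetical names.

```r
# Hypothetical registry: an environment whose `sigs` vector grows on demand.
registry_new <- function() {
  reg <- new.env(parent = emptyenv())
  reg$sigs <- character(0)
  reg
}

# Return the position of `sig`, appending it first if it is new.
registry_id <- function(reg, sig) {
  id <- match(sig, reg$sigs)
  if (is.na(id)) {
    reg$sigs <- c(reg$sigs, sig)   # environments are mutable, so this persists
    id <- length(reg$sigs)
  }
  id
}

reg <- registry_new()
ids <- c(registry_id(reg, "sig_fit_lm"),
         registry_id(reg, "sig_fit_glm"),
         registry_id(reg, "sig_fit_lm"))
ids
```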


Internal helper utilities

Description

Small utilities used internally by the cross-fitting engine.

Usage

idx_mod(n, mod)

pass_named(fun, args)

is.int(n)

required_no_default(fun)

Arguments

n

Integer index or count used in modulo arithmetic.

mod

Modulus (number of folds).

fun

A function to be called.

args

A named list of arguments.

Details

These helpers are not intended for direct use by end users.

Value

Various small values used for internal control flow and argument handling.
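The helper names suggest cyclic fold indexing, argument filtering before a call, and an integer check. The sketches below are guesses at that behaviour under hypothetical names (cyc_idx, call_with_named, is_int_value); they are not the package's implementations.

```r
# Hypothetical 1-based cyclic index: wraps any integer n onto 1..mod.
cyc_idx <- function(n, mod) ((n - 1L) %% mod) + 1L

# Hypothetical: call `fun` with the subset of `args` it can accept
# (all of them if it has `...`, otherwise only matching formal names).
call_with_named <- function(fun, args) {
  fml <- names(formals(fun))
  if (!("..." %in% fml)) args <- args[names(args) %in% fml]
  do.call(fun, args)
}

# Hypothetical integer-valued check (accepts 3 and 3L, rejects 2.5 and NA).
is_int_value <- function(n) {
  is.numeric(n) && length(n) == 1L && !is.na(n) && n == round(n)
}

cyc_idx(1:7, 3L)
call_with_named(function(a, b) a + b, list(a = 1, b = 2, c = 99))
```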


Internal: wrap a model and its child predictors

Description

Wraps a fitted model and its prediction function into a layered predictor that also calls dependency predictors for nuisance inputs. This is used internally in mode = "predict" to build cross-fitted predictors that compose multiple nuisance learners.

Usage

layer(pred, model, pred_deps_predict = NULL)

Arguments

pred

Prediction function for the nuisance node (typically the predict component of a nuisance specification).

model

Fitted model object returned by the corresponding fit().

pred_deps_predict

Optional list of child predictors, each of which will be called on newdata and passed into pred under the appropriate argument name.

Value

A function f(newdata, ...) calling pred(model, data = newdata, ...) with extra arguments coming from pred_deps_predict. Intended for internal use.
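A layered predictor is essentially a closure over the model, its predict function and its children. This sketch assumes each child predictor takes newdata and that its output is passed to pred under the child's name; make_layer is a hypothetical name, not the internal layer function.

```r
# Hypothetical sketch: child predictors run on newdata first, and their
# outputs are passed to `pred` as named arguments alongside the model.
make_layer <- function(pred, model, child_preds = NULL) {
  force(pred); force(model); force(child_preds)
  function(newdata, ...) {
    extra <- lapply(child_preds, function(g) g(newdata))
    do.call(pred, c(list(model, data = newdata), extra, list(...)))
  }
}

# Toy child nuisance e(x) = 2x and a parent whose prediction uses it.
child       <- function(newdata, ...) 2 * newdata$x
parent_pred <- function(model, data, e, ...) model$slope * e

f <- make_layer(parent_pred, list(slope = 3), list(e = child))
f(data.frame(x = 1:3))
```

Layers compose: a layered function can itself be passed as a child of another layer, which is how multi-level nuisance graphs can be turned into a single predictor.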


Internal logging utilities

Description

These helpers print training and prediction schedules for debugging the cross-fitting plan. Logging is controlled by the global show_log flag: when show_log is TRUE, both log_train() and log_pred() write messages to the console.

Usage

log_train(inst, token, folds)

log_pred(inst, token, folds, insts)

Arguments

inst

An instance descriptor from the internal plan object.

token

A token created by make_eval_token() or make_train_token(), indicating the method index and panel.

folds

Integer vector of fold indices involved in the operation.

insts

List of all instances (for log_pred() only), used to resolve the goal of the prediction.

Value

These functions are called for their side-effect (printing); they return NULL invisibly.


Internal: create an evaluation token

Description

Constructs a token describing an evaluation context for a given method and panel index. If inst_key is NULL, the token refers to the synthetic "__TARGET__" node of method mi.

Usage

make_eval_token(inst_key = NULL, mi, p)

Arguments

inst_key

Optional instance key; if NULL, defaults to paste0("M", mi, "|__TARGET__").

mi

Integer method index.

p

Integer panel index (fold position in the cyclic schedule).

Value

A list token with fields type = "eval", inst_key, mi, p.


Internal: create a training token

Description

Constructs a token describing a training context for a given instance key, method index and panel index.

Usage

make_train_token(inst_key, mi, p)

Arguments

inst_key

Instance key for the nuisance to be trained.

mi

Integer method index.

p

Integer panel index (fold position in the cyclic schedule).

Value

A list token with fields type = "train", inst_key, mi, p.


Aggregators for scalar estimates

Description

These helpers implement simple aggregation schemes for panel-level and repetition-level estimates in crossfit and crossfit_multi.

Usage

mean_estimate(xs)

median_estimate(xs)

Arguments

xs

A list of numeric values or numeric vectors. Elements are unlisted and concatenated prior to aggregation, so xs may contain scalars or length-k vectors.

Details

In mode = "estimate", each repetition typically produces a list of numeric values (one per evaluation panel). The functions mean_estimate() and median_estimate() aggregate such lists into a single numeric value.

Value

A single numeric value (the mean or median of all entries in xs).

Examples

xs <- list(c(1, 2, 3), 4, c(5, 6))
mean_estimate(xs)
xs <- list(c(1, 100), 10, 20)
median_estimate(xs)
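The two aggregators and their use at two levels (panels within a repetition, then repetitions) can be sketched as follows. mean_est_sketch and median_est_sketch are hypothetical stand-ins that follow the documented flatten-then-aggregate behaviour; the toy panel values are made up for illustration.

```r
# Flatten the list, then aggregate (documented behaviour, hypothetical names).
mean_est_sketch   <- function(xs) mean(unlist(xs))
median_est_sketch <- function(xs) stats::median(unlist(xs))

# Toy values: two repetitions with two panels each.
panels_by_rep <- list(rep1 = list(1, 3), rep2 = list(2, 6))
per_rep <- lapply(panels_by_rep, median_est_sketch)  # one value per repetition
final   <- median_est_sketch(per_rep)                # aggregate over repetitions
pooled  <- median_est_sketch(panels_by_rep)          # all panels at once, for comparison
```

For the median the order of aggregation matters: the two-stage value here is 3, while the pooled median of the four panel values is 2.5.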

Aggregators for cross-fitted predictors

Description

These helpers aggregate several cross-fitted predictors into a single ensemble predictor. They are designed for methods run with mode = "predict" in crossfit and crossfit_multi.

Usage

mean_predictor(fs)

median_predictor(fs)

Arguments

fs

A list of prediction functions. Each function must accept at least a newdata argument and return a numeric vector of predictions of the same length as nrow(newdata).

Value

A function of the form function(newdata, ...), which returns a numeric vector of predictions. If fs is empty, the returned function always returns numeric(0).

Examples

# Two simple prediction functions of x
f1 <- function(newdata, ...) newdata$x
f2 <- function(newdata, ...) 2 * newdata$x

ens_mean   <- mean_predictor(list(f1, f2))

newdata <- data.frame(x = 1:5)
ens_mean(newdata)

ens_median <- median_predictor(list(f1, f2))

ens_median(newdata)
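The ensemble predictors can be sketched as closures that evaluate every function on newdata and combine the results row by row. The names below are hypothetical; they follow the documented contract (each function returns nrow(newdata) predictions, an empty list yields numeric(0)) but are not the package's code.

```r
# Hypothetical mean ensemble: average the predictions element-wise.
mean_predictor_sketch <- function(fs) {
  force(fs)
  function(newdata, ...) {
    if (length(fs) == 0L) return(numeric(0))
    preds <- lapply(fs, function(f) as.numeric(f(newdata, ...)))
    Reduce(`+`, preds) / length(fs)
  }
}

# Hypothetical median ensemble: row-wise median across predictors.
median_predictor_sketch <- function(fs) {
  force(fs)
  function(newdata, ...) {
    if (length(fs) == 0L) return(numeric(0))
    preds <- do.call(cbind, lapply(fs, function(f) as.numeric(f(newdata, ...))))
    apply(preds, 1L, stats::median)
  }
}

f1 <- function(newdata, ...) newdata$x
f2 <- function(newdata, ...) 2 * newdata$x
f3 <- function(newdata, ...) 10 * newdata$x
nd <- data.frame(x = 1:5)
median_predictor_sketch(list(f1, f2, f3))(nd)
```

With the three toy predictors, the median (2x) ignores the outlying 10x, while the mean would be pulled toward it.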

Internal: get a model from the cache

Description

Simple environment-based cache lookup used to store fitted models keyed by their model signature.

Usage

model_cache_get(cache_env, key)

Arguments

cache_env

Environment acting as a hash table for models.

key

Character key, typically produced by model_signature.

Value

The cached model object or NULL if absent.


Internal: set a model in the cache

Description

Stores a fitted model in an environment-based cache under the given key.

Usage

model_cache_set(cache_env, key, value)

Arguments

cache_env

Environment acting as a hash table for models.

key

Character key, typically produced by model_signature.

value

Fitted model object to store.

Value

Invisibly returns NULL.


Internal: compute a model cache key for an instance

Description

Computes a unique model signature for a given instance under a given evaluation token. The signature combines:

Usage

model_signature(inst_key, eval_token, methods, plan, K)

Arguments

inst_key

Instance key (character) in plan$instances.

eval_token

An evaluation token created by make_eval_token().

methods

The methods list used to build plan.

plan

The cross-fitting plan as returned by build_instances.

K

Number of folds.

Details

This is used as a key in the model cache so that models are reused exactly when both structure and training data coincide.

Value

A character string uniquely identifying the model for this instance and training-fold configuration.


Internal: indices used for target evaluation

Description

Given an evaluation token and the cross-fitting plan, returns the row indices and fold labels used to evaluate the target model for that panel. The evaluation window spans plan$eval_width[mi] folds and wraps cyclically around the K folds.

Usage

obs_for_eval(eval_token, K, fold_idx, plan)

Arguments

eval_token

An evaluation token created by make_eval_token.

K

Total number of folds.

fold_idx

A list mapping fold labels 1:K to integer row indices in data.

plan

The cross-fitting plan as returned by build_instances().

Value

A list with components idx (row indices) and folds (fold labels).
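A cyclic window of this kind reduces to modular arithmetic on fold labels. The sketch assumes, hypothetically, that the window for panel p starts at fold p; eval_window is not the package's obs_for_eval.

```r
# Hypothetical sketch: `width` consecutive folds starting at fold p,
# wrapping around K; returns row indices and fold labels.
eval_window <- function(p, width, K, fold_idx) {
  folds <- ((p - 1L + seq_len(width) - 1L) %% K) + 1L
  list(idx = unlist(fold_idx[folds], use.names = FALSE), folds = folds)
}

fold_idx <- list(1:3, 4:6, 7:9, 10:12)   # toy split of 12 rows into K = 4 folds
w <- eval_window(p = 4L, width = 2L, K = 4L, fold_idx = fold_idx)
w$folds   # the window wraps from fold 4 back to fold 1
```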


Internal: indices used for nuisance training

Description

Given a training token and the cross-fitting plan, returns the row indices and fold labels used to train a given nuisance instance. The training window is determined by the instance-specific offset and train_fold value.

Usage

obs_for_train(train_token, K, fold_idx, plan, methods)

Arguments

train_token

A training token created by make_train_token.

K

Total number of folds.

fold_idx

A list mapping fold labels 1:K to integer row indices.

plan

The cross-fitting plan as returned by build_instances().

methods

The methods list used to build plan.

Value

A list with components idx (row indices) and folds (fold labels).
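Under a cyclic convention, a training window shifted by an offset stays disjoint from the evaluation window as long as it never wraps back onto it. The sketch assumes the offset is measured from the first fold of the panel; train_window and the numbers used are illustrative, not how build_instances chooses offsets.

```r
# Hypothetical sketch: the training window for panel p starts `offset`
# folds after fold p and spans `train_fold` folds, wrapping around K.
train_window <- function(p, offset, train_fold, K) {
  ((p - 1L + offset + seq_len(train_fold) - 1L) %% K) + 1L
}

K <- 5L
eval_folds  <- 4L                                   # panel 4, evaluation width 1
train_folds <- train_window(p = 4L, offset = 1L, train_fold = 3L, K = K)
train_folds                                         # wraps past fold 5
```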


Internal: predict nuisance values for an instance and token

Description

Recursively produces predictions for a given instance under a given token. This ensures the corresponding model is fitted (via ensure_model), gathers predictions from child nuisances as additional inputs, and then calls the instance's predict() function.

Usage

predict_instance_for_token(
  inst_key,
  token,
  data,
  methods,
  plan,
  fit_child_map,
  pred_child_map,
  K,
  fold_idx,
  model_cache,
  mode
)

Arguments

inst_key

Instance key in plan$instances.

token

An evaluation or training token.

data

Training data (matrix or data frame).

methods

Methods list.

plan

Cross-fitting plan.

fit_child_map

Child map for fit() dependencies.

pred_child_map

Child map for predict() dependencies.

K

Number of folds.

fold_idx

List mapping folds to row indices.

model_cache

Environment used to store fitted models.

mode

Either "estimate" or "predict" (method mode).

Details

In mode = "predict", this returns a layered prediction function on newdata, suitable for building cross-fitted predictors for the target. In mode = "estimate", it returns the prediction values on the appropriate folds (eval or train).

Value

In mode = "predict", a prediction function of newdata. In mode = "estimate", a vector or matrix of predictions on the relevant subset of data.


Internal: validate and normalize a batch of methods

Description

Standardizes a list of method specifications:

Usage

validate_batch(methods, default_agg_panels, default_agg_repeats)

Arguments

methods

A list of method specifications.

default_agg_panels

Default aggregation function for panels (used when a method does not provide aggregate_panels). Must be a function.

default_agg_repeats

Default aggregation function for repetitions (used when a method does not provide aggregate_repeats). Must be a function.

Details

This is the entry point used by crossfit_multi to validate user-specified methods. It is not intended for direct use by end users.

Value

A named list of fully validated and normalized methods.
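The fallback to batch-level aggregators described for default_agg_panels and default_agg_repeats can be sketched in a few lines. fill_aggregators is hypothetical and covers only this step, not the rest of the validation.

```r
# Hypothetical sketch: method-level aggregators win when non-NULL,
# otherwise the batch-level defaults are filled in.
fill_aggregators <- function(methods, default_agg_panels, default_agg_repeats) {
  stopifnot(is.function(default_agg_panels), is.function(default_agg_repeats))
  lapply(methods, function(m) {
    if (is.null(m$aggregate_panels))  m$aggregate_panels  <- default_agg_panels
    if (is.null(m$aggregate_repeats)) m$aggregate_repeats <- default_agg_repeats
    m
  })
}

ms  <- list(a = list(aggregate_panels = median), b = list())
out <- fill_aggregators(ms, mean, mean)
```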


Internal: validate and normalize a single method specification

Description

Validates a method specification and normalizes its nuisance graph. This includes:

Usage

validate_method(met, mname = NULL)

Arguments

met

A method specification list.

mname

Optional method name (used for error messages).

Details

This function is normally called from validate_batch() and is not intended to be used directly by end users.

Value

A normalized method specification with a fully validated nuisance graph and an added "__TARGET__" node.


Internal: basic structural checks for a single nuisance

Description

Performs basic validation and normalization for a single nuisance specification. This ensures that the nuisance:

Usage

validate_nuisance(nf, nm = NULL, mname = NULL)

Arguments

nf

A list representing a nuisance specification.

nm

Optional nuisance name for error messages.

mname

Optional method name for error messages.

Details

This function does not infer dependencies or check cycles; that is done at the method level by validate_method().

Value

A normalized nuisance list with guaranteed components fit, predict, train_fold, fit_deps, and pred_deps.