#' @importFrom stats sd
#' @importFrom Rglpk Rglpk_solve_LP
#' Estimate second-moment matrix with missing data
#'
#' @description
#' Compute the empirical second-moment (pairwise) matrix from a design matrix
#' that may contain missing values.
#'
#' @usage
#' compute_hatSigma(X_all)
#'
#' @param X_all
#' Numeric matrix of dimension \eqn{n \times p} containing covariates.
#' Entries may be \code{NA}, indicating missing values.
#'
#' @return
#' A symmetric numeric matrix of dimension \eqn{p \times p}.
#' The \eqn{(j,k)} entry equals the sample mean of
#' \eqn{X_{ij} X_{ik}} computed over rows for which both
#' variables \eqn{j} and \eqn{k} are observed.
#'
#' @details
#' The function computes pairwise products using all available observations.
#' An error is raised if there exists a pair of variables that are never
#' observed together.
#'
#' @examples
#' \dontrun{
#' X <- matrix(rnorm(20), ncol = 2)
#' X[sample(length(X), 4)] <- NA
#' compute_hatSigma(X)
#' }
compute_hatSigma <- function(X_all) {
  # Pairwise-complete second moments: entry (j, k) averages X_ij * X_ik over
  # the rows in which both column j and column k are observed.
  is_obs <- !is.na(X_all)
  zero_filled <- X_all
  zero_filled[!is_obs] <- 0  # a zeroed entry adds nothing to any product sum

  product_sums <- crossprod(zero_filled)  # p x p sums over co-observed rows
  pair_counts <- crossprod(is_obs)        # p x p numbers of co-observed rows
  if (any(pair_counts == 0)) {
    stop("Method will not work - some covariates are never observed alongside each other")
  }

  moments <- product_sums / pair_counts
  # Average with the transpose to remove any floating-point asymmetry.
  (moments + t(moments)) / 2
}

#' Precompute OSS pattern metadata
#'
#' @description
#' Internal function that computes missingness-pattern metadata required by
#' the OSS estimator from labelled and optional unlabelled design matrices.
#'
#' @usage
#' precompute_oss_meta(X_labelled, X_unlabelled = NULL, use_unlabelled_rule = TRUE)
#'
#' @param X_labelled
#' Numeric matrix of dimension \eqn{n_L \times p} containing labelled covariates.
#' Entries may be \code{NA}.
#'
#' @param X_unlabelled
#' Optional numeric matrix of dimension \eqn{n_U \times p} containing unlabelled
#' covariates. Entries may be \code{NA}.
#'
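#' @param use_unlabelled_rule
#' Logical, only used when \code{X_unlabelled} is supplied. If \code{TRUE}, the
#' second-moment matrix is estimated from the unlabelled rows alone when every
#' column of \code{X_unlabelled} has at least \eqn{5 n_L} observed entries, and
#' from the labelled and unlabelled rows combined otherwise. If \code{FALSE},
#' the combined rows are always used.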
#'
#' @return
#' A list with components:
#' \describe{
#'   \item{hatSigma}{Estimated second-moment matrix of dimension \eqn{p \times p}.}
#'   \item{groups}{List of group-specific metadata indexed by missingness pattern.}
#'   \item{group_ids}{Factor assigning each labelled row to a missingness pattern.}
#'   \item{p}{Number of covariates.}
#' }
#'
#' @details
#' Labelled observations are partitioned into groups according to their
#' missingness patterns. For each pattern, the function computes a linear
#' projection matrix mapping observed covariates to the full covariate space,
#' together with a corresponding Schur complement for missing coordinates when
#' applicable. These quantities are used internally by the OSS estimator.
#'
#' @noRd
precompute_oss_meta <- function(X_labelled, X_unlabelled = NULL, use_unlabelled_rule = TRUE) {
  # Builds hatSigma plus, for every missingness pattern of the labelled rows,
  # the projection Pi = S_OO^{-1} S_O. and (when columns are missing) the Schur
  # complement Sigma_M = S_MM - S_MO S_OO^{-1} S_OM.
  # X_labelled: (n_L x p) with NA; X_unlabelled: optional (n_U x p) with NA.
  if (!is.matrix(X_labelled)) X_labelled <- as.matrix(X_labelled)
  if (!is.null(X_unlabelled) && !is.matrix(X_unlabelled)) X_unlabelled <- as.matrix(X_unlabelled)
  n_L <- nrow(X_labelled)
  p <- ncol(X_labelled)

  # Catch a column mismatch up front: in the unlabelled-only branch below, a
  # wider X_unlabelled would otherwise produce a hatSigma that silently
  # disagrees with the labelled design; a narrower one fails obscurely in rbind().
  if (!is.null(X_unlabelled) && ncol(X_unlabelled) != p) {
    stop("X_unlabelled must have the same number of columns as X_labelled")
  }

  # Rows used for hatSigma: labelled only when no unlabelled rows are given;
  # unlabelled only when the rule is on and every unlabelled column has at
  # least 5 * n_L observed entries; otherwise labelled and unlabelled combined.
  X_all <- X_labelled
  if (!is.null(X_unlabelled)) {
    unlabelled_suffices <- if (use_unlabelled_rule) {
      all(colSums(!is.na(X_unlabelled)) >= 5 * n_L)
    } else {
      FALSE
    }
    X_all <- if (unlabelled_suffices) X_unlabelled else rbind(X_labelled, X_unlabelled)
  }
  hatSigma <- compute_hatSigma(X_all)

  # Group labelled rows by missingness pattern, encoded as a 0/1 string
  # (1 = observed) with one character per column.
  obs_mat <- !is.na(X_labelled)
  keys <- apply(obs_mat, 1, function(r) paste0(as.integer(r), collapse = ""))
  group_ids <- factor(keys)
  levels_g <- levels(group_ids)
  # One pass over the rows instead of one full scan per pattern.
  idx_by_group <- split(seq_len(n_L), group_ids)

  groups <- vector("list", length(levels_g))
  names(groups) <- levels_g

  for (g in levels_g) {
    idx <- idx_by_group[[g]]
    Oi <- which(obs_mat[idx[1], ])
    Mi <- setdiff(seq_len(p), Oi)
    meta <- list(idx = idx, Oi = Oi, Mi = Mi)

    # A fully missing pattern gets neither Pi nor Sigma_M.
    if (length(Oi) > 0) {
      S_OO <- hatSigma[Oi, Oi, drop = FALSE]
      S_O <- hatSigma[Oi, , drop = FALSE]

      # Try Cholesky first; hatSigma is built pairwise and may be indefinite,
      # in which case chol() fails and we fall back to solve()/ridge/ginv.
      ch <- tryCatch(chol(S_OO), error = function(e) NULL)
      inv_S_OO <- if (is.null(ch)) NULL else chol2inv(ch)
      if (!is.null(inv_S_OO)) {
        Pi <- inv_S_OO %*% S_O
      } else {
        Pi <- tryCatch(solve(S_OO, S_O), error = function(e) {
          # Tiny ridge, scaled to the diagonal, only when solve() fails.
          ridged <- S_OO + 1e-12 * max(1, mean(diag(S_OO))) * diag(length(Oi))
          tryCatch(solve(ridged), error = function(e2) MASS::ginv(ridged)) %*% S_O
        })
      }
      meta$Pi <- Pi

      if (length(Mi) > 0) {
        S_MO <- hatSigma[Mi, Oi, drop = FALSE]
        if (is.null(inv_S_OO)) {
          inv_S_OO <- tryCatch(solve(S_OO), error = function(e) MASS::ginv(S_OO))
        }
        Sigma_M <- hatSigma[Mi, Mi, drop = FALSE] - S_MO %*% (inv_S_OO %*% t(S_MO))
        meta$Sigma_M <- (Sigma_M + t(Sigma_M)) / 2  # enforce symmetry numerically
      }
    }

    groups[[g]] <- meta
  }

  list(hatSigma = hatSigma, groups = groups, group_ids = group_ids, p = p)
}
#' Precompute OSS metadata from a model formula
#'
#' @description
#' Internal convenience function that constructs a model matrix from a formula
#' and data frame, then computes OSS missingness-pattern metadata.
#'
#' @usage
#' precompute_oss_meta_from_formula(formula, data)
#'
#' @param formula
#' A model formula specifying the response and covariates.
#'
#' @param data
#' A data frame containing the variables appearing in the formula.
#'
#' @return
#' A list of OSS metadata as returned by \code{precompute_oss_meta}, with the
#' following additional components:
#' \describe{
#'   \item{mm_colnames}{Column names of the constructed model matrix.}
#'   \item{n_labelled}{Number of labelled observations.}
#'   \item{n_unlabelled}{Number of unlabelled observations.}
#' }
#'
#' @details
#' The model frame is constructed using \code{na.action = na.pass} so that
#' missing responses are retained. Rows with observed responses are treated
#' as labelled observations, while rows with missing responses are treated as
#' unlabelled observations.
#'
#' @noRd
precompute_oss_meta_from_formula <- function(formula, data) {
  if (missing(formula) || missing(data)) stop("Please provide formula and data arguments.")
  if (!inherits(formula, "formula")) stop("First argument must be a formula, e.g. y ~ x1 + x2")
  if (!is.data.frame(data)) stop("data must be a data.frame")

  # na.pass keeps incomplete rows: a missing response marks an unlabelled row,
  # missing covariates are handled by the pattern metadata.
  mf <- model.frame(formula, data = data, na.action = na.pass)
  y_all <- model.response(mf)
  # model.response() returns NULL for a one-sided formula; without this check
  # the caller would get the misleading "response is all NA" error below.
  if (is.null(y_all)) stop("formula must have a response on the left-hand side, e.g. y ~ x1 + x2")
  X_all_mm <- model.matrix(attr(mf, "terms"), mf)

  labelled_idx <- !is.na(y_all)
  if (!any(labelled_idx)) stop("No labelled observations found (response is all NA).")
  X_labelled <- as.matrix(X_all_mm[labelled_idx, , drop = FALSE])
  X_unlabelled <- if (any(!labelled_idx)) as.matrix(X_all_mm[!labelled_idx, , drop = FALSE]) else NULL

  meta <- precompute_oss_meta(X_labelled = X_labelled, X_unlabelled = X_unlabelled)
  meta$mm_colnames <- colnames(X_all_mm)
  meta$n_labelled <- nrow(X_labelled)
  meta$n_unlabelled <- if (is.null(X_unlabelled)) 0 else nrow(X_unlabelled)
  meta
}
#' Compute OSS estimator (core implementation)
#'
#' @description
#' Internal implementation of the OSS estimator for linear regression with
#' missing covariates. This function performs estimation without cross-fitting
#' and is used by higher-level user-facing wrappers.
#'
#' @usage
#' oss_estimator_core(formula, data, all_weights_one = FALSE, precomputed_meta = NULL)
#'
#' @param formula
#' A model formula specifying the response and covariates.
#'
#' @param data
#' A data frame containing the variables appearing in the formula.
#'
#' @param all_weights_one
#' Logical. If \code{TRUE}, all missingness-pattern weights are set equal to one.
#'
#' @param precomputed_meta
#' Optional list of metadata as produced by
#' \code{precompute_oss_meta} or
#' \code{precompute_oss_meta_from_formula}.
#'
#' @return
#' An invisible list with components:
#' \describe{
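#'   \item{coef}{Estimated regression coefficients as returned by \code{solve}
#'     (a \eqn{p \times 1} matrix), named by the model-matrix columns.}
#'   \item{sigma2_hat}{Estimated noise variance, or \code{NA} when
#'     \code{all_weights_one = TRUE}.}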
#'   \item{weights}{Named vector of missingness-pattern weights.}
#'   \item{groups}{Data frame mapping labelled observations to pattern groups.}
#'   \item{beta_cc}{Complete-case coefficient estimates if used, otherwise \code{NULL}.}
#' }
#'
#' @details
#' Labelled observations are grouped according to their missingness patterns.
#' The estimator computes an initial coefficient estimate and an estimate of the
#' noise variance, which are then used to construct pattern-specific weights.
#' A final weighted least-squares problem is solved to obtain the OSS estimate.
#'
#' @noRd
oss_estimator_core <- function(formula, data, all_weights_one = FALSE, precomputed_meta = NULL) {
  # Non-cross-fitted OSS fit. precomputed_meta (e.g. from precompute_oss_meta()
  # or precompute_oss_meta_crossfit()) lets callers supply the nuisance objects;
  # its group indices must refer to the labelled rows of `data`, in order.
  if (missing(formula) || missing(data)) stop("Please provide formula and data arguments.")
  if (!inherits(formula, "formula")) stop("First argument must be a formula, e.g. y ~ x1 + x2")
  if (!is.data.frame(data)) stop("data must be a data.frame")

  # na.pass keeps incomplete rows: a missing response marks an unlabelled row,
  # missing covariates define the missingness patterns.
  mf <- model.frame(formula, data = data, na.action = na.pass)
  y_all <- model.response(mf)
  # model.response() returns NULL for a one-sided formula, which would otherwise
  # surface as the misleading "response is all NA" error below.
  if (is.null(y_all)) stop("formula must have a response on the left-hand side, e.g. y ~ x1 + x2")
  X_all_mm <- model.matrix(attr(mf, "terms"), mf)
  mm_colnames <- colnames(X_all_mm)

  labelled_idx <- !is.na(y_all)
  if (all(!labelled_idx)) stop("No labelled observations found (response is all NA).")

  X_labelled <- as.matrix(X_all_mm[labelled_idx, , drop = FALSE])
  X_unlabelled <- if (any(!labelled_idx)) as.matrix(X_all_mm[!labelled_idx, , drop = FALSE]) else NULL
  y <- y_all[labelled_idx]
  n_L <- nrow(X_labelled)
  p <- ncol(X_labelled)

  if (is.null(precomputed_meta)) {
    meta <- precompute_oss_meta(X_labelled = X_labelled, X_unlabelled = X_unlabelled)
  } else {
    meta <- precomputed_meta
    if (!is.null(meta$p) && meta$p != p) {
      stop("precomputed_meta does not match the formula/data (different number of columns).")
    }
    # Group row indices are positions among the labelled rows, so the meta must
    # describe exactly n_L rows or they would silently point at the wrong rows.
    if (!is.null(meta$group_ids) && length(meta$group_ids) != n_L) {
      stop("precomputed_meta does not match the formula/data (different number of labelled rows).")
    }
  }

  hatSigma <- meta$hatSigma
  groups <- meta$groups
  group_ids <- meta$group_ids
  levels_g <- names(groups)
  unit_weights <- setNames(rep(1, length(levels_g)), levels_g)

  # Step 1: pattern weights w_g = sigma2 / (sigma2 + beta_M' Sigma_M beta_M).
  beta_cc <- NULL
  if (isTRUE(all_weights_one)) {
    weights <- unit_weights
    sigma2_hat <- NA
  } else {
    obs_complete <- complete.cases(X_labelled)
    if (sum(obs_complete) / p >= 5) {
      # At least 5 complete cases per coefficient: complete-case OLS.
      X_cc <- X_labelled[obs_complete, , drop = FALSE]
      y_cc <- y[obs_complete]
      XtX <- crossprod(X_cc)
      Xty <- crossprod(X_cc, y_cc)
      beta_cc <- tryCatch(solve(XtX, Xty), error = function(e) MASS::ginv(XtX) %*% Xty)
      sigma2_hat <- mean((y_cc - X_cc %*% beta_cc)^2)
      beta_for_weights <- beta_cc
    } else {
      # Otherwise: unweighted OSS fit, then a median-based noise-variance estimate.
      beta_init <- .oss_grouped_fit(X_labelled, y, groups, unit_weights)
      sigma2_hat <- median(.oss_residual_variances(X_labelled, y, groups, hatSigma, beta_init))
      if (!is.finite(sigma2_hat) || sigma2_hat <= 0) sigma2_hat <- 1
      beta_for_weights <- beta_init
    }
    weights <- .oss_pattern_weights(groups, beta_for_weights, sigma2_hat)
  }

  # Step 2: final weighted OSS fit (a p x 1 matrix, as returned by solve()).
  beta_hat <- .oss_grouped_fit(X_labelled, y, groups, weights)
  if (!is.null(mm_colnames)) names(beta_hat) <- mm_colnames

  invisible(list(
    coef = beta_hat,
    sigma2_hat = sigma2_hat,
    weights = weights,
    groups = data.frame(row = which(labelled_idx), group = as.character(group_ids)),
    beta_cc = beta_cc
  ))
}

# Solves the pattern-weighted least-squares problem
#   min_beta sum_g w_g || y_g - X_{g,O} Pi_g beta ||^2
# skipping patterns with no observed columns; falls back to ginv() if singular.
.oss_grouped_fit <- function(X, y, groups, weights) {
  p <- ncol(X)
  A <- matrix(0, p, p)
  b <- numeric(p)
  for (g in names(groups)) {
    m <- groups[[g]]
    if (length(m$Oi) == 0) next
    Vg <- X[m$idx, m$Oi, drop = FALSE] %*% m$Pi
    w <- weights[[g]]
    A <- A + w * crossprod(Vg)
    b <- b + w * crossprod(Vg, y[m$idx])
  }
  tryCatch(solve(A, b), error = function(e) MASS::ginv(A) %*% b)
}

# Per-row noise-variance estimates (y - fitted)^2 - beta_M' Sigma_M beta_M,
# clipped at 0, where fitted is the projection-based prediction for the row's pattern.
.oss_residual_variances <- function(X, y, groups, hatSigma, beta) {
  beta <- as.numeric(beta)
  out <- numeric(nrow(X))
  for (g in names(groups)) {
    m <- groups[[g]]
    idx <- m$idx
    if (length(m$Oi) == 0) {
      # Nothing observed: the prediction is 0 and the missing part of x'beta
      # has second moment beta' hatSigma beta (the Schur complement with an
      # empty conditioning set is hatSigma itself).
      fitted <- 0
      Sigma_M <- if (is.null(m$Sigma_M)) hatSigma else m$Sigma_M
    } else {
      fitted <- as.numeric(X[idx, m$Oi, drop = FALSE] %*% m$Pi %*% beta)
      Sigma_M <- m$Sigma_M
    }
    term_miss <- 0
    if (!is.null(Sigma_M) && length(m$Mi) > 0) {
      b_M <- beta[m$Mi]
      term_miss <- as.numeric(crossprod(b_M, Sigma_M %*% b_M))
    }
    out[idx] <- pmax((y[idx] - fitted)^2 - term_miss, 0)
  }
  out
}

# Pattern weights sigma2 / (sigma2 + max(beta_M' Sigma_M beta_M, 0)); patterns
# with no missing columns (or no Sigma_M) keep weight 1.
.oss_pattern_weights <- function(groups, beta, sigma2_hat) {
  beta <- as.numeric(beta)
  w <- setNames(rep(1, length(groups)), names(groups))
  for (g in names(groups)) {
    m <- groups[[g]]
    if (length(m$Mi) == 0 || is.null(m$Sigma_M)) next
    b_M <- beta[m$Mi]
    quad_term <- max(as.numeric(crossprod(b_M, m$Sigma_M %*% b_M)), 0)
    denom <- sigma2_hat + quad_term
    # A perfect complete-case fit with quad_term = 0 would give 0/0 = NaN,
    # which poisons the final solve; keep weight 1 in that case.
    if (isTRUE(denom > 0)) w[[g]] <- sigma2_hat / denom
  }
  w
}

#' Precompute OSS metadata for cross-fitting
#'
#' @description
#' Internal helper used for two-fold cross-fitting. Estimates the second-moment
#' matrix using one fold (and optional unlabelled data), then constructs
#' missingness-pattern metadata for the other fold.
#'
#' @usage
#' precompute_oss_meta_crossfit(fold1, fold2, X_unlabelled = NULL)
#'
#' @param fold1
#' Numeric matrix whose rows define the target fold for which missingness
#' patterns and group-wise metadata are constructed.
#'
#' @param fold2
#' Numeric matrix whose rows are used to estimate the second-moment matrix.
#'
#' @param X_unlabelled
#' Optional numeric matrix of additional unlabelled rows to include when
#' estimating the second-moment matrix.
#'
#' @return
#' A list with the same structure as returned by \code{precompute_oss_meta},
#' containing:
#' \describe{
#'   \item{hatSigma}{Estimated second-moment matrix.}
#'   \item{groups}{List of group-wise missingness metadata.}
#'   \item{group_ids}{Factor assigning rows of \code{fold1} to groups.}
#'   \item{p}{Number of columns.}
#' }
#'
#' @details
#' This function is used internally by the cross-fitted OSS estimator to ensure
#' that nuisance quantities are estimated on data independent of the target
#' fold.
#'
#' @noRd
precompute_oss_meta_crossfit <- function(fold1, fold2, X_unlabelled = NULL) {
  # Cross-fitting nuisance metadata: hatSigma is estimated from fold2 (stacked
  # with X_unlabelled when given), while the missingness groups describe fold1.
  # All inputs are (rows x p) matrices with NA for missing entries.
  # Returns list(hatSigma, groups, group_ids, p).
  to_matrix <- function(x) if (is.matrix(x)) x else as.matrix(x)
  fold1 <- to_matrix(fold1)
  fold2 <- to_matrix(fold2)
  if (!is.null(X_unlabelled)) X_unlabelled <- to_matrix(X_unlabelled)

  n_col <- ncol(fold1)
  if (ncol(fold2) != n_col) stop("fold1 and fold2 must have the same number of columns")
  if (!is.null(X_unlabelled) && ncol(X_unlabelled) != n_col) {
    stop("X_unlabelled must have same number of columns as folds")
  }

  sigma_source <- if (is.null(X_unlabelled)) fold2 else rbind(fold2, X_unlabelled)
  hatSigma <- compute_hatSigma(sigma_source)

  # Missingness pattern of each fold1 row as a 0/1 string (1 = observed).
  observed <- !is.na(fold1)
  pattern_keys <- apply(observed, 1, function(row_obs) paste0(as.integer(row_obs), collapse = ""))
  group_ids <- factor(pattern_keys)
  levels_g <- levels(group_ids)
  p <- n_col

  # Projection Pi = S_OO^{-1} S_O. and Schur complement for one pattern.
  build_pattern <- function(g) {
    rows <- which(group_ids == g)
    Oi <- which(observed[rows[1], ])
    Mi <- setdiff(seq_len(p), Oi)
    info <- list(idx = rows, Oi = Oi, Mi = Mi)
    if (length(Oi) == 0) {
      return(info)
    }

    S_OO <- hatSigma[Oi, Oi, drop = FALSE]
    S_Oall <- hatSigma[Oi, , drop = FALSE]
    chol_fac <- tryCatch(chol(S_OO), error = function(e) NULL)
    inv_OO <- if (is.null(chol_fac)) NULL else chol2inv(chol_fac)

    info$Pi <- if (!is.null(inv_OO)) {
      inv_OO %*% S_Oall
    } else {
      tryCatch(
        solve(S_OO, S_Oall),
        error = function(e) {
          eps <- 1e-12 * max(1, mean(diag(S_OO)))
          ridged <- S_OO + eps * diag(length(Oi))
          inv_ridged <- tryCatch(solve(ridged), error = function(e2) MASS::ginv(ridged))
          inv_ridged %*% S_Oall
        }
      )
    }

    if (length(Mi) > 0) {
      S_MO <- hatSigma[Mi, Oi, drop = FALSE]
      inv_for_schur <- if (!is.null(inv_OO)) {
        inv_OO
      } else {
        tryCatch(solve(S_OO), error = function(e) MASS::ginv(S_OO))
      }
      schur <- hatSigma[Mi, Mi, drop = FALSE] - S_MO %*% (inv_for_schur %*% t(S_MO))
      info$Sigma_M <- (schur + t(schur)) / 2
    }
    info
  }

  groups <- lapply(levels_g, build_pattern)
  names(groups) <- levels_g

  list(hatSigma = hatSigma, groups = groups, group_ids = group_ids, p = p)
}


  #' Two-fold cross-fitted OSS estimator
  #'
  #' @description
  #' Internal two-fold cross-fitting wrapper around the core OSS estimator.
  #' Nuisance quantities are estimated on one fold and applied to the other,
  #' with final coefficients obtained by averaging across folds.
  #'
  #' @usage
  #' OSS_estimator_Crossfit(formula, data, all_weights_one = FALSE,
  #'                        precomputed_meta = NULL)
  #'
  #' @param formula
  #' A model formula specifying the linear regression.
  #'
  #' @param data
  #' A data.frame containing the variables in the model. Some response values
  #' may be missing and are treated as unlabelled observations.
  #'
  #' @param all_weights_one
  #' Logical; if TRUE, all missingness-pattern weights are set to one.
  #'
  #' @param precomputed_meta
  #' Ignored (a warning is issued if it is supplied). Present only for
  #' compatibility with the non-cross-fitted estimator interface.
  #'
  #' @return
  #' An invisible list containing:
  #' \describe{
  #'   \item{coef}{Cross-fitted coefficient estimates.}
  #'   \item{beta_A, beta_B}{Fold-specific coefficient estimates.}
  #'   \item{sigma2_hat_A, sigma2_hat_B}{Fold-specific noise variance estimates.}
  #'   \item{weights_A, weights_B}{Fold-specific pattern weights.}
  #'   \item{foldA_rownums, foldB_rownums}{Row indices defining the two folds.}
  #' }
  #'
  #' @details
  #' This function is used internally to reduce bias arising from estimating
  #' nuisance quantities on the same data used for coefficient estimation.
  #' It is not intended to be called directly by users.
  #'
  #' @noRd

OSS_estimator_Crossfit <- function(formula, data, all_weights_one = FALSE, precomputed_meta = NULL) {
  if (missing(formula) || missing(data)) stop("Please provide formula and data arguments.")
  if (!inherits(formula, "formula")) stop("First argument must be a formula, e.g. y ~ x1 + x2")
  if (!is.data.frame(data)) stop("data must be a data.frame")
  if (!is.null(precomputed_meta)) {
    warning("precomputed_meta is ignored for cross-fitting (nuisance objects are re-estimated per fold).")
  }

  # Build the model frame / matrix exactly as oss_estimator_core() does so that
  # column ordering matches the per-fold fits.
  mf <- model.frame(formula, data = data, na.action = na.pass)
  y_all <- model.response(mf)
  if (is.null(y_all)) stop("formula must have a response on the left-hand side, e.g. y ~ x1 + x2")
  X_all_mm <- model.matrix(attr(mf, "terms"), mf)
  mm_colnames <- colnames(X_all_mm)
  p <- ncol(X_all_mm)

  labelled_idx <- !is.na(y_all)
  if (all(!labelled_idx)) stop("No labelled observations found (response is all NA).")
  nL <- sum(labelled_idx)
  if (nL < 2) stop("Need at least 2 labelled observations for two-fold cross-fitting.")

  # Folds are handed to oss_estimator_core() by blanking the response of the
  # other fold's rows in `data`, so the (first) response variable must be a
  # column of `data` rather than, e.g., a vector in the formula's environment.
  response_colname <- all.vars(formula[[2]])[1]
  if (is.na(response_colname) || !(response_colname %in% names(data))) {
    stop("Cross-fitting requires the response variable to be a column of data.")
  }

  # Labelled / unlabelled design matrices, rows in original data order.
  X_labelled_full <- as.matrix(X_all_mm[labelled_idx, , drop = FALSE])
  X_unlabelled <- if (any(!labelled_idx)) as.matrix(X_all_mm[!labelled_idx, , drop = FALSE]) else NULL

  # Random split of the labelled rows: fold A gets floor(nL / 2) >= 1 rows and
  # fold B the remaining (>= 1) rows. Positions are sorted so they line up with
  # the labelled-row order oss_estimator_core() sees for each fold.
  perm <- sample(nL)
  nA <- floor(nL / 2)
  foldA_pos <- sort(perm[seq_len(nA)])
  foldB_pos <- sort(perm[(nA + 1):nL])

  # Map fold positions back to row numbers of `data`.
  labelled_rownums <- which(labelled_idx)
  foldA_rownums <- labelled_rownums[foldA_pos]
  foldB_rownums <- labelled_rownums[foldB_pos]

  # Nuisance metadata for a target fold, estimated from the other fold plus
  # all unlabelled rows.
  meta_for_fold <- function(target_pos, other_pos, target, other) {
    tryCatch(
      precompute_oss_meta_crossfit(fold1 = X_labelled_full[target_pos, , drop = FALSE],
                                   fold2 = X_labelled_full[other_pos, , drop = FALSE],
                                   X_unlabelled = X_unlabelled),
      error = function(e) {
        stop("Failed to compute hatSigma/meta for fold ", target, " (using fold ", other, "). ",
             "This usually means some column pairs never co-occur in fold2+unlabelled and compute_hatSigma failed. ",
             "Try using more unlabelled rows or a different split. (orig error: ", e$message, ")")
      }
    )
  }
  meta_A_from_B <- meta_for_fold(foldA_pos, foldB_pos, "A", "B")
  meta_B_from_A <- meta_for_fold(foldB_pos, foldA_pos, "B", "A")

  if (!is.null(meta_A_from_B$p) && meta_A_from_B$p != p) stop("Meta (A_from_B) does not match model dimension p.")
  if (!is.null(meta_B_from_A$p) && meta_B_from_A$p != p) stop("Meta (B_from_A) does not match model dimension p.")

  # Copy of `data` in which labelled rows outside the kept fold have their
  # response set to NA, so oss_estimator_core() treats them as unlabelled.
  data_with_fold_labelled <- function(keep_rownums) {
    fold_data <- data
    rows_to_na <- setdiff(labelled_rownums, keep_rownums)
    if (length(rows_to_na) > 0) fold_data[[response_colname]][rows_to_na] <- NA
    fold_data
  }
  data_A <- data_with_fold_labelled(foldA_rownums)
  data_B <- data_with_fold_labelled(foldB_rownums)

  res_A <- tryCatch(
    oss_estimator_core(formula = formula, data = data_A, all_weights_one = all_weights_one, precomputed_meta = meta_A_from_B),
    error = function(e) stop("oss_estimator failed on fold A: ", e$message)
  )
  res_B <- tryCatch(
    oss_estimator_core(formula = formula, data = data_B, all_weights_one = all_weights_one, precomputed_meta = meta_B_from_A),
    error = function(e) stop("oss_estimator failed on fold B: ", e$message)
  )

  beta_A <- as.numeric(res_A$coef)
  beta_B <- as.numeric(res_B$coef)
  if (length(beta_A) != p || length(beta_B) != p) stop("Returned beta length does not match model dimension p.")
  names(beta_A) <- names(beta_B) <- mm_colnames

  # Fold-size-weighted average of the two fold estimates.
  nB <- length(foldB_pos)
  beta_cf <- (nA * beta_A + nB * beta_B) / (nA + nB)
  names(beta_cf) <- mm_colnames

  invisible(list(
    coef = beta_cf,
    beta_A = beta_A,
    beta_B = beta_B,
    sigma2_hat_A = res_A$sigma2_hat,
    sigma2_hat_B = res_B$sigma2_hat,
    weights_A = res_A$weights,
    weights_B = res_B$weights,
    foldA_rownums = foldA_rownums,
    foldB_rownums = foldB_rownums
  ))
}


#' Linear regression with missing data
#'
#' @description
#' Fits a linear regression model in the presence of missing covariates and/or
#' missing responses using the OSS (Ordinary Semi-Supervised) estimator. This corresponds
#' to Section 2 of \insertCite{RisebrowSSLR;textual}{LRMiss}.
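#' The method exploits partially observed covariates and, optionally, unlabelled
#' observations to improve estimation efficiency. If there are at least five
#' complete cases per coefficient among the labelled rows, the weights are
#' estimated from a complete-case fit; otherwise they are estimated from an
#' initial consistent OSS estimate. Without cross-fitting, if every column of the
#' design matrix is observed in at least five times as many unlabelled rows as
#' there are labelled rows, the second-moment matrix is estimated from the
#' unlabelled rows alone; otherwise it is estimated from the labelled and
#' unlabelled rows combined, each entry using all rows in which both covariates
#' are observed. With cross-fitting, it is estimated from the other labelled fold
#' together with all unlabelled rows.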
#'
#' @usage
#' oss_estimator(formula, data,
#'               all_weights_one = FALSE,
#'               crossfitting = FALSE)
#'
#' @param formula
#' A model formula specifying the linear regression, e.g. \code{y ~ x1 + x2}.
#'
#' @param data
#' A data.frame containing the variables in the model. Rows with missing
#' responses are treated as unlabelled observations.
#'
#' @param all_weights_one
#' Logical; if TRUE, all missingness-pattern weights are set to one, yielding an
#' unweighted OSS estimator.
#'
#' @param crossfitting
#' Logical; if TRUE, a two-fold cross-fitted version of the OSS estimator is
#' used.
#'
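#' @return
#' An invisible list. Without cross-fitting, its components are:
#' \describe{
#'   \item{coef}{Estimated regression coefficients, named by the model-matrix columns.}
#'   \item{sigma2_hat}{Estimated noise variance, or \code{NA} if not computed.}
#'   \item{weights}{Named vector of weights associated with each missingness pattern.}
#'   \item{groups}{Data.frame mapping labelled observations to missingness patterns.}
#'   \item{beta_cc}{Complete-case coefficient estimates if used, otherwise \code{NULL}.}
#' }
#' With \code{crossfitting = TRUE}, the components are instead \code{coef} (the
#' fold-size-weighted average of the two fold estimates), \code{beta_A},
#' \code{beta_B}, \code{sigma2_hat_A}, \code{sigma2_hat_B}, \code{weights_A},
#' \code{weights_B}, \code{foldA_rownums} and \code{foldB_rownums}.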
#'
#'
#'
#' @examples
#' set.seed(1)
#' dat <- data.frame(
#'   y = c(1.0, NA, 2.3, 0.5),
#'   x1 = rnorm(4),
#'   x2 = rnorm(4)
#' )
#'
#' ## Without cross-fitting
#' res <- oss_estimator(y ~ x1 + x2, dat)
#'
#' ## With cross-fitting
#' res_cf <- oss_estimator(y ~ x1 + x2, dat, crossfitting = TRUE)
#' @references
#' \insertRef{RisebrowSSLR}{LRMiss}
#' @export

oss_estimator <- function(formula, data, all_weights_one = FALSE, crossfitting = FALSE) {
  # Both flags must be a single non-NA logical. Without this, crossfitting = NA
  # fails later with "missing value where TRUE/FALSE needed", and a
  # non-logical all_weights_one (e.g. 1) is silently treated as FALSE downstream.
  is_flag <- function(x) is.logical(x) && length(x) == 1 && !is.na(x)
  if (!is_flag(crossfitting)) stop("crossfitting must be TRUE or FALSE")
  if (!is_flag(all_weights_one)) stop("all_weights_one must be TRUE or FALSE")

  # Both estimators return their result invisibly; `if` preserves that.
  if (crossfitting) {
    OSS_estimator_Crossfit(formula = formula, data = data, all_weights_one = all_weights_one)
  } else {
    oss_estimator_core(formula = formula, data = data, all_weights_one = all_weights_one)
  }
}

