#' @title Calculate the Kaplan-Meier estimator and the Aalen-Johansen estimator
#' @description
#' Core estimation routine that computes a survfit-compatible object
#' from a formula + data interface (`Event()` or `survival::Surv()` on
#' the LHS, and a stratification variable on the RHS if necessary).
#' The back-end C++ routine supports both weighted and stratified data. Use this
#' when you want **numbers only** (e.g. estimates, SEs, CIs and influence functions)
#' and will plot them yourself.
#'
#' @inheritParams cif-stat-arguments
#'
#' @param formula A model formula specifying the time-to-event outcome on the LHS
#'   (typically `Event(time, status)` or `survival::Surv(time, status)`)
#'   and, optionally, a stratification variable on the RHS.
#'   Unlike [cifplot()], this function does not accept a fitted survfit object.
#' @param report.influence.function Logical. When `TRUE` and `engine = "calculateAJ_Rcpp"`,
#' the influence function is also computed and returned (default `FALSE`).
#' @param report.survfit.std.err Logical. If `TRUE`, report SE on the log-survival
#' scale (survfit's convention). Otherwise SE is on the probability scale.
#' @param engine Character. One of `"auto"`, `"calculateKM"`, or `"calculateAJ_Rcpp"` (default `"calculateAJ_Rcpp"`).
#' @param prob.bound Numeric lower bound used to internally truncate probabilities away from 0 and 1 (default `1e-7`).
#'
#' @details
#'
#' ### Typical use cases
#' - When `outcome.type = "survival"`, this is a thin wrapper around
#' the KM estimator with the chosen variance / CI transformation.
#' - When `outcome.type = "competing-risk"`, this computes the AJ estimator of CIF for `code.event1`.
#'  The returned `$surv` is **1 - CIF**, i.e. in the format that ggsurvfit expects.
#' - Use [cifplot()] if you want to go straight to a figure; use [cifcurve()] if you only want the numbers.
#'
#' ### Risk set display
#' - Set `n.risk.type` to control whether `$n.risk` reflects weighted, unweighted,
#'   or Kish effective sample size (ESS) counts. This only affects the reported
#'   counts (e.g., for plotting or debugging) and leaves estimates and SEs unchanged.
#'
#' ### Standard error and confidence intervals
#'
#' | Argument | Description | Default |
#' |---|---|---|
#' | `error` | SE for KM: `"greenwood"`, `"tsiatis"`, `"if"`. For CIF: `"aalen"`, `"delta"`, `"if"`. | `"greenwood"`, `"delta"` or `"if"` |
#' | `conf.type` | Transformation for CIs: `"plain"`, `"log"`, `"log-log"`, `"arcsin"`, `"logit"`, or `"none"`. | `"arcsine-square root"` |
#' | `conf.int` | Two-sided CI level. | `0.95` |
#'
#'
#' @return
#' A `"survfit"` object. For `outcome.type="survival"`, `$surv` is the survival function.
#' For `outcome.type="competing-risk"`, `$surv` equals `1 - CIF` for `code.event1`.
#' SE and CIs are provided per `error`, `conf.type` and `conf.int`.
#' This enables an independent use of standard methods for `survfit` such as:
#' -   `summary()`: time-by-time estimates with SEs and CIs
#' -   `plot()`: base R stepwise survival/CIF curves
#' -   `mean()`: restricted mean survival estimates with CIs
#' -   `quantile()`: quantile estimates with CIs
#'
#' Note that `$n.risk`, `$n.event`, and `$n.censor` are rounded up to the nearest integer
#' regardless of whether the data is weighted or not.
#' Some methods (e.g. `residuals.survfit`) may not be supported.
#'
#' @examples
#' data(diabetes.complications)
#' output1 <- cifcurve(Event(t,epsilon) ~ fruitq,
#'                     data = diabetes.complications,
#'                     outcome.type="competing-risk")
#' cifplot(output1,
#'         outcome.type = "competing-risk",
#'         type.y = "risk",
#'         add.risktable = FALSE,
#'         label.y = "CIF of diabetic retinopathy",
#'         label.x = "Years from registration")
#'
#' @importFrom Rcpp sourceCpp
#' @importFrom Rcpp evalCpp
#' @importFrom stats formula
#'
#' @name cifcurve
#' @section Lifecycle:
#' \lifecycle{stable}
#'
#' @seealso [polyreg()] for log-odds product modeling of CIFs; [cifplot()] for display of a CIF; [cifpanel()] for display of multiple CIFs; [ggsurvfit][ggsurvfit], [patchwork][patchwork] and [modelsummary][modelsummary] for display helpers.
#' @export
cifcurve <- function(
    formula,
    data,
    weights = NULL,
    n.risk.type = "weighted",
    subset.condition = NULL,
    na.action = na.omit,
    outcome.type = c("survival","competing-risk"),
    code.event1 = 1,
    code.event2 = 2,
    code.censoring = 0,
    error = NULL,
    conf.type = "arcsine-square root",
    conf.int = 0.95,
    report.influence.function = FALSE,
    report.survfit.std.err = FALSE,
    engine = "calculateAJ_Rcpp",
    prob.bound = 1e-7
) {
  # ---- Validate enumerated arguments --------------------------------------
  outcome.type <- util_check_outcome_type(outcome.type, formula = formula, data = data)
  n.risk.type <- util_check_n_risk_type(n.risk.type)

  # ---- Read the data -------------------------------------------------------
  # `weights` is captured unevaluated and spliced into the util_read_surv()
  # call, so the reader receives the expression exactly as the caller wrote it.
  substitute_weights <- substitute(weights)
  has_weights_user <- !missing(weights) && !identical(substitute_weights, quote(NULL))
  surv_input <- eval(substitute(
    util_read_surv(
      formula = formula, data = data, weights = arg_weights,
      code.event1 = code.event1, code.event2 = code.event2, code.censoring = code.censoring,
      subset.condition = subset.condition, na.action = na.action
    ),
    list(arg_weights = substitute_weights)
  ))
  error <- curve_check_error(error, outcome.type, weights = surv_input$w, has_weights = has_weights_user)
  call <- match.call()

  # ---- Strata labels ("<variable>=<level>" when the variable name is known) --
  strata_lvls <- levels(as.factor(surv_input$strata))
  strata_var <- surv_input$strata_name
  strata_fullnames <- if (is.null(strata_var)) {
    strata_lvls
  } else {
    paste0(strata_var, "=", strata_lvls)
  }
  strata_int <- as.integer(surv_input$strata)

  # ---- Recode the event indicator to 0 (censored) / 1 (event 1) / 2 (event 2) --
  raw_codes <- surv_input$epsilon
  epsilon_norm <- rep.int(0L, length(raw_codes))
  epsilon_norm[raw_codes == code.event1]    <- 1L
  epsilon_norm[raw_codes == code.event2]    <- 2L
  epsilon_norm[raw_codes == code.censoring] <- 0L

  # ---- Helpers shared by the two R-level ("calculateKM") branches ----------
  # TRUE when the caller asked for no confidence limits.
  ci_suppressed <- function() {
    is.null(conf.type) || conf.type %in% c("none","n")
  }
  # Tag the result, fill in the standard fields and class it as survfit.
  finish_r_engine <- function(fit) {
    fit$n.risk.type <- n.risk.type
    fit <- harmonize_engine_output(fit)
    class(fit) <- "survfit"
    fit
  }
  use_r_engine <- identical(engine, "calculateKM")

  # ---- Kaplan-Meier via calculateKM() --------------------------------------
  if (use_r_engine && identical(outcome.type, "survival")) {
    km <- calculateKM(surv_input$t, surv_input$d, surv_input$w, strata_int, error)
    # The CI routine is fed the SE multiplied by the survival estimate; the
    # multiplication is undone afterwards when the survfit-style SE is requested.
    km$std.err <- km$surv * km$std.err
    ci <- calculateCI(km, conf.int, conf.type, conf.lower = NULL)
    if (isTRUE(report.survfit.std.err)) {
      km$std.err <- km$std.err / km$surv
    }

    drop_ci <- ci_suppressed()
    fit <- list(
      time        = km$time,
      surv        = km$surv,
      n           = km$n,
      n.risk      = ceiling(km$n.risk),
      n.event     = ceiling(km$n.event),
      n.censor    = ceiling(km$n.censor),
      std.err     = km$std.err,
      std.err.cif = km$`std.err.cif`,
      upper       = if (drop_ci) NULL else ci$upper,
      lower       = if (drop_ci) NULL else ci$lower,
      conf.type   = conf.type,
      call        = call,
      type        = "kaplan-meier",
      method      = "Kaplan-Meier"
    )
    # Strata are only reported when more than one stratum code is present.
    if (any(strata_int != 1)) {
      names(km$strata) <- strata_fullnames
      fit$strata <- km$strata
    }
    return(finish_r_engine(fit))
  }

  # ---- Aalen-Johansen via calculateAJ() ------------------------------------
  if (use_r_engine && identical(outcome.type, "competing-risk")) {
    aj <- calculateAJ(surv_input)
    names(aj$strata1) <- strata_fullnames
    is_stratified <- any(strata_int != 1)

    # Number at risk = stratum size minus everything that has already left.
    if (is_stratified) {
      n <- table(as.integer(surv_input$strata))
      size_per_row <- do.call(c, Map(rep, n, aj$strata1))
      at_risk <- size_per_row -
        aj$n.cum.censor - aj$n.cum.event1 - aj$n.cum.event2
    } else {
      n <- length(surv_input$strata)
      at_risk <- n - aj$n.cum.censor - aj$n.cum.event1 - aj$n.cum.event2
    }

    std_err_cif <- calculateAalenDeltaSE(
      aj$time1, aj$aj1,
      aj$n.event1, aj$n.event2,
      at_risk,
      aj$time0, aj$km0, aj$strata1, error
    )
    aj$std.err <- std_err_cif
    aj$surv <- 1 - aj$aj1
    ci <- calculateCI(aj, conf.int, conf.type, conf.lower = NULL)
    if (isTRUE(report.survfit.std.err)) {
      aj$std.err <- aj$std.err / aj$surv
    }

    drop_ci <- ci_suppressed()
    fit <- list(
      time        = aj$time1,
      surv        = aj$surv,
      n           = n,
      n.risk      = ceiling(at_risk),
      n.event     = ceiling(aj$n.event1),
      n.censor    = ceiling(aj$n.censor),
      std.err     = aj$std.err,
      std.err.cif = std_err_cif,
      upper       = if (drop_ci) NULL else ci$upper,
      lower       = if (drop_ci) NULL else ci$lower,
      conf.type   = conf.type,
      call        = call,
      type        = "aalen-johansen",
      method      = "aalen-johansen"
    )
    if (is_stratified) {
      fit$strata <- aj$strata1
    }
    return(finish_r_engine(fit))
  }

  # ---- Default: C++ engine (any other engine / outcome combination) --------
  fit <- call_calculateAJ_Rcpp(
    t = surv_input$t,
    epsilon = as.integer(epsilon_norm),
    w = if (has_weights_user) surv_input$w else NULL,
    strata = strata_int,
    error = error,
    conf.type = conf.type,
    conf.int = conf.int,
    report.influence.function = report.influence.function,
    prob.bound = prob.bound,
    n.risk.type = n.risk.type
  )
  if (length(strata_fullnames) > 0 && length(fit$strata) > 0) {
    names(fit$strata) <- strata_fullnames
  }
  if (length(strata_lvls) > 0) {
    fit$`strata.levels` <- strata_lvls
  }
  fit$call <- call
  fit <- harmonize_engine_output(fit)
  # Counts are reported as whole numbers (rounded up).
  fit$n.risk <- ceiling(fit$n.risk)
  fit$n.censor <- ceiling(fit$n.censor)
  fit$n.event <- ceiling(fit$n.event)
  fit$n.risk.type <- n.risk.type
  if (isTRUE(report.survfit.std.err)) {
    fit$std.err <- fit$std.err / fit$surv
  }
  class(fit) <- "survfit"
  fit
}

# Aalen-Johansen-type estimate of the cumulative incidence of event 1,
# computed stratum by stratum with inverse-probability-of-censoring weights.
#
# `data` is a list with elements `t` (times), `d0`, `d1`, `d2` (indicators
# passed to calculateKM() / summed below as censoring, event 1, event 2),
# `w` (weights) and `strata`.
#
# Returns a list of per-stratum results concatenated in increasing order of
# the integer stratum code:
#   time1                     unique times within each stratum
#   aj1                       cumulative weighted sum of event-1 rows up to
#                             each time, divided by the stratum row count
#   n.cum.censor/.event1/.event2  cumulative sums of d0 / d1 / d2
#   n.censor/.event1/.event2  the increments of those cumulative sums
#   strata1                   number of unique times per stratum
#   time0, km0                `time` and `surv` from the calculateKM() fit on d0
calculateAJ <- function(data) {
  strata_vec <- as.integer(data$strata)

  # Survival curve with d0 as the "event", read back at each row's own time;
  # its reciprocal is the weight applied to rows with d0 == 0.
  out_km0 <- calculateKM(data$t, data$d0, data$w, strata_vec, "none")
  km0 <- util_get_surv(data$t, out_km0$surv, out_km0$time, strata_vec, out_km0$strata, out_km0$strata.levels)
  ip.weight <- (data$d0 == 0) * ifelse(km0 > 0, 1 / km0, 0)
  d1_ipw <- as.numeric(data$w * data$d1 * ip.weight)
  if (length(d1_ipw) != length(data$t)) {
    stop("calculateAJ: weights and event indicators must have one value per observation.",
         call. = FALSE)
  }

  # Summaries for the rows `idx` of a single stratum.
  summarise_stratum <- function(idx) {
    ord <- idx[order(data$t[idx])]
    sub_t <- data$t[ord]
    # With rows sorted by time, "sum over all rows with time <= t" is the
    # running total read at the LAST row of each run of tied times. This
    # replaces an n x n at-risk matrix with a linear-time scan.
    last_of_tie <- !duplicated(sub_t, fromLast = TRUE)
    cum_at_times <- function(x) cumsum(as.numeric(x[ord]))[last_of_tie]
    increments <- function(cum) c(cum[1], diff(cum))

    cum_nc  <- cum_at_times(data$d0)
    cum_ne1 <- cum_at_times(data$d1)
    cum_ne2 <- cum_at_times(data$d2)

    list(
      time1        = sub_t[last_of_tie],
      aj1          = cum_at_times(d1_ipw) / length(idx),
      n.cum.censor = cum_nc,
      n.cum.event1 = cum_ne1,
      n.cum.event2 = cum_ne2,
      n.censor     = increments(cum_nc),
      n.event1     = increments(cum_ne1),
      n.event2     = increments(cum_ne2),
      strata1      = sum(last_of_tie)
    )
  }

  per_stratum <- lapply(
    sort(unique(strata_vec)),
    function(level) summarise_stratum(which(strata_vec == level))
  )

  # Concatenate one field across strata; `init` fixes the result when there
  # are no strata at all (and the storage type, as c() would).
  gather <- function(field, init) {
    do.call(c, c(list(init), lapply(per_stratum, `[[`, field)))
  }

  list(
    time1        = gather("time1", integer(0)),
    aj1          = gather("aj1", integer(0)),
    n.event1     = gather("n.event1", numeric(0)),
    n.event2     = gather("n.event2", numeric(0)),
    n.censor     = gather("n.censor", numeric(0)),
    n.cum.event1 = gather("n.cum.event1", numeric(0)),
    n.cum.event2 = gather("n.cum.event2", numeric(0)),
    n.cum.censor = gather("n.cum.censor", numeric(0)),
    strata1      = gather("strata1", integer(0)),
    time0        = out_km0$time,
    km0          = out_km0$surv
  )
}

# Resolve the user-supplied `error` (SE method) to one of the canonical names
# valid for the outcome type.
#
# x            requested method, or NULL to take the default
# outcome.type outcome type, validated through util_check_outcome_type()
# weights      weight vector; only consulted when `has_weights` is NULL
# has_weights  explicit flag saying whether the analysis is weighted
#
# Returns a single string. With weights the answer is always "if" (a warning
# is raised when something else was requested). An unrecognised method falls
# back, with a warning, to the default for the outcome type.
curve_check_error <- function(x, outcome.type, weights = NULL, has_weights = NULL) {
  ot <- util_check_outcome_type(x = outcome.type, auto_message = FALSE)

  choices <- switch(
    ot,
    "survival"       = c("greenwood", "tsiatis", "if"),
    "competing-risk" = c("aalen", "delta", "if"),
    stop(sprintf("Invalid outcome.type: %s", outcome.type), call. = FALSE)
  )

  # Infer the flag from the weight vector when it was not given explicitly
  # (NULL has length 0, so this also covers "no weights at all").
  if (is.null(has_weights)) {
    has_weights <- length(weights) > 0
  }

  if (has_weights) {
    fallback <- "if"
  } else {
    fallback <- if (ot == "survival") "greenwood" else "delta"
  }

  if (is.null(x)) {
    return(fallback)
  }

  # Accepted spellings -> canonical method name.
  aliases <- c(
    "g"                  = "greenwood",
    "greenwood"          = "greenwood",
    "t"                  = "tsiatis",
    "tsiatis"            = "tsiatis",
    "if"                 = "if",
    "influence function" = "if",
    "influence_function" = "if",
    "influence curve"    = "if",
    "ic"                 = "if",
    "a"                  = "aalen",
    "aalen"              = "aalen",
    "d"                  = "delta",
    "delta"              = "delta"
  )
  key <- tolower(trimws(as.character(x)))
  x_norm <- if (key %in% names(aliases)) aliases[[key]] else key

  # Weighted analyses only support the influence-function SE.
  if (has_weights && x_norm != "if") {
    warning(
      sprintf("%s with weights: error='%s' is not supported; falling back to 'if'.",
              ot, as.character(x)),
      call. = FALSE
    )
    return("if")
  }

  if (!(x_norm %in% choices)) {
    warning(
      sprintf("%s: unsupported error='%s'; falling back to '%s'.",
              ot, as.character(x), fallback),
      call. = FALSE
    )
    return(fallback)
  }

  x_norm
}


# R-side adapter for calculateAJ_Rcpp(): translates the dotted R argument
# names into the underscore names the routine takes and replaces the optional
# NULL inputs with zero-length vectors.
#
# t, epsilon                 passed through unchanged
# w, strata                  NULL becomes numeric() / integer()
# error, conf.type, conf.int,
# prob.bound, n.risk.type    passed through unchanged
# report.influence.function  reduced to a strict TRUE/FALSE via isTRUE()
#
# Returns whatever calculateAJ_Rcpp() returns.
call_calculateAJ_Rcpp <- function(t, epsilon, w = NULL, strata = NULL,
                                   error = "greenwood",
                                   conf.type = "arcsin",
                                   conf.int = 0.95,
                                   report.influence.function = FALSE,
                                   prob.bound = 1e-5,
                                   n.risk.type = "weighted") {
  weight_vec <- if (is.null(w)) numeric() else w
  strata_vec <- if (is.null(strata)) integer() else strata
  want_influence <- isTRUE(report.influence.function)

  calculateAJ_Rcpp(
    t           = t,
    epsilon     = epsilon,
    w           = weight_vec,
    strata      = strata_vec,
    error       = error,
    conf_type   = conf.type,
    conf_int    = conf.int,
    return_if   = want_influence,
    prob_bound  = prob.bound,
    n_risk_type = n.risk.type
  )
}

#' Fill in the fields that every engine result is expected to carry
#'
#' Fields that are already present are left untouched. Missing ones are
#' filled as follows: `lower` / `upper` are taken from `low` / `high` when
#' those exist; `std.err.cif` becomes `NA` for each element of `time`;
#' `influence.function` becomes an empty list; `strata` and `strata.levels`
#' become `integer()`; `type`, `method` and `conf.type` get the
#' Kaplan-Meier / `"arcsine-square root"` defaults.
#'
#' @param out A list returned by one of the estimation engines.
#' @return `out` with the fields above filled in.
#' @keywords internal
harmonize_engine_output <- function(out) {
  # Fields are READ with `[[`, which matches names exactly. `$` partially
  # matches list names, so a result lacking `strata` but carrying
  # `strata.levels` would otherwise get its `strata` filled with the level
  # labels (and `time` could likewise pick up `time0` / `time1`).
  n <- length(out[["time"]] %||% numeric())
  out$lower               <- out[["lower"]]              %||% out[["low"]]
  out$upper               <- out[["upper"]]              %||% out[["high"]]
  out$`std.err.cif`        <- out[["std.err.cif"]]        %||% rep(NA_real_, n)
  out$`influence.function` <- out[["influence.function"]] %||% list()
  out$strata               <- out[["strata"]]             %||% integer()
  out$`strata.levels`      <- out[["strata.levels"]]      %||% integer()
  out$type                 <- out[["type"]]               %||% "kaplan-meier"
  out$method               <- out[["method"]]             %||% "Kaplan-Meier"
  out$`conf.type`          <- out[["conf.type"]]          %||% "arcsine-square root"
  out
}
