
// poLCAParallel
// Copyright (C) 2022 Sherman Lo

// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2 of the License, or
// (at your option) any later version.

// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License along
// with this program; if not, write to the Free Software Foundation, Inc.,
// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

#include <RcppArmadillo.h>

#include <array>
#include <cstddef>
#include <iterator>
#include <map>
#include <span>
#include <vector>

#include "goodness_fit.h"
#include "util.h"

/**
 * Function to be exported to R, goodness of fit statistics
 *
 * Get goodness of fit statistics given fitted probabilities. The unique
 * responses are tabulated together with their observed frequency and their
 * expected frequency, and the log likelihood ratio and chi squared statistics
 * are returned alongside the table.
 *
 * @param responses Design matrix <b>transposed</b> of responses, matrix
 * containing outcomes/responses for each category as integers 1, 2, 3, ....
 * Missing values may be encoded as 0. The matrix has dimensions
 * <ul>
 *   <li>dim 0: for each category</li>
 *   <li>dim 1: for each data point</li>
 * </ul>
 * @param prior Vector of prior probabilities, for each cluster
 * @param outcome_prob Vector of response probabilities for each outcome,
 * conditioned on the category and cluster. Can be the return value of
 * <code>poLCAParallel.vectorize.R</code>. Flatten list in the following order
 * <ul>
 *   <li>dim 0: for each outcome</li>
 *   <li>dim 1: for each category</li>
 *   <li>dim 2: for each cluster</li>
 * </ul>
 * @param n_data Number of data points
 * @param n_outcomes_int Vector, number of possible responses for each category.
 * The values are converted to <code>std::size_t</code>, so they are expected
 * to be non-negative
 * @param n_cluster Number of clusters, or classes, to fit
 * @return List containing:
 * <ul>
 *   <li>
 *     <code>[[1]]</code>: unique_freq_table, a numeric matrix with one row
 *     per unique response. The first n_category columns hold the response,
 *     the second-to-last column holds the observed frequency and the last
 *     column holds the expected frequency. Rows are in the key order of the
 *     underlying <code>std::map</code>, which is the lexicographic order of
 *     the response vectors.
 *   </li>
 *   <li><code>[[2]]</code>: ln_l_ratio, log likelihood ratio statistic</li>
 *   <li><code>[[3]]</code>: chi_squared, chi squared statistic</li>
 * </ul>
 */
// [[Rcpp::export]]
Rcpp::List GoodnessFitRcpp(Rcpp::IntegerMatrix responses,
                           Rcpp::NumericVector prior,
                           Rcpp::NumericVector outcome_prob, std::size_t n_data,
                           Rcpp::IntegerVector n_outcomes_int,
                           std::size_t n_cluster) {
  // NOutcomes is built from a pointer and a size, so it presumably does not
  // own its data (TODO confirm). n_outcomes_size_t is kept alive for the
  // whole function because of this.
  std::vector<std::size_t> n_outcomes_size_t(n_outcomes_int.cbegin(),
                                             n_outcomes_int.cend());
  polca_parallel::NOutcomes n_outcomes(n_outcomes_size_t.data(),
                                       n_outcomes_size_t.size());
  const std::size_t n_category = n_outcomes.size();

  polca_parallel::GoodnessOfFit goodness_of_fit;
  goodness_of_fit.Calc(responses, prior, outcome_prob, n_data, n_outcomes,
                       n_cluster);

  // keys: unique response vectors, values: observed and expected frequency.
  // The map is only read from here on, so bind it as a const reference.
  const std::map<std::vector<int>, polca_parallel::Frequency>& frequency_map =
      goodness_of_fit.GetFrequencyMap();

  // get log likelihood ratio and chi squared statistics
  auto [ln_l_ratio, chi_squared] = goodness_of_fit.GetStatistics();

  // transfer results from frequency_map to a NumericMatrix with one row per
  // unique response; the last two columns hold the observed and expected
  // frequency
  const std::size_t n_unique = frequency_map.size();
  Rcpp::NumericMatrix frequency_table(n_unique, n_category + 2);

  std::size_t row = 0;
  for (const auto& [response_i, frequency] : frequency_map) {
    // copy over response
    for (std::size_t j = 0; j < n_category; ++j) {
      frequency_table(row, j) = response_i[j];
    }
    // copy over observed and expected frequency
    frequency_table(row, n_category) = static_cast<double>(frequency.observed);
    frequency_table(row, n_category + 1) = frequency.expected;
    ++row;
  }

  // unnamed list, accessed positionally ([[1]], [[2]], [[3]]) from R
  return Rcpp::List::create(frequency_table, ln_l_ratio, chi_squared);
}
