diff --git a/R/MultipleComparison.R b/R/MultipleComparison.R index 1b78f784..48b1dd96 100644 --- a/R/MultipleComparison.R +++ b/R/MultipleComparison.R @@ -98,7 +98,7 @@ MultipleComparison <- R6Class( group_i = private$.ij$i - 1, group_j = private$.ij$j - 1, data = unname(private$.data), - group = as.integer(names(private$.data)) - 1, + group = as.integer(names(private$.data)), statistic_func = private$.statistic_func, n_permu = as.integer(private$.n_permu) ) diff --git a/R/RcppExports.R b/R/RcppExports.R index eddef808..b4de7492 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -25,7 +25,7 @@ table_pmt <- function(row_loc, col_loc, statistic_func, n_permu) { .Call(`_LearnNonparam_table_pmt`, row_loc, col_loc, statistic_func, n_permu) } -twosample_pmt <- function(n_1, n_2, c_xy, statistic_func, n_permu) { - .Call(`_LearnNonparam_twosample_pmt`, n_1, n_2, c_xy, statistic_func, n_permu) +twosample_pmt <- function(data, where_y, statistic_func, n_permu) { + .Call(`_LearnNonparam_twosample_pmt`, data, where_y, statistic_func, n_permu) } diff --git a/R/TwoSampleTest.R b/R/TwoSampleTest.R index 5915f8e0..09c7a5e3 100644 --- a/R/TwoSampleTest.R +++ b/R/TwoSampleTest.R @@ -37,10 +37,11 @@ TwoSampleTest <- R6Class( }, .calculate_statistic_permu = function() { + n_1 <- length(private$.data$x) + n_2 <- length(private$.data$y) private$.statistic_permu <- twosample_pmt( - n_1 = length(private$.data$x), - n_2 = length(private$.data$y), - c_xy = c(private$.data$x, private$.data$y), + data = c(private$.data$x, private$.data$y), + where_y = rep.int(c(FALSE, TRUE), c(n_1, n_2)), statistic_func = private$.statistic_func, n_permu = as.integer(private$.n_permu) ) diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index a992ae39..f0d3b5d8 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -11,101 +11,100 @@ Rcpp::Rostream& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get(); #endif // association_pmt -NumericVector association_pmt(NumericVector x, NumericVector y, Function statistic_func, int n_permu); +NumericVector association_pmt(const NumericVector x, NumericVector y, const Function statistic_func, const unsigned n_permu); RcppExport SEXP _LearnNonparam_association_pmt(SEXP xSEXP, SEXP ySEXP, SEXP statistic_funcSEXP, SEXP n_permuSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; - Rcpp::traits::input_parameter< NumericVector >::type x(xSEXP); + Rcpp::traits::input_parameter< const NumericVector >::type x(xSEXP); Rcpp::traits::input_parameter< NumericVector >::type y(ySEXP); - Rcpp::traits::input_parameter< Function >::type statistic_func(statistic_funcSEXP); - Rcpp::traits::input_parameter< int >::type n_permu(n_permuSEXP); + Rcpp::traits::input_parameter< const Function >::type statistic_func(statistic_funcSEXP); + Rcpp::traits::input_parameter< const unsigned >::type n_permu(n_permuSEXP); rcpp_result_gen = Rcpp::wrap(association_pmt(x, y, statistic_func, n_permu)); return rcpp_result_gen; END_RCPP } // ksample_pmt -NumericVector ksample_pmt(NumericVector data, IntegerVector group, Function statistic_func, int n_permu); +NumericVector ksample_pmt(const NumericVector data, IntegerVector group, const Function statistic_func, const unsigned n_permu); RcppExport SEXP _LearnNonparam_ksample_pmt(SEXP dataSEXP, SEXP groupSEXP, SEXP statistic_funcSEXP, SEXP n_permuSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; - Rcpp::traits::input_parameter< NumericVector >::type data(dataSEXP); + Rcpp::traits::input_parameter< const NumericVector >::type data(dataSEXP); Rcpp::traits::input_parameter< IntegerVector >::type group(groupSEXP); - Rcpp::traits::input_parameter< Function >::type statistic_func(statistic_funcSEXP); - Rcpp::traits::input_parameter< int >::type n_permu(n_permuSEXP); + Rcpp::traits::input_parameter< const Function >::type statistic_func(statistic_funcSEXP); + Rcpp::traits::input_parameter< const unsigned >::type n_permu(n_permuSEXP); rcpp_result_gen = Rcpp::wrap(ksample_pmt(data, group, statistic_func, n_permu)); return rcpp_result_gen; END_RCPP } // multicomp_pmt -NumericMatrix multicomp_pmt(IntegerVector group_i, IntegerVector group_j, NumericVector data, IntegerVector group, Function statistic_func, int n_permu); +NumericVector multicomp_pmt(const IntegerVector group_i, const IntegerVector group_j, const NumericVector data, IntegerVector group, const Function statistic_func, const unsigned n_permu); RcppExport SEXP _LearnNonparam_multicomp_pmt(SEXP group_iSEXP, SEXP group_jSEXP, SEXP dataSEXP, SEXP groupSEXP, SEXP statistic_funcSEXP, SEXP n_permuSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; - Rcpp::traits::input_parameter< IntegerVector >::type group_i(group_iSEXP); - Rcpp::traits::input_parameter< IntegerVector >::type group_j(group_jSEXP); - Rcpp::traits::input_parameter< NumericVector >::type data(dataSEXP); + Rcpp::traits::input_parameter< const IntegerVector >::type group_i(group_iSEXP); + Rcpp::traits::input_parameter< const IntegerVector >::type group_j(group_jSEXP); + Rcpp::traits::input_parameter< const NumericVector >::type data(dataSEXP); Rcpp::traits::input_parameter< IntegerVector >::type group(groupSEXP); - Rcpp::traits::input_parameter< Function >::type statistic_func(statistic_funcSEXP); - Rcpp::traits::input_parameter< int >::type n_permu(n_permuSEXP); + Rcpp::traits::input_parameter< const Function >::type statistic_func(statistic_funcSEXP); + Rcpp::traits::input_parameter< const unsigned >::type n_permu(n_permuSEXP); rcpp_result_gen = Rcpp::wrap(multicomp_pmt(group_i, group_j, data, group, statistic_func, n_permu)); return rcpp_result_gen; END_RCPP } // paired_pmt -NumericVector paired_pmt(int n, Function statistic_func, int n_permu); +NumericVector paired_pmt(const unsigned n, const Function statistic_func, const unsigned n_permu); RcppExport SEXP _LearnNonparam_paired_pmt(SEXP nSEXP, SEXP statistic_funcSEXP, SEXP n_permuSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; - Rcpp::traits::input_parameter< int >::type n(nSEXP); - Rcpp::traits::input_parameter< Function >::type statistic_func(statistic_funcSEXP); - Rcpp::traits::input_parameter< int >::type n_permu(n_permuSEXP); + Rcpp::traits::input_parameter< const unsigned >::type n(nSEXP); + Rcpp::traits::input_parameter< const Function >::type statistic_func(statistic_funcSEXP); + Rcpp::traits::input_parameter< const unsigned >::type n_permu(n_permuSEXP); rcpp_result_gen = Rcpp::wrap(paired_pmt(n, statistic_func, n_permu)); return rcpp_result_gen; END_RCPP } // rcbd_pmt -NumericVector rcbd_pmt(NumericMatrix data, Function statistic_func, int n_permu); +NumericVector rcbd_pmt(NumericMatrix data, const Function statistic_func, const unsigned n_permu); RcppExport SEXP _LearnNonparam_rcbd_pmt(SEXP dataSEXP, SEXP statistic_funcSEXP, SEXP n_permuSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; Rcpp::traits::input_parameter< NumericMatrix >::type data(dataSEXP); - Rcpp::traits::input_parameter< Function >::type statistic_func(statistic_funcSEXP); - Rcpp::traits::input_parameter< int >::type n_permu(n_permuSEXP); + Rcpp::traits::input_parameter< const Function >::type statistic_func(statistic_funcSEXP); + Rcpp::traits::input_parameter< const unsigned >::type n_permu(n_permuSEXP); rcpp_result_gen = Rcpp::wrap(rcbd_pmt(data, statistic_func, n_permu)); return rcpp_result_gen; END_RCPP } // table_pmt -NumericVector table_pmt(IntegerVector row_loc, IntegerVector col_loc, Function statistic_func, int n_permu); +NumericVector table_pmt(IntegerVector row_loc, const IntegerVector col_loc, const Function statistic_func, const unsigned n_permu); RcppExport SEXP _LearnNonparam_table_pmt(SEXP row_locSEXP, SEXP col_locSEXP, SEXP statistic_funcSEXP, SEXP n_permuSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; Rcpp::traits::input_parameter< IntegerVector >::type row_loc(row_locSEXP); - Rcpp::traits::input_parameter< IntegerVector >::type col_loc(col_locSEXP); - Rcpp::traits::input_parameter< Function >::type statistic_func(statistic_funcSEXP); - Rcpp::traits::input_parameter< int >::type n_permu(n_permuSEXP); + Rcpp::traits::input_parameter< const IntegerVector >::type col_loc(col_locSEXP); + Rcpp::traits::input_parameter< const Function >::type statistic_func(statistic_funcSEXP); + Rcpp::traits::input_parameter< const unsigned >::type n_permu(n_permuSEXP); rcpp_result_gen = Rcpp::wrap(table_pmt(row_loc, col_loc, statistic_func, n_permu)); return rcpp_result_gen; END_RCPP } // twosample_pmt -NumericVector twosample_pmt(int n_1, int n_2, NumericVector c_xy, Function statistic_func, int n_permu); -RcppExport SEXP _LearnNonparam_twosample_pmt(SEXP n_1SEXP, SEXP n_2SEXP, SEXP c_xySEXP, SEXP statistic_funcSEXP, SEXP n_permuSEXP) { +NumericVector twosample_pmt(const NumericVector data, LogicalVector where_y, const Function statistic_func, const unsigned n_permu); +RcppExport SEXP _LearnNonparam_twosample_pmt(SEXP dataSEXP, SEXP where_ySEXP, SEXP statistic_funcSEXP, SEXP n_permuSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; - Rcpp::traits::input_parameter< int >::type n_1(n_1SEXP); - Rcpp::traits::input_parameter< int >::type n_2(n_2SEXP); - Rcpp::traits::input_parameter< NumericVector >::type c_xy(c_xySEXP); - Rcpp::traits::input_parameter< Function >::type statistic_func(statistic_funcSEXP); - Rcpp::traits::input_parameter< int >::type n_permu(n_permuSEXP); - rcpp_result_gen = Rcpp::wrap(twosample_pmt(n_1, n_2, c_xy, statistic_func, n_permu)); + Rcpp::traits::input_parameter< const NumericVector >::type data(dataSEXP); + Rcpp::traits::input_parameter< LogicalVector >::type where_y(where_ySEXP); + Rcpp::traits::input_parameter< const Function >::type statistic_func(statistic_funcSEXP); + Rcpp::traits::input_parameter< const unsigned >::type n_permu(n_permuSEXP); + rcpp_result_gen = Rcpp::wrap(twosample_pmt(data, where_y, statistic_func, n_permu)); return rcpp_result_gen; END_RCPP } @@ -117,7 +116,7 @@ static const R_CallMethodDef CallEntries[] = { {"_LearnNonparam_paired_pmt", (DL_FUNC) &_LearnNonparam_paired_pmt, 3}, {"_LearnNonparam_rcbd_pmt", (DL_FUNC) &_LearnNonparam_rcbd_pmt, 3}, {"_LearnNonparam_table_pmt", (DL_FUNC) &_LearnNonparam_table_pmt, 4}, - {"_LearnNonparam_twosample_pmt", (DL_FUNC) &_LearnNonparam_twosample_pmt, 5}, + {"_LearnNonparam_twosample_pmt", (DL_FUNC) &_LearnNonparam_twosample_pmt, 4}, {NULL, NULL, 0} }; diff --git a/src/association_pmt.cpp b/src/association_pmt.cpp index 2c590a84..e413bb15 100644 --- a/src/association_pmt.cpp +++ b/src/association_pmt.cpp @@ -1,56 +1,46 @@ #include "utils.h" -#include -#include -#include using namespace Rcpp; -inline void association_do( - int i, - NumericVector x, - NumericVector y, - Function statistic_func, - NumericVector statistic_permu, - RObject bar) +void association_do( + unsigned& i, + const NumericVector& x, + const NumericVector& y, + const Function& statistic_func, + NumericVector& statistic_permu, + RObject& bar) { statistic_permu[i] = as(statistic_func(x, y)); - if (CLI_SHOULD_TICK) { - cli_progress_set(bar, i); - } + update_bar_and_i(bar, i); } // [[Rcpp::export]] NumericVector association_pmt( - NumericVector x, + const NumericVector x, NumericVector y, - Function statistic_func, - int n_permu) + const Function statistic_func, + const unsigned n_permu) { - int total; - if (n_permu == 0) { - total = n_permutation(y); - } else { - total = n_permu; - } - - NumericVector statistic_permu(total); - RObject bar = cli_progress_bar(total, NULL); + std::pair paired; + unsigned i = 0; if (n_permu == 0) { - int i = 0; + paired = statistic_permu_with_bar(n_permutation(y), true); + do { - association_do(i, x, y, statistic_func, statistic_permu, bar); - i++; + association_do(i, x, y, statistic_func, paired.first, paired.second); } while (std::next_permutation(y.begin(), y.end())); } else { - for (int i = 0; i < total; i++) { + paired = statistic_permu_with_bar(n_permu, false); + + while (i < n_permu) { random_shuffle(y); - association_do(i, x, y, statistic_func, statistic_permu, bar); + association_do(i, x, y, statistic_func, paired.first, paired.second); } } - cli_progress_done(bar); + cli_progress_done(paired.second);; - return statistic_permu; + return paired.first; } diff --git a/src/ksample_pmt.cpp b/src/ksample_pmt.cpp index 0fe3487b..634e0968 100644 --- a/src/ksample_pmt.cpp +++ b/src/ksample_pmt.cpp @@ -1,56 +1,46 @@ #include "utils.h" -#include -#include -#include using namespace Rcpp; -inline void ksample_do( - int i, - NumericVector data, - IntegerVector group, - Function statistic_func, - NumericVector statistic_permu, - RObject bar) +void ksample_do( + unsigned& i, + const NumericVector& data, + const IntegerVector& group, + const Function& statistic_func, + NumericVector& statistic_permu, + RObject& bar) { statistic_permu[i] = as(statistic_func(data, group)); - if (CLI_SHOULD_TICK) { - cli_progress_set(bar, i); - } + update_bar_and_i(bar, i); } // [[Rcpp::export]] NumericVector ksample_pmt( - NumericVector data, + const NumericVector data, IntegerVector group, - Function statistic_func, - int n_permu) + const Function statistic_func, + const unsigned n_permu) { - int total; - if (n_permu == 0) { - total = n_permutation(group); - } else { - total = n_permu; - } - - NumericVector statistic_permu(total); - RObject bar = cli_progress_bar(total, NULL); + std::pair paired; + unsigned i = 0; if (n_permu == 0) { - int i = 0; + paired = statistic_permu_with_bar(n_permutation(group), true); + do { - ksample_do(i, data, group, statistic_func, statistic_permu, bar); - i++; + ksample_do(i, data, group, statistic_func, paired.first, paired.second); } while (std::next_permutation(group.begin(), group.end())); } else { - for (int i = 0; i < total; i++) { + paired = statistic_permu_with_bar(n_permu, false); + + while (i < n_permu) { random_shuffle(group); - ksample_do(i, data, group, statistic_func, statistic_permu, bar); + ksample_do(i, data, group, statistic_func, paired.first, paired.second); } } - cli_progress_done(bar); + cli_progress_done(paired.second);; - return statistic_permu; + return paired.first; } diff --git a/src/multicomp_pmt.cpp b/src/multicomp_pmt.cpp index cb540f02..dd300ea7 100644 --- a/src/multicomp_pmt.cpp +++ b/src/multicomp_pmt.cpp @@ -1,71 +1,62 @@ #include "utils.h" -#include -#include -#include using namespace Rcpp; -inline void multicomp_do( - int i, int n, int n_pair, - IntegerVector group_i, - IntegerVector group_j, - NumericVector data, - IntegerVector group, - Function statistic_func, - NumericMatrix statistic_permu, - List split, RObject bar) +void multicomp_do( + unsigned& i, + const unsigned& n, + const unsigned& n_pair, + const IntegerVector& group_i, + const IntegerVector& group_j, + const NumericVector& data, + const IntegerVector& group, + const Function& statistic_func, + NumericVector& statistic_permu, + RObject& bar, List& split) { - for (int j = 0; j < n; j++) { - split[j] = data[group == j]; + for (unsigned j = 1; j <= n; j++) { + split[j - 1] = data[group == j]; } - for (int k = 0; k < n_pair; k++) { + for (unsigned k = 0; k < n_pair; k++) { statistic_permu(k, i) = as(statistic_func(split[group_i[k]], split[group_j[k]], data, group)); } - if (CLI_SHOULD_TICK) { - cli_progress_set(bar, i); - } + update_bar_and_i(bar, i); } // [[Rcpp::export]] -NumericMatrix multicomp_pmt( - IntegerVector group_i, - IntegerVector group_j, - NumericVector data, +NumericVector multicomp_pmt( + const IntegerVector group_i, + const IntegerVector group_j, + const NumericVector data, IntegerVector group, - Function statistic_func, - int n_permu) + const Function statistic_func, + const unsigned n_permu) { - int total; - if (n_permu == 0) { - total = n_permutation(group); - } else { - total = n_permu; - } - - int n_pair = group_i.size(); - - NumericMatrix statistic_permu(n_pair, total); - RObject bar = cli_progress_bar(total, NULL); - - int n = *std::max_element(group.begin(), group.end()) + 1; + unsigned n_pair = group_i.size(); + unsigned n = group[group.size() - 1]; List split(n); + std::pair paired; + + unsigned i = 0; if (n_permu == 0) { - int i = 0; + paired = statistic_permu_with_bar(n_permutation(group), true, n_pair); + do { - multicomp_do(i, n, n_pair, group_i, group_j, data, group, statistic_func, statistic_permu, split, bar); - i++; + multicomp_do(i, n, n_pair, group_i, group_j, data, group, statistic_func, paired.first, paired.second, split); } while (std::next_permutation(group.begin(), group.end())); } else { - for (int i = 0; i < total; i++) { + paired = statistic_permu_with_bar(n_permu, false, n_pair); + + while (i < n_permu) { random_shuffle(group); - multicomp_do(i, n, n_pair, group_i, group_j, data, group, statistic_func, statistic_permu, split, bar); + multicomp_do(i, n, n_pair, group_i, group_j, data, group, statistic_func, paired.first, paired.second, split); } } - cli_progress_done(bar); + cli_progress_done(paired.second);; - return statistic_permu; + return paired.first; } diff --git a/src/paired_pmt.cpp b/src/paired_pmt.cpp index 3a12cfbe..318dae0a 100644 --- a/src/paired_pmt.cpp +++ b/src/paired_pmt.cpp @@ -1,59 +1,52 @@ #include "utils.h" -#include -#include using namespace Rcpp; -inline void paired_do( - int i, - Function statistic_func, - NumericVector statistic_permu, - LogicalVector swapped, RObject bar) +void paired_do( + unsigned& i, + const unsigned& k, + const unsigned& n, + const Function& statistic_func, + NumericVector& statistic_permu, + RObject& bar, LogicalVector& swapped) { + for (unsigned j = 0; j < n; j++) { + swapped[j] = ((i & (1 << j)) != 0); + } + statistic_permu[i] = as(statistic_func(swapped)); - if (CLI_SHOULD_TICK) { - cli_progress_set(bar, i); - } + update_bar_and_i(bar, i); } // [[Rcpp::export]] NumericVector paired_pmt( - int n, - Function statistic_func, - int n_permu) + const unsigned n, + const Function statistic_func, + const unsigned n_permu) { - int total; - if (n_permu == 0) { - total = (1 << n); - } else { - total = n_permu; - } + LogicalVector swapped(n); - NumericVector statistic_permu(total); - RObject bar = cli_progress_bar(total, NULL); + unsigned total = (1 << n); - LogicalVector swapped(n); + std::pair paired; + unsigned i = 0; if (n_permu == 0) { - for (int i = 0; i < total; i++) { - for (int j = 0; j < n; j++) { - swapped[j] = ((i & (1 << j)) != 0); - } - paired_do(i, statistic_func, statistic_permu, swapped, bar); + paired = statistic_permu_with_bar(total, true); + + while (i < total) { + paired_do(i, i, n, statistic_func, paired.first, paired.second, swapped); } } else { - int r_int; - for (int i = 0; i < total; i++) { - r_int = rand_int(total); - for (int j = 0; j < n; j++) { - swapped[j] = ((r_int & (1 << j)) != 0); - } - paired_do(i, statistic_func, statistic_permu, swapped, bar); + paired = statistic_permu_with_bar(n_permu, false); + + while (i < n_permu) { + paired_do(i, rand_int(total), n, statistic_func, paired.first, paired.second, swapped); } } - cli_progress_done(bar); + cli_progress_done(paired.second);; - return statistic_permu; + return paired.first; } \ No newline at end of file diff --git a/src/rcbd_pmt.cpp b/src/rcbd_pmt.cpp index 0baf248e..30bd208d 100644 --- a/src/rcbd_pmt.cpp +++ b/src/rcbd_pmt.cpp @@ -1,51 +1,41 @@ #include "utils.h" -#include -#include -#include using namespace Rcpp; -inline void rcbd_do( - int i, - NumericMatrix data, - Function statistic_func, - NumericVector statistic_permu, - RObject bar) +void rcbd_do( + unsigned& i, + const NumericMatrix& data, + const Function& statistic_func, + NumericVector& statistic_permu, + RObject& bar) { statistic_permu[i] = as(statistic_func(data)); - if (CLI_SHOULD_TICK) { - cli_progress_set(bar, i); - } + update_bar_and_i(bar, i); } // [[Rcpp::export]] NumericVector rcbd_pmt( NumericMatrix data, - Function statistic_func, - int n_permu) + const Function statistic_func, + const unsigned n_permu) { - int n_col = data.ncol(); + unsigned n_col = data.ncol(); - int total = 1; + std::pair paired; + + unsigned i = 0; + unsigned j = 0; if (n_permu == 0) { - for (int k = 0; k < n_col; k++) { + unsigned total = 1; + for (unsigned k = 0; k < n_col; k++) { total *= n_permutation(data.column(k)); } - } else { - total = n_permu; - } - - NumericVector statistic_permu(total); - RObject bar = cli_progress_bar(total, NULL); + paired = statistic_permu_with_bar(total, true); - if (n_permu == 0) { - int i = 0; - int j = 0; while (j < n_col) { if (j == 0) { - rcbd_do(i, data, statistic_func, statistic_permu, bar); - i++; + rcbd_do(i, data, statistic_func, paired.first, paired.second); } if (std::next_permutation(data.column(j).begin(), data.column(j).end())) { @@ -55,15 +45,17 @@ NumericVector rcbd_pmt( } } } else { - for (int i = 0; i < total; i++) { - for (int j = 0; j < n_col; j++) { + paired = statistic_permu_with_bar(n_permu, false); + + while (i < n_permu) { + for (j = 0; j < n_col; j++) { random_shuffle(data.column(j)); } - rcbd_do(i, data, statistic_func, statistic_permu, bar); + rcbd_do(i, data, statistic_func, paired.first, paired.second); } } - cli_progress_done(bar); + cli_progress_done(paired.second);; - return statistic_permu; + return paired.first; } diff --git a/src/table_pmt.cpp b/src/table_pmt.cpp index b7ad8e03..5f9012dd 100644 --- a/src/table_pmt.cpp +++ b/src/table_pmt.cpp @@ -1,65 +1,56 @@ #include "utils.h" -#include -#include -#include using namespace Rcpp; -inline void table_do( - int i, int n, - IntegerVector row_loc, - IntegerVector col_loc, - Function statistic_func, - NumericVector statistic_permu, - IntegerMatrix data, RObject bar) +void table_do( + unsigned& i, + const unsigned& n, + const IntegerVector& row_loc, + const IntegerVector& col_loc, + const Function& statistic_func, + NumericVector& statistic_permu, + RObject& bar, IntegerMatrix& data) { - std::fill(data.begin(), data.end(), 0); + data.fill(0); - for (int j = 0; j < n; j++) { + for (unsigned j = 0; j < n; j++) { data(row_loc[j], col_loc[j])++; } statistic_permu[i] = as(statistic_func(data)); - if (CLI_SHOULD_TICK) { - cli_progress_set(bar, i); - } + update_bar_and_i(bar, i); } // [[Rcpp::export]] NumericVector table_pmt( IntegerVector row_loc, - IntegerVector col_loc, - Function statistic_func, - int n_permu) + const IntegerVector col_loc, + const Function statistic_func, + const unsigned n_permu) { - int total; - if (n_permu == 0) { - total = n_permutation(row_loc); - } else { - total = n_permu; - } - - NumericVector statistic_permu(total); - RObject bar = cli_progress_bar(total, NULL); - - int n = row_loc.size(); + unsigned n = row_loc.size(); IntegerMatrix data(row_loc[n - 1] + 1, col_loc[n - 1] + 1); + std::pair paired; + + unsigned i = 0; if (n_permu == 0) { - int i = 0; + paired = statistic_permu_with_bar(n_permutation(row_loc), true); + do { - table_do(i, n, row_loc, col_loc, statistic_func, statistic_permu, data, bar); - i++; + table_do(i, n, row_loc, col_loc, statistic_func, paired.first, paired.second, data); } while (std::next_permutation(row_loc.begin(), row_loc.end())); } else { - for (int i = 0; i < total; i++) { + paired = statistic_permu_with_bar(n_permu, false); + + while (i < n_permu) { random_shuffle(row_loc); - table_do(i, n, row_loc, col_loc, statistic_func, statistic_permu, data, bar); + table_do(i, n, row_loc, col_loc, statistic_func, paired.first, paired.second, data); } } - cli_progress_done(bar); + cli_progress_done(paired.second);; - return statistic_permu; + return paired.first; } diff --git a/src/twosample_pmt.cpp b/src/twosample_pmt.cpp index 473051e3..1f444c5c 100644 --- a/src/twosample_pmt.cpp +++ b/src/twosample_pmt.cpp @@ -1,72 +1,46 @@ #include "utils.h" -#include -#include -#include using namespace Rcpp; -int n_combination(int n, int k) +void twosample_do( + unsigned& i, + const NumericVector& data, + const LogicalVector& where_y, + const Function& statistic_func, + NumericVector& statistic_permu, + RObject& bar) { - double C = 1; + statistic_permu[i] = as(statistic_func(data[!where_y], data[where_y])); - for (int i = 1; i <= k; i++) { - C *= (i + n - k); - C /= i; - } - - return (int)C; -} - -inline void twosample_do( - int i, - NumericVector c_xy, - Function statistic_func, - NumericVector statistic_permu, - LogicalVector where_x, RObject bar) -{ - statistic_permu[i] = as(statistic_func(c_xy[where_x], c_xy[!where_x])); - - if (CLI_SHOULD_TICK) { - cli_progress_set(bar, i); - } + update_bar_and_i(bar, i); } // [[Rcpp::export]] NumericVector twosample_pmt( - int n_1, int n_2, - NumericVector c_xy, - Function statistic_func, - int n_permu) + const NumericVector data, + LogicalVector where_y, + const Function statistic_func, + const unsigned n_permu) { - int total; - if (n_permu == 0) { - total = n_combination(n_1 + n_2, std::min(n_1, n_2)); - } else { - total = n_permu; - } - - NumericVector statistic_permu(total); - RObject bar = cli_progress_bar(total, NULL); - - LogicalVector where_x(n_1 + n_2, FALSE); - for (int k = 0; k < n_1; k++) { - where_x[k] = TRUE; - } + std::pair paired; + unsigned i = 0; if (n_permu == 0) { - int i = 0; + paired = statistic_permu_with_bar(n_permutation(where_y), true); + do { - twosample_do(i, c_xy, statistic_func, statistic_permu, where_x, bar); - i++; - } while (std::prev_permutation(where_x.begin(), where_x.end())); + twosample_do(i, data, where_y, statistic_func, paired.first, paired.second); + } while (std::next_permutation(where_y.begin(), where_y.end())); } else { - for (int i = 0; i < total; i++) { - random_shuffle(where_x); - twosample_do(i, c_xy, statistic_func, statistic_permu, where_x, bar); + paired = statistic_permu_with_bar(n_permu, false); + + while (i < n_permu) { + random_shuffle(where_y); + twosample_do(i, data, where_y, statistic_func, paired.first, paired.second); } } - cli_progress_done(bar); + cli_progress_done(paired.second);; - return statistic_permu; + return paired.first; } diff --git a/src/utils.cpp b/src/utils.cpp new file mode 100644 index 00000000..3729f4a0 --- /dev/null +++ b/src/utils.cpp @@ -0,0 +1,22 @@ +#include "utils.h" + +std::pair statistic_permu_with_bar( + const unsigned n, const bool exact, + const unsigned statistic_size) +{ + RObject bar = cli_progress_bar(n, NULL); + cli_progress_set_type(bar, "iterator"); + + if (exact) { + cli_progress_set_name(bar, "Building exact permutation distribution"); + } else { + cli_progress_set_name(bar, "Sampling from exact permutation distribution"); + } + + NumericVector statistic_permu(no_init(n * statistic_size)); + if (statistic_size > 1) { + statistic_permu.attr("dim") = IntegerVector::create(statistic_size, n); + } + + return std::pair(statistic_permu, bar); +} \ No newline at end of file diff --git a/src/utils.h b/src/utils.h index 5fa493ad..f8bb545d 100644 --- a/src/utils.h +++ b/src/utils.h @@ -1,33 +1,56 @@ -#pragma once +#ifndef UTILS_H +#define UTILS_H #include #include +#include +#include -inline int rand_int(const int n) +using namespace Rcpp; + +// progress bar + +std::pair statistic_permu_with_bar( + const unsigned n, const bool exact, + const unsigned statistic_size = 1); + +inline void update_bar_and_i(RObject& bar, unsigned& i) +{ + if (CLI_SHOULD_TICK) { + cli_progress_set(bar, i); + } + i++; +} + +// random shuffle (tied to the same RNG which R uses) + +inline unsigned rand_int(const unsigned& n) { return floor(unif_rand() * n); } template -void random_shuffle(T v) +void random_shuffle(T&& v) { - int j; - int n = v.size(); - for (int i = 0; i < n - 1; i++) { + unsigned j; + unsigned n = v.size(); + for (unsigned i = 0; i < n - 1; i++) { j = i + rand_int(n - i); std::swap(v[i], v[j]); } } +// count + template -int n_permutation(T v) +unsigned n_permutation(T&& v) { double A = 1; - int n_i = 0; - int n = v.size(); + unsigned n_i = 0; + unsigned n = v.size(); double current = v[0]; - for (int i = 0; i < n; i++) { + for (unsigned i = 0; i < n; i++) { A *= (i + 1); if (v[i] == current) { n_i++; @@ -38,5 +61,7 @@ int n_permutation(T v) current = v[i]; } - return (int)A; + return (unsigned)A; } + +#endif