Skip to content

Commit

Permalink
improved cpp code
Browse files Browse the repository at this point in the history
  • Loading branch information
qddyy committed Nov 20, 2023
1 parent c9a2d00 commit 980cde0
Show file tree
Hide file tree
Showing 12 changed files with 259 additions and 265 deletions.
4 changes: 2 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ table_pmt <- function(row_loc, col_loc, statistic_func, n_permu) {
.Call(`_LearnNonparam_table_pmt`, row_loc, col_loc, statistic_func, n_permu)
}

twosample_pmt <- function(n_1, n_2, c_xy, statistic_func, n_permu) {
.Call(`_LearnNonparam_twosample_pmt`, n_1, n_2, c_xy, statistic_func, n_permu)
twosample_pmt <- function(data, where_y, statistic_func, n_permu) {
.Call(`_LearnNonparam_twosample_pmt`, data, where_y, statistic_func, n_permu)
}

7 changes: 4 additions & 3 deletions R/TwoSampleTest.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,11 @@ TwoSampleTest <- R6Class(
},

.calculate_statistic_permu = function() {
n_1 <- length(private$.data$x)
n_2 <- length(private$.data$y)
private$.statistic_permu <- twosample_pmt(
n_1 = length(private$.data$x),
n_2 = length(private$.data$y),
c_xy = c(private$.data$x, private$.data$y),
data = c(private$.data$x, private$.data$y),
where_y = rep.int(c(FALSE, TRUE), c(n_1, n_2)),
statistic_func = private$.statistic_func,
n_permu = as.integer(private$.n_permu)
)
Expand Down
67 changes: 33 additions & 34 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,101 +11,100 @@ Rcpp::Rostream<false>& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get();
#endif

// association_pmt
NumericVector association_pmt(NumericVector x, NumericVector y, Function statistic_func, int n_permu);
NumericVector association_pmt(const NumericVector x, NumericVector y, const Function statistic_func, const unsigned n_permu);
RcppExport SEXP _LearnNonparam_association_pmt(SEXP xSEXP, SEXP ySEXP, SEXP statistic_funcSEXP, SEXP n_permuSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< NumericVector >::type x(xSEXP);
Rcpp::traits::input_parameter< const NumericVector >::type x(xSEXP);
Rcpp::traits::input_parameter< NumericVector >::type y(ySEXP);
Rcpp::traits::input_parameter< Function >::type statistic_func(statistic_funcSEXP);
Rcpp::traits::input_parameter< int >::type n_permu(n_permuSEXP);
Rcpp::traits::input_parameter< const Function >::type statistic_func(statistic_funcSEXP);
Rcpp::traits::input_parameter< const unsigned >::type n_permu(n_permuSEXP);
rcpp_result_gen = Rcpp::wrap(association_pmt(x, y, statistic_func, n_permu));
return rcpp_result_gen;
END_RCPP
}
// ksample_pmt
NumericVector ksample_pmt(NumericVector data, IntegerVector group, Function statistic_func, int n_permu);
NumericVector ksample_pmt(const NumericVector data, IntegerVector group, const Function statistic_func, const unsigned n_permu);
RcppExport SEXP _LearnNonparam_ksample_pmt(SEXP dataSEXP, SEXP groupSEXP, SEXP statistic_funcSEXP, SEXP n_permuSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< NumericVector >::type data(dataSEXP);
Rcpp::traits::input_parameter< const NumericVector >::type data(dataSEXP);
Rcpp::traits::input_parameter< IntegerVector >::type group(groupSEXP);
Rcpp::traits::input_parameter< Function >::type statistic_func(statistic_funcSEXP);
Rcpp::traits::input_parameter< int >::type n_permu(n_permuSEXP);
Rcpp::traits::input_parameter< const Function >::type statistic_func(statistic_funcSEXP);
Rcpp::traits::input_parameter< const unsigned >::type n_permu(n_permuSEXP);
rcpp_result_gen = Rcpp::wrap(ksample_pmt(data, group, statistic_func, n_permu));
return rcpp_result_gen;
END_RCPP
}
// multicomp_pmt
NumericMatrix multicomp_pmt(IntegerVector group_i, IntegerVector group_j, NumericVector data, IntegerVector group, Function statistic_func, int n_permu);
NumericVector multicomp_pmt(const IntegerVector group_i, const IntegerVector group_j, const NumericVector data, IntegerVector group, const Function statistic_func, const unsigned n_permu);
RcppExport SEXP _LearnNonparam_multicomp_pmt(SEXP group_iSEXP, SEXP group_jSEXP, SEXP dataSEXP, SEXP groupSEXP, SEXP statistic_funcSEXP, SEXP n_permuSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< IntegerVector >::type group_i(group_iSEXP);
Rcpp::traits::input_parameter< IntegerVector >::type group_j(group_jSEXP);
Rcpp::traits::input_parameter< NumericVector >::type data(dataSEXP);
Rcpp::traits::input_parameter< const IntegerVector >::type group_i(group_iSEXP);
Rcpp::traits::input_parameter< const IntegerVector >::type group_j(group_jSEXP);
Rcpp::traits::input_parameter< const NumericVector >::type data(dataSEXP);
Rcpp::traits::input_parameter< IntegerVector >::type group(groupSEXP);
Rcpp::traits::input_parameter< Function >::type statistic_func(statistic_funcSEXP);
Rcpp::traits::input_parameter< int >::type n_permu(n_permuSEXP);
Rcpp::traits::input_parameter< const Function >::type statistic_func(statistic_funcSEXP);
Rcpp::traits::input_parameter< const unsigned >::type n_permu(n_permuSEXP);
rcpp_result_gen = Rcpp::wrap(multicomp_pmt(group_i, group_j, data, group, statistic_func, n_permu));
return rcpp_result_gen;
END_RCPP
}
// paired_pmt
NumericVector paired_pmt(int n, Function statistic_func, int n_permu);
NumericVector paired_pmt(const unsigned n, const Function statistic_func, const unsigned n_permu);
RcppExport SEXP _LearnNonparam_paired_pmt(SEXP nSEXP, SEXP statistic_funcSEXP, SEXP n_permuSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< int >::type n(nSEXP);
Rcpp::traits::input_parameter< Function >::type statistic_func(statistic_funcSEXP);
Rcpp::traits::input_parameter< int >::type n_permu(n_permuSEXP);
Rcpp::traits::input_parameter< const unsigned >::type n(nSEXP);
Rcpp::traits::input_parameter< const Function >::type statistic_func(statistic_funcSEXP);
Rcpp::traits::input_parameter< const unsigned >::type n_permu(n_permuSEXP);
rcpp_result_gen = Rcpp::wrap(paired_pmt(n, statistic_func, n_permu));
return rcpp_result_gen;
END_RCPP
}
// rcbd_pmt
NumericVector rcbd_pmt(NumericMatrix data, Function statistic_func, int n_permu);
NumericVector rcbd_pmt(NumericMatrix data, const Function statistic_func, const unsigned n_permu);
RcppExport SEXP _LearnNonparam_rcbd_pmt(SEXP dataSEXP, SEXP statistic_funcSEXP, SEXP n_permuSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< NumericMatrix >::type data(dataSEXP);
Rcpp::traits::input_parameter< Function >::type statistic_func(statistic_funcSEXP);
Rcpp::traits::input_parameter< int >::type n_permu(n_permuSEXP);
Rcpp::traits::input_parameter< const Function >::type statistic_func(statistic_funcSEXP);
Rcpp::traits::input_parameter< const unsigned >::type n_permu(n_permuSEXP);
rcpp_result_gen = Rcpp::wrap(rcbd_pmt(data, statistic_func, n_permu));
return rcpp_result_gen;
END_RCPP
}
// table_pmt
NumericVector table_pmt(IntegerVector row_loc, IntegerVector col_loc, Function statistic_func, int n_permu);
NumericVector table_pmt(IntegerVector row_loc, const IntegerVector col_loc, const Function statistic_func, const unsigned n_permu);
RcppExport SEXP _LearnNonparam_table_pmt(SEXP row_locSEXP, SEXP col_locSEXP, SEXP statistic_funcSEXP, SEXP n_permuSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< IntegerVector >::type row_loc(row_locSEXP);
Rcpp::traits::input_parameter< IntegerVector >::type col_loc(col_locSEXP);
Rcpp::traits::input_parameter< Function >::type statistic_func(statistic_funcSEXP);
Rcpp::traits::input_parameter< int >::type n_permu(n_permuSEXP);
Rcpp::traits::input_parameter< const IntegerVector >::type col_loc(col_locSEXP);
Rcpp::traits::input_parameter< const Function >::type statistic_func(statistic_funcSEXP);
Rcpp::traits::input_parameter< const unsigned >::type n_permu(n_permuSEXP);
rcpp_result_gen = Rcpp::wrap(table_pmt(row_loc, col_loc, statistic_func, n_permu));
return rcpp_result_gen;
END_RCPP
}
// twosample_pmt
NumericVector twosample_pmt(int n_1, int n_2, NumericVector c_xy, Function statistic_func, int n_permu);
RcppExport SEXP _LearnNonparam_twosample_pmt(SEXP n_1SEXP, SEXP n_2SEXP, SEXP c_xySEXP, SEXP statistic_funcSEXP, SEXP n_permuSEXP) {
NumericVector twosample_pmt(const NumericVector data, LogicalVector where_y, const Function statistic_func, const unsigned n_permu);
RcppExport SEXP _LearnNonparam_twosample_pmt(SEXP dataSEXP, SEXP where_ySEXP, SEXP statistic_funcSEXP, SEXP n_permuSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< int >::type n_1(n_1SEXP);
Rcpp::traits::input_parameter< int >::type n_2(n_2SEXP);
Rcpp::traits::input_parameter< NumericVector >::type c_xy(c_xySEXP);
Rcpp::traits::input_parameter< Function >::type statistic_func(statistic_funcSEXP);
Rcpp::traits::input_parameter< int >::type n_permu(n_permuSEXP);
rcpp_result_gen = Rcpp::wrap(twosample_pmt(n_1, n_2, c_xy, statistic_func, n_permu));
Rcpp::traits::input_parameter< const NumericVector >::type data(dataSEXP);
Rcpp::traits::input_parameter< LogicalVector >::type where_y(where_ySEXP);
Rcpp::traits::input_parameter< const Function >::type statistic_func(statistic_funcSEXP);
Rcpp::traits::input_parameter< const unsigned >::type n_permu(n_permuSEXP);
rcpp_result_gen = Rcpp::wrap(twosample_pmt(data, where_y, statistic_func, n_permu));
return rcpp_result_gen;
END_RCPP
}
Expand All @@ -117,7 +116,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_LearnNonparam_paired_pmt", (DL_FUNC) &_LearnNonparam_paired_pmt, 3},
{"_LearnNonparam_rcbd_pmt", (DL_FUNC) &_LearnNonparam_rcbd_pmt, 3},
{"_LearnNonparam_table_pmt", (DL_FUNC) &_LearnNonparam_table_pmt, 4},
{"_LearnNonparam_twosample_pmt", (DL_FUNC) &_LearnNonparam_twosample_pmt, 5},
{"_LearnNonparam_twosample_pmt", (DL_FUNC) &_LearnNonparam_twosample_pmt, 4},
{NULL, NULL, 0}
};

Expand Down
43 changes: 19 additions & 24 deletions src/association_pmt.cpp
Original file line number Diff line number Diff line change
@@ -1,50 +1,45 @@
#include "utils.h"
#include <Rcpp.h>
#include <algorithm>
#include <cli/progress.h>

using namespace Rcpp;

inline void association_do(
int i,
NumericVector x,
NumericVector y,
Function statistic_func,
NumericVector statistic_permu,
RObject bar)
unsigned& i,
const NumericVector& x,
const NumericVector& y,
const Function& statistic_func,
NumericVector& statistic_permu,
RObject& bar)
{
statistic_permu[i] = as<double>(statistic_func(x, y));

if (CLI_SHOULD_TICK) {
cli_progress_set(bar, i);
}
i++;
}

// [[Rcpp::export]]
NumericVector association_pmt(
NumericVector x,
const NumericVector x,
NumericVector y,
Function statistic_func,
int n_permu)
const Function statistic_func,
const unsigned n_permu)
{
int total;
if (n_permu == 0) {
total = n_permutation(y);
} else {
total = n_permu;
}

NumericVector statistic_permu(total);
RObject bar = cli_progress_bar(total, NULL);
RObject bar;
cli_progress_init_timer();
NumericVector statistic_permu;

unsigned i = 0;
if (n_permu == 0) {
int i = 0;
std::tie(statistic_permu, bar) = statistic_permu_with_bar(n_permutation(y), true);

do {
association_do(i, x, y, statistic_func, statistic_permu, bar);
i++;
} while (std::next_permutation(y.begin(), y.end()));
} else {
for (int i = 0; i < total; i++) {
std::tie(statistic_permu, bar) = statistic_permu_with_bar(n_permu, false);

while (i < n_permu) {
random_shuffle(y);
association_do(i, x, y, statistic_func, statistic_permu, bar);
}
Expand Down
43 changes: 19 additions & 24 deletions src/ksample_pmt.cpp
Original file line number Diff line number Diff line change
@@ -1,50 +1,45 @@
#include "utils.h"
#include <Rcpp.h>
#include <algorithm>
#include <cli/progress.h>

using namespace Rcpp;

inline void ksample_do(
int i,
NumericVector data,
IntegerVector group,
Function statistic_func,
NumericVector statistic_permu,
RObject bar)
unsigned& i,
const NumericVector& data,
const IntegerVector& group,
const Function& statistic_func,
NumericVector& statistic_permu,
RObject& bar)
{
statistic_permu[i] = as<double>(statistic_func(data, group));

if (CLI_SHOULD_TICK) {
cli_progress_set(bar, i);
}
i++;
}

// [[Rcpp::export]]
NumericVector ksample_pmt(
NumericVector data,
const NumericVector data,
IntegerVector group,
Function statistic_func,
int n_permu)
const Function statistic_func,
const unsigned n_permu)
{
int total;
if (n_permu == 0) {
total = n_permutation(group);
} else {
total = n_permu;
}

NumericVector statistic_permu(total);
RObject bar = cli_progress_bar(total, NULL);
RObject bar;
cli_progress_init_timer();
NumericVector statistic_permu;

unsigned i = 0;
if (n_permu == 0) {
int i = 0;
std::tie(statistic_permu, bar) = statistic_permu_with_bar(n_permutation(group), true);

do {
ksample_do(i, data, group, statistic_func, statistic_permu, bar);
i++;
} while (std::next_permutation(group.begin(), group.end()));
} else {
for (int i = 0; i < total; i++) {
std::tie(statistic_permu, bar) = statistic_permu_with_bar(n_permu, false);

while (i < n_permu) {
random_shuffle(group);
ksample_do(i, data, group, statistic_func, statistic_permu, bar);
}
Expand Down
Loading

0 comments on commit 980cde0

Please sign in to comment.