Skip to content

Commit

Permalink
improved MultipleComparison
Browse files Browse the repository at this point in the history
  • Loading branch information
qddyy committed Jan 7, 2024
1 parent 61b66f4 commit a4e58a3
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 39 deletions.
25 changes: 15 additions & 10 deletions R/MultiCompT.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,24 +38,29 @@ MultiCompT <- R6Class(
.bonferroni = NULL,

.define = function() {
lengths <- vapply(
X = split(private$.data, names(private$.data)),
FUN = length, FUN.VALUE = integer(1), USE.NAMES = FALSE
)

if (private$.scoring == "none") {
N <- length(private$.data)
k <- as.integer(names(private$.data)[N])
private$.statistic_func <- function(x, y, data, group) {
mse <- sum(vapply(
X = split(data, group),
FUN = function(x) (length(x) - 1) * var(x),
private$.statistic_func <- function(i, j, data, group) {
means <- vapply(
X = split(data, group), FUN = mean,
FUN.VALUE = numeric(1), USE.NAMES = FALSE
)) / (N - k)
(mean(x) - mean(y)) / sqrt(
mse * (1 / length(x) + 1 / length(y))
)
mse <- sum((data - means[group])^2) / (N - k)
(means[i] - means[j]) / sqrt(
mse * (1 / lengths[i] + 1 / lengths[j])
)
}
} else {
var <- var(private$.data)
private$.statistic_func <- function(x, y, ...) {
(mean(x) - mean(y)) / sqrt(
var * (1 / length(x) + 1 / length(y))
private$.statistic_func <- function(i, j, data, group) {
(mean(data[group == i]) - mean(data[group == j])) / sqrt(
var * (1 / lengths[i] + 1 / lengths[j])
)
}
}
Expand Down
9 changes: 3 additions & 6 deletions R/MultipleComparison.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,17 @@ MultipleComparison <- R6Class(
.calculate_statistic = function() {
data <- unname(private$.data)
group <- as.integer(names(private$.data))
where <- split(seq_along(group), group)
private$.statistic <- as.numeric(.mapply(
FUN = function(i, j) {
private$.statistic_func(
data[where[[i]]], data[where[[j]]], data, group
)
private$.statistic_func(i, j, data, group)
}, dots = private$.group_ij, MoreArgs = NULL
))
},

.calculate_statistic_permu = function() {
private$.statistic_permu <- multicomp_pmt(
group_i = private$.group_ij$i - 1,
group_j = private$.group_ij$j - 1,
group_i = private$.group_ij$i,
group_j = private$.group_ij$j,
data = unname(private$.data),
group = as.integer(names(private$.data)),
statistic_func = private$.statistic_func,
Expand Down
25 changes: 15 additions & 10 deletions R/TukeyHSD.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,24 +34,29 @@ TukeyHSD <- R6Class(
.name = "Tukey's HSD",

.define = function() {
lengths <- vapply(
X = split(private$.data, names(private$.data)),
FUN = length, FUN.VALUE = integer(1), USE.NAMES = FALSE
)

if (private$.scoring == "none") {
N <- length(private$.data)
k <- as.integer(names(private$.data)[N])
private$.statistic_func <- function(x, y, data, group) {
mse <- sum(vapply(
X = split(data, group),
FUN = function(x) (length(x) - 1) * var(x),
private$.statistic_func <- function(i, j, data, group) {
means <- vapply(
X = split(data, group), FUN = mean,
FUN.VALUE = numeric(1), USE.NAMES = FALSE
)) / (N - k)
(mean(x) - mean(y)) / sqrt(
mse / 2 * (1 / length(x) + 1 / length(y))
)
mse <- sum((data - means[group])^2) / (N - k)
(means[i] - means[j]) / sqrt(
mse / 2 * (1 / lengths[i] + 1 / lengths[j])
)
}
} else {
var <- var(private$.data)
private$.statistic_func <- function(x, y, ...) {
(mean(x) - mean(y)) / sqrt(
var / 2 * (1 / length(x) + 1 / length(y))
private$.statistic_func <- function(i, j, data, group) {
(mean(data[group == i]) - mean(data[group == j])) / sqrt(
var / 2 * (1 / lengths[i] + 1 / lengths[j])
)
}
}
Expand Down
17 changes: 4 additions & 13 deletions src/multicomp_pmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,17 @@ NumericVector multicomp_pmt(
R_len_t n_group = group[group.size() - 1];
R_len_t n_pair = n_group * (n_group - 1) / 2;

List split(n_group);
auto do_split = [&]() {
for (R_len_t i = 1; i <= n_group; i++) {
split[i - 1] = data[group == i];
}
};

R_len_t j;
R_len_t k;

auto multicomp_statistic = [&]() -> double {
return as<double>(statistic_func(split[group_i[j]], split[group_j[j]], data, group));
return as<double>(statistic_func(group_i[k], group_j[k], data, group));
};

if (n_permu == 0) {
PermuBar bar(n_permutation(group), true, n_pair);

do {
do_split();
for (j = 0; j < n_pair; j++) {
for (k = 0; k < n_pair; k++) {
bar.update(multicomp_statistic());
};
} while (std::next_permutation(group.begin(), group.end()));
Expand All @@ -43,8 +35,7 @@ NumericVector multicomp_pmt(

do {
random_shuffle(group);
do_split();
for (j = 0; j < n_pair - 1; j++) {
for (k = 0; k < n_pair - 1; k++) {
bar.update(multicomp_statistic());
};
} while (bar.update(multicomp_statistic()));
Expand Down

0 comments on commit a4e58a3

Please sign in to comment.