Skip to content

Commit

Permalink
Feat/simulator with existing cards and limit (#141)
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock authored Dec 21, 2023
1 parent 458ed0a commit 8ff6ccf
Showing 1 changed file with 111 additions and 12 deletions.
123 changes: 111 additions & 12 deletions src/optimal_retention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ pub struct SimulatorConfig {
pub first_rating_prob: [f64; 4],
pub review_rating_prob: [f64; 3],
pub loss_aversion: f64,
pub learn_limit: usize,
pub review_limit: usize,
}

impl Default for SimulatorConfig {
Expand All @@ -70,6 +72,8 @@ impl Default for SimulatorConfig {
first_rating_prob: [0.15, 0.2, 0.6, 0.05],
review_rating_prob: [0.3, 0.6, 0.1],
loss_aversion: 2.5,
learn_limit: usize::MAX,
review_limit: usize::MAX,
}
}
}
Expand All @@ -90,7 +94,20 @@ fn stability_after_failure(w: &[f64], s: f64, r: f64, d: f64) -> f64 {
.clamp(0.1, s)
}

fn simulate(config: &SimulatorConfig, w: &[f64], desired_retention: f64, seed: Option<u64>) -> f64 {
struct Card {
pub difficulty: f64,
pub stability: f64,
pub last_date: f64,
pub due: f64,
}

fn simulate(
config: &SimulatorConfig,
w: &[f64],
desired_retention: f64,
seed: Option<u64>,
existing_cards: Option<Vec<Card>>,
) -> (Array1<f64>, Array1<usize>, Array1<usize>) {
let SimulatorConfig {
deck_size,
learn_span,
Expand All @@ -102,6 +119,8 @@ fn simulate(config: &SimulatorConfig, w: &[f64], desired_retention: f64, seed: O
first_rating_prob,
review_rating_prob,
loss_aversion,
learn_limit,
review_limit,
} = config.clone();
let mut card_table = Array2::zeros((Column::COUNT, deck_size));
card_table
Expand All @@ -110,8 +129,18 @@ fn simulate(config: &SimulatorConfig, w: &[f64], desired_retention: f64, seed: O
card_table.slice_mut(s![Column::Difficulty, ..]).fill(1e-10);
card_table.slice_mut(s![Column::Stability, ..]).fill(1e-10);

// let mut review_cnt_per_day = Array1::<f64>::zeros(learn_span);
// let mut learn_cnt_per_day = Array1::<f64>::zeros(learn_span);
// fill card table based on existing_cards
if let Some(existing_cards) = existing_cards {
for (i, card) in existing_cards.into_iter().enumerate() {
card_table[[Column::Difficulty as usize, i]] = card.difficulty;
card_table[[Column::Stability as usize, i]] = card.stability;
card_table[[Column::LastDate as usize, i]] = card.last_date;
card_table[[Column::Due as usize, i]] = card.due;
}
}

let mut review_cnt_per_day = Array1::<usize>::zeros(learn_span);
let mut learn_cnt_per_day = Array1::<usize>::zeros(learn_span);
let mut memorized_cnt_per_day = Array1::zeros(learn_span);

let first_rating_choices = [1, 2, 3, 4];
Expand Down Expand Up @@ -206,12 +235,18 @@ fn simulate(config: &SimulatorConfig, w: &[f64], desired_retention: f64, seed: O
cum_sum[i] = cum_sum[i - 1] + cost[i];
}

// Create 'true_review' mask based on 'need_review' and 'cum_sum'
// Create 'true_review' mask based on 'need_review' and 'cum_sum' and 'review_limit'
let mut review_count = 0;
let true_review =
Zip::from(&need_review)
.and(&cum_sum)
.map_collect(|&need_review_flag, &cum_cost| {
need_review_flag && (cum_cost <= max_cost_perday)
if need_review_flag {
review_count += 1;
}
need_review_flag
&& (cum_cost <= max_cost_perday)
&& (review_count <= review_limit)
});

let need_learn = old_due.mapv(|x| x == learn_span as f64);
Expand All @@ -229,12 +264,16 @@ fn simulate(config: &SimulatorConfig, w: &[f64], desired_retention: f64, seed: O

// dbg!(&cum_sum);

// Create 'true_learn' mask based on 'need_learn' and 'cum_sum'
// Create 'true_learn' mask based on 'need_learn' and 'cum_sum' and 'learn_limit'
let mut learn_count = 0;
let true_learn =
Zip::from(&need_learn)
.and(&cum_sum)
.map_collect(|&need_learn_flag, &cum_cost| {
need_learn_flag && (cum_cost <= max_cost_perday)
if need_learn_flag {
learn_count += 1;
}
need_learn_flag && (cum_cost <= max_cost_perday) && (learn_count <= learn_limit)
});

// Sample 'rating' for 'true_learn' entries
Expand Down Expand Up @@ -346,12 +385,12 @@ fn simulate(config: &SimulatorConfig, w: &[f64], desired_retention: f64, seed: O
.slice_mut(s![Column::Interval, ..])
.assign(&new_interval);
// Update the review_cnt_per_day, learn_cnt_per_day and memorized_cnt_per_day
// review_cnt_per_day[today] = true_review.iter().filter(|&&x| x).count() as f64;
// learn_cnt_per_day[today] = true_learn.iter().filter(|&&x| x).count() as f64;
review_cnt_per_day[today] = true_review.iter().filter(|&&x| x).count();
learn_cnt_per_day[today] = true_learn.iter().filter(|&&x| x).count();
memorized_cnt_per_day[today] = retrievability.sum();
}

memorized_cnt_per_day[memorized_cnt_per_day.len() - 1]
(memorized_cnt_per_day, review_cnt_per_day, learn_cnt_per_day)
}

fn sample<F>(
Expand All @@ -370,12 +409,15 @@ where
Ok((0..n)
.into_par_iter()
.map(|i| {
simulate(
let memorization = simulate(
config,
weights,
desired_retention,
Some((i + 42).try_into().unwrap()),
None,
)
.0;
memorization[memorization.len() - 1]
})
.sum::<f64>()
/ n as f64)
Expand Down Expand Up @@ -628,8 +670,65 @@ mod tests {
&DEFAULT_WEIGHTS.iter().map(|v| *v as f64).collect_vec(),
0.9,
None,
None,
)
.0;
assert_eq!(memorization[memorization.len() - 1], 2380.9836436993573)
}

#[test]
fn simulate_with_existing_cards() {
let mut config = SimulatorConfig::default();
config.learn_span = 10;
let cards = vec![
Card {
difficulty: 5.0,
stability: 5.0,
last_date: -5.0,
due: 0.0,
},
Card {
difficulty: 5.0,
stability: 2.0,
last_date: -2.0,
due: 0.0,
},
];
let memorization = simulate(
&config,
&DEFAULT_WEIGHTS.iter().map(|v| *v as f64).collect_vec(),
0.9,
None,
Some(cards),
);
assert_eq!(memorization, 2380.9836436993573)
dbg!(memorization);
}

#[test]
fn simulate_with_learn_review_limit() {
let mut config = SimulatorConfig::default();
config.learn_span = 30;
config.learn_limit = 50;
config.review_limit = 200;
config.max_cost_perday = f64::INFINITY;
let results = simulate(
&config,
&DEFAULT_WEIGHTS.iter().map(|v| *v as f64).collect_vec(),
0.9,
None,
None,
);
assert_eq!(
results.1.to_vec(),
vec![
0, 48, 57, 77, 86, 102, 119, 105, 133, 137, 141, 163, 151, 164, 157, 186, 174, 171,
194, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200
]
);
assert_eq!(
results.2.to_vec(),
vec![config.learn_limit; config.learn_span]
)
}

#[test]
Expand Down

0 comments on commit 8ff6ccf

Please sign in to comment.