Skip to content

Commit

Permalink
Merge branch 'main' into Feat/improve-outlier-filter
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock authored Dec 26, 2023
2 parents b65366b + 7bd2ac4 commit 6b5bdeb
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 19 deletions.
8 changes: 4 additions & 4 deletions src/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ pub(crate) const FACTOR: f64 = 19f64 / 81f64;
pub type Weights = [f32];

pub static DEFAULT_WEIGHTS: [f32; 17] = [
0.27, 0.74, 1.3, 5.52, 5.1, 1.02, 0.78, 0.06, 1.57, 0.14, 0.94, 2.16, 0.06, 0.31, 1.34, 0.21,
2.69,
0.5614, 1.2546, 3.5878, 7.9731, 5.1043, 1.1303, 0.823, 0.0465, 1.629, 0.135, 1.0045, 2.132,
0.0839, 0.3204, 1.3547, 0.219, 2.7849,
];

fn infer<B: Backend>(
Expand Down Expand Up @@ -410,7 +410,7 @@ mod tests {
let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();

Data::from([metrics.log_loss, metrics.rmse_bins])
.assert_approx_eq(&Data::from([0.213_643_97, 0.053_706_862]), 5);
.assert_approx_eq(&Data::from([0.205_166, 0.024_658]), 5);

let fsrs = FSRS::new(Some(WEIGHTS))?;
let metrics = fsrs.evaluate(items, |_| true).unwrap();
Expand Down Expand Up @@ -503,7 +503,7 @@ mod tests {
fsrs.memory_state_from_sm2(2.5, 10.0, 0.9).unwrap(),
MemoryState {
stability: 9.999995,
difficulty: 6.6293178
difficulty: 7.200902
}
);
assert_eq!(
Expand Down
14 changes: 7 additions & 7 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ mod tests {
let stability = model.init_stability(rating);
assert_eq!(
stability.to_data(),
Data::from([0.27, 0.74, 1.3, 5.52, 0.27, 0.74])
Data::from([0.5614, 1.2546, 3.5878, 7.9731, 0.5614, 1.2546])
)
}

Expand All @@ -289,7 +289,7 @@ mod tests {
let difficulty = model.init_difficulty(rating);
assert_eq!(
difficulty.to_data(),
Data::from([7.14, 6.12, 5.1, 4.08, 7.14, 6.12])
Data::from([7.3649, 6.2346, 5.1043, 3.974, 7.3649, 6.2346])
)
}

Expand Down Expand Up @@ -317,13 +317,13 @@ mod tests {
next_difficulty.clone().backward();
assert_eq!(
next_difficulty.to_data(),
Data::from([6.56, 5.7799997, 5.0, 4.2200003])
Data::from([6.646, 5.823, 5.0, 4.177])
);
let next_difficulty = model.mean_reversion(next_difficulty);
next_difficulty.clone().backward();
assert_eq!(
next_difficulty.to_data(),
Data::from([6.4723997, 5.7391996, 5.006, 4.2728004])
Data::from([6.574311, 5.7895803, 5.00485, 4.2201195])
)
}

Expand All @@ -343,19 +343,19 @@ mod tests {
s_recall.clone().backward();
assert_eq!(
s_recall.to_data(),
Data::from([23.908455, 12.499619, 54.99991, 169.89117])
Data::from([26.678038, 13.996968, 62.718544, 202.76956])
);
let s_forget = model.stability_after_failure(stability, difficulty, retention);
s_forget.clone().backward();
assert_eq!(
s_forget.to_data(),
Data::from([1.8343093, 2.0118992, 2.245103, 2.5231054])
Data::from([1.8932177, 2.0453987, 2.2637987, 2.5304008])
);
let next_stability = s_recall.mask_where(rating.clone().equal_elem(1), s_forget);
next_stability.clone().backward();
assert_eq!(
next_stability.to_data(),
Data::from([1.8343093, 12.499619, 54.99991, 169.89117])
Data::from([1.8932177, 13.996968, 62.718544, 202.76956])
)
}

Expand Down
19 changes: 12 additions & 7 deletions src/optimal_retention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -673,13 +673,18 @@ mod tests {
None,
)
.0;
assert_eq!(memorization[memorization.len() - 1], 2380.9836436993573)
assert_eq!(memorization[memorization.len() - 1], 3022.055014122344)
}

#[test]
fn simulate_with_existing_cards() {
let mut config = SimulatorConfig::default();
config.learn_span = 10;
let config = SimulatorConfig {
learn_span: 30,
learn_limit: 60,
review_limit: 200,
max_cost_perday: f64::INFINITY,
..Default::default()
};
let cards = vec![
Card {
difficulty: 5.0,
Expand Down Expand Up @@ -708,7 +713,7 @@ mod tests {
fn simulate_with_learn_review_limit() {
let mut config = SimulatorConfig::default();
config.learn_span = 30;
config.learn_limit = 50;
config.learn_limit = 60;
config.review_limit = 200;
config.max_cost_perday = f64::INFINITY;
let results = simulate(
Expand All @@ -721,8 +726,8 @@ mod tests {
assert_eq!(
results.1.to_vec(),
vec![
0, 48, 57, 77, 86, 102, 119, 105, 133, 137, 141, 163, 151, 164, 157, 186, 174, 171,
194, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200
0, 16, 27, 29, 86, 73, 96, 95, 96, 105, 112, 113, 124, 131, 139, 124, 130, 141,
162, 175, 168, 179, 186, 185, 198, 189, 200, 200, 200, 200
]
);
assert_eq!(
Expand All @@ -736,7 +741,7 @@ mod tests {
let config = SimulatorConfig::default();
let fsrs = FSRS::new(None)?;
let optimal_retention = fsrs.optimal_retention(&config, &[], |_v| true).unwrap();
assert_eq!(optimal_retention, 0.8568971936549108);
assert_eq!(optimal_retention, 0.864870726919112);
assert!(fsrs.optimal_retention(&config, &[1.], |_v| true).is_err());
Ok(())
}
Expand Down
3 changes: 2 additions & 1 deletion src/pre_training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ mod tests {
let pretrainset = split_data(items, 1).0;
Data::from(pretrain(pretrainset, average_recall).unwrap())
.assert_approx_eq(&Data::from([1.001_276, 1.811_072, 4.405_64, 8.532_001]), 4)

}

#[test]
Expand All @@ -350,6 +351,6 @@ mod tests {
let mut rating_stability = HashMap::from([(2, 0.35)]);
let rating_count = HashMap::from([(2, 1)]);
let actual = smooth_and_fill(&mut rating_stability, &rating_count).unwrap();
assert_eq!(actual, [0.1277027, 0.35, 0.6148648, 2.6108108,]);
assert_eq!(actual, [0.15661564, 0.35, 1.0009006, 2.2242827,]);
}
}

0 comments on commit 6b5bdeb

Please sign in to comment.