From 7bd2ac4e919c29b5ad085be9a9f81f84cd345c76 Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Tue, 26 Dec 2023 20:35:42 +0800 Subject: [PATCH] update default weights (#145) --- src/inference.rs | 8 ++++---- src/model.rs | 14 +++++++------- src/optimal_retention.rs | 19 ++++++++++++------- src/pre_training.rs | 4 ++-- 4 files changed, 25 insertions(+), 20 deletions(-) diff --git a/src/inference.rs b/src/inference.rs index 8c652ac1..f6d00a89 100644 --- a/src/inference.rs +++ b/src/inference.rs @@ -19,8 +19,8 @@ pub(crate) const FACTOR: f64 = 19f64 / 81f64; pub type Weights = [f32]; pub static DEFAULT_WEIGHTS: [f32; 17] = [ - 0.27, 0.74, 1.3, 5.52, 5.1, 1.02, 0.78, 0.06, 1.57, 0.14, 0.94, 2.16, 0.06, 0.31, 1.34, 0.21, - 2.69, + 0.5614, 1.2546, 3.5878, 7.9731, 5.1043, 1.1303, 0.823, 0.0465, 1.629, 0.135, 1.0045, 2.132, + 0.0839, 0.3204, 1.3547, 0.219, 2.7849, ]; fn infer( @@ -410,7 +410,7 @@ mod tests { let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap(); Data::from([metrics.log_loss, metrics.rmse_bins]) - .assert_approx_eq(&Data::from([0.213_643_97, 0.053_706_862]), 5); + .assert_approx_eq(&Data::from([0.205_166, 0.024_658]), 5); let fsrs = FSRS::new(Some(WEIGHTS))?; let metrics = fsrs.evaluate(items, |_| true).unwrap(); @@ -503,7 +503,7 @@ mod tests { fsrs.memory_state_from_sm2(2.5, 10.0, 0.9).unwrap(), MemoryState { stability: 9.999995, - difficulty: 6.6293178 + difficulty: 7.200902 } ); assert_eq!( diff --git a/src/model.rs b/src/model.rs index 6a5e5d7a..1f3ef963 100644 --- a/src/model.rs +++ b/src/model.rs @@ -278,7 +278,7 @@ mod tests { let stability = model.init_stability(rating); assert_eq!( stability.to_data(), - Data::from([0.27, 0.74, 1.3, 5.52, 0.27, 0.74]) + Data::from([0.5614, 1.2546, 3.5878, 7.9731, 0.5614, 1.2546]) ) } @@ -289,7 +289,7 @@ mod tests { let difficulty = model.init_difficulty(rating); assert_eq!( difficulty.to_data(), - Data::from([7.14, 6.12, 5.1, 4.08, 7.14, 6.12]) + Data::from([7.3649, 6.2346, 5.1043, 3.974, 7.3649, 6.2346]) ) } @@ -317,13 +317,13 @@ mod tests { next_difficulty.clone().backward(); assert_eq!( next_difficulty.to_data(), - Data::from([6.56, 5.7799997, 5.0, 4.2200003]) + Data::from([6.646, 5.823, 5.0, 4.177]) ); let next_difficulty = model.mean_reversion(next_difficulty); next_difficulty.clone().backward(); assert_eq!( next_difficulty.to_data(), - Data::from([6.4723997, 5.7391996, 5.006, 4.2728004]) + Data::from([6.574311, 5.7895803, 5.00485, 4.2201195]) ) } @@ -343,19 +343,19 @@ mod tests { s_recall.clone().backward(); assert_eq!( s_recall.to_data(), - Data::from([23.908455, 12.499619, 54.99991, 169.89117]) + Data::from([26.678038, 13.996968, 62.718544, 202.76956]) ); let s_forget = model.stability_after_failure(stability, difficulty, retention); s_forget.clone().backward(); assert_eq!( s_forget.to_data(), - Data::from([1.8343093, 2.0118992, 2.245103, 2.5231054]) + Data::from([1.8932177, 2.0453987, 2.2637987, 2.5304008]) ); let next_stability = s_recall.mask_where(rating.clone().equal_elem(1), s_forget); next_stability.clone().backward(); assert_eq!( next_stability.to_data(), - Data::from([1.8343093, 12.499619, 54.99991, 169.89117]) + Data::from([1.8932177, 13.996968, 62.718544, 202.76956]) ) } diff --git a/src/optimal_retention.rs b/src/optimal_retention.rs index 83d34a85..ce49eff7 100644 --- a/src/optimal_retention.rs +++ b/src/optimal_retention.rs @@ -673,13 +673,18 @@ mod tests { None, ) .0; - assert_eq!(memorization[memorization.len() - 1], 2380.9836436993573) + assert_eq!(memorization[memorization.len() - 1], 3022.055014122344) } #[test] fn simulate_with_existing_cards() { - let mut config = SimulatorConfig::default(); - config.learn_span = 10; + let config = SimulatorConfig { + learn_span: 30, + learn_limit: 60, + review_limit: 200, + max_cost_perday: f64::INFINITY, + ..Default::default() + }; let cards = vec![ Card { difficulty: 5.0, @@ -708,7 +713,7 @@ mod tests { fn simulate_with_learn_review_limit() { let mut config = SimulatorConfig::default(); config.learn_span = 30; - config.learn_limit = 50; + config.learn_limit = 60; config.review_limit = 200; config.max_cost_perday = f64::INFINITY; let results = simulate( @@ -721,8 +726,8 @@ mod tests { assert_eq!( results.1.to_vec(), vec![ - 0, 48, 57, 77, 86, 102, 119, 105, 133, 137, 141, 163, 151, 164, 157, 186, 174, 171, - 194, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200, 200 + 0, 16, 27, 29, 86, 73, 96, 95, 96, 105, 112, 113, 124, 131, 139, 124, 130, 141, + 162, 175, 168, 179, 186, 185, 198, 189, 200, 200, 200, 200 ] ); assert_eq!( @@ -736,7 +741,7 @@ mod tests { let config = SimulatorConfig::default(); let fsrs = FSRS::new(None)?; let optimal_retention = fsrs.optimal_retention(&config, &[], |_v| true).unwrap(); - assert_eq!(optimal_retention, 0.8568971936549108); + assert_eq!(optimal_retention, 0.864870726919112); assert!(fsrs.optimal_retention(&config, &[1.], |_v| true).is_err()); Ok(()) } diff --git a/src/pre_training.rs b/src/pre_training.rs index 7c9da41b..c6f07278 100644 --- a/src/pre_training.rs +++ b/src/pre_training.rs @@ -337,7 +337,7 @@ mod tests { let average_recall = calculate_average_recall(&items); let pretrainset = split_data(items, 1).0; Data::from(pretrain(pretrainset, average_recall).unwrap()) - .assert_approx_eq(&Data::from([1.001_276, 1.811_072, 4.405_640, 8.532_001]), 4) + .assert_approx_eq(&Data::from([1.001_276, 1.810_509, 4.401_906, 8.529_174]), 4) } #[test] @@ -350,6 +350,6 @@ mod tests { let mut rating_stability = HashMap::from([(2, 0.35)]); let rating_count = HashMap::from([(2, 1)]); let actual = smooth_and_fill(&mut rating_stability, &rating_count).unwrap(); - assert_eq!(actual, [0.1277027, 0.35, 0.6148648, 2.6108108,]); + assert_eq!(actual, [0.15661564, 0.35, 1.0009006, 2.2242827,]); } }