Skip to content

Commit

Permalink
Factor sm2 retention into memory state calculation (#101)
Browse files Browse the repository at this point in the history
  • Loading branch information
dae authored Oct 12, 2023
1 parent 1886a00 commit 5089470
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions src/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,20 @@ impl<B: Backend> FSRS<B> {
/// If a card has incomplete learning history, memory state can be approximated from
/// current sm2 values.
/// Weights must have been provided when calling FSRS::new().
pub fn memory_state_from_sm2(&self, ease_factor: f32, interval: f32) -> MemoryState {
let stability = interval.max(0.1);
pub fn memory_state_from_sm2(
&self,
ease_factor: f32,
interval: f32,
sm2_retention: f32,
) -> MemoryState {
let stability = interval.max(0.1) * 0.9_f32.ln() / sm2_retention.ln();
let w = &self.model().w;
let w8: f32 = w.get(8).into_scalar().elem();
let w9: f32 = w.get(9).into_scalar().elem();
let w10: f32 = w.get(10).into_scalar().elem();
let difficulty = 11.0
- (ease_factor - 1.0) / (w8.exp() * stability.powf(-w9) * ((0.1 * w10).exp() - 1.0));
- (ease_factor - 1.0)
/ (w8.exp() * stability.powf(-w9) * (((1.0 - sm2_retention) * w10).exp() - 1.0));
MemoryState {
stability,
difficulty: difficulty.clamp(1.0, 10.0),
Expand Down Expand Up @@ -471,14 +477,14 @@ mod tests {
fn memory_from_sm2() -> Result<()> {
let fsrs = FSRS::new(Some(&[]))?;
assert_eq!(
fsrs.memory_state_from_sm2(2.5, 10.0),
fsrs.memory_state_from_sm2(2.5, 10.0, 0.9),
MemoryState {
stability: 10.0,
difficulty: 6.265295
}
);
assert_eq!(
fsrs.memory_state_from_sm2(1.3, 20.0),
fsrs.memory_state_from_sm2(1.3, 20.0, 0.9),
MemoryState {
stability: 20.0,
difficulty: 9.956561
Expand All @@ -488,7 +494,7 @@ mod tests {
let ease_factor = 2.0;
let fsrs_factor = fsrs
.next_states(
Some(fsrs.memory_state_from_sm2(ease_factor, interval as f32)),
Some(fsrs.memory_state_from_sm2(ease_factor, interval as f32, 0.9)),
0.9,
interval,
)
Expand Down

0 comments on commit 5089470

Please sign in to comment.