diff --git a/__tests__/algorithm.test.ts b/__tests__/algorithm.test.ts index 9025276..c8a3020 100644 --- a/__tests__/algorithm.test.ts +++ b/__tests__/algorithm.test.ts @@ -460,3 +460,53 @@ describe('change Params', () => { }).toThrow('Requested retention rate should be in the range (0,1]') }) }) + +describe('next_state', () => { + it('next_state not NaN', () => { + const f = fsrs() + const next_state = f.next_state( + { stability: 0, difficulty: 0 }, + 1, + 1 /** Again */ + ) + + expect(Number.isNaN(next_state.stability)).toBe(false) + expect(next_state).toEqual(f.next_state(null, 1, 1 /** Again */)) + expect(next_state).toEqual( + f.next_state({ difficulty: 0, stability: 0 }, 1, 1 /** Again */) + ) + }) + + it('invalid memory state', () => { + const f = fsrs() + + const init = f.next_state(null, 0, 3 /** Good */) + // d<1 + expect(() => { + f.next_state( + { stability: init.stability, difficulty: 0 }, + 1, + 1 /** Again */ + ) + }).toThrow('invalid memory state') + + // s<0.01 + expect(() => { + f.next_state( + { stability: 0, difficulty: init.stability }, + 1, + 1 /** Again */ + ) + }).toThrow('invalid memory state') + + // g<0 + expect(() => { + f.next_state(init, 1, -1 /** invalid grade */) + }).toThrow('invalid memory state') + + // g>4 + expect(() => { + f.next_state(init, 1, 5 /** invalid grade */) + }).toThrow('invalid memory state') + }) +}) diff --git a/src/fsrs/algorithm.ts b/src/fsrs/algorithm.ts index 6f8ab13..75940f5 100644 --- a/src/fsrs/algorithm.ts +++ b/src/fsrs/algorithm.ts @@ -285,14 +285,20 @@ export class FSRSAlgorithm { * @returns The next state of memory with updated difficulty and stability. */ next_state(memory_state: FSRSState | null, t: number, g: number): FSRSState { - if (!memory_state) { + const { difficulty: d, stability: s } = memory_state ?? { + difficulty: 0, + stability: 0, + } + if (d === 0 && s === 0) { return { difficulty: this.init_difficulty(clamp(g, 1, 4)), stability: this.init_stability(clamp(g, 1, 4)), } } - const { difficulty: d, stability: s } = memory_state - if (g < 1 || g > 4) { + if (d < 1 || s < 0.01 || g < 0 || g > 4) { + throw new Error('invalid memory state') + } + if (g === 0) { return { difficulty: d, stability: s,