Skip to content

Commit

Permalink
Add sanity check for NaN values
Browse files Browse the repository at this point in the history
  • Loading branch information
dae committed Nov 30, 2023
1 parent 96ae7fc commit 7a45972
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 38 deletions.
3 changes: 2 additions & 1 deletion benches/benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub(crate) fn calc_mem(inf: &FSRS, past_reviews: usize) -> MemoryState {
delta_t: 21,
};
let reviews = repeat(review.clone()).take(past_reviews + 1).collect_vec();
inf.memory_state(FSRSItem { reviews }, None)
inf.memory_state(FSRSItem { reviews }, None).unwrap()
}

pub(crate) fn next_states(inf: &FSRS) -> NextStates {
Expand All @@ -31,6 +31,7 @@ pub(crate) fn next_states(inf: &FSRS) -> NextStates {
0.9,
21,
)
.unwrap()
}

pub fn criterion_benchmark(c: &mut Criterion) {
Expand Down
1 change: 1 addition & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pub enum FSRSError {
Interrupted,
InvalidWeights,
OptimalNotFound,
InvalidInput,
}

pub type Result<T, E = FSRSError> = std::result::Result<T, E>;
96 changes: 60 additions & 36 deletions src/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,11 @@ impl<B: Backend> FSRS<B> {
/// [FSRS::memory_state_from_sm2] for the first review (which should not be included
/// in FSRSItem). If not provided, the card starts as new.
/// Weights must have been provided when calling FSRS::new().
pub fn memory_state(&self, item: FSRSItem, starting_state: Option<MemoryState>) -> MemoryState {
pub fn memory_state(
&self,
item: FSRSItem,
starting_state: Option<MemoryState>,
) -> Result<MemoryState> {
let (time_history, rating_history) =
item.reviews.iter().map(|r| (r.delta_t, r.rating)).unzip();
let size = item.reviews.len();
Expand All @@ -81,9 +85,15 @@ impl<B: Backend> FSRS<B> {
Tensor::from_data(Data::new(rating_history, Shape { dims: [size] }).convert())
.unsqueeze()
.transpose();
self.model()
let state: MemoryState = self
.model()
.forward(time_history, rating_history, starting_state.map(Into::into))
.into()
.into();
if !state.stability.is_normal() || !state.difficulty.is_normal() {
Err(FSRSError::InvalidInput)
} else {
Ok(state)
}
}

/// If a card has incomplete learning history, memory state can be approximated from
Expand All @@ -94,7 +104,7 @@ impl<B: Backend> FSRS<B> {
ease_factor: f32,
interval: f32,
sm2_retention: f32,
) -> MemoryState {
) -> Result<MemoryState> {
let stability = interval.max(0.1) / (9.0 * (1.0 / sm2_retention - 1.0));
let w = &self.model().w;
let w8: f32 = w.get(8).into_scalar().elem();
Expand All @@ -103,9 +113,13 @@ impl<B: Backend> FSRS<B> {
let difficulty = 11.0
- (ease_factor - 1.0)
/ (w8.exp() * stability.powf(-w9) * ((1.0 - sm2_retention) * w10).exp_m1());
MemoryState {
stability,
difficulty: difficulty.clamp(1.0, 10.0),
if !stability.is_normal() || !difficulty.is_normal() {
Err(FSRSError::InvalidInput)
} else {
Ok(MemoryState {
stability,
difficulty: difficulty.clamp(1.0, 10.0),
})
}
}

Expand Down Expand Up @@ -134,35 +148,41 @@ impl<B: Backend> FSRS<B> {
current_memory_state: Option<MemoryState>,
desired_retention: f32,
days_elapsed: u32,
) -> NextStates {
) -> Result<NextStates> {
let delta_t = Tensor::from_data(Data::new(vec![days_elapsed.elem()], Shape { dims: [1] }));
let current_memory_state_tensors = current_memory_state.map(MemoryStateTensors::from);
let model = self.model();
let mut next_memory_states = (1..=4).map(|rating| {
if let (Some(current_memory_state), 0) = (current_memory_state, days_elapsed) {
// When there's an existing memory state and no days have elapsed, we leave it unchanged.
current_memory_state
} else {
MemoryState::from(model.step(
delta_t.clone(),
Tensor::from_data(Data::new(vec![rating.elem()], Shape { dims: [1] })),
current_memory_state_tensors.clone(),
))
}
Ok(
if let (Some(current_memory_state), 0) = (current_memory_state, days_elapsed) {
// When there's an existing memory state and no days have elapsed, we leave it unchanged.
current_memory_state
} else {
let state = MemoryState::from(model.step(
delta_t.clone(),
Tensor::from_data(Data::new(vec![rating.elem()], Shape { dims: [1] })),
current_memory_state_tensors.clone(),
));
if !state.stability.is_normal() || !state.difficulty.is_normal() {
return Err(FSRSError::InvalidInput);
}
state
},
)
});

let mut get_next_state = || {
let memory = next_memory_states.next().unwrap();
let memory = next_memory_states.next().unwrap()?;
let interval = next_interval(memory.stability, desired_retention);
ItemState { memory, interval }
Ok(ItemState { memory, interval })
};

NextStates {
again: get_next_state(),
hard: get_next_state(),
good: get_next_state(),
easy: get_next_state(),
}
Ok(NextStates {
again: get_next_state()?,
hard: get_next_state()?,
good: get_next_state()?,
easy: get_next_state()?,
})
}

/// Determine how well the model and weights predict performance.
Expand Down Expand Up @@ -343,7 +363,7 @@ mod tests {
};
let fsrs = FSRS::new(Some(WEIGHTS))?;
assert_eq!(
fsrs.memory_state(item, None),
fsrs.memory_state(item, None).unwrap(),
MemoryState {
stability: 51.344814,
difficulty: 7.005062
Expand All @@ -359,6 +379,7 @@ mod tests {
0.9,
21
)
.unwrap()
.good
.memory,
MemoryState {
Expand Down Expand Up @@ -420,9 +441,9 @@ mod tests {
],
};
let fsrs = FSRS::new(Some(WEIGHTS))?;
let state = fsrs.memory_state(item, None);
let state = fsrs.memory_state(item, None).unwrap();
assert_eq!(
fsrs.next_states(Some(state), 0.9, 21),
fsrs.next_states(Some(state), 0.9, 21).unwrap(),
NextStates {
again: ItemState {
memory: MemoryState {
Expand Down Expand Up @@ -462,12 +483,12 @@ mod tests {
fn states_are_unchaged_when_no_days_elapsed() -> Result<()> {
let fsrs = FSRS::new(Some(&[]))?;
// the first time a card is seen, a memory state must be set
let mut state_a = fsrs.next_states(None, 1.0, 0).again.memory;
let mut state_a = fsrs.next_states(None, 1.0, 0)?.again.memory;
// but if no days have elapsed and it's reviewed again, the state should be unchanged
let state_b = fsrs.next_states(Some(state_a), 1.0, 0).again.memory;
let state_b = fsrs.next_states(Some(state_a), 1.0, 0)?.again.memory;
assert_eq!(state_a, state_b);
// if a day elapses, it's counted
state_a = fsrs.next_states(Some(state_a), 1.0, 1).again.memory;
state_a = fsrs.next_states(Some(state_a), 1.0, 1)?.again.memory;
assert_ne!(state_a, state_b);

Ok(())
Expand All @@ -477,14 +498,14 @@ mod tests {
fn memory_from_sm2() -> Result<()> {
let fsrs = FSRS::new(Some(&[]))?;
assert_eq!(
fsrs.memory_state_from_sm2(2.5, 10.0, 0.9),
fsrs.memory_state_from_sm2(2.5, 10.0, 0.9).unwrap(),
MemoryState {
stability: 9.999995,
difficulty: 6.2652965
}
);
assert_eq!(
fsrs.memory_state_from_sm2(1.3, 20.0, 0.9),
fsrs.memory_state_from_sm2(1.3, 20.0, 0.9).unwrap(),
MemoryState {
stability: 19.99999,
difficulty: 9.956561
Expand All @@ -494,10 +515,13 @@ mod tests {
let ease_factor = 2.0;
let fsrs_factor = fsrs
.next_states(
Some(fsrs.memory_state_from_sm2(ease_factor, interval as f32, 0.9)),
Some(
fsrs.memory_state_from_sm2(ease_factor, interval as f32, 0.9)
.unwrap(),
),
0.9,
interval,
)
)?
.good
.memory
.stability
Expand Down
8 changes: 7 additions & 1 deletion src/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ impl<B: Backend> FSRS<B> {
finish_progress();

let weight_sets = weight_sets?;
let average_weights = weight_sets
let average_weights: Vec<f32> = weight_sets
.iter()
.fold(vec![0.0; weight_sets[0].len()], |sum, weights| {
sum.par_iter().zip(weights).map(|(a, b)| a + b).collect()
Expand All @@ -308,6 +308,12 @@ impl<B: Backend> FSRS<B> {
.map(|&sum| sum / n_splits as f32)
.collect();

for weight in &average_weights {
if !weight.is_normal() {
return Err(FSRSError::InvalidInput);
}
}

Ok(average_weights)
}
}
Expand Down

0 comments on commit 7a45972

Please sign in to comment.