Skip to content

Commit

Permalink
Use array instead of HashMap (#140)
Browse files Browse the repository at this point in the history
* match arr

* clippy --fix

* use arr
  • Loading branch information
asukaminato0721 authored Dec 19, 2023
1 parent 3d785fc commit 458ed0a
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 65 deletions.
2 changes: 1 addition & 1 deletion src/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ mod tests {
let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();

Data::from([metrics.log_loss, metrics.rmse_bins])
.assert_approx_eq(&Data::from([0.21364396810531616, 0.05370686203241348]), 5);
.assert_approx_eq(&Data::from([0.213_643_97, 0.053_706_862]), 5);

let fsrs = FSRS::new(Some(WEIGHTS))?;
let metrics = fsrs.evaluate(items, |_| true).unwrap();
Expand Down
104 changes: 40 additions & 64 deletions src/pre_training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use crate::error::{FSRSError, Result};
use crate::inference::{DECAY, FACTOR};
use crate::FSRSItem;
use crate::DEFAULT_WEIGHTS;
use itertools::Itertools;
use ndarray::Array1;
use std::collections::HashMap;

Expand Down Expand Up @@ -200,7 +199,13 @@ fn smooth_and_fill(
.iter()
.cloned()
.collect::<HashMap<_, _>>();

let mut rating_stability_arr = [
None,
rating_stability.get(&1).cloned(),
rating_stability.get(&2).cloned(),
rating_stability.get(&3).cloned(),
rating_stability.get(&4).cloned(),
];
match rating_stability.len() {
0 => return Err(FSRSError::NotEnoughData),
1 => {
Expand All @@ -210,87 +215,64 @@ fn smooth_and_fill(
init_s0.sort_by(|a, b| a.partial_cmp(b).unwrap());
}
2 => {
match (
rating_stability.get(&1),
rating_stability.get(&2),
rating_stability.get(&3),
rating_stability.get(&4),
) {
(None, None, Some(&r3), Some(&r4)) => {
match rating_stability_arr {
[_, None, None, Some(r3), Some(r4)] => {
let r2 = r3.powf(1.0 / (1.0 - w2)) * r4.powf(1.0 - 1.0 / (1.0 - w2));
rating_stability.insert(2, r2);
rating_stability.insert(1, (r2.powf(1.0 / w1)) * (r3.powf(1.0 - 1.0 / w1)));
rating_stability_arr[2] = Some(r2);
rating_stability_arr[1] = Some(r2.powf(1.0 / w1) * r3.powf(1.0 - 1.0 / w1));
}
(None, Some(&r2), None, Some(&r4)) => {
[_, None, Some(r2), None, Some(r4)] => {
let r3 = r2.powf(1.0 - w2) * r4.powf(w2);
rating_stability.insert(3, r3);
rating_stability.insert(1, r2.powf(1.0 / w1) * r3.powf(1.0 - 1.0 / w1));
rating_stability_arr[3] = Some(r3);
rating_stability_arr[1] = Some(r2.powf(1.0 / w1) * r3.powf(1.0 - 1.0 / w1));
}
(None, Some(&r2), Some(&r3), None) => {
rating_stability.insert(4, r2.powf(1.0 - 1.0 / w2) * r3.powf(1.0 / w2));
rating_stability.insert(1, r2.powf(1.0 / w1) * r3.powf(1.0 - 1.0 / w1));
[_, None, Some(r2), Some(r3), None] => {
rating_stability_arr[4] = Some(r2.powf(1.0 - 1.0 / w2) * r3.powf(1.0 / w2));
rating_stability_arr[1] = Some(r2.powf(1.0 / w1) * r3.powf(1.0 - 1.0 / w1));
}
(Some(&r1), None, None, Some(&r4)) => {
[_, Some(r1), None, None, Some(r4)] => {
let r2 = r1.powf(w1 / (w1.mul_add(-w2, w1 + w2)))
* r4.powf(1.0 - w1 / (w1.mul_add(-w2, w1 + w2)));
rating_stability.insert(2, r2);
rating_stability.insert(
3,
rating_stability_arr[2] = Some(r2);
rating_stability_arr[3] = Some(
r1.powf(1.0 - w2 / (w1.mul_add(-w2, w1 + w2)))
* r4.powf(w2 / (w1.mul_add(-w2, w1 + w2))),
);
}
(Some(&r1), None, Some(&r3), None) => {
[_, Some(r1), None, Some(r3), None] => {
let r2 = r1.powf(w1) * r3.powf(1.0 - w1);
rating_stability.insert(2, r2);
rating_stability.insert(4, r2.powf(1.0 - 1.0 / w2) * r3.powf(1.0 / w2));
rating_stability_arr[2] = Some(r2);
rating_stability_arr[4] = Some(r2.powf(1.0 - 1.0 / w2) * r3.powf(1.0 / w2));
}
(Some(&r1), Some(&r2), None, None) => {
[_, Some(r1), Some(r2), None, None] => {
let r3 = r1.powf(1.0 - 1.0 / (1.0 - w1)) * r2.powf(1.0 / (1.0 - w1));
rating_stability.insert(3, r3);
rating_stability.insert(4, r2.powf(1.0 - 1.0 / w2) * r3.powf(1.0 / w2));
rating_stability_arr[3] = Some(r3);
rating_stability_arr[4] = Some(r2.powf(1.0 - 1.0 / w2) * r3.powf(1.0 / w2));
}
_ => {}
}
init_s0 = rating_stability
.iter()
.sorted_by(|a, b| a.0.cmp(b.0))
.map(|(_, &v)| v)
.collect();
init_s0 = rating_stability_arr.into_iter().flatten().collect();
}
3 => {
match (
rating_stability.get(&1),
rating_stability.get(&2),
rating_stability.get(&3),
rating_stability.get(&4),
) {
(None, Some(r2), Some(r3), _) => {
rating_stability.insert(1, r2.powf(1.0 / w1) * r3.powf(1.0 - 1.0 / w1));
match rating_stability_arr {
[_, None, Some(r2), Some(r3), _] => {
rating_stability_arr[1] = Some(r2.powf(1.0 / w1) * r3.powf(1.0 - 1.0 / w1));
}
(Some(r1), None, Some(r3), _) => {
rating_stability.insert(2, r1.powf(w1) * r3.powf(1.0 - w1));
[_, Some(r1), None, Some(r3), _] => {
rating_stability_arr[2] = Some(r1.powf(w1) * r3.powf(1.0 - w1));
}
(_, Some(r2), None, Some(r4)) => {
rating_stability.insert(3, r2.powf(1.0 - w2) * r4.powf(w2));
[_, _, Some(r2), None, Some(r4)] => {
rating_stability_arr[3] = Some(r2.powf(1.0 - w2) * r4.powf(w2));
}
(_, Some(r2), Some(r3), None) => {
rating_stability.insert(4, r2.powf(1.0 - 1.0 / w2) * r3.powf(1.0 / w2));
[_, _, Some(r2), Some(r3), None] => {
rating_stability_arr[4] = Some(r2.powf(1.0 - 1.0 / w2) * r3.powf(1.0 / w2));
}
_ => {}
}
init_s0 = rating_stability
.iter()
.sorted_by(|a, b| a.0.cmp(b.0))
.map(|(_, &v)| v)
.collect();
init_s0 = rating_stability_arr.into_iter().flatten().collect();
}
4 => {
init_s0 = rating_stability
.iter()
.sorted_by(|a, b| a.0.cmp(b.0))
.map(|(_, &v)| v)
.collect();
init_s0 = rating_stability_arr.into_iter().flatten().collect();
}
_ => {}
}
Expand Down Expand Up @@ -358,8 +340,7 @@ mod tests {
],
)]);
let actual = search_parameters(pretrainset, 0.9);
Data::from([actual.get(&4).unwrap().clone()])
.assert_approx_eq(&Data::from([1.2301323413848877]), 4);
Data::from([*actual.get(&4).unwrap()]).assert_approx_eq(&Data::from([1.230_132_3]), 4);
}

#[test]
Expand All @@ -369,12 +350,7 @@ mod tests {
let average_recall = calculate_average_recall(&items);
let pretrainset = split_data(items, 1).0;
Data::from(pretrain(pretrainset, average_recall).unwrap()).assert_approx_eq(
&Data::from([
0.9560174345970154,
1.694406509399414,
3.998023509979248,
8.26822280883789,
]),
&Data::from([0.956_017_43, 1.694_406_5, 3.998_023_5, 8.268_223]),
4,
)
}
Expand Down

0 comments on commit 458ed0a

Please sign in to comment.