Skip to content

Commit

Permalink
Merge branch 'main' into Feat/option-enable_short_term-in-training-
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock authored Jan 1, 2025
2 parents 06cc550 + a7aaa40 commit 29b010a
Show file tree
Hide file tree
Showing 7 changed files with 250 additions and 114 deletions.
8 changes: 3 additions & 5 deletions .github/workflows/check.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,7 @@

set -eux -o pipefail

cargo fmt --check || (
echo
echo "Please run 'cargo fmt' to format the code."
exit 1
)
cargo fmt --check

cargo clippy -- -Dwarnings

Expand All @@ -15,5 +11,7 @@ pushd tests/data/
wget https://github.com/open-spaced-repetition/fsrs-optimizer-burn/files/12394182/collection.anki21.zip
unzip *.zip

RUSTDOCFLAGS="-D warnings" cargo doc --release

cargo install cargo-llvm-cov --locked
SKIP_TRAINING=1 cargo llvm-cov --release
11 changes: 8 additions & 3 deletions src/batch_shuffle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,17 +105,22 @@ mod tests {
backend::{ndarray::NdArrayDevice, NdArray},
tensor::Shape,
};
use itertools::Itertools;

use super::*;
use crate::{
convertor_tests::anki21_sample_file_converted_to_fsrs, dataset::prepare_training_data,
convertor_tests::anki21_sample_file_converted_to_fsrs,
dataset::{constant_weighted_fsrs_items, prepare_training_data},
};

#[test]
fn test_simple_dataloader() {
let train_set = anki21_sample_file_converted_to_fsrs();
let train_set = anki21_sample_file_converted_to_fsrs()
.into_iter()
.sorted_by_cached_key(|item| item.reviews.len())
.collect();
let (_pre_train_set, train_set) = prepare_training_data(train_set);
let dataset = FSRSDataset::from(train_set);
let dataset = FSRSDataset::from(constant_weighted_fsrs_items(train_set));
let batch_size = 512;
let seed = 114514;
let device = NdArrayDevice::Cpu;
Expand Down
27 changes: 16 additions & 11 deletions src/convertor_tests.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::convertor_tests::RevlogReviewKind::*;
use crate::dataset::FSRSBatcher;
use crate::dataset::{constant_weighted_fsrs_items, FSRSBatcher};
use crate::dataset::{FSRSItem, FSRSReview};
use crate::optimal_retention::{RevlogEntry, RevlogReviewKind};
use crate::test_helpers::NdArrayAutodiff;
Expand Down Expand Up @@ -94,7 +94,7 @@ fn convert_to_fsrs_items(
mut entries: Vec<RevlogEntry>,
next_day_starts_at: i64,
timezone: Tz,
) -> Option<Vec<FSRSItem>> {
) -> Option<Vec<(i64, FSRSItem)>> {
// entries = filter_out_cram(entries);
// entries = filter_out_manual(entries);
entries = remove_revlog_before_last_first_learn(entries);
Expand All @@ -110,7 +110,7 @@ fn convert_to_fsrs_items(
.iter()
.enumerate()
.skip(1)
.map(|(idx, _)| {
.map(|(idx, entry)| {
let reviews = entries
.iter()
.take(idx + 1)
Expand All @@ -119,9 +119,9 @@ fn convert_to_fsrs_items(
delta_t: r.last_interval.max(0) as u32,
})
.collect();
FSRSItem { reviews }
(entry.id, FSRSItem { reviews })
})
.filter(|item| item.current().delta_t > 0)
.filter(|(_, item)| item.current().delta_t > 0)
.collect(),
)
}
Expand All @@ -137,8 +137,8 @@ pub(crate) fn anki_to_fsrs(revlogs: Vec<RevlogEntry>) -> Vec<FSRSItem> {
})
.flatten()
.collect_vec();
revlogs.sort_by_cached_key(|r| r.reviews.len());
revlogs
revlogs.sort_by_cached_key(|(id, _)| *id);
revlogs.into_iter().map(|(_, item)| item).collect()
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
Expand Down Expand Up @@ -256,10 +256,11 @@ fn conversion_works() {
);

// convert a subset and check it matches expectations
let mut fsrs_items = single_card_revlog
let fsrs_items = single_card_revlog
.into_iter()
.filter_map(|entries| convert_to_fsrs_items(entries, 4, Tz::Asia__Shanghai))
.flatten()
.map(|(_, item)| item)
.collect_vec();
assert_eq!(
fsrs_items,
Expand Down Expand Up @@ -387,9 +388,11 @@ fn conversion_works() {
]
);

let mut weighted_fsrs_items = constant_weighted_fsrs_items(fsrs_items);

let device = NdArrayDevice::Cpu;
let batcher = FSRSBatcher::<NdArrayAutodiff>::new(device);
let res = batcher.batch(vec![fsrs_items.pop().unwrap()]);
let res = batcher.batch(vec![weighted_fsrs_items.pop().unwrap()]);
assert_eq!(res.delta_ts.into_scalar(), 64.0);
assert_eq!(
res.r_historys.squeeze(1).to_data(),
Expand Down Expand Up @@ -443,7 +446,8 @@ fn delta_t_is_correct() -> Result<()> {
],
NEXT_DAY_AT,
Tz::Asia__Shanghai
),
)
.map(|items| items.into_iter().map(|(_, item)| item).collect_vec()),
Some(vec![FSRSItem {
reviews: vec![
FSRSReview {
Expand All @@ -468,7 +472,8 @@ fn delta_t_is_correct() -> Result<()> {
],
NEXT_DAY_AT,
Tz::Asia__Shanghai
),
)
.map(|items| items.into_iter().map(|(_, item)| item).collect_vec()),
Some(vec![
FSRSItem {
reviews: vec![
Expand Down
103 changes: 76 additions & 27 deletions src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,19 @@ pub struct FSRSItem {
pub reviews: Vec<FSRSReview>,
}

#[derive(Debug, Clone)]
pub(crate) struct WeightedFSRSItem {
pub weight: f32,
pub item: FSRSItem,
}

#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq)]
pub struct FSRSReview {
/// 1-4
pub rating: u32,
/// The number of days that passed
/// # Warning
/// [`delta_t`] for item first(initial) review must be 0
/// `delta_t` for item first(initial) review must be 0
pub delta_t: u32,
}

Expand Down Expand Up @@ -88,22 +94,26 @@ pub(crate) struct FSRSBatch<B: Backend> {
pub r_historys: Tensor<B, 2, Float>,
pub delta_ts: Tensor<B, 1, Float>,
pub labels: Tensor<B, 1, Int>,
pub weights: Tensor<B, 1, Float>,
}

impl<B: Backend> Batcher<FSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
fn batch(&self, items: Vec<FSRSItem>) -> FSRSBatch<B> {
let pad_size = items
impl<B: Backend> Batcher<WeightedFSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
fn batch(&self, weighted_items: Vec<WeightedFSRSItem>) -> FSRSBatch<B> {
let pad_size = weighted_items
.iter()
.map(|x| x.reviews.len())
.map(|x| x.item.reviews.len())
.max()
.expect("FSRSItem is empty")
- 1;

let (time_histories, rating_histories) = items
let (time_histories, rating_histories) = weighted_items
.iter()
.map(|item| {
let (mut delta_t, mut rating): (Vec<_>, Vec<_>) =
item.history().map(|r| (r.delta_t, r.rating)).unzip();
.map(|weighted_item| {
let (mut delta_t, mut rating): (Vec<_>, Vec<_>) = weighted_item
.item
.history()
.map(|r| (r.delta_t, r.rating))
.unzip();
delta_t.resize(pad_size, 0);
rating.resize(pad_size, 0);
let delta_t = Tensor::from_data(
Expand All @@ -130,19 +140,23 @@ impl<B: Backend> Batcher<FSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
})
.unzip();

let (delta_ts, labels) = items
let (delta_ts, labels, weights) = weighted_items
.iter()
.map(|item| {
let current = item.current();
let delta_t = Tensor::from_data(Data::from([current.delta_t.elem()]), &self.device);
.map(|weighted_item| {
let current = weighted_item.item.current();
let delta_t: Tensor<B, 1> =
Tensor::from_data(Data::from([current.delta_t.elem()]), &self.device);
let label = match current.rating {
1 => 0.0,
_ => 1.0,
};
let label = Tensor::from_data(Data::from([label.elem()]), &self.device);
(delta_t, label)
let label: Tensor<B, 1, Int> =
Tensor::from_data(Data::from([label.elem()]), &self.device);
let weight: Tensor<B, 1> =
Tensor::from_data(Data::from([weighted_item.weight.elem()]), &self.device);
(delta_t, label, weight)
})
.unzip();
.multiunzip();

let t_historys = Tensor::cat(time_histories, 0)
.transpose()
Expand All @@ -152,6 +166,7 @@ impl<B: Backend> Batcher<FSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
.to_device(&self.device); // [seq_len, batch_size]
let delta_ts = Tensor::cat(delta_ts, 0).to_device(&self.device);
let labels = Tensor::cat(labels, 0).to_device(&self.device);
let weights = Tensor::cat(weights, 0).to_device(&self.device);

// dbg!(&items[0].t_history);
// dbg!(&t_historys);
Expand All @@ -161,27 +176,28 @@ impl<B: Backend> Batcher<FSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
r_historys,
delta_ts,
labels,
weights,
}
}
}

pub(crate) struct FSRSDataset {
pub(crate) items: Vec<FSRSItem>,
pub(crate) items: Vec<WeightedFSRSItem>,
}

impl Dataset<FSRSItem> for FSRSDataset {
impl Dataset<WeightedFSRSItem> for FSRSDataset {
fn len(&self) -> usize {
self.items.len()
}

fn get(&self, index: usize) -> Option<FSRSItem> {
fn get(&self, index: usize) -> Option<WeightedFSRSItem> {
// info!("get {}", index);
self.items.get(index).cloned()
}
}

impl From<Vec<FSRSItem>> for FSRSDataset {
fn from(items: Vec<FSRSItem>) -> Self {
impl From<Vec<WeightedFSRSItem>> for FSRSDataset {
fn from(items: Vec<WeightedFSRSItem>) -> Self {
Self { items }
}
}
Expand Down Expand Up @@ -252,6 +268,33 @@ pub fn prepare_training_data(items: Vec<FSRSItem>) -> (Vec<FSRSItem>, Vec<FSRSIt
(pretrainset.clone(), [pretrainset, trainset].concat())
}

pub(crate) fn sort_items_by_review_length(
mut weighted_items: Vec<WeightedFSRSItem>,
) -> Vec<WeightedFSRSItem> {
weighted_items.sort_by_cached_key(|weighted_item| weighted_item.item.reviews.len());
weighted_items
}

pub(crate) fn constant_weighted_fsrs_items(items: Vec<FSRSItem>) -> Vec<WeightedFSRSItem> {
items
.into_iter()
.map(|item| WeightedFSRSItem { weight: 1.0, item })
.collect()
}

/// The input items should be sorted by the review timestamp.
pub(crate) fn recency_weighted_fsrs_items(items: Vec<FSRSItem>) -> Vec<WeightedFSRSItem> {
let length = items.len() as f32;
items
.into_iter()
.enumerate()
.map(|(idx, item)| WeightedFSRSItem {
weight: 0.25 + 0.75 * (idx as f32 / length).powi(3),
item,
})
.collect()
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -261,19 +304,21 @@ mod tests {
fn from_anki() {
use burn::data::dataloader::Dataset;

let dataset = FSRSDataset::from(anki21_sample_file_converted_to_fsrs());
let dataset = FSRSDataset::from(sort_items_by_review_length(constant_weighted_fsrs_items(
anki21_sample_file_converted_to_fsrs(),
)));
assert_eq!(
dataset.get(704).unwrap(),
dataset.get(704).unwrap().item,
FSRSItem {
reviews: vec![
FSRSReview {
rating: 3,
delta_t: 0,
rating: 4,
delta_t: 0
},
FSRSReview {
rating: 3,
delta_t: 1,
},
delta_t: 3
}
],
}
);
Expand Down Expand Up @@ -435,6 +480,10 @@ mod tests {
],
},
];
let items = items
.into_iter()
.map(|item| WeightedFSRSItem { weight: 1.0, item })
.collect();
let batch = batcher.batch(items);
assert_eq!(
batch.t_historys.to_data(),
Expand Down
Loading

0 comments on commit 29b010a

Please sign in to comment.