Skip to content

Commit

Permalink
Correctly track progress with n_splits > 1
Browse files Browse the repository at this point in the history
  • Loading branch information
dae committed Oct 14, 2023
1 parent 5089470 commit 64c6341
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 16 deletions.
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ pub use inference::{
};
pub use model::FSRS;
pub use optimal_retention::SimulatorConfig;
pub use training::ProgressState;
pub use training::CombinedProgressState;
52 changes: 37 additions & 15 deletions src/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,35 +113,53 @@ impl<B: Backend> ValidStep<FSRSBatch<B>, ClassificationOutput<B>> for Model<B> {
}
}

#[derive(Debug, Default)]
#[derive(Debug, Default, Clone)]
/// Training progress for a single split, holding the most recent values
/// copied from the trainer's `TrainingProgress` report in `render_train`.
pub struct ProgressState {
/// Epoch reported by the trainer. `current()` subtracts one (saturating)
/// before multiplying by `items_total`, so this is presumably 1-based —
/// NOTE(review): confirm against the trainer.
pub epoch: usize,
/// Total number of epochs reported by the trainer.
pub epoch_total: usize,
/// Items processed so far within the reported epoch.
pub items_processed: usize,
/// Items per epoch as reported by the trainer; used by `current()` to
/// account for already-completed epochs.
pub items_total: usize,
}

/// Progress shared between the caller and all training splits.
///
/// `compute_weights` resets `splits` to one default [`ProgressState`] per
/// split before training starts; each `ProgressCollector` then updates the
/// entry at its own `index`.
#[derive(Debug, Default)]
pub struct CombinedProgressState {
    /// Set to `true` to request cancellation. It is checked on every
    /// `render_train` call, which then stops the training interrupter.
    pub want_abort: bool,
    /// Per-split progress, indexed by split number.
    pub splits: Vec<ProgressState>,
}

impl CombinedProgressState {
pub fn new_shared() -> Arc<Mutex<Self>> {
Default::default()
}

pub fn current(&self) -> usize {
self.splits.iter().map(|s| s.current()).sum()
}

pub fn total(&self) -> usize {
self.splits.iter().map(|s| s.total()).sum()
}
}

#[derive(Clone, Default)]
#[derive(Clone)]
pub struct ProgressCollector {
pub state: Arc<Mutex<ProgressState>>,
pub state: Arc<Mutex<CombinedProgressState>>,
pub interrupter: TrainingInterrupter,
/// The index of the split we should update.
pub index: usize,
}

impl ProgressCollector {
pub fn new(state: Arc<Mutex<ProgressState>>) -> Self {
pub fn new(state: Arc<Mutex<CombinedProgressState>>, index: usize) -> Self {
Self {
state,
..Default::default()
interrupter: Default::default(),
index,
}
}
}

impl ProgressState {
pub fn new_shared() -> Arc<Mutex<Self>> {
Default::default()
}

pub fn current(&self) -> usize {
self.epoch.saturating_sub(1) * self.items_total + self.items_processed
}
Expand All @@ -158,10 +176,11 @@ impl DashboardRenderer for ProgressCollector {

fn render_train(&mut self, item: TrainingProgress) {
let mut info = self.state.lock().unwrap();
info.epoch = item.epoch;
info.epoch_total = item.epoch_total;
info.items_processed = item.progress.items_processed;
info.items_total = item.progress.items_total;
let split = &mut info.splits[self.index];
split.epoch = item.epoch;
split.epoch_total = item.epoch_total;
split.items_processed = item.progress.items_processed;
split.items_total = item.progress.items_total;
if info.want_abort {
self.interrupter.stop();
}
Expand Down Expand Up @@ -206,9 +225,12 @@ impl<B: Backend> FSRS<B> {
pub fn compute_weights(
&self,
items: Vec<FSRSItem>,
progress: Option<Arc<Mutex<ProgressState>>>,
mut progress: Option<Arc<Mutex<CombinedProgressState>>>,
) -> Result<Vec<f32>> {
let n_splits = 5;
if let Some(progress) = &mut progress {
progress.lock().unwrap().splits = vec![ProgressState::default(); n_splits];
}
let average_recall = calculate_average_recall(&items);
let (pre_trainset, trainsets) = split_data(items, n_splits);
let initial_stability = pretrain(pre_trainset, average_recall)?;
Expand All @@ -234,7 +256,7 @@ impl<B: Backend> FSRS<B> {
trainset,
&config,
self.device(),
progress.clone().map(ProgressCollector::new),
progress.clone().map(|p| ProgressCollector::new(p, i)),
);
model.unwrap().w.val().to_data().convert().value
})
Expand Down

0 comments on commit 64c6341

Please sign in to comment.