From aa797b950375d5f5be1cddacd5f9b85416fd8e0f Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Fri, 27 Oct 2023 21:56:40 +0900 Subject: [PATCH] Clippy nursery (#112) * const fn && Self && less clone && more precise * mul_add --------- Co-authored-by: Jarrett Ye --- src/batch_shuffle.rs | 8 ++++---- src/cosine_annealing.rs | 9 ++++----- src/dataset.rs | 6 +++--- src/inference.rs | 14 +++++++------- src/model.rs | 10 +++++----- src/optimal_retention.rs | 25 ++++++++++++------------- src/pre_training.rs | 13 +++++++------ src/training.rs | 8 ++++---- 8 files changed, 46 insertions(+), 47 deletions(-) diff --git a/src/batch_shuffle.rs b/src/batch_shuffle.rs index d59a1e13..d9502c1e 100644 --- a/src/batch_shuffle.rs +++ b/src/batch_shuffle.rs @@ -167,7 +167,7 @@ where for (dataset, rng) in datasets.into_iter().zip(rngs) { let strategy = strategy.new_like(); - let dataloader = BatchShuffledDataLoader::new( + let dataloader = Self::new( strategy, Arc::new(dataset), batcher.clone(), @@ -225,7 +225,7 @@ impl BatchShuffledDataloaderIterator { dataset: Arc>, batcher: Arc>, ) -> Self { - BatchShuffledDataloaderIterator { + Self { current_index: 0, strategy, dataset, @@ -324,7 +324,7 @@ where /// # Returns /// /// The data loader builder. - pub fn shuffle(mut self, seed: u64) -> Self { + pub const fn shuffle(mut self, seed: u64) -> Self { self.shuffle = Some(seed); self } @@ -338,7 +338,7 @@ where /// # Returns /// /// The data loader builder. - pub fn num_workers(mut self, num_workers: usize) -> Self { + pub const fn num_workers(mut self, num_workers: usize) -> Self { self.num_threads = Some(num_workers); self } diff --git a/src/cosine_annealing.rs b/src/cosine_annealing.rs index 33a16ffa..59140dd2 100644 --- a/src/cosine_annealing.rs +++ b/src/cosine_annealing.rs @@ -10,8 +10,8 @@ pub(crate) struct CosineAnnealingLR { } impl CosineAnnealingLR { - pub fn init(t_max: f64, init_lr: LearningRate) -> CosineAnnealingLR { - CosineAnnealingLR { + pub const fn init(t_max: f64, init_lr: LearningRate) -> Self { + Self { t_max, eta_min: 0.0, init_lr, @@ -38,9 +38,8 @@ impl LrScheduler for CosineAnnealingLR { if (step_count - 1.0 - t_max) % (2.0 * t_max) == 0.0 { (init_lr - eta_min) * (1.0 - f64::cos(PI / t_max)) / 2.0 } else { - (1.0 + f64::cos(cosine_arg)) / (1.0 + f64::cos(PI * (step_count - 1.0) / t_max)) - * (lr - eta_min) - + eta_min + ((1.0 + f64::cos(cosine_arg)) / (1.0 + f64::cos(PI * (step_count - 1.0) / t_max))) + .mul_add(lr - eta_min, eta_min) } } self.current_lr = cosine_annealing_lr( diff --git a/src/dataset.rs b/src/dataset.rs index 0d2facad..715681d8 100644 --- a/src/dataset.rs +++ b/src/dataset.rs @@ -12,12 +12,12 @@ use serde::{Deserialize, Serialize}; /// first one. /// When used during review, the last item should include the correct delta_t, but /// the provided rating is ignored as all four ratings are returned by .next_states() -#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] pub struct FSRSItem { pub reviews: Vec, } -#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] pub struct FSRSReview { /// 1-4 pub rating: u32, @@ -41,7 +41,7 @@ pub(crate) struct FSRSBatcher { } impl FSRSBatcher { - pub fn new(device: B::Device) -> Self { + pub const fn new(device: B::Device) -> Self { Self { device } } } diff --git a/src/inference.rs b/src/inference.rs index 3e475377..e548d73f 100644 --- a/src/inference.rs +++ b/src/inference.rs @@ -26,7 +26,7 @@ fn infer( batch: FSRSBatch, ) -> (MemoryStateTensors, Tensor) { let state = model.forward(batch.t_historys, batch.r_historys, None); - let retention = model.power_forgetting_curve(batch.delta_ts.clone(), state.stability.clone()); + let retention = model.power_forgetting_curve(batch.delta_ts, state.stability.clone()); (state, retention) } @@ -38,7 +38,7 @@ pub struct MemoryState { impl From> for MemoryState { fn from(m: MemoryStateTensors) -> Self { - MemoryState { + Self { stability: m.stability.to_data().value[0].elem(), difficulty: m.difficulty.to_data().value[0].elem(), } @@ -47,7 +47,7 @@ impl From> for MemoryState { impl From for MemoryStateTensors { fn from(m: MemoryState) -> Self { - MemoryStateTensors { + Self { stability: Tensor::from_data(Data::new(vec![m.stability.elem()], Shape { dims: [1] })), difficulty: Tensor::from_data(Data::new( vec![m.difficulty.elem()], @@ -102,7 +102,7 @@ impl FSRS { let w10: f32 = w.get(10).into_scalar().elem(); let difficulty = 11.0 - (ease_factor - 1.0) - / (w8.exp() * stability.powf(-w9) * (((1.0 - sm2_retention) * w10).exp() - 1.0)); + / (w8.exp() * stability.powf(-w9) * ((1.0 - sm2_retention) * w10).exp_m1()); MemoryState { stability, difficulty: difficulty.clamp(1.0, 10.0), @@ -122,7 +122,7 @@ impl FSRS { // get initial stability for new card let rating = Tensor::from_data(Data::new(vec![rating.elem()], Shape { dims: [1] })); let model = self.model(); - model.init_stability(rating.clone()).into_scalar().elem() + model.init_stability(rating).into_scalar().elem() }); next_interval(stability, desired_retention) } @@ -211,7 +211,7 @@ impl FSRS { /// How well the user is likely to remember the item after `days_elapsed` since the previous /// review. pub fn current_retrievability(&self, state: MemoryState, days_elapsed: u32) -> f32 { - (days_elapsed as f32 / (state.stability * 9.0) + 1.0).powf(-1.0) + (days_elapsed as f32 / (state.stability * 9.0) + 1.0).powi(-1) } } @@ -480,7 +480,7 @@ mod tests { fsrs.memory_state_from_sm2(2.5, 10.0, 0.9), MemoryState { stability: 9.999995, - difficulty: 6.265295 + difficulty: 6.2652965 } ); assert_eq!( diff --git a/src/model.rs b/src/model.rs index 8674ac5e..78a0a7ec 100644 --- a/src/model.rs +++ b/src/model.rs @@ -21,7 +21,7 @@ pub(crate) trait Get { } impl Get for Tensor { - fn get(&self, n: usize) -> Tensor { + fn get(&self, n: usize) -> Self { self.clone().slice([n..(n + 1)]) } } @@ -32,7 +32,7 @@ trait Pow { } impl Pow for Tensor { - fn pow(&self, other: Tensor) -> Tensor { + fn pow(&self, other: Self) -> Self { // a ^ b => exp(ln(a^b)) => exp(b ln (a)) (self.clone().log() * other).exp() } @@ -43,7 +43,7 @@ impl Model { pub fn new(config: ModelConfig) -> Self { let initial_params = config .initial_stability - .unwrap_or(DEFAULT_WEIGHTS[0..4].try_into().unwrap()) + .unwrap_or_else(|| DEFAULT_WEIGHTS[0..4].try_into().unwrap()) .into_iter() .chain([ 4.93, 0.94, 0.86, 0.01, // difficulty @@ -106,7 +106,7 @@ impl Model { } pub(crate) fn init_stability(&self, rating: Tensor) -> Tensor { - self.w.val().select(0, rating.clone().int() - 1) + self.w.val().select(0, rating.int() - 1) } fn init_difficulty(&self, rating: Tensor) -> Tensor { @@ -134,7 +134,7 @@ impl Model { let stability_after_failure = self.stability_after_failure( state.stability.clone(), state.difficulty.clone(), - retention.clone(), + retention, ); let mut new_stability = stability_after_success .mask_where(rating.clone().equal_elem(1), stability_after_failure); diff --git a/src/optimal_retention.rs b/src/optimal_retention.rs index 1df7ee17..57c4836b 100644 --- a/src/optimal_retention.rs +++ b/src/optimal_retention.rs @@ -40,7 +40,7 @@ impl ndarray::SliceNextDim for Column { impl From for SliceInfoElem { fn from(value: Column) -> Self { - SliceInfoElem::Index(value as isize) + Self::Index(value as isize) } } @@ -59,8 +59,8 @@ pub struct SimulatorConfig { } impl Default for SimulatorConfig { - fn default() -> SimulatorConfig { - SimulatorConfig { + fn default() -> Self { + Self { deck_size: 10000, learn_span: 365, max_cost_perday: 1800.0, @@ -78,13 +78,12 @@ impl Default for SimulatorConfig { fn stability_after_success(w: &[f64], s: f64, r: f64, d: f64, response: usize) -> f64 { let hard_penalty = if response == 2 { w[15] } else { 1.0 }; let easy_bonus = if response == 4 { w[16] } else { 1.0 }; - s * (1.0 - + f64::exp(w[8]) - * (11.0 - d) - * s.powf(-w[9]) - * (f64::exp((1.0 - r) * w[10]) - 1.0) - * hard_penalty - * easy_bonus) + s * (f64::exp(w[8]) + * (11.0 - d) + * s.powf(-w[9]) + * (f64::exp((1.0 - r) * w[10]) - 1.0) + * hard_penalty) + .mul_add(easy_bonus, 1.0) } fn stability_after_failure(w: &[f64], s: f64, r: f64, d: f64) -> f64 { @@ -146,7 +145,7 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O izip!(&mut retrievability, &delta_t, &old_stability, &has_learned) .filter(|(.., &has_learned_flag)| has_learned_flag) .for_each(|(retrievability, &delta_t, &stability, ..)| { - *retrievability = (1.0 + delta_t / (9.0 * stability)).powf(-1.0) + *retrievability = (1.0 + delta_t / (9.0 * stability)).powi(-1) }); // Set 'cost' column to 0 @@ -278,7 +277,7 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O izip!(&mut new_difficulty, &old_difficulty, &true_review, &forget) .filter(|(.., &true_rev, &frgt)| true_rev && frgt) .for_each(|(new_diff, &old_diff, ..)| { - *new_diff = (old_diff + 2.0 * w[6]).clamp(1.0, 10.0); + *new_diff = (2.0f64.mul_add(w[6], old_diff)).clamp(1.0, 10.0); }); // Update the difficulty values based on the condition 'true_review & !forget' @@ -310,7 +309,7 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O .filter(|(.., &true_learn_flag)| true_learn_flag) .for_each(|(new_stab, new_diff, &rating, _)| { *new_stab = w[rating - 1]; - *new_diff = (w[4] - w[5] * (rating as f64 - 3.0)).clamp(1.0, 10.0); + *new_diff = (w[5].mul_add(-(rating as f64 - 3.0), w[4])).clamp(1.0, 10.0); }); let old_interval = card_table.slice(s![Column::Interval, ..]); let mut new_interval = old_interval.to_owned(); diff --git a/src/pre_training.rs b/src/pre_training.rs index fea6e72b..04f8a8cc 100644 --- a/src/pre_training.rs +++ b/src/pre_training.rs @@ -9,8 +9,8 @@ static R_S0_DEFAULT_ARRAY: &[(u32, f32); 4] = &[(1, 0.4), (2, 0.9), (3, 2.3), (4 pub fn pretrain(fsrs_items: Vec, average_recall: f32) -> Result<[f32; 4]> { let pretrainset = create_pretrain_data(fsrs_items); let rating_count = total_rating_count(&pretrainset); - let rating_stability = search_parameters(pretrainset, average_recall); - smooth_and_fill(&mut rating_stability.clone(), &rating_count) + let mut rating_stability = search_parameters(pretrainset, average_recall); + smooth_and_fill(&mut rating_stability, &rating_count) } type FirstRating = u32; @@ -130,7 +130,7 @@ fn search_parameters( // https://github.com/open-spaced-repetition/fsrs4anki/pull/358/files#diff-35b13c8e3466e8bd1231a51c71524fc31a945a8f332290726214d3a6fa7f442aR491 let real_recall = Array1::from_iter(data.iter().map(|d| d.recall)); let n = data.iter().map(|d| d.count).sum::(); - (real_recall * n + average_recall * 1.0) / (n + 1.0) + (real_recall * n + average_recall) / (n + 1.0) }; let count = Array1::from_iter(data.iter().map(|d| d.count)); @@ -221,12 +221,13 @@ fn smooth_and_fill( rating_stability.insert(1, r2.powf(1.0 / w1) * r3.powf(1.0 - 1.0 / w1)); } (Some(&r1), None, None, Some(&r4)) => { - let r2 = - r1.powf(w1 / (w1 + w2 - w1 * w2)) * r4.powf(1.0 - w1 / (w1 + w2 - w1 * w2)); + let r2 = r1.powf(w1 / (w1.mul_add(-w2, w1 + w2))) + * r4.powf(1.0 - w1 / (w1.mul_add(-w2, w1 + w2))); rating_stability.insert(2, r2); rating_stability.insert( 3, - r1.powf(1.0 - w2 / (w1 + w2 - w1 * w2)) * r4.powf(w2 / (w1 + w2 - w1 * w2)), + r1.powf(1.0 - w2 / (w1.mul_add(-w2, w1 + w2))) + * r4.powf(w2 / (w1.mul_add(-w2, w1 + w2))), ); } (Some(&r1), None, Some(&r3), None) => { diff --git a/src/training.rs b/src/training.rs index 945d4b42..dc15513b 100644 --- a/src/training.rs +++ b/src/training.rs @@ -36,7 +36,7 @@ pub struct BCELoss { } impl BCELoss { - pub fn new() -> Self { + pub const fn new() -> Self { Self { backend: PhantomData, } @@ -60,7 +60,7 @@ impl Model { // info!("t_historys: {}", &t_historys); // info!("r_historys: {}", &r_historys); let state = self.forward(t_historys, r_historys, None); - let retention = self.power_forgetting_curve(delta_ts.clone(), state.stability); + let retention = self.power_forgetting_curve(delta_ts, state.stability); let logits = Tensor::cat(vec![-retention.clone() + 1, retention.clone()], 0).unsqueeze::<2>(); let loss = BCELoss::new().forward(retention, labels.clone().float()); @@ -167,11 +167,11 @@ impl ProgressCollector { } impl ProgressState { - pub fn current(&self) -> usize { + pub const fn current(&self) -> usize { self.epoch.saturating_sub(1) * self.items_total + self.items_processed } - pub fn total(&self) -> usize { + pub const fn total(&self) -> usize { self.epoch_total * self.items_total } }