Clippy nursery (#112)
* const fn && Self && less clone && more precise

* mul_add

---------

Co-authored-by: Jarrett Ye <[email protected]>
asukaminato0721 and L-M-Sherlock authored Oct 27, 2023
1 parent a113052 commit aa797b9
Showing 8 changed files with 46 additions and 47 deletions.
8 changes: 4 additions & 4 deletions src/batch_shuffle.rs
@@ -167,7 +167,7 @@ where

for (dataset, rng) in datasets.into_iter().zip(rngs) {
let strategy = strategy.new_like();
- let dataloader = BatchShuffledDataLoader::new(
+ let dataloader = Self::new(
strategy,
Arc::new(dataset),
batcher.clone(),
@@ -225,7 +225,7 @@ impl<I, O> BatchShuffledDataloaderIterator<I, O> {
dataset: Arc<dyn Dataset<I>>,
batcher: Arc<dyn Batcher<I, O>>,
) -> Self {
- BatchShuffledDataloaderIterator {
+ Self {
current_index: 0,
strategy,
dataset,
@@ -324,7 +324,7 @@ where
/// # Returns
///
/// The data loader builder.
- pub fn shuffle(mut self, seed: u64) -> Self {
+ pub const fn shuffle(mut self, seed: u64) -> Self {
self.shuffle = Some(seed);
self
}
@@ -338,7 +338,7 @@ where
/// # Returns
///
/// The data loader builder.
- pub fn num_workers(mut self, num_workers: usize) -> Self {
+ pub const fn num_workers(mut self, num_workers: usize) -> Self {
self.num_threads = Some(num_workers);
self
}
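
A note on the `const fn` changes above: builder-style setters that only move `self` and assign plain fields qualify for `const fn` (clippy's `missing_const_for_fn` nursery lint), which makes them callable in const contexts at no runtime cost. A minimal sketch, using a hypothetical `Builder` rather than the crate's type:

struct Builder {
    shuffle: Option<u64>,
    num_threads: Option<usize>,
}

impl Builder {
    // Assigning to a field of a local `mut self` is allowed in const fn,
    // and Option<u64> has no Drop glue, so this body is const-compatible.
    const fn shuffle(mut self, seed: u64) -> Self {
        self.shuffle = Some(seed);
        self
    }
}

const B: Builder = Builder { shuffle: None, num_threads: None }.shuffle(42);

fn main() {
    assert!(matches!(B.shuffle, Some(42))); // setter ran at compile time
}
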
9 changes: 4 additions & 5 deletions src/cosine_annealing.rs
@@ -10,8 +10,8 @@ pub(crate) struct CosineAnnealingLR {
}

impl CosineAnnealingLR {
- pub fn init(t_max: f64, init_lr: LearningRate) -> CosineAnnealingLR {
- CosineAnnealingLR {
+ pub const fn init(t_max: f64, init_lr: LearningRate) -> Self {
+ Self {
t_max,
eta_min: 0.0,
init_lr,
@@ -38,9 +38,8 @@ impl LrScheduler for CosineAnnealingLR {
if (step_count - 1.0 - t_max) % (2.0 * t_max) == 0.0 {
(init_lr - eta_min) * (1.0 - f64::cos(PI / t_max)) / 2.0
} else {
- (1.0 + f64::cos(cosine_arg)) / (1.0 + f64::cos(PI * (step_count - 1.0) / t_max))
- * (lr - eta_min)
- + eta_min
+ ((1.0 + f64::cos(cosine_arg)) / (1.0 + f64::cos(PI * (step_count - 1.0) / t_max)))
+ .mul_add(lr - eta_min, eta_min)
}
}
self.current_lr = cosine_annealing_lr(
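
The `mul_add` rewrite is clippy's `suboptimal_flops` suggestion: `a.mul_add(b, c)` evaluates `a * b + c` as one fused operation with a single rounding step, so it is at least as precise as the two-step form (and one instruction on FMA hardware). A small self-contained demonstration, with values chosen only to expose the rounding difference:

fn main() {
    let (a, b, c) = (0.1_f64, 10.0, -1.0);
    let naive = a * b + c;       // 0.1 * 10.0 rounds to exactly 1.0, so this is 0.0
    let fused = a.mul_add(b, c); // keeps the exact product: ~5.55e-17
    println!("naive = {naive:e}, fused = {fused:e}");
}
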
6 changes: 3 additions & 3 deletions src/dataset.rs
@@ -12,12 +12,12 @@ use serde::{Deserialize, Serialize};
/// first one.
/// When used during review, the last item should include the correct delta_t, but
/// the provided rating is ignored as all four ratings are returned by .next_states()
- #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
+ #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct FSRSItem {
pub reviews: Vec<FSRSReview>,
}

- #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
+ #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
pub struct FSRSReview {
/// 1-4
pub rating: u32,
@@ -41,7 +41,7 @@ pub(crate) struct FSRSBatcher<B: Backend> {
}

impl<B: Backend> FSRSBatcher<B> {
- pub fn new(device: B::Device) -> Self {
+ pub const fn new(device: B::Device) -> Self {
Self { device }
}
}
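
On the added `Eq` derives: `Eq` is a marker trait on top of `PartialEq` asserting that equality is total (reflexive for every value), and it can only be derived when every field is itself `Eq`. The integer fields here qualify; a float field would not, since `NaN != NaN`. The practical payoff is that `Eq` types can serve as keys. A sketch with a hypothetical struct mirroring the fields shown (the `delta_t` field is an assumption):

use std::collections::HashSet;

#[derive(PartialEq, Eq, Hash)]
struct Review {
    rating: u32,  // 1-4
    delta_t: u32, // days since previous review (assumed field)
}

fn main() {
    // HashSet/HashMap keys require Eq + Hash; PartialEq alone is not enough.
    let mut seen = HashSet::new();
    seen.insert(Review { rating: 3, delta_t: 1 });
    assert!(seen.contains(&Review { rating: 3, delta_t: 1 }));
}
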
14 changes: 7 additions & 7 deletions src/inference.rs
@@ -26,7 +26,7 @@ fn infer<B: Backend>(
batch: FSRSBatch<B>,
) -> (MemoryStateTensors<B>, Tensor<B, 1>) {
let state = model.forward(batch.t_historys, batch.r_historys, None);
- let retention = model.power_forgetting_curve(batch.delta_ts.clone(), state.stability.clone());
+ let retention = model.power_forgetting_curve(batch.delta_ts, state.stability.clone());
(state, retention)
}

@@ -38,7 +38,7 @@ pub struct MemoryState {

impl<B: Backend> From<MemoryStateTensors<B>> for MemoryState {
fn from(m: MemoryStateTensors<B>) -> Self {
- MemoryState {
+ Self {
stability: m.stability.to_data().value[0].elem(),
difficulty: m.difficulty.to_data().value[0].elem(),
}
@@ -47,7 +47,7 @@ impl<B: Backend> From<MemoryStateTensors<B>> for MemoryState {

impl<B: Backend> From<MemoryState> for MemoryStateTensors<B> {
fn from(m: MemoryState) -> Self {
- MemoryStateTensors {
+ Self {
stability: Tensor::from_data(Data::new(vec![m.stability.elem()], Shape { dims: [1] })),
difficulty: Tensor::from_data(Data::new(
vec![m.difficulty.elem()],
@@ -102,7 +102,7 @@ impl<B: Backend> FSRS<B> {
let w10: f32 = w.get(10).into_scalar().elem();
let difficulty = 11.0
- (ease_factor - 1.0)
- / (w8.exp() * stability.powf(-w9) * (((1.0 - sm2_retention) * w10).exp() - 1.0));
+ / (w8.exp() * stability.powf(-w9) * ((1.0 - sm2_retention) * w10).exp_m1());
MemoryState {
stability,
difficulty: difficulty.clamp(1.0, 10.0),
@@ -122,7 +122,7 @@ impl<B: Backend> FSRS<B> {
// get initial stability for new card
let rating = Tensor::from_data(Data::new(vec![rating.elem()], Shape { dims: [1] }));
let model = self.model();
- model.init_stability(rating.clone()).into_scalar().elem()
+ model.init_stability(rating).into_scalar().elem()
});
next_interval(stability, desired_retention)
}
@@ -211,7 +211,7 @@ impl<B: Backend> FSRS<B> {
/// How well the user is likely to remember the item after `days_elapsed` since the previous
/// review.
pub fn current_retrievability(&self, state: MemoryState, days_elapsed: u32) -> f32 {
- (days_elapsed as f32 / (state.stability * 9.0) + 1.0).powf(-1.0)
+ (days_elapsed as f32 / (state.stability * 9.0) + 1.0).powi(-1)
}
}

@@ -480,7 +480,7 @@ mod tests {
fsrs.memory_state_from_sm2(2.5, 10.0, 0.9),
MemoryState {
stability: 9.999995,
- difficulty: 6.265295
+ difficulty: 6.2652965
}
);
assert_eq!(
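
Two numeric rewrites above deserve a note. `((1.0 - sm2_retention) * w10).exp_m1()` computes `e^x - 1` directly, avoiding the cancellation that `exp() - 1.0` suffers when `x` is near zero (the updated expected `difficulty` in the test presumably reflects this more precise path), and `powi(-1)` takes the integer-power path instead of the general `powf`. A quick check of the first point:

fn main() {
    let x = 1e-10_f64;
    // True value is x + x^2/2 + ... ≈ 1.00000000005e-10.
    println!("exp_m1:  {:e}", x.exp_m1());    // accurate to full precision
    println!("exp - 1: {:e}", x.exp() - 1.0); // ~1.0000000827e-10: digits lost
}
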
10 changes: 5 additions & 5 deletions src/model.rs
@@ -21,7 +21,7 @@ pub(crate) trait Get<B: Backend, const N: usize> {
}

impl<B: Backend, const N: usize> Get<B, N> for Tensor<B, N> {
- fn get(&self, n: usize) -> Tensor<B, N> {
+ fn get(&self, n: usize) -> Self {
self.clone().slice([n..(n + 1)])
}
}
@@ -32,7 +32,7 @@ trait Pow<B: Backend, const N: usize> {
}

impl<B: Backend, const N: usize> Pow<B, N> for Tensor<B, N> {
- fn pow(&self, other: Tensor<B, N>) -> Tensor<B, N> {
+ fn pow(&self, other: Self) -> Self {
// a ^ b => exp(ln(a^b)) => exp(b ln (a))
(self.clone().log() * other).exp()
}
@@ -43,7 +43,7 @@ impl<B: Backend> Model<B> {
pub fn new(config: ModelConfig) -> Self {
let initial_params = config
.initial_stability
- .unwrap_or(DEFAULT_WEIGHTS[0..4].try_into().unwrap())
+ .unwrap_or_else(|| DEFAULT_WEIGHTS[0..4].try_into().unwrap())
.into_iter()
.chain([
4.93, 0.94, 0.86, 0.01, // difficulty
@@ -106,7 +106,7 @@ impl<B: Backend> Model<B> {
}

pub(crate) fn init_stability(&self, rating: Tensor<B, 1>) -> Tensor<B, 1> {
- self.w.val().select(0, rating.clone().int() - 1)
+ self.w.val().select(0, rating.int() - 1)
}

fn init_difficulty(&self, rating: Tensor<B, 1>) -> Tensor<B, 1> {
@@ -134,7 +134,7 @@ impl<B: Backend> Model<B> {
let stability_after_failure = self.stability_after_failure(
state.stability.clone(),
state.difficulty.clone(),
- retention.clone(),
+ retention,
);
let mut new_stability = stability_after_success
.mask_where(rating.clone().equal_elem(1), stability_after_failure);
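
The `unwrap_or_else` change is clippy's `or_fun_call` fix: the argument of `unwrap_or` is built eagerly even when the `Option` is `Some`, whereas `unwrap_or_else` takes a closure that only runs on `None`. A minimal sketch, with a hypothetical `DEFAULTS` array standing in for `DEFAULT_WEIGHTS`:

const DEFAULTS: [f32; 8] = [0.4, 0.9, 2.3, 10.9, 4.93, 0.94, 0.86, 0.01];

fn initial_params(configured: Option<[f32; 4]>) -> [f32; 4] {
    // The slice-to-array conversion (and its potential panic) is deferred:
    // it executes only when `configured` is None.
    configured.unwrap_or_else(|| DEFAULTS[0..4].try_into().unwrap())
}

fn main() {
    assert_eq!(initial_params(None), [0.4, 0.9, 2.3, 10.9]);
}
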
25 changes: 12 additions & 13 deletions src/optimal_retention.rs
@@ -40,7 +40,7 @@ impl ndarray::SliceNextDim for Column {

impl From<Column> for SliceInfoElem {
fn from(value: Column) -> Self {
- SliceInfoElem::Index(value as isize)
+ Self::Index(value as isize)
}
}

@@ -59,8 +59,8 @@ pub struct SimulatorConfig {
}

impl Default for SimulatorConfig {
- fn default() -> SimulatorConfig {
- SimulatorConfig {
+ fn default() -> Self {
+ Self {
deck_size: 10000,
learn_span: 365,
max_cost_perday: 1800.0,
@@ -78,13 +78,12 @@ impl Default for SimulatorConfig {
fn stability_after_success(w: &[f64], s: f64, r: f64, d: f64, response: usize) -> f64 {
let hard_penalty = if response == 2 { w[15] } else { 1.0 };
let easy_bonus = if response == 4 { w[16] } else { 1.0 };
- s * (1.0
- + f64::exp(w[8])
- * (11.0 - d)
- * s.powf(-w[9])
- * (f64::exp((1.0 - r) * w[10]) - 1.0)
- * hard_penalty
- * easy_bonus)
+ s * (f64::exp(w[8])
+ * (11.0 - d)
+ * s.powf(-w[9])
+ * (f64::exp((1.0 - r) * w[10]) - 1.0)
+ * hard_penalty)
+ .mul_add(easy_bonus, 1.0)
}

fn stability_after_failure(w: &[f64], s: f64, r: f64, d: f64) -> f64 {
@@ -146,7 +145,7 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O
izip!(&mut retrievability, &delta_t, &old_stability, &has_learned)
.filter(|(.., &has_learned_flag)| has_learned_flag)
.for_each(|(retrievability, &delta_t, &stability, ..)| {
- *retrievability = (1.0 + delta_t / (9.0 * stability)).powf(-1.0)
+ *retrievability = (1.0 + delta_t / (9.0 * stability)).powi(-1)
});

// Set 'cost' column to 0
@@ -278,7 +277,7 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O
izip!(&mut new_difficulty, &old_difficulty, &true_review, &forget)
.filter(|(.., &true_rev, &frgt)| true_rev && frgt)
.for_each(|(new_diff, &old_diff, ..)| {
- *new_diff = (old_diff + 2.0 * w[6]).clamp(1.0, 10.0);
+ *new_diff = (2.0f64.mul_add(w[6], old_diff)).clamp(1.0, 10.0);
});

// Update the difficulty values based on the condition 'true_review & !forget'
@@ -310,7 +309,7 @@ fn simulate(config: &SimulatorConfig, w: &[f64], request_retention: f64, seed: O
.filter(|(.., &true_learn_flag)| true_learn_flag)
.for_each(|(new_stab, new_diff, &rating, _)| {
*new_stab = w[rating - 1];
- *new_diff = (w[4] - w[5] * (rating as f64 - 3.0)).clamp(1.0, 10.0);
+ *new_diff = (w[5].mul_add(-(rating as f64 - 3.0), w[4])).clamp(1.0, 10.0);
});
let old_interval = card_table.slice(s![Column::Interval, ..]);
let mut new_interval = old_interval.to_owned();
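
The `stability_after_success` rewrite is pure algebra: with `k` standing for the whole `exp(w[8]) * (11 - d) * s^(-w[9]) * (exp((1 - r) * w[10]) - 1) * hard_penalty` product, `s * (1.0 + k * easy_bonus)` becomes `s * k.mul_add(easy_bonus, 1.0)`. The two forms agree up to the final rounding; a sketch with made-up numbers:

fn main() {
    let (k, easy_bonus, s) = (0.35_f64, 1.3, 12.0);
    let before = s * (1.0 + k * easy_bonus);
    let after = s * k.mul_add(easy_bonus, 1.0);
    // Same algebra; mul_add only fuses the multiply with the +1.0.
    assert!((before - after).abs() < 1e-12);
}
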
13 changes: 7 additions & 6 deletions src/pre_training.rs
@@ -9,8 +9,8 @@ static R_S0_DEFAULT_ARRAY: &[(u32, f32); 4] = &[(1, 0.4), (2, 0.9), (3, 2.3), (4
pub fn pretrain(fsrs_items: Vec<FSRSItem>, average_recall: f32) -> Result<[f32; 4]> {
let pretrainset = create_pretrain_data(fsrs_items);
let rating_count = total_rating_count(&pretrainset);
- let rating_stability = search_parameters(pretrainset, average_recall);
- smooth_and_fill(&mut rating_stability.clone(), &rating_count)
+ let mut rating_stability = search_parameters(pretrainset, average_recall);
+ smooth_and_fill(&mut rating_stability, &rating_count)
}

type FirstRating = u32;
@@ -130,7 +130,7 @@ fn search_parameters(
// https://github.com/open-spaced-repetition/fsrs4anki/pull/358/files#diff-35b13c8e3466e8bd1231a51c71524fc31a945a8f332290726214d3a6fa7f442aR491
let real_recall = Array1::from_iter(data.iter().map(|d| d.recall));
let n = data.iter().map(|d| d.count).sum::<f32>();
- (real_recall * n + average_recall * 1.0) / (n + 1.0)
+ (real_recall * n + average_recall) / (n + 1.0)
};

let count = Array1::from_iter(data.iter().map(|d| d.count));
@@ -221,12 +221,13 @@ fn smooth_and_fill(
rating_stability.insert(1, r2.powf(1.0 / w1) * r3.powf(1.0 - 1.0 / w1));
}
(Some(&r1), None, None, Some(&r4)) => {
- let r2 =
- r1.powf(w1 / (w1 + w2 - w1 * w2)) * r4.powf(1.0 - w1 / (w1 + w2 - w1 * w2));
+ let r2 = r1.powf(w1 / (w1.mul_add(-w2, w1 + w2)))
+ * r4.powf(1.0 - w1 / (w1.mul_add(-w2, w1 + w2)));
rating_stability.insert(2, r2);
rating_stability.insert(
3,
- r1.powf(1.0 - w2 / (w1 + w2 - w1 * w2)) * r4.powf(w2 / (w1 + w2 - w1 * w2)),
+ r1.powf(1.0 - w2 / (w1.mul_add(-w2, w1 + w2)))
+ * r4.powf(w2 / (w1.mul_add(-w2, w1 + w2))),
);
}
(Some(&r1), None, Some(&r3), None) => {
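
The `pretrain` fix at the top of this file is about mutating a temporary: `&mut rating_stability.clone()` hands the callee a clone that no other code can observe, so the in-place edits survive only through the callee's return value. Binding `let mut rating_stability` and passing `&mut rating_stability` says what is meant. A sketch of why the old form is a trap in general:

use std::collections::HashMap;

fn fill_defaults(stabilities: &mut HashMap<u32, f32>) {
    stabilities.entry(1).or_insert(0.4);
}

fn main() {
    let ratings: HashMap<u32, f32> = HashMap::new();
    fill_defaults(&mut ratings.clone()); // mutates a temporary, then drops it
    assert!(ratings.is_empty());         // the original is untouched
}
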
8 changes: 4 additions & 4 deletions src/training.rs
@@ -36,7 +36,7 @@ pub struct BCELoss<B: Backend> {
}

impl<B: Backend> BCELoss<B> {
- pub fn new() -> Self {
+ pub const fn new() -> Self {
Self {
backend: PhantomData,
}
@@ -60,7 +60,7 @@ impl<B: Backend> Model<B> {
// info!("t_historys: {}", &t_historys);
// info!("r_historys: {}", &r_historys);
let state = self.forward(t_historys, r_historys, None);
- let retention = self.power_forgetting_curve(delta_ts.clone(), state.stability);
+ let retention = self.power_forgetting_curve(delta_ts, state.stability);
let logits =
Tensor::cat(vec![-retention.clone() + 1, retention.clone()], 0).unsqueeze::<2>();
let loss = BCELoss::new().forward(retention, labels.clone().float());
@@ -167,11 +167,11 @@ impl ProgressCollector {
}

impl ProgressState {
- pub fn current(&self) -> usize {
+ pub const fn current(&self) -> usize {
self.epoch.saturating_sub(1) * self.items_total + self.items_processed
}

- pub fn total(&self) -> usize {
+ pub const fn total(&self) -> usize {
self.epoch_total * self.items_total
}
}
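
The `current`/`total` accessors made `const fn` here only do integer arithmetic; `usize::saturating_sub` has itself been a `const fn` in the standard library since Rust 1.47, so the whole body is const-evaluable. A sketch mirroring the `ProgressState` fields used above:

struct ProgressState {
    epoch: usize,
    epoch_total: usize,
    items_processed: usize,
    items_total: usize,
}

impl ProgressState {
    const fn current(&self) -> usize {
        self.epoch.saturating_sub(1) * self.items_total + self.items_processed
    }

    const fn total(&self) -> usize {
        self.epoch_total * self.items_total
    }
}

const P: ProgressState = ProgressState { epoch: 2, epoch_total: 5, items_processed: 10, items_total: 100 };

fn main() {
    assert_eq!(P.current(), 110); // 1 * 100 + 10, computable at compile time
    assert_eq!(P.total(), 500);
}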
