
Fix/add pub(crate) back #170

Merged · merged 1 commit · Mar 16, 2024
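Context for the diff below: a plain `pub` item is part of the crate's public API, while `pub(crate)` restricts it to the defining crate, so this PR takes these internals back out of the public surface. A minimal illustration of the distinction (not taken from this PR's diff):

// Illustrative only; not from this PR.
pub struct PublicApi;            // visible to downstream crates
pub(crate) struct CrateInternal; // visible anywhere in this crate, but not outside it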
2 changes: 1 addition & 1 deletion src/batch_shuffle.rs
@@ -13,7 +13,7 @@ use std::{

use crate::{dataset::FSRSDataset, FSRSItem};

-pub struct BatchShuffledDataset<I> {
+pub(crate) struct BatchShuffledDataset<I> {
dataset: Arc<FSRSDataset>,
indices: Vec<usize>,
input: PhantomData<I>,
2 changes: 1 addition & 1 deletion src/cosine_annealing.rs
@@ -1,6 +1,6 @@
use burn::{lr_scheduler::LrScheduler, tensor::backend::Backend, LearningRate};
#[derive(Clone, Debug)]
-pub struct CosineAnnealingLR {
+pub(crate) struct CosineAnnealingLR {
t_max: f64,
eta_min: f64,
init_lr: LearningRate,
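CosineAnnealingLR's fields (t_max, eta_min, init_lr) match the standard cosine-annealing schedule, lr(t) = eta_min + (init_lr - eta_min) * (1 + cos(pi * t / t_max)) / 2. A minimal sketch of that formula, assuming this struct implements the usual schedule (its actual LrScheduler step logic is not shown in this diff):

// Sketch of the standard cosine-annealing formula; parameter names follow
// the struct above, but the crate's real implementation is not shown here.
fn cosine_annealed_lr(init_lr: f64, eta_min: f64, t: f64, t_max: f64) -> f64 {
    eta_min + (init_lr - eta_min) * (1.0 + (std::f64::consts::PI * t / t_max).cos()) / 2.0
}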
6 changes: 3 additions & 3 deletions src/dataset.rs
@@ -50,7 +50,7 @@ impl FSRSItem {
}
}

-pub struct FSRSBatcher<B: Backend> {
+pub(crate) struct FSRSBatcher<B: Backend> {
device: B::Device,
}

@@ -61,7 +61,7 @@ impl<B: Backend> FSRSBatcher<B> {
}

#[derive(Debug, Clone)]
-pub struct FSRSBatch<B: Backend> {
+pub(crate) struct FSRSBatch<B: Backend> {
pub t_historys: Tensor<B, 2, Float>,
pub r_historys: Tensor<B, 2, Float>,
pub delta_ts: Tensor<B, 1, Float>,
@@ -133,7 +133,7 @@ impl<B: Backend> Batcher<FSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
}
}

-pub struct FSRSDataset {
+pub(crate) struct FSRSDataset {
items: Vec<FSRSItem>,
}

6 changes: 3 additions & 3 deletions src/inference.rs
@@ -13,10 +13,10 @@ use crate::model::Model;
use crate::training::BCELoss;
use crate::{FSRSError, FSRSItem};
use burn::tensor::ElementConversion;
-pub const DECAY: f64 = -0.5;
+pub(crate) const DECAY: f64 = -0.5;
/// (9/10) ^ (1 / DECAY) - 1
-pub const FACTOR: f64 = 19f64 / 81f64;
-pub const S_MIN: f32 = 0.01;
+pub(crate) const FACTOR: f64 = 19f64 / 81f64;
+pub(crate) const S_MIN: f32 = 0.01;
/// This is a slice for efficiency, but should always be 17 in length.
pub type Parameters = [f32];
use itertools::izip;
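The 19/81 value for FACTOR follows directly from its doc comment: with DECAY = -0.5, 1/DECAY = -2, so (9/10)^(1/DECAY) - 1 = (10/9)^2 - 1 = 100/81 - 1 = 19/81 ≈ 0.2346. A quick check of that arithmetic (illustrative, not part of the diff):

// Verifies FACTOR = (9/10)^(1/DECAY) - 1 = 19/81 for DECAY = -0.5.
const DECAY: f64 = -0.5;
fn main() {
    let factor = (9.0f64 / 10.0).powf(1.0 / DECAY) - 1.0;
    assert!((factor - 19.0 / 81.0).abs() < 1e-12);
}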
6 changes: 3 additions & 3 deletions src/model.rs
@@ -16,7 +16,7 @@ pub struct Model<B: Backend> {
pub config: ModelConfig,
}

-pub trait Get<B: Backend, const N: usize> {
+pub(crate) trait Get<B: Backend, const N: usize> {
fn get(&self, n: usize) -> Tensor<B, N>;
}

@@ -177,7 +177,7 @@ impl<B: Backend> Model<B> {
}

#[derive(Debug, Clone)]
-pub struct MemoryStateTensors<B: Backend> {
+pub(crate) struct MemoryStateTensors<B: Backend> {
pub stability: Tensor<B, 1>,
pub difficulty: Tensor<B, 1>,
}
@@ -240,7 +240,7 @@ impl<B: Backend> FSRS<B> {
}
}

-pub fn parameters_to_model<B: Backend>(parameters: &Parameters) -> Model<B> {
+pub(crate) fn parameters_to_model<B: Backend>(parameters: &Parameters) -> Model<B> {
let config = ModelConfig::default();
let mut model = Model::new(config);
model.w = Param::from(Tensor::from_floats(
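parameters_to_model builds a fresh Model and loads the flat 17-element parameter slice into its weight tensor. A hedged usage sketch (the NdArray backend and the 0.4 values are placeholders, not from this PR):

// Hypothetical call site; any burn Backend works, values are placeholders.
let params: Vec<f32> = vec![0.4; 17];
let model = parameters_to_model::<burn::backend::NdArray>(&params);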
2 changes: 1 addition & 1 deletion src/pre_training.rs
@@ -106,7 +106,7 @@ fn loss(
logloss + l1
}

-pub const INIT_S_MAX: f32 = 100.0;
+pub(crate) const INIT_S_MAX: f32 = 100.0;

fn search_parameters(
mut pretrainset: HashMap<FirstRating, Vec<AverageRecall>>,
2 changes: 1 addition & 1 deletion src/training.rs
@@ -166,7 +166,7 @@ impl MetricsRenderer for ProgressCollector {
}

#[derive(Config)]
-pub struct TrainingConfig {
+pub(crate) struct TrainingConfig {
pub model: ModelConfig,
pub optimizer: AdamConfig,
#[config(default = 5)]
4 changes: 2 additions & 2 deletions src/weight_clipper.rs
@@ -4,15 +4,15 @@ use crate::{
};
use burn::tensor::{backend::Backend, Data, Tensor};

-pub fn weight_clipper<B: Backend>(parameters: Tensor<B, 1>) -> Tensor<B, 1> {
+pub(crate) fn weight_clipper<B: Backend>(parameters: Tensor<B, 1>) -> Tensor<B, 1> {
let val = clip_parameters(&parameters.to_data().convert().value);
Tensor::from_data(
Data::new(val, parameters.shape()).convert(),
&B::Device::default(),
)
}

-pub fn clip_parameters(parameters: &Parameters) -> Vec<f32> {
+pub(crate) fn clip_parameters(parameters: &Parameters) -> Vec<f32> {
// https://regex101.com/r/21mXNI/1
const CLAMPS: [(f32, f32); 17] = [
(S_MIN, INIT_S_MAX),
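clip_parameters clamps each of the 17 parameters to its own (min, max) range from CLAMPS; the table is truncated in the diff above, with only the first entry (S_MIN, INIT_S_MAX) visible. A sketch of that element-wise clamp (names are illustrative):

// Element-wise clamp of each parameter to its (min, max) bounds;
// the real CLAMPS table has 17 entries, truncated in the diff above.
fn clip_parameters_sketch(parameters: &[f32], clamps: &[(f32, f32)]) -> Vec<f32> {
    parameters
        .iter()
        .zip(clamps)
        .map(|(&w, &(lo, hi))| w.clamp(lo, hi))
        .collect()
}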