Run cargo fmt #11

Merged · 1 commit · Aug 21, 2023
src/lib.rs: 6 changes (3 additions & 3 deletions)

@@ -1,5 +1,5 @@
-pub mod model;
+pub mod convertor;
 pub mod dataset;
+pub mod model;
 pub mod training;
-pub mod convertor;
-mod weight_clipper;
+mod weight_clipper;
src/model.rs: 10 changes (4 additions & 6 deletions)

@@ -1,9 +1,9 @@
 use burn::{
-    module::{Param, Module},
-    tensor::{backend::Backend, Float, Tensor}, config::Config,
+    config::Config,
+    module::{Module, Param},
+    tensor::{backend::Backend, Float, Tensor},
 };
 
-
 #[derive(Module, Debug)]
 pub struct Model<B: Backend> {
     pub w: Param<Tensor<B, 1>>,
@@ -119,10 +119,8 @@ impl<B: Backend<FloatElem = f32>> Model<B> {
     }
 }
 
-
 #[derive(Config, Debug)]
-pub struct ModelConfig {
-}
+pub struct ModelConfig {}
 
 impl ModelConfig {
     pub fn init<B: Backend<FloatElem = f32>>(&self) -> Model<B> {
src/training.rs: 48 changes (28 additions & 20 deletions)

@@ -1,22 +1,16 @@
-use crate::dataset::{FSRSBatcher, FSRSDataset, FSRSBatch};
-use crate::model::{ModelConfig, Model};
+use crate::dataset::{FSRSBatch, FSRSBatcher, FSRSDataset};
+use crate::model::{Model, ModelConfig};
 use crate::weight_clipper::weight_clipper;
 use burn::module::Module;
 use burn::nn::loss::CrossEntropyLoss;
 use burn::optim::AdamConfig;
-use burn::record::{PrettyJsonFileRecorder, FullPrecisionSettings, Recorder};
-use burn::tensor::{Tensor, Int};
+use burn::record::{FullPrecisionSettings, PrettyJsonFileRecorder, Recorder};
 use burn::tensor::backend::Backend;
-use burn::train::{TrainStep, TrainOutput, ValidStep, ClassificationOutput};
+use burn::tensor::{Int, Tensor};
+use burn::train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep};
 use burn::{
-    config::Config,
-    data::dataloader::DataLoaderBuilder,
-    tensor::backend::ADBackend,
-    train::{
-        // metric::{AccuracyMetric, LossMetric},
-        LearnerBuilder,
-    },
-    module::Param,
+    config::Config, data::dataloader::DataLoaderBuilder, module::Param, tensor::backend::ADBackend,
+    train::LearnerBuilder,
 };
 use log::info;
 
@@ -33,7 +27,8 @@ impl<B: Backend<FloatElem = f32>> Model<B> {
         let (stability, _difficulty) = self.forward(t_historys, r_historys);
         let retention = self.power_forgetting_curve(delta_ts.clone(), stability.clone());
         // dbg!(&retention);
-        let logits = Tensor::cat(vec![retention.clone(), -retention.clone() + 1], 0).reshape([1, -1]);
+        let logits =
+            Tensor::cat(vec![retention.clone(), -retention.clone() + 1], 0).reshape([1, -1]);
         info!("stability: {}", &stability);
         info!("delta_ts: {}", &delta_ts);
         info!("retention: {}", &retention);
@@ -46,15 +41,25 @@ impl<B: Backend<FloatElem = f32>> Model<B> {
 
 impl<B: ADBackend<FloatElem = f32>> TrainStep<FSRSBatch<B>, ClassificationOutput<B>> for Model<B> {
     fn step(&self, batch: FSRSBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
-        let item = self.forward_classification(batch.t_historys, batch.r_historys, batch.delta_ts, batch.labels);
+        let item = self.forward_classification(
+            batch.t_historys,
+            batch.r_historys,
+            batch.delta_ts,
+            batch.labels,
+        );
 
         TrainOutput::new(self, item.loss.backward(), item)
     }
 }
 
 impl<B: Backend<FloatElem = f32>> ValidStep<FSRSBatch<B>, ClassificationOutput<B>> for Model<B> {
     fn step(&self, batch: FSRSBatch<B>) -> ClassificationOutput<B> {
-        self.forward_classification(batch.t_historys, batch.r_historys, batch.delta_ts, batch.labels)
+        self.forward_classification(
+            batch.t_historys,
+            batch.r_historys,
+            batch.delta_ts,
+            batch.labels,
+        )
     }
 }
 
@@ -76,7 +81,11 @@ pub struct TrainingConfig {
     pub learning_rate: f64,
 }
 
-pub fn train<B: ADBackend<FloatElem = f32>>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {
+pub fn train<B: ADBackend<FloatElem = f32>>(
+    artifact_dir: &str,
+    config: TrainingConfig,
+    device: B::Device,
+) {
     std::fs::create_dir_all(artifact_dir).ok();
     config
         .save(&format!("{artifact_dir}/config.json"))
@@ -121,7 +130,7 @@ pub fn train<B: ADBackend<FloatElem = f32>>(artifact_dir: &str, config: Training
         .save(format!("{ARTIFACT_DIR}/config.json").as_str())
         .unwrap();
 
-    PrettyJsonFileRecorder::<FullPrecisionSettings>::new()
+    PrettyJsonFileRecorder::<FullPrecisionSettings>::new()
         .record(
             model_trained.clone().into_record(),
             format!("{ARTIFACT_DIR}/model").into(),
@@ -131,7 +140,6 @@ pub fn train<B: ADBackend<FloatElem = f32>>(artifact_dir: &str, config: Training
     info!("trained weights: {}", &model_trained.w.val());
 }
 
-
 #[test]
 fn test() {
     use burn_ndarray::NdArrayBackend;
@@ -146,4 +154,4 @@ fn test() {
         TrainingConfig::new(ModelConfig::new(), AdamConfig::new()),
         device.clone(),
     );
-}
+}
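
A side note on the reformatted logits line in the second hunk: for a single review it concatenates the predicted retention with its complement (1 - retention), producing the pair of class scores that is later compared against the review label. A minimal plain-Rust sketch of that pairing, assuming one review per batch; retention_pairs is an illustrative helper, not part of this crate, and no burn tensor ops are used here:

// Illustrative only: pair each predicted retention probability with its
// complement, mirroring what the Tensor::cat(..).reshape(..) line builds
// for a single review.
fn retention_pairs(retention: &[f32]) -> Vec<[f32; 2]> {
    retention.iter().map(|&r| [r, 1.0 - r]).collect()
}

fn main() {
    // A well-retained prediction and a borderline one.
    let pairs = retention_pairs(&[0.75, 0.5]);
    assert_eq!(pairs, vec![[0.75, 0.25], [0.5, 0.5]]);
}
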
src/weight_clipper.rs: 23 changes (12 additions & 11 deletions)

@@ -1,7 +1,6 @@
-use burn::tensor::{backend::Backend, Tensor, Data};
+use burn::tensor::{backend::Backend, Data, Tensor};
 
-
-pub fn weight_clipper<B: Backend<FloatElem = f32>>(weights:Tensor<B, 1>) -> Tensor<B, 1> {
+pub fn weight_clipper<B: Backend<FloatElem = f32>>(weights: Tensor<B, 1>) -> Tensor<B, 1> {
     const CLAMPS: [(f32, f32); 13] = [
         (1.0, 10.0),
         (0.1, 5.0),
@@ -23,7 +22,7 @@ pub fn weight_clipper<B: Backend<FloatElem = f32>>(weights:Tensor<B, 1>) -> Tens
 
     for (i, w) in val.iter_mut().skip(4).enumerate() {
         *w = w.clamp(CLAMPS[i].0.into(), CLAMPS[i].1.into());
-    }
+    }
 
     Tensor::from_data(Data::new(val.clone(), weights.shape()))
 }
@@ -33,14 +32,16 @@ fn weight_clipper_test() {
     type Backend = burn_ndarray::NdArrayBackend<f32>;
     //type AutodiffBackend = burn_autodiff::ADBackendDecorator<Backend>;
 
-    let tensor = Tensor::from_floats(
-        [0.0, -1000.0, 1000.0, 0.0, // Ignored
-        1000.0, -1000.0, 1.0, 0.25, -0.1]); // Clamped (1.0, 10.0),(0.1, 5.0),(0.1, 5.0),(0.0, 0.5),
+    let tensor = Tensor::from_floats([
+        0.0, -1000.0, 1000.0, 0.0, // Ignored
+        1000.0, -1000.0, 1.0, 0.25, -0.1,
+    ]); // Clamped (1.0, 10.0),(0.1, 5.0),(0.1, 5.0),(0.0, 0.5),
 
     let param: Tensor<Backend, 1> = weight_clipper(tensor);
     let values = &param.to_data().value;
 
-    assert_eq!(*values, vec!
-        [0.0, -1000.0, 1000.0, 0.0,
-        10.0, 0.1, 1.0, 0.25, 0.0]);
-}
+    assert_eq!(
+        *values,
+        vec![0.0, -1000.0, 1000.0, 0.0, 10.0, 0.1, 1.0, 0.25, 0.0]
+    );
+}
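
The formatting pass does not change what the test above checks: weight_clipper copies the tensor's values out, leaves the first four weights untouched, clamps each later weight to its (min, max) pair from CLAMPS, and rebuilds the tensor. A minimal plain-Rust sketch of that clamping rule, using an illustrative four-entry bounds table rather than the crate's full 13-entry CLAMPS constant; clamp_weights is a hypothetical helper, not the crate's API:

// Illustrative only: the first four weights pass through unchanged; each
// later weight is clamped to its (min, max) pair, as weight_clipper does.
fn clamp_weights(mut weights: Vec<f32>, clamps: &[(f32, f32)]) -> Vec<f32> {
    for (w, (lo, hi)) in weights.iter_mut().skip(4).zip(clamps) {
        *w = w.clamp(*lo, *hi);
    }
    weights
}

fn main() {
    let clamps = [(1.0, 10.0), (0.1, 5.0), (0.1, 5.0), (0.0, 0.5)];
    let weights = vec![0.5, -2.0, 3.0, 0.0, 1000.0, -1000.0, 1.0, 0.25];
    assert_eq!(
        clamp_weights(weights, &clamps),
        vec![0.5, -2.0, 3.0, 0.0, 10.0, 0.1, 1.0, 0.25]
    );
}
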