From 55a17c8a48314f78c3dd25cae5e693d917afbe49 Mon Sep 17 00:00:00 2001 From: Damien Elmes Date: Mon, 21 Aug 2023 14:21:20 +1000 Subject: [PATCH] Run cargo fmt --- src/lib.rs | 6 +++--- src/model.rs | 10 ++++----- src/training.rs | 48 +++++++++++++++++++++++++------------------ src/weight_clipper.rs | 23 +++++++++++---------- 4 files changed, 47 insertions(+), 40 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index db6dd508..9b4d539b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,5 @@ -pub mod model; +pub mod convertor; pub mod dataset; +pub mod model; pub mod training; -pub mod convertor; -mod weight_clipper; \ No newline at end of file +mod weight_clipper; diff --git a/src/model.rs b/src/model.rs index 95e73d1d..cf09454d 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1,9 +1,9 @@ use burn::{ - module::{Param, Module}, - tensor::{backend::Backend, Float, Tensor}, config::Config, + config::Config, + module::{Module, Param}, + tensor::{backend::Backend, Float, Tensor}, }; - #[derive(Module, Debug)] pub struct Model { pub w: Param>, @@ -119,10 +119,8 @@ impl> Model { } } - #[derive(Config, Debug)] -pub struct ModelConfig { -} +pub struct ModelConfig {} impl ModelConfig { pub fn init>(&self) -> Model { diff --git a/src/training.rs b/src/training.rs index c57f03a5..2fb21203 100644 --- a/src/training.rs +++ b/src/training.rs @@ -1,22 +1,16 @@ -use crate::dataset::{FSRSBatcher, FSRSDataset, FSRSBatch}; -use crate::model::{ModelConfig, Model}; +use crate::dataset::{FSRSBatch, FSRSBatcher, FSRSDataset}; +use crate::model::{Model, ModelConfig}; use crate::weight_clipper::weight_clipper; use burn::module::Module; use burn::nn::loss::CrossEntropyLoss; use burn::optim::AdamConfig; -use burn::record::{PrettyJsonFileRecorder, FullPrecisionSettings, Recorder}; -use burn::tensor::{Tensor, Int}; +use burn::record::{FullPrecisionSettings, PrettyJsonFileRecorder, Recorder}; use burn::tensor::backend::Backend; -use burn::train::{TrainStep, TrainOutput, ValidStep, ClassificationOutput}; +use burn::tensor::{Int, Tensor}; +use burn::train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep}; use burn::{ - config::Config, - data::dataloader::DataLoaderBuilder, - tensor::backend::ADBackend, - train::{ - // metric::{AccuracyMetric, LossMetric}, - LearnerBuilder, - }, - module::Param, + config::Config, data::dataloader::DataLoaderBuilder, module::Param, tensor::backend::ADBackend, + train::LearnerBuilder, }; use log::info; @@ -33,7 +27,8 @@ impl> Model { let (stability, _difficulty) = self.forward(t_historys, r_historys); let retention = self.power_forgetting_curve(delta_ts.clone(), stability.clone()); // dbg!(&retention); - let logits = Tensor::cat(vec![retention.clone(), -retention.clone() + 1], 0).reshape([1, -1]); + let logits = + Tensor::cat(vec![retention.clone(), -retention.clone() + 1], 0).reshape([1, -1]); info!("stability: {}", &stability); info!("delta_ts: {}", &delta_ts); info!("retention: {}", &retention); @@ -46,7 +41,12 @@ impl> Model { impl> TrainStep, ClassificationOutput> for Model { fn step(&self, batch: FSRSBatch) -> TrainOutput> { - let item = self.forward_classification(batch.t_historys, batch.r_historys, batch.delta_ts, batch.labels); + let item = self.forward_classification( + batch.t_historys, + batch.r_historys, + batch.delta_ts, + batch.labels, + ); TrainOutput::new(self, item.loss.backward(), item) } @@ -54,7 +54,12 @@ impl> TrainStep, ClassificationOutput impl> ValidStep, ClassificationOutput> for Model { fn step(&self, batch: FSRSBatch) -> ClassificationOutput { - self.forward_classification(batch.t_historys, batch.r_historys, batch.delta_ts, batch.labels) + self.forward_classification( + batch.t_historys, + batch.r_historys, + batch.delta_ts, + batch.labels, + ) } } @@ -76,7 +81,11 @@ pub struct TrainingConfig { pub learning_rate: f64, } -pub fn train>(artifact_dir: &str, config: TrainingConfig, device: B::Device) { +pub fn train>( + artifact_dir: &str, + config: TrainingConfig, + device: B::Device, +) { std::fs::create_dir_all(artifact_dir).ok(); config .save(&format!("{artifact_dir}/config.json")) @@ -121,7 +130,7 @@ pub fn train>(artifact_dir: &str, config: Training .save(format!("{ARTIFACT_DIR}/config.json").as_str()) .unwrap(); - PrettyJsonFileRecorder::::new() + PrettyJsonFileRecorder::::new() .record( model_trained.clone().into_record(), format!("{ARTIFACT_DIR}/model").into(), @@ -131,7 +140,6 @@ pub fn train>(artifact_dir: &str, config: Training info!("trained weights: {}", &model_trained.w.val()); } - #[test] fn test() { use burn_ndarray::NdArrayBackend; @@ -146,4 +154,4 @@ fn test() { TrainingConfig::new(ModelConfig::new(), AdamConfig::new()), device.clone(), ); -} \ No newline at end of file +} diff --git a/src/weight_clipper.rs b/src/weight_clipper.rs index b9ce9430..1eac7e90 100644 --- a/src/weight_clipper.rs +++ b/src/weight_clipper.rs @@ -1,7 +1,6 @@ -use burn::tensor::{backend::Backend, Tensor, Data}; - -pub fn weight_clipper>(weights:Tensor) -> Tensor { +use burn::tensor::{backend::Backend, Data, Tensor}; +pub fn weight_clipper>(weights: Tensor) -> Tensor { const CLAMPS: [(f32, f32); 13] = [ (1.0, 10.0), (0.1, 5.0), @@ -23,7 +22,7 @@ pub fn weight_clipper>(weights:Tensor) -> Tens for (i, w) in val.iter_mut().skip(4).enumerate() { *w = w.clamp(CLAMPS[i].0.into(), CLAMPS[i].1.into()); - } + } Tensor::from_data(Data::new(val.clone(), weights.shape())) } @@ -33,14 +32,16 @@ fn weight_clipper_test() { type Backend = burn_ndarray::NdArrayBackend; //type AutodiffBackend = burn_autodiff::ADBackendDecorator; - let tensor = Tensor::from_floats( - [0.0, -1000.0, 1000.0, 0.0, // Ignored - 1000.0, -1000.0, 1.0, 0.25, -0.1]); // Clamped (1.0, 10.0),(0.1, 5.0),(0.1, 5.0),(0.0, 0.5), + let tensor = Tensor::from_floats([ + 0.0, -1000.0, 1000.0, 0.0, // Ignored + 1000.0, -1000.0, 1.0, 0.25, -0.1, + ]); // Clamped (1.0, 10.0),(0.1, 5.0),(0.1, 5.0),(0.0, 0.5), let param: Tensor = weight_clipper(tensor); let values = ¶m.to_data().value; - assert_eq!(*values, vec! - [0.0, -1000.0, 1000.0, 0.0, - 10.0, 0.1, 1.0, 0.25, 0.0]); -} \ No newline at end of file + assert_eq!( + *values, + vec![0.0, -1000.0, 1000.0, 0.0, 10.0, 0.1, 1.0, 0.25, 0.0] + ); +}