Commit

Run cargo fmt (#11)
dae authored Aug 21, 2023
1 parent 14f7b07 commit 306e924
Showing 4 changed files with 47 additions and 40 deletions.
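
A note for readers skimming the diff: every change below is mechanical reformatting, with no change in behavior. It looks like the output of rustfmt, presumably produced by running cargo fmt at the repository root (cargo fmt -- --check reports the same differences without rewriting anything). Typical effects visible here are reordered use and mod items, rewrapped long lines, and normalized braces and blank lines.
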
src/lib.rs (6 changes: 3 additions & 3 deletions)
@@ -1,5 +1,5 @@
-pub mod model;
+pub mod convertor;
 pub mod dataset;
+pub mod model;
 pub mod training;
-pub mod convertor;
-mod weight_clipper;
+mod weight_clipper;
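
(The reshuffle of the mod declarations above matches rustfmt's reorder_modules option, which is enabled by default and sorts module declarations alphabetically; the paired removal and re-addition of mod weight_clipper most likely reflects a whitespace-only difference.)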
src/model.rs (10 changes: 4 additions & 6 deletions)
@@ -1,9 +1,9 @@
 use burn::{
-    module::{Param, Module},
-    tensor::{backend::Backend, Float, Tensor}, config::Config,
+    config::Config,
+    module::{Module, Param},
+    tensor::{backend::Backend, Float, Tensor},
 };
-
 
 #[derive(Module, Debug)]
 pub struct Model<B: Backend> {
     pub w: Param<Tensor<B, 1>>,
@@ -119,10 +119,8 @@ impl<B: Backend<FloatElem = f32>> Model<B> {
     }
 }
-
 
 #[derive(Config, Debug)]
-pub struct ModelConfig {
-}
+pub struct ModelConfig {}
 
 impl ModelConfig {
     pub fn init<B: Backend<FloatElem = f32>>(&self) -> Model<B> {
src/training.rs (48 changes: 28 additions & 20 deletions)
@@ -1,22 +1,16 @@
-use crate::dataset::{FSRSBatcher, FSRSDataset, FSRSBatch};
-use crate::model::{ModelConfig, Model};
+use crate::dataset::{FSRSBatch, FSRSBatcher, FSRSDataset};
+use crate::model::{Model, ModelConfig};
 use crate::weight_clipper::weight_clipper;
 use burn::module::Module;
 use burn::nn::loss::CrossEntropyLoss;
 use burn::optim::AdamConfig;
-use burn::record::{PrettyJsonFileRecorder, FullPrecisionSettings, Recorder};
-use burn::tensor::{Tensor, Int};
+use burn::record::{FullPrecisionSettings, PrettyJsonFileRecorder, Recorder};
 use burn::tensor::backend::Backend;
-use burn::train::{TrainStep, TrainOutput, ValidStep, ClassificationOutput};
+use burn::tensor::{Int, Tensor};
+use burn::train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep};
 use burn::{
-    config::Config,
-    data::dataloader::DataLoaderBuilder,
-    tensor::backend::ADBackend,
-    train::{
-        // metric::{AccuracyMetric, LossMetric},
-        LearnerBuilder,
-    },
-    module::Param,
+    config::Config, data::dataloader::DataLoaderBuilder, module::Param, tensor::backend::ADBackend,
+    train::LearnerBuilder,
 };
 use log::info;

@@ -33,7 +27,8 @@ impl<B: Backend<FloatElem = f32>> Model<B> {
         let (stability, _difficulty) = self.forward(t_historys, r_historys);
         let retention = self.power_forgetting_curve(delta_ts.clone(), stability.clone());
         // dbg!(&retention);
-        let logits = Tensor::cat(vec![retention.clone(), -retention.clone() + 1], 0).reshape([1, -1]);
+        let logits =
+            Tensor::cat(vec![retention.clone(), -retention.clone() + 1], 0).reshape([1, -1]);
         info!("stability: {}", &stability);
         info!("delta_ts: {}", &delta_ts);
         info!("retention: {}", &retention);
@@ -46,15 +41,25 @@ impl<B: Backend<FloatElem = f32>> Model<B> {

 impl<B: ADBackend<FloatElem = f32>> TrainStep<FSRSBatch<B>, ClassificationOutput<B>> for Model<B> {
     fn step(&self, batch: FSRSBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
-        let item = self.forward_classification(batch.t_historys, batch.r_historys, batch.delta_ts, batch.labels);
+        let item = self.forward_classification(
+            batch.t_historys,
+            batch.r_historys,
+            batch.delta_ts,
+            batch.labels,
+        );
 
         TrainOutput::new(self, item.loss.backward(), item)
     }
 }
 
 impl<B: Backend<FloatElem = f32>> ValidStep<FSRSBatch<B>, ClassificationOutput<B>> for Model<B> {
     fn step(&self, batch: FSRSBatch<B>) -> ClassificationOutput<B> {
-        self.forward_classification(batch.t_historys, batch.r_historys, batch.delta_ts, batch.labels)
+        self.forward_classification(
+            batch.t_historys,
+            batch.r_historys,
+            batch.delta_ts,
+            batch.labels,
+        )
     }
 }

@@ -76,7 +81,11 @@ pub struct TrainingConfig {
     pub learning_rate: f64,
 }
 
-pub fn train<B: ADBackend<FloatElem = f32>>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {
+pub fn train<B: ADBackend<FloatElem = f32>>(
+    artifact_dir: &str,
+    config: TrainingConfig,
+    device: B::Device,
+) {
     std::fs::create_dir_all(artifact_dir).ok();
     config
         .save(&format!("{artifact_dir}/config.json"))
@@ -121,7 +130,7 @@ pub fn train<B: ADBackend<FloatElem = f32>>(artifact_dir: &str, config: Training
         .save(format!("{ARTIFACT_DIR}/config.json").as_str())
         .unwrap();
 
-    PrettyJsonFileRecorder::<FullPrecisionSettings>::new()
+    PrettyJsonFileRecorder::<FullPrecisionSettings>::new()
         .record(
             model_trained.clone().into_record(),
             format!("{ARTIFACT_DIR}/model").into(),
@@ -131,7 +140,6 @@ pub fn train<B: ADBackend<FloatElem = f32>>(artifact_dir: &str, config: Training
info!("trained weights: {}", &model_trained.w.val());
}


#[test]
fn test() {
use burn_ndarray::NdArrayBackend;
@@ -146,4 +154,4 @@ fn test() {
         TrainingConfig::new(ModelConfig::new(), AdamConfig::new()),
         device.clone(),
     );
-}
+}
src/weight_clipper.rs (23 changes: 12 additions & 11 deletions)
@@ -1,7 +1,6 @@
-use burn::tensor::{backend::Backend, Tensor, Data};
+use burn::tensor::{backend::Backend, Data, Tensor};
 
-
-pub fn weight_clipper<B: Backend<FloatElem = f32>>(weights:Tensor<B, 1>) -> Tensor<B, 1> {
+pub fn weight_clipper<B: Backend<FloatElem = f32>>(weights: Tensor<B, 1>) -> Tensor<B, 1> {
     const CLAMPS: [(f32, f32); 13] = [
         (1.0, 10.0),
         (0.1, 5.0),
@@ -23,7 +22,7 @@ pub fn weight_clipper<B: Backend<FloatElem = f32>>(weights:Tensor<B, 1>) -> Tens

     for (i, w) in val.iter_mut().skip(4).enumerate() {
         *w = w.clamp(CLAMPS[i].0.into(), CLAMPS[i].1.into());
-    }
+    }
 
     Tensor::from_data(Data::new(val.clone(), weights.shape()))
 }
@@ -33,14 +32,16 @@ fn weight_clipper_test() {
     type Backend = burn_ndarray::NdArrayBackend<f32>;
     //type AutodiffBackend = burn_autodiff::ADBackendDecorator<Backend>;
 
-    let tensor = Tensor::from_floats(
-        [0.0, -1000.0, 1000.0, 0.0, // Ignored
-        1000.0, -1000.0, 1.0, 0.25, -0.1]); // Clamped (1.0, 10.0),(0.1, 5.0),(0.1, 5.0),(0.0, 0.5),
+    let tensor = Tensor::from_floats([
+        0.0, -1000.0, 1000.0, 0.0, // Ignored
+        1000.0, -1000.0, 1.0, 0.25, -0.1,
+    ]); // Clamped (1.0, 10.0),(0.1, 5.0),(0.1, 5.0),(0.0, 0.5),
 
     let param: Tensor<Backend, 1> = weight_clipper(tensor);
     let values = &param.to_data().value;
 
-    assert_eq!(*values, vec!
-        [0.0, -1000.0, 1000.0, 0.0,
-        10.0, 0.1, 1.0, 0.25, 0.0]);
-}
+    assert_eq!(
+        *values,
+        vec![0.0, -1000.0, 1000.0, 0.0, 10.0, 0.1, 1.0, 0.25, 0.0]
+    );
+}
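
For a quick read of what the clipping does, independent of the diff noise: the first four weights are skipped and every later weight is clamped to its (min, max) pair from CLAMPS. Below is a plain-Rust sketch of that pattern; it deliberately avoids the burn tensor types, and the helper name clip_weights plus the shortened clamp list are illustrative only, not part of this repository.

fn clip_weights(weights: &mut [f32], clamps: &[(f32, f32)]) {
    // Skip the first four weights (left untouched, as in weight_clipper above),
    // then clamp each remaining weight to its corresponding (min, max) pair.
    for (w, (lo, hi)) in weights.iter_mut().skip(4).zip(clamps) {
        *w = w.clamp(*lo, *hi);
    }
}

fn main() {
    let mut w = vec![0.4, 0.9, 2.3, 10.9, 1000.0, -1000.0, 1.0];
    let clamps = [(1.0, 10.0), (0.1, 5.0), (0.1, 5.0)];
    clip_weights(&mut w, &clamps);
    // First four entries unchanged, the rest clamped into range.
    let expected: Vec<f32> = vec![0.4, 0.9, 2.3, 10.9, 10.0, 0.1, 1.0];
    assert_eq!(w, expected);
}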
