
Commit ae84cc2
fix parameter_clipper
L-M-Sherlock committed Oct 28, 2024
1 parent f9c128c commit ae84cc2
Showing 4 changed files with 57 additions and 64 deletions.
75 changes: 30 additions & 45 deletions Cargo.lock

Some generated files are not rendered by default.

16 changes: 8 additions & 8 deletions Cargo.toml
@@ -15,18 +15,18 @@ description = "FSRS for Rust, including Optimizer and Scheduler"
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

 [dependencies.burn]
-version = "0.13.1"
-# git = "https://github.com/tracel-ai/burn.git"
-# rev = "6ae3926006872a204869e84ffc303417c54b6b7f"
-# path = "../burn/burn"
+# version = "0.13.1"
+git = "https://github.com/open-spaced-repetition/burn.git"
+rev = "4c29ed3f3ce1faf15822dada2db5cc58aa8c752e"
+# path = "../burn/crates/burn"
 default-features = false
 features = ["std", "train", "ndarray"]

 [dev-dependencies.burn]
-version = "0.13.1"
-# git = "https://github.com/tracel-ai/burn.git"
-# rev = "6ae3926006872a204869e84ffc303417c54b6b7f"
-# path = "../burn/burn"
+# version = "0.13.1"
+git = "https://github.com/open-spaced-repetition/burn.git"
+rev = "4c29ed3f3ce1faf15822dada2db5cc58aa8c752e"
+# path = "../burn/crates/burn"
 default-features = false
 features = ["std", "train", "ndarray", "sqlite-bundled"]

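Both [dependencies.burn] and [dev-dependencies.burn] now track the same commit of the open-spaced-repetition fork rather than the crates.io 0.13.1 release. Pinning by rev fixes the exact commit, so cargo update cannot silently move it. For reference only, the same pinned dependency can be written as a standard Cargo inline table (this line is an illustration, not part of the diff):

    burn = { git = "https://github.com/open-spaced-repetition/burn.git", rev = "4c29ed3f3ce1faf15822dada2db5cc58aa8c752e", default-features = false, features = ["std", "train", "ndarray"] }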
21 changes: 15 additions & 6 deletions src/parameter_clipper.rs
@@ -2,13 +2,22 @@ use crate::{
     inference::{Parameters, S_MIN},
     pre_training::INIT_S_MAX,
 };
-use burn::tensor::{backend::Backend, Data, Tensor};
+use burn::{
+    module::Param,
+    tensor::{backend::Backend, Data, Tensor},
+};

-pub(crate) fn parameter_clipper<B: Backend>(parameters: Tensor<B, 1>) -> Tensor<B, 1> {
+pub(crate) fn parameter_clipper<B: Backend>(
+    parameters: Param<Tensor<B, 1>>,
+) -> Param<Tensor<B, 1>> {
     let val = clip_parameters(&parameters.to_data().convert().value);
-    Tensor::from_data(
-        Data::new(val, parameters.shape()).convert(),
-        &B::Device::default(),
+    Param::initialized(
+        parameters.clone().id,
+        Tensor::from_data(
+            Data::new(val, parameters.shape()).convert(),
+            &B::Device::default(),
+        )
+        .require_grad(),
     )
 }

@@ -58,7 +67,7 @@ mod tests {
             &device,
         );

-        let param: Tensor<1> = parameter_clipper(tensor);
+        let param = parameter_clipper(Param::from_tensor(tensor));
         let values = &param.to_data().value;

         assert_eq!(
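The core of the fix: parameter_clipper now takes and returns the Param wrapper itself. Param::initialized(parameters.clone().id, ...) rebuilds the parameter under its original id, and .require_grad() turns gradient tracking back on for the rebuilt tensor, so callers no longer need to re-wrap the result with Param::from_tensor, which allocated a fresh id on every step. A minimal plain-Rust sketch of why keeping the id matters, given that burn's stateful optimizers key per-parameter records by ParamId; all names below are hypothetical stand-ins, not burn's API:

    use std::collections::HashMap;

    fn main() {
        // Per-parameter optimizer state (e.g. a momentum term), keyed by id.
        let mut state: HashMap<u64, f32> = HashMap::new();
        let original_id = 7u64;
        state.insert(original_id, 0.9);

        // Clipping that preserves the id (the new Param::initialized path):
        // the optimizer still finds its accumulated state.
        assert!(state.contains_key(&original_id));

        // Re-wrapping under a fresh id (the old Param::from_tensor pattern):
        // the lookup misses, so per-parameter state is effectively reset.
        let fresh_id = 8u64;
        assert!(!state.contains_key(&fresh_id));
    }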
9 changes: 4 additions & 5 deletions src/training.rs
@@ -18,7 +18,7 @@ use burn::tensor::backend::Backend;
 use burn::tensor::{Int, Tensor};
 use burn::train::renderer::{MetricState, MetricsRenderer, TrainingProgress};
 use burn::train::TrainingInterrupter;
-use burn::{config::Config, module::Param, tensor::backend::AutodiffBackend};
+use burn::{config::Config, tensor::backend::AutodiffBackend};
 use core::marker::PhantomData;
 use log::info;

@@ -369,8 +369,7 @@ fn train<B: AutodiffBackend>(
             }
             let grads = GradientsParams::from_grads(gradients, &model);
             model = optim.step(lr, model, grads);
-            // TODO: bug in https://github.com/tracel-ai/burn/issues/2428
-            model.w = Param::from_tensor(parameter_clipper(model.w.val()));
+            model.w = parameter_clipper(model.w);
             // info!("epoch: {:?} iteration: {:?} lr: {:?}", epoch, iteration, lr);
             renderer.render_train(TrainingProgress {
                 progress,
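With the re-wrapping moved inside parameter_clipper, the TODO pointing at tracel-ai/burn#2428 is dropped and the loop body reduces to: take an optimizer step, then project the weights back into their valid ranges. A minimal sketch of that step-then-clip pattern, with plain SGD standing in for optim.step and a single placeholder bound instead of the per-parameter ranges clip_parameters actually enforces:

    // Illustrative only: one projected-gradient update over a weight vector.
    fn step_then_clip(w: &mut [f32], grad: &[f32], lr: f32, lo: f32, hi: f32) {
        for (wi, gi) in w.iter_mut().zip(grad) {
            *wi -= lr * gi;         // optimizer step (optim.step in the real code)
            *wi = wi.clamp(lo, hi); // parameter_clipper: project back into bounds
        }
    }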
@@ -521,7 +520,7 @@ mod tests {
         let lr = 0.04;
         let grads = GradientsParams::from_grads(gradients, &model);
         model = optim.step(lr, model, grads);
-        model.w = Param::from_tensor(parameter_clipper(model.w.val()));
+        model.w = parameter_clipper(model.w);
         assert_eq!(
             model.w.val().to_data(),
             Data::from([
@@ -592,7 +591,7 @@ mod tests {
             .assert_approx_eq(&w_grad.clone().into_data(), 5);
         let grads = GradientsParams::from_grads(gradients, &model);
         model = optim.step(lr, model, grads);
-        model.w = Param::from_tensor(parameter_clipper(model.w.val()));
+        model.w = parameter_clipper(model.w);
         assert_eq!(
             model.w.val().to_data(),
             Data::from([
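These tests pin exact post-clipping weight values; the invariant behind them is that every weight lands inside its allowed range after parameter_clipper runs. A hedged sketch of that invariant as a helper (illustrative only; the crate's actual bounds, such as S_MIN and INIT_S_MAX, vary by parameter position):

    // Illustrative invariant check, not the crate's actual assertion.
    fn assert_within_bounds(values: &[f32], lo: f32, hi: f32) {
        for v in values {
            assert!((lo..=hi).contains(v), "value {v} escaped [{lo}, {hi}]");
        }
    }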
