Skip to content

Commit

Permalink
use path join (#30)
Browse files Browse the repository at this point in the history
  • Loading branch information
asukaminato0721 authored Aug 27, 2023
1 parent b7f13e9 commit ed5b7b4
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/convertor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ fn read_collection() -> Result<Vec<RevlogEntry>> {
}

fn group_by_cid(revlogs: Vec<RevlogEntry>) -> Vec<Vec<RevlogEntry>> {
let mut grouped: HashMap<i64, Vec<RevlogEntry>> = HashMap::new();
let mut grouped = HashMap::new();
for revlog in revlogs {
grouped
.entry(revlog.cid)
Expand Down
4 changes: 2 additions & 2 deletions src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ impl<B: Backend> Batcher<FSRSItem, FSRSBatch<B>> for FSRSBatcher<B> {
let (time_histories, rating_histories) = items
.iter()
.map(|item| {
let (mut delta_t, mut rating): (Vec<i32>, Vec<i32>) =
let (mut delta_t, mut rating): (Vec<_>, Vec<_>) =
item.history().map(|r| (r.delta_t, r.rating)).unzip();
delta_t.resize(pad_size, 0);
rating.resize(pad_size, 0);
Expand Down Expand Up @@ -187,7 +187,7 @@ fn test_batcher() {
use burn_ndarray::NdArrayDevice;
type Backend = NdArrayBackend<f32>;
let device = NdArrayDevice::Cpu;
let batcher: FSRSBatcher<Backend> = FSRSBatcher::<Backend>::new(device);
let batcher = FSRSBatcher::<Backend>::new(device);
let items = vec![
FSRSItem {
reviews: vec![
Expand Down
18 changes: 15 additions & 3 deletions src/training.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::path::Path;

use crate::cosine_annealing::CosineAnnealingLR;
use crate::dataset::{FSRSBatch, FSRSBatcher, FSRSDataset};
use crate::model::{Model, ModelConfig};
Expand Down Expand Up @@ -107,7 +109,12 @@ pub fn train<B: ADBackend<FloatElem = f32>>(
) {
std::fs::create_dir_all(artifact_dir).ok();
config
.save(&format!("{artifact_dir}/config.json"))
.save(
Path::new(artifact_dir)
.join("config.json")
.to_str()
.unwrap(),
)
.expect("Save without error");

B::seed(config.seed);
Expand Down Expand Up @@ -153,13 +160,18 @@ pub fn train<B: ADBackend<FloatElem = f32>>(
info!("clipped weights: {}", &model_trained.w.val());

config
.save(format!("{ARTIFACT_DIR}/config.json").as_str())
.save(
Path::new(ARTIFACT_DIR)
.join("config.json")
.to_str()
.unwrap(),
)
.unwrap();

PrettyJsonFileRecorder::<FullPrecisionSettings>::new()
.record(
model_trained.into_record(),
format!("{ARTIFACT_DIR}/model").into(),
Path::new(ARTIFACT_DIR).join("model"),
)
.expect("Failed to save trained model");
}
Expand Down

0 comments on commit ed5b7b4

Please sign in to comment.