diff --git a/src/slim.rs b/src/slim.rs index 9f148de..e0b2311 100644 --- a/src/slim.rs +++ b/src/slim.rs @@ -69,6 +69,20 @@ impl SlimMSE { } } + pub fn fit_identified(&mut self, user_interactions: Vec<(i32, i32, f32, f32)>, add_interaction: Option, update_interaction: Option) { + for (user_id, item_id, tstamp, rating) in user_interactions { + if let Err(e) = { + if add_interaction.unwrap_or(true) { + self.interactions.add_interaction(user_id, item_id, tstamp, rating, update_interaction.unwrap_or(false)); + } + self.update_weights(user_id, item_id); + Ok::<(), Box>(()) + } { + warn!("Failed to fit interaction: {}", e); + } + } + } + #[pyo3(signature = (pydf, epochs = 1, random_seed = None))] pub fn bulk_fit(&mut self, pydf: PyDataFrame, epochs: usize, random_seed: Option) -> Result<()> { let df: DataFrame = pydf.into(); @@ -135,20 +149,6 @@ impl SlimMSE { Ok(()) } - pub fn fit_identified(&mut self, user_interactions: Vec<(i32, i32, f32, f32)>, add_interaction: Option, update_interaction: Option) { - for (user_id, item_id, tstamp, rating) in user_interactions { - if let Err(e) = { - if add_interaction.unwrap_or(true) { - self.interactions.add_interaction(user_id, item_id, tstamp, rating, update_interaction.unwrap_or(false)); - } - self.update_weights(user_id, item_id); - Ok::<(), Box>(()) - } { - warn!("Failed to fit interaction: {}", e); - } - } - } - /// Bulk identify users and items from the provided interactions. #[inline] pub fn bulk_identify(&mut self, user_interactions: Vec<(SerializableValue, SerializableValue)>) -> Vec<(i32, i32)> { @@ -193,7 +193,7 @@ impl SlimMSE { // // Vectorized updates for better performance. let updates: Vec<_> = user_items - .par_iter() + .iter() .filter(|&&ui| ui != item_id) // Exclude the target item_id .map(|&ui| { let rating = self.interactions.get_user_item_rating(user_id, ui, 0.0); @@ -203,7 +203,7 @@ impl SlimMSE { .filter(|(_, grad)| grad.abs() > 1e-6) // Skip very small gradients .collect(); - self.ftrl.par_update_gradients(item_id, &updates); + self.ftrl.vectorized_update_gradients(item_id, &updates); } pub fn predict_rating(&self, user: SerializableValue, item: SerializableValue) -> f32 {