Skip to content

Commit

Permalink
Reverted to use vectorized_update_gradients() instead of self.ftrl.pa…
Browse files Browse the repository at this point in the history
…r_update_gradients() as len(user_items) is 20 or less and no enough room to utilize parallel processing
  • Loading branch information
myui committed Nov 17, 2024
1 parent bee9fe3 commit cb81476
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions src/slim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,20 @@ impl SlimMSE {
}
}

pub fn fit_identified(&mut self, user_interactions: Vec<(i32, i32, f32, f32)>, add_interaction: Option<bool>, update_interaction: Option<bool>) {
for (user_id, item_id, tstamp, rating) in user_interactions {
if let Err(e) = {
if add_interaction.unwrap_or(true) {
self.interactions.add_interaction(user_id, item_id, tstamp, rating, update_interaction.unwrap_or(false));
}
self.update_weights(user_id, item_id);
Ok::<(), Box<dyn std::error::Error>>(())
} {
warn!("Failed to fit interaction: {}", e);
}
}
}

#[pyo3(signature = (pydf, epochs = 1, random_seed = None))]
pub fn bulk_fit(&mut self, pydf: PyDataFrame, epochs: usize, random_seed: Option<u64>) -> Result<()> {
let df: DataFrame = pydf.into();
Expand Down Expand Up @@ -135,20 +149,6 @@ impl SlimMSE {
Ok(())
}

pub fn fit_identified(&mut self, user_interactions: Vec<(i32, i32, f32, f32)>, add_interaction: Option<bool>, update_interaction: Option<bool>) {
for (user_id, item_id, tstamp, rating) in user_interactions {
if let Err(e) = {
if add_interaction.unwrap_or(true) {
self.interactions.add_interaction(user_id, item_id, tstamp, rating, update_interaction.unwrap_or(false));
}
self.update_weights(user_id, item_id);
Ok::<(), Box<dyn std::error::Error>>(())
} {
warn!("Failed to fit interaction: {}", e);
}
}
}

/// Bulk identify users and items from the provided interactions.
#[inline]
pub fn bulk_identify(&mut self, user_interactions: Vec<(SerializableValue, SerializableValue)>) -> Vec<(i32, i32)> {
Expand Down Expand Up @@ -193,7 +193,7 @@ impl SlimMSE {
//
// Vectorized updates for better performance.
let updates: Vec<_> = user_items
.par_iter()
.iter()
.filter(|&&ui| ui != item_id) // Exclude the target item_id
.map(|&ui| {
let rating = self.interactions.get_user_item_rating(user_id, ui, 0.0);
Expand All @@ -203,7 +203,7 @@ impl SlimMSE {
.filter(|(_, grad)| grad.abs() > 1e-6) // Skip very small gradients
.collect();

self.ftrl.par_update_gradients(item_id, &updates);
self.ftrl.vectorized_update_gradients(item_id, &updates);
}

pub fn predict_rating(&self, user: SerializableValue, item: SerializableValue) -> f32 {
Expand Down

0 comments on commit cb81476

Please sign in to comment.