From a9aa00f65265e785b1aa0e60b46b19f0e850f319 Mon Sep 17 00:00:00 2001 From: Nikola Whallon Date: Fri, 4 Nov 2022 11:29:47 -0700 Subject: [PATCH 1/2] sst5 support instead of sst2 --- src/pipelines/sentiment.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/pipelines/sentiment.rs b/src/pipelines/sentiment.rs index 5e41f29ac..0055c3f0f 100644 --- a/src/pipelines/sentiment.rs +++ b/src/pipelines/sentiment.rs @@ -65,6 +65,7 @@ use serde::{Deserialize, Serialize}; pub enum SentimentPolarity { Positive, Negative, + Neutral, } #[derive(Debug, Serialize, Deserialize)] @@ -141,8 +142,10 @@ impl SentimentModel { let labels = self.sequence_classification_model.predict(input); let mut sentiments = Vec::with_capacity(labels.len()); for label in labels { - let polarity = if label.id == 1 { + let polarity = if label.id == 5 || label.id == 4 { SentimentPolarity::Positive + } else if label.id == 3 { + SentimentPolarity::Neutral } else { SentimentPolarity::Negative }; From cd81a838287949bd81368b4e33d873c386fe4806 Mon Sep 17 00:00:00 2001 From: Nikola Whallon Date: Fri, 11 Nov 2022 10:44:22 -0800 Subject: [PATCH 2/2] fixed off-by-one sentiment --- src/pipelines/sentiment.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pipelines/sentiment.rs b/src/pipelines/sentiment.rs index 0055c3f0f..feb3d8f9f 100644 --- a/src/pipelines/sentiment.rs +++ b/src/pipelines/sentiment.rs @@ -142,9 +142,9 @@ impl SentimentModel { let labels = self.sequence_classification_model.predict(input); let mut sentiments = Vec::with_capacity(labels.len()); for label in labels { - let polarity = if label.id == 5 || label.id == 4 { + let polarity = if label.id == 4 || label.id == 3 { SentimentPolarity::Positive - } else if label.id == 3 { + } else if label.id == 2 { SentimentPolarity::Neutral } else { SentimentPolarity::Negative