diff --git a/rig-core/src/embeddings/embeddable.rs b/rig-core/src/embeddings/embeddable.rs index accdf402..f5a69fd6 100644 --- a/rig-core/src/embeddings/embeddable.rs +++ b/rig-core/src/embeddings/embeddable.rs @@ -29,7 +29,7 @@ //! } //! ``` -use crate::vec_utils::OneOrMany; +use crate::one_or_many::OneOrMany; /// Error type used for when the `embeddable` method fails. /// Used by default implementations of `Embeddable` for common types. diff --git a/rig-core/src/embeddings/embedding.rs b/rig-core/src/embeddings/embedding.rs index 81a980ef..735284a8 100644 --- a/rig-core/src/embeddings/embedding.rs +++ b/rig-core/src/embeddings/embedding.rs @@ -1,5 +1,5 @@ //! The module defines the [EmbeddingModel] trait, which represents an embedding model that can -//! generate embeddings for documents. It also provides an implementation of the [embeddings::EmbeddingsBuilder] +//! generate embeddings for documents. It also provides an implementation of the [crate::embeddings::EmbeddingsBuilder] //! struct, which allows users to build collections of document embeddings using different embedding //! models and document sources. //! diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index 6997c6ff..a4850791 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -52,8 +52,8 @@ //! //! ## Vector stores and indexes //! Rig provides a common interface for working with vector stores and indexes. Specifically, the library -//! provides the [VectorStore](crate::vector_store::VectorStore) and [VectorStoreIndex](crate::vector_store::VectorStoreIndex) -//! traits, which can be implemented to define vector stores and indices respectively. +//! provides the [VectorStoreIndex](crate::vector_store::VectorStoreIndex) +//! trait, which can be implemented to define vector stores and indices. //! Those can then be used as the knowledgebase for a [RagAgent](crate::rag::RagAgent), or //! as a source of context documents in a custom architecture that use multiple LLMs or agents. //! @@ -72,14 +72,14 @@ pub mod completion; pub mod embeddings; pub mod extractor; pub mod json_utils; +pub mod one_or_many; pub mod providers; pub mod tool; -pub mod vec_utils; pub mod vector_store; // Re-export commonly used types and traits pub use embeddings::embeddable::Embeddable; -pub use vec_utils::OneOrMany; +pub use one_or_many::OneOrMany; #[cfg(feature = "derive")] pub use rig_derive::Embeddable; diff --git a/rig-core/src/vec_utils.rs b/rig-core/src/one_or_many.rs similarity index 59% rename from rig-core/src/vec_utils.rs rename to rig-core/src/one_or_many.rs index a487e71c..23ece94f 100644 --- a/rig-core/src/vec_utils.rs +++ b/rig-core/src/one_or_many.rs @@ -33,14 +33,6 @@ impl OneOrMany { self.rest.clone() } - /// Use the Iterator trait on OneOrMany - pub fn iter(&self) -> OneOrManyIterator { - OneOrManyIterator { - one_or_many: self, - index: 0, - } - } - /// Create a OneOrMany object with a single item of any type. pub fn one(item: T) -> Self { OneOrMany { @@ -70,41 +62,102 @@ impl OneOrMany { OneOrMany::many(items) } + + pub fn iter(&self) -> Iter { + Iter { + first: Some(&self.first), + rest: self.rest.iter(), + } + } + + pub fn iter_mut(&mut self) -> IterMut<'_, T> { + IterMut { + first: Some(&mut self.first), + rest: self.rest.iter_mut(), + } + } } -/// Implement Iterator for OneOrMany. -/// Iterates over all items in both `first` and `rest`. -/// Borrows the OneOrMany object that is being iterator over. -pub struct OneOrManyIterator<'a, T> { - one_or_many: &'a OneOrMany, - index: usize, +// ================================================================ +// Implementations of Iterator for OneOrMany +// - OneOrMany::iter() -> iterate over references of T objects +// - OneOrMany::into_iter() -> iterate over owned T objects +// - OneOrMany::iter_mut() -> iterate over mutable references of T objects +// ================================================================ + +/// Struct returned by call to `OneOrMany::iter()`. +pub struct Iter<'a, T> { + // References. + first: Option<&'a T>, + rest: std::slice::Iter<'a, T>, } -impl<'a, T> Iterator for OneOrManyIterator<'a, T> { +/// Implement `Iterator` for `Iter`. +/// The Item type of the `Iterator` trait is a reference of `T`. +impl<'a, T> Iterator for Iter<'a, T> { type Item = &'a T; fn next(&mut self) -> Option { - let mut item = None; - if self.index == 0 { - item = Some(&self.one_or_many.first) - } else if self.index - 1 < self.one_or_many.rest.len() { - item = Some(&self.one_or_many.rest[self.index - 1]); - }; - - self.index += 1; - item + if let Some(first) = self.first.take() { + Some(first) + } else { + self.rest.next() + } } } -/// Implement IntoIterator for OneOrMany. -/// Iterates over all items in both `first` and `rest`. -/// Takes ownership the OneOrMany object that is being iterator over. +/// Struct returned by call to `OneOrMany::into_iter()`. +pub struct IntoIter { + // Owned. + first: Option, + rest: std::vec::IntoIter, +} + +/// Implement `Iterator` for `IntoIter`. impl IntoIterator for OneOrMany { type Item = T; - type IntoIter = std::iter::Chain, std::vec::IntoIter>; + type IntoIter = IntoIter; fn into_iter(self) -> Self::IntoIter { - std::iter::once(self.first).chain(self.rest) + IntoIter { + first: Some(self.first), + rest: self.rest.into_iter(), + } + } +} + +/// Implement `Iterator` for `IntoIter`. +/// The Item type of the `Iterator` trait is an owned `T`. +impl Iterator for IntoIter { + type Item = T; + + fn next(&mut self) -> Option { + if let Some(first) = self.first.take() { + Some(first) + } else { + self.rest.next() + } + } +} + +/// Struct returned by call to `OneOrMany::iter_mut()`. +pub struct IterMut<'a, T> { + // Mutable references. + first: Option<&'a mut T>, + rest: std::slice::IterMut<'a, T>, +} + +// Implement `Iterator` for `IterMut`. +// The Item type of the `Iterator` trait is a mutable reference of `OneOrMany`. +impl<'a, T> Iterator for IterMut<'a, T> { + type Item = &'a mut T; + + fn next(&mut self) -> Option { + if let Some(first) = self.first.take() { + Some(first) + } else { + self.rest.next() + } } } @@ -113,7 +166,7 @@ mod test { use super::OneOrMany; #[test] - fn test_one_or_many_iter_single() { + fn test_single() { let one_or_many = OneOrMany::one("hello".to_string()); assert_eq!(one_or_many.iter().count(), 1); @@ -124,7 +177,7 @@ mod test { } #[test] - fn test_one_or_many_iter() { + fn test() { let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap(); assert_eq!(one_or_many.iter().count(), 2); @@ -189,6 +242,34 @@ mod test { }); } + #[test] + fn test_mut_single() { + let mut one_or_many = OneOrMany::one("hello".to_string()); + + assert_eq!(one_or_many.iter_mut().count(), 1); + + one_or_many.iter_mut().for_each(|i| { + assert_eq!(i, "hello"); + }); + } + + #[test] + fn test_mut() { + let mut one_or_many = + OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap(); + + assert_eq!(one_or_many.iter_mut().count(), 2); + + one_or_many.iter_mut().enumerate().for_each(|(i, item)| { + if i == 0 { + assert_eq!(item, "hello"); + } + if i == 1 { + assert_eq!(item, "word"); + } + }); + } + #[test] fn test_one_or_many_error() { assert!(OneOrMany::::many(vec![]).is_err())