Skip to content

Commit

Permalink
Use L2DistanceCalculator for PQ (hicder#24)
Browse files Browse the repository at this point in the history
* Add new rustfmt.toml and format (hicder#25)

* Use L2DistanceCalculator for PQ

---------

Co-authored-by: Hieu Pham <[email protected]>
Co-authored-by: BuildKite <[email protected]>
  • Loading branch information
3 people authored Oct 21, 2024
1 parent 8328298 commit 8cd23e2
Show file tree
Hide file tree
Showing 14 changed files with 65 additions and 56 deletions.
3 changes: 2 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion rs/aggregator/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
mod aggregator;

use crate::aggregator::AggregatorServerImpl;
use clap::Parser;
use log::info;
use proto::muopdb::aggregator_server::AggregatorServer;
use tonic::transport::Server;

use crate::aggregator::AggregatorServerImpl;

#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
struct Args {
Expand Down
10 changes: 5 additions & 5 deletions rs/index/src/hnsw/builder.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
use std::cmp::min;
use std::collections::{BinaryHeap, HashMap, HashSet};

use ordered_float::NotNan;
use quantization::quantization::Quantizer;
use rand::Rng;
use std::{
cmp::min,
collections::{BinaryHeap, HashMap, HashSet},
};

use super::utils::{GraphTraversal, PointAndDistance, SearchContext};

Expand Down Expand Up @@ -273,9 +272,10 @@ impl GraphTraversal for HnswBuilder {
// Test
#[cfg(test)]
mod tests {
use super::*;
use quantization::pq::ProductQuantizer;

use super::*;

fn generate_random_vector(dimension: usize) -> Vec<f32> {
let mut rng = rand::thread_rng();
let mut vector = vec![];
Expand Down
10 changes: 6 additions & 4 deletions rs/index/src/hnsw/index.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use std::fs::File;
use std::vec;

use memmap2::Mmap;
use quantization::{pq::ProductQuantizerReader, quantization::Quantizer};
use quantization::pq::ProductQuantizerReader;
use quantization::quantization::Quantizer;
use rand::Rng;
use std::{fs::File, vec};

use crate::hnsw::writer::Header;

use super::utils::{GraphTraversal, SearchContext};
use crate::hnsw::writer::Header;

pub struct Hnsw {
// Need this for mmap
Expand Down
20 changes: 9 additions & 11 deletions rs/index/src/hnsw/reader.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use std::fs::File;

use byteorder::{ByteOrder, LittleEndian};
use memmap2::Mmap;
use std::fs::File;

use crate::hnsw::{
index::Hnsw,
writer::{Header, Version},
};
use crate::hnsw::index::Hnsw;
use crate::hnsw::writer::{Header, Version};

pub struct HnswReader {
base_directory: String,
Expand Down Expand Up @@ -79,14 +78,13 @@ impl HnswReader {
// Test
#[cfg(test)]
mod tests {
use crate::hnsw::{builder::HnswBuilder, writer::HnswWriter};
use quantization::pq::{ProductQuantizerConfig, ProductQuantizerWriter};
use quantization::pq_builder::{ProductQuantizerBuilder, ProductQuantizerBuilderConfig};
use utils::test_utils::generate_random_vector;

use super::*;
use quantization::{
pq::{ProductQuantizerConfig, ProductQuantizerWriter},
pq_builder::{ProductQuantizerBuilder, ProductQuantizerBuilderConfig},
};
use utils::test_utils::generate_random_vector;
use crate::hnsw::builder::HnswBuilder;
use crate::hnsw::writer::HnswWriter;

#[test]
fn test_read_header() {
Expand Down
21 changes: 7 additions & 14 deletions rs/index/src/hnsw/writer.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
use std::{
fs::{self, File},
io::{BufReader, BufWriter, Read, Write},
};
use std::fs::{self, File};
use std::io::{BufReader, BufWriter, Read, Write};

use utils::io::wrap_write;

Expand Down Expand Up @@ -242,19 +240,14 @@ mod tests {
use std::collections::HashMap;

use ordered_float::NotNan;
use quantization::{
pq::{ProductQuantizerConfig, ProductQuantizerWriter},
pq_builder::{ProductQuantizerBuilder, ProductQuantizerBuilderConfig},
};
use quantization::pq::{ProductQuantizerConfig, ProductQuantizerWriter};
use quantization::pq_builder::{ProductQuantizerBuilder, ProductQuantizerBuilderConfig};
use utils::test_utils::generate_random_vector;

use crate::hnsw::{
builder::Layer,
reader::HnswReader,
utils::{GraphTraversal, PointAndDistance},
};

use super::*;
use crate::hnsw::builder::Layer;
use crate::hnsw::reader::HnswReader;
use crate::hnsw::utils::{GraphTraversal, PointAndDistance};

fn construct_layers(hnsw_builder: &mut HnswBuilder) {
// Prepare all layers
Expand Down
3 changes: 2 additions & 1 deletion rs/index/src/vector/file.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::{fs::OpenOptions, vec};
use std::fs::OpenOptions;
use std::vec;

use num_traits::ToBytes;

Expand Down
1 change: 1 addition & 0 deletions rs/quantization/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ serde.workspace = true
tempdir.workspace = true
kmeans.workspace = true
rand.workspace = true
utils.workspace = true
26 changes: 13 additions & 13 deletions rs/quantization/src/pq.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
use crate::quantization::Quantizer;
use core::result::Result;
use std::fs::File;
use std::io::Write;
use std::path::Path;

use serde::{Deserialize, Serialize};
use std::{fs::File, io::Write, path::Path};
use utils::l2::L2DistanceCalculator;
use utils::DistanceCalculator;

use crate::quantization::Quantizer;

pub struct ProductQuantizer {
pub dimension: usize,
Expand All @@ -10,6 +16,7 @@ pub struct ProductQuantizer {
pub codebook: Vec<f32>,
pub base_directory: String,
pub codebook_name: String,
pub distance_calculator: L2DistanceCalculator,
}

#[derive(Serialize, Deserialize, Debug)]
Expand Down Expand Up @@ -138,6 +145,7 @@ impl ProductQuantizer {
codebook,
base_directory,
codebook_name,
distance_calculator: L2DistanceCalculator::new(),
}
}

Expand Down Expand Up @@ -166,6 +174,7 @@ impl ProductQuantizer {
codebook,
base_directory: config.base_directory,
codebook_name: config.codebook_name,
distance_calculator: L2DistanceCalculator::new(),
}
}

Expand All @@ -191,15 +200,6 @@ impl ProductQuantizer {
}
}

/// Compute L2 distance between two vectors
/// TODO: Move this to a separate file
fn compute_l2_distance(a: &[f32], b: &[f32]) -> f32 {
a.iter()
.zip(b.iter())
.map(|(a, b)| (a - b).powi(2))
.sum::<f32>()
}

/// TODO(hicder): Make this faster
/// TODO(hicder): Support multiple distance type
impl Quantizer for ProductQuantizer {
Expand All @@ -218,7 +218,7 @@ impl Quantizer for ProductQuantizer {
for i in 0..num_centroids {
let offset = subspace_offset + i * self.subvector_dimension;
let centroid = &self.codebook[offset..offset + self.subvector_dimension];
let distance = compute_l2_distance(subvector, centroid);
let distance = self.distance_calculator.calculate(&subvector, &centroid);
if distance < min_distance {
min_distance = distance;
min_centroid_id = i;
Expand Down Expand Up @@ -271,7 +271,7 @@ impl Quantizer for ProductQuantizer {
&self.codebook[a_centroid_offset..a_centroid_offset + self.subvector_dimension];
let b_vec =
&self.codebook[b_centroid_offset..b_centroid_offset + self.subvector_dimension];
compute_l2_distance(a_vec, b_vec)
self.distance_calculator.calculate(&a_vec, &b_vec)
})
.sum::<f32>()
.sqrt()
Expand Down
3 changes: 2 additions & 1 deletion rs/quantization/src/pq_builder.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::pq::{ProductQuantizer, ProductQuantizerConfig};
use kmeans::*;

use crate::pq::{ProductQuantizer, ProductQuantizerConfig};

pub struct ProductQuantizerBuilderConfig {
pub max_iteration: usize,
pub batch_size: usize,
Expand Down
3 changes: 2 additions & 1 deletion rs/utils/src/hdf5_reader.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use hdf5::File;
use std::path::Path;

use hdf5::File;

/// Sample function to read a HDF5 file
/// TODO(hicder): Fix this function to make it generic
pub fn read_hdf5_sift_128(path: &str) -> Result<(), Box<dyn std::error::Error>> {
Expand Down
6 changes: 2 additions & 4 deletions rs/utils/src/io.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
use std::{
fs::File,
io::{BufWriter, Write},
};
use std::fs::File;
use std::io::{BufWriter, Write};

/// Convenient wrapper for going from io::Result<usize> to Result<usize, String>
pub fn wrap_write(writer: &mut BufWriter<&mut File>, buf: &[u8]) -> Result<usize, String> {
Expand Down
7 changes: 7 additions & 0 deletions rs/utils/src/l2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@ use crate::DistanceCalculator;

pub struct L2DistanceCalculator {}

impl L2DistanceCalculator {
pub fn new() -> Self {
Self {}
}
}

impl DistanceCalculator for L2DistanceCalculator {
/// Compute L2 distance between two vectors
fn calculate(&self, a: &[f32], b: &[f32]) -> f32 {
let mut dist = 0.0;
for i in 0..a.len() {
Expand Down
5 changes: 5 additions & 0 deletions rustfmt.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Check https://rust-lang.github.io/rustfmt for more options

reorder_imports = true
imports_granularity = "Module"
group_imports = "StdExternalCrate"

0 comments on commit 8cd23e2

Please sign in to comment.