Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor index writer for readability #150

Merged
merged 1 commit into from
Dec 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions rs/index_writer/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,17 @@ pub struct IvfConfig {
pub batch_size: usize,
}

#[derive(Debug, Clone, Default)]
pub struct HnswIvfConfig {
pub hnsw_config: HnswConfig,
pub ivf_config: IvfConfig,
}

#[derive(Debug, Clone)]
pub enum IndexWriterConfig {
Hnsw(HnswConfig),
Ivf(IvfConfig),
HnswIvf(HnswIvfConfig),
}

impl Default for IndexWriterConfig {
Expand Down
251 changes: 137 additions & 114 deletions rs/index_writer/src/index_writer.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use anyhow::Result;
use anyhow::{Ok, Result};
use index::hnsw::builder::HnswBuilder;
use index::hnsw::writer::HnswWriter;
use index::ivf::builder::{IvfBuilder, IvfBuilderConfig};
Expand All @@ -8,7 +8,7 @@ use quantization::pq::{ProductQuantizerConfig, ProductQuantizerWriter};
use quantization::pq_builder::{ProductQuantizerBuilder, ProductQuantizerBuilderConfig};
use rand::seq::SliceRandom;

use crate::config::{IndexWriterConfig, QuantizerType};
use crate::config::{HnswConfig, HnswIvfConfig, IndexWriterConfig, IvfConfig, QuantizerType};
use crate::input::Input;

pub struct IndexWriter {
Expand All @@ -28,122 +28,145 @@ impl IndexWriter {
ret
}

fn do_build_hnsw_index(
&mut self,
input: &mut impl Input,
hnsw_config: &HnswConfig,
) -> Result<()> {
info!("Start indexing (HNSW)");
let path = &hnsw_config.base_config.output_path;
let pg_temp_dir = format!("{}/pq_tmp", path);
std::fs::create_dir_all(&pg_temp_dir)?;

// First, train the product quantizer
let mut pq_builder = match hnsw_config.quantizer_type {
QuantizerType::ProductQuantizer => {
let pq_config = ProductQuantizerConfig {
dimension: hnsw_config.base_config.dimension,
subvector_dimension: hnsw_config.subvector_dimension,
num_bits: hnsw_config.num_bits,
};
let pq_builder_config = ProductQuantizerBuilderConfig {
max_iteration: hnsw_config.max_iteration,
batch_size: hnsw_config.batch_size,
};
ProductQuantizerBuilder::new(pq_config, pq_builder_config)
}
};

info!("Start training product quantizer");
let sorted_random_rows =
Self::get_sorted_random_rows(input.num_rows(), hnsw_config.num_training_rows);
for row_idx in sorted_random_rows {
input.skip_to(row_idx as usize);
pq_builder.add(input.next().data.to_vec());
}

let pq = pq_builder.build(pg_temp_dir.clone())?;

info!("Start writing product quantizer");
let pq_directory = format!("{}/quantizer", path);
std::fs::create_dir_all(&pq_directory)?;

let pq_writer = ProductQuantizerWriter::new(pq_directory);
pq_writer.write(&pq)?;

info!("Start building index");
let vector_directory = format!("{}/vectors", path);
std::fs::create_dir_all(&vector_directory)?;

let mut hnsw_builder = HnswBuilder::new(
hnsw_config.max_num_neighbors,
hnsw_config.num_layers,
hnsw_config.ef_construction,
hnsw_config.base_config.max_memory_size,
hnsw_config.base_config.file_size,
hnsw_config.base_config.dimension / hnsw_config.subvector_dimension,
Box::new(pq),
vector_directory.clone(),
);

input.reset();
while input.has_next() {
let row = input.next();
hnsw_builder.insert(row.id, row.data)?;
if row.id % 10000 == 0 {
debug!("Inserted {} rows", row.id);
}
}

let hnsw_directory = format!("{}/hnsw", path);
std::fs::create_dir_all(&hnsw_directory)?;

info!("Start writing index");
let hnsw_writer = HnswWriter::new(hnsw_directory);
hnsw_writer.write(&mut hnsw_builder, hnsw_config.reindex)?;

// Cleanup tmp directory. It's ok to fail
std::fs::remove_dir_all(&pg_temp_dir).unwrap_or_default();
std::fs::remove_dir_all(&vector_directory).unwrap_or_default();
Ok(())
}

fn do_build_ivf_index(&mut self, input: &mut impl Input, ivf_config: &IvfConfig) -> Result<()> {
info!("Start indexing (IVF)");
let path = &ivf_config.base_config.output_path;

let mut ivf_builder = IvfBuilder::new(IvfBuilderConfig {
max_iteration: ivf_config.max_iteration,
batch_size: ivf_config.batch_size,
num_clusters: ivf_config.num_clusters,
num_data_points: ivf_config.num_data_points,
max_clusters_per_vector: ivf_config.max_clusters_per_vector,
base_directory: path.to_string(),
memory_size: ivf_config.base_config.max_memory_size,
file_size: ivf_config.base_config.file_size,
num_features: ivf_config.base_config.dimension,
})?;

input.reset();
while input.has_next() {
let row = input.next();
ivf_builder.add_vector(row.data.to_vec())?;
if row.id % 10000 == 0 {
debug!("Inserted {} rows", row.id);
}
}

info!("Start building index");
ivf_builder.build()?;

let ivf_directory = format!("{}/ivf", path);
std::fs::create_dir_all(&ivf_directory)?;

info!("Start writing index");
let ivf_writer = IvfWriter::new(ivf_directory);
ivf_writer.write(&mut ivf_builder)?;

// Cleanup tmp directory. It's ok to fail
ivf_builder.cleanup()?;
Ok(())
}

#[allow(unused_variables)]
fn do_build_ivf_hnsw_index(
&mut self,
input: &mut impl Input,
hnsw_ivf_config: &HnswIvfConfig,
) -> Result<()> {
todo!()
}

// TODO(hicder): Support multiple inputs
pub fn process(&mut self, input: &mut impl Input) -> Result<()> {
match &self.config {
let cfg = self.config.clone();
match cfg {
IndexWriterConfig::Hnsw(hnsw_config) => {
info!("Start indexing (HNSW)");
let path = &hnsw_config.base_config.output_path;
let pg_temp_dir = format!("{}/pq_tmp", path);
std::fs::create_dir_all(&pg_temp_dir)?;

// First, train the product quantizer
let mut pq_builder = match hnsw_config.quantizer_type {
QuantizerType::ProductQuantizer => {
let pq_config = ProductQuantizerConfig {
dimension: hnsw_config.base_config.dimension,
subvector_dimension: hnsw_config.subvector_dimension,
num_bits: hnsw_config.num_bits,
};
let pq_builder_config = ProductQuantizerBuilderConfig {
max_iteration: hnsw_config.max_iteration,
batch_size: hnsw_config.batch_size,
};
ProductQuantizerBuilder::new(pq_config, pq_builder_config)
}
};

info!("Start training product quantizer");
let sorted_random_rows =
Self::get_sorted_random_rows(input.num_rows(), hnsw_config.num_training_rows);
for row_idx in sorted_random_rows {
input.skip_to(row_idx as usize);
pq_builder.add(input.next().data.to_vec());
}

let pq = pq_builder.build(pg_temp_dir.clone())?;

info!("Start writing product quantizer");
let pq_directory = format!("{}/quantizer", path);
std::fs::create_dir_all(&pq_directory)?;

let pq_writer = ProductQuantizerWriter::new(pq_directory);
pq_writer.write(&pq)?;

info!("Start building index");
let vector_directory = format!("{}/vectors", path);
std::fs::create_dir_all(&vector_directory)?;

let mut hnsw_builder = HnswBuilder::new(
hnsw_config.max_num_neighbors,
hnsw_config.num_layers,
hnsw_config.ef_construction,
hnsw_config.base_config.max_memory_size,
hnsw_config.base_config.file_size,
hnsw_config.base_config.dimension / hnsw_config.subvector_dimension,
Box::new(pq),
vector_directory.clone(),
);

input.reset();
while input.has_next() {
let row = input.next();
hnsw_builder.insert(row.id, row.data)?;
if row.id % 10000 == 0 {
debug!("Inserted {} rows", row.id);
}
}

let hnsw_directory = format!("{}/hnsw", path);
std::fs::create_dir_all(&hnsw_directory)?;

info!("Start writing index");
let hnsw_writer = HnswWriter::new(hnsw_directory);
hnsw_writer.write(&mut hnsw_builder, hnsw_config.reindex)?;

// Cleanup tmp directory. It's ok to fail
std::fs::remove_dir_all(&pg_temp_dir).unwrap_or_default();
std::fs::remove_dir_all(&vector_directory).unwrap_or_default();
Ok(())
Ok(self.do_build_hnsw_index(input, &hnsw_config)?)
}
IndexWriterConfig::Ivf(ivf_config) => {
info!("Start indexing (IVF)");
let path = &ivf_config.base_config.output_path;

let mut ivf_builder = IvfBuilder::new(IvfBuilderConfig {
max_iteration: ivf_config.max_iteration,
batch_size: ivf_config.batch_size,
num_clusters: ivf_config.num_clusters,
num_data_points: ivf_config.num_data_points,
max_clusters_per_vector: ivf_config.max_clusters_per_vector,
base_directory: path.to_string(),
memory_size: ivf_config.base_config.max_memory_size,
file_size: ivf_config.base_config.file_size,
num_features: ivf_config.base_config.dimension,
})?;

input.reset();
while input.has_next() {
let row = input.next();
ivf_builder.add_vector(row.data.to_vec())?;
if row.id % 10000 == 0 {
debug!("Inserted {} rows", row.id);
}
}

info!("Start building index");
ivf_builder.build()?;

let ivf_directory = format!("{}/ivf", path);
std::fs::create_dir_all(&ivf_directory)?;

info!("Start writing index");
let ivf_writer = IvfWriter::new(ivf_directory);
ivf_writer.write(&mut ivf_builder)?;

// Cleanup tmp directory. It's ok to fail
ivf_builder.cleanup()?;
Ok(())
IndexWriterConfig::Ivf(ivf_config) => Ok(self.do_build_ivf_index(input, &ivf_config)?),
IndexWriterConfig::HnswIvf(hnsw_ivf_config) => {
Ok(self.do_build_ivf_hnsw_index(input, &hnsw_ivf_config)?)
}
}
}
Expand Down