diff --git a/src/main/java/io/anserini/collection/SafeTensorsDenseVectorCollection.java b/src/main/java/io/anserini/collection/SafeTensorsDenseVectorCollection.java
new file mode 100644
index 0000000000..dfbfc0d663
--- /dev/null
+++ b/src/main/java/io/anserini/collection/SafeTensorsDenseVectorCollection.java
@@ -0,0 +1,334 @@
+/*
+ * Anserini: A Lucene toolkit for reproducible information retrieval research
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.anserini.collection;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.NoSuchElementException;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+/**
+ * Collection class for managing SafeTensors dense vectors and corresponding document IDs.
+ * Extends the DocumentCollection class for handling documents.
+ */
+public class SafeTensorsDenseVectorCollection extends DocumentCollection<SafeTensorsDenseVectorCollection.Document> {
+  private static final Logger LOG = LogManager.getLogger(SafeTensorsDenseVectorCollection.class);
+  private String vectorsFilePath; // Path to the vectors file
+  private String docidsFilePath; // Path to the document IDs file
+  public double[][] vectors; // Array to store vector data
+  public String[] docids; // Array to store document IDs
+  private static final ConcurrentHashMap<String, Boolean> processedDocuments = new ConcurrentHashMap<>(); // Track processed documents (currently unused in this class)
+
+  /**
+   * Constructor that initializes the collection by reading vector and doc ID data from the specified path.
+   * @param path the path to the directory containing the data files.
+   * @throws IOException if an I/O error occurs during file reading.
+   */
+  public SafeTensorsDenseVectorCollection(Path path) throws IOException {
+    this.path = path;
+    generateFilePaths(path);
+    readData();
+  }
+
+  /**
+   * Default constructor.
+   */
+  public SafeTensorsDenseVectorCollection() {
+    // Default constructor
+  }
+
+  /**
+   * Creates a file segment for the specified path.
+   * @param p the path to the file segment.
+   * @return a FileSegment instance.
+   * @throws IOException if an I/O error occurs.
+   */
+  @Override
+  public FileSegment<SafeTensorsDenseVectorCollection.Document> createFileSegment(Path p) throws IOException {
+    return new SafeTensorsDenseVectorCollection.Segment(p, vectors, docids);
+  }
+
+  /**
+   * Throws UnsupportedOperationException, as BufferedReader is not supported for this collection.
+   * @param bufferedReader the BufferedReader instance.
+   * @throws UnsupportedOperationException indicating the method is not supported.
+   */
+  @Override
+  public FileSegment<SafeTensorsDenseVectorCollection.Document> createFileSegment(BufferedReader bufferedReader) throws IOException {
+    throw new UnsupportedOperationException("BufferedReader is not supported for SafeTensorsDenseVectorCollection.");
+  }
+
+  /**
+   * Generates file paths for vectors and doc IDs files from the input folder.
+   * @param inputFolder the directory containing the data files.
+   * @throws IOException if an I/O error occurs or files are not found.
+   */
+  private void generateFilePaths(Path inputFolder) throws IOException {
+    List<Path> files;
+    try (Stream<Path> stream = Files.list(inputFolder)) {
+      files = stream.collect(Collectors.toList());
+    }
+
+    vectorsFilePath = files.stream()
+        .filter(file -> file.toString().contains("_vectors.safetensors"))
+        .map(Path::toString)
+        .findFirst()
+        .orElseThrow(() -> new IOException("No vectors file found in the directory " + inputFolder));
+
+    docidsFilePath = files.stream()
+        .filter(file -> file.toString().contains("_docids.safetensors"))
+        .map(Path::toString)
+        .findFirst()
+        .orElseThrow(() -> new IOException("No docids file found in the directory " + inputFolder));
+  }
+
+  /**
+   * Reads the data from vectors and doc IDs files.
+   * @throws IOException if an I/O error occurs during file reading.
+   */
+  private void readData() throws IOException {
+    vectors = readVectors(vectorsFilePath);
+    docids = readDocidAsciiValues(docidsFilePath);
+  }
+
+  /**
+   * Reads vector data from the specified file path.
+   * @param filePath the path to the vectors file.
+   * @return a 2D array of vectors.
+   * @throws IOException if an I/O error occurs during file reading.
+   */
+  private double[][] readVectors(String filePath) throws IOException {
+    byte[] data = Files.readAllBytes(Paths.get(filePath));
+    Map<String, Object> header = parseHeader(data);
+    return extractVectors(data, header);
+  }
+
+  /**
+   * Reads document ID ASCII values from the specified file path.
+   * @param filePath the path to the doc IDs file.
+   * @return an array of document IDs.
+   * @throws IOException if an I/O error occurs during file reading.
+   */
+  private String[] readDocidAsciiValues(String filePath) throws IOException {
+    byte[] data = Files.readAllBytes(Paths.get(filePath));
+    Map<String, Object> header = parseHeader(data);
+    return extractDocids(data, header);
+  }
+
+  /**
+   * Parses the header from the byte data.
+   * @param data the byte data.
+   * @return a map representing the header.
+   * @throws IOException if an I/O error occurs during parsing.
+   */
+  private Map<String, Object> parseHeader(byte[] data) throws IOException {
+    ByteBuffer buffer = ByteBuffer.wrap(data).order(ByteOrder.LITTLE_ENDIAN);
+    long headerSize = buffer.getLong();
+    byte[] headerBytes = new byte[(int) headerSize];
+    buffer.get(headerBytes);
+    String headerJson = new String(headerBytes, StandardCharsets.UTF_8).trim();
+    ObjectMapper objectMapper = new ObjectMapper();
+    return objectMapper.readValue(headerJson, Map.class);
+  }
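+
+  // For reference, the safetensors layout assumed by parseHeader, extractVectors, and
+  // extractDocids (a sketch of the format as documented by the safetensors project,
+  // not anything Anserini-specific):
+  //   bytes 0-7    : unsigned 64-bit little-endian length N of the JSON header
+  //   bytes 8..8+N : JSON header mapping tensor names to {"dtype", "shape", "data_offsets"}
+  //   remainder    : raw tensor data; "data_offsets" are relative to the start of this
+  //                  section, hence the buffer.position(begin + headerSize + 8) seeks below.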
+
+  /**
+   * Extracts vectors from the byte data using the header information.
+   * @param data the byte data.
+   * @param header the header information.
+   * @return a 2D array of vectors.
+   */
+  private double[][] extractVectors(byte[] data, Map<String, Object> header) {
+    Map<String, Object> vectorsInfo = (Map<String, Object>) header.get("vectors");
+    String dtype = (String) vectorsInfo.get("dtype");
+
+    List<Integer> shapeList = (List<Integer>) vectorsInfo.get("shape");
+    int rows = shapeList.get(0);
+    int cols = shapeList.get(1);
+    List<Number> dataOffsets = (List<Number>) vectorsInfo.get("data_offsets");
+    long begin = dataOffsets.get(0).longValue();
+    long end = dataOffsets.get(1).longValue();
+
+    ByteBuffer buffer = ByteBuffer.wrap(data).order(ByteOrder.LITTLE_ENDIAN);
+    // Skip the 8-byte header length and the JSON header itself, then seek to this tensor's offset.
+    buffer.position((int) (begin + buffer.getLong(0) + 8));
+
+    double[][] vectors = new double[rows][cols];
+    if (dtype.equals("F64")) {
+      for (int i = 0; i < rows; i++) {
+        for (int j = 0; j < cols; j++) {
+          vectors[i][j] = buffer.getDouble();
+        }
+      }
+    } else {
+      throw new UnsupportedOperationException("Unsupported data type: " + dtype);
+    }
+
+    return vectors;
+  }
+
+  /**
+   * Extracts document IDs from the byte data using the header information.
+   * @param data the byte data.
+   * @param header the header information.
+   * @return an array of document IDs.
+   */
+  private String[] extractDocids(byte[] data, Map<String, Object> header) {
+    Map<String, Object> docidsInfo = (Map<String, Object>) header.get("docids");
+    String dtype = (String) docidsInfo.get("dtype");
+
+    List<Integer> shapeList = (List<Integer>) docidsInfo.get("shape");
+    int length = shapeList.get(0);
+    int maxCols = shapeList.get(1);
+
+    List<Number> dataOffsets = (List<Number>) docidsInfo.get("data_offsets");
+    long begin = dataOffsets.get(0).longValue();
+    long end = dataOffsets.get(1).longValue();
+
+    ByteBuffer buffer = ByteBuffer.wrap(data).order(ByteOrder.LITTLE_ENDIAN);
+    buffer.position((int) (begin + buffer.getLong(0) + 8));
+
+    String[] docids = new String[length];
+    StringBuilder sb = new StringBuilder();
+    if (dtype.equals("I64")) {
+      for (int i = 0; i < length; i++) {
+        sb.setLength(0);
+        // Each docid is stored as a zero-padded row of int64 ASCII codes; drop the padding.
+        for (int j = 0; j < maxCols; j++) {
+          char c = (char) buffer.getLong();
+          if (c != 0)
+            sb.append(c);
+        }
+        docids[i] = sb.toString();
+      }
+    } else {
+      throw new UnsupportedOperationException("Unsupported data type: " + dtype);
+    }
+
+    return docids;
+  }
+
+  /**
+   * Inner class representing a file segment for SafeTensorsDenseVectorCollection.
+   */
+  public static class Segment extends FileSegment<SafeTensorsDenseVectorCollection.Document> {
+    private double[][] vectors;
+    private String[] docids;
+    private int currentIndex;
+
+    /**
+     * Constructor for the Segment class.
+     * @param path the path to the file segment.
+     * @param vectors the vectors data.
+     * @param docids the document IDs data.
+     * @throws IOException if an I/O error occurs during file reading.
+     */
+    public Segment(Path path, double[][] vectors, String[] docids) throws IOException {
+      super(path);
+      this.vectors = vectors;
+      this.docids = docids;
+      this.currentIndex = 0;
+    }
+
+    /**
+     * Reads the next document in the segment.
+     * @throws IOException if an I/O error occurs during file reading.
+     * @throws NoSuchElementException if the end of the segment is reached.
+     */
+    @Override
+    protected synchronized void readNext() throws IOException, NoSuchElementException {
+      if (currentIndex >= docids.length) {
+        atEOF = true;
+        throw new NoSuchElementException("End of file reached");
+      }
+
+      String id = docids[currentIndex];
+      double[] vector = vectors[currentIndex];
+      bufferedRecord = new SafeTensorsDenseVectorCollection.Document(id, vector, "");
+      currentIndex++;
+    }
+  }
+
+  /**
+   * Inner class representing a document in the SafeTensorsDenseVectorCollection.
+   */
+  public static class Document implements SourceDocument {
+    private final String id; // Document ID
+    private final double[] vector; // Vector data
+    private final String raw; // Raw data
+
+    /**
+     * Constructor for the Document class.
+     * @param id the document ID.
+     * @param vector the vector data.
+     * @param raw the raw data.
+     */
+    public Document(String id, double[] vector, String raw) {
+      this.id = id;
+      this.vector = vector;
+      this.raw = raw;
+    }
+
+    /**
+     * Returns the document ID.
+     * @return the document ID.
+     */
+    @Override
+    public String id() {
+      return id;
+    }
+
+    /**
+     * Returns the vector contents as a string.
+     * @return the vector contents.
+     */
+    @Override
+    public String contents() {
+      return Arrays.toString(vector);
+    }
+
+    /**
+     * Returns the raw data.
+     * @return the raw data.
+     */
+    @Override
+    public String raw() {
+      return raw;
+    }
+
+    /**
+     * Indicates whether the document is indexable.
+     * @return true if the document is indexable, false otherwise.
+     */
+    @Override
+    public boolean indexable() {
+      return true;
+    }
+  }
+}
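For context on the byte-level encoding the collection reads: the docids tensor stores each docid as a row of int64 ASCII codes, zero-padded to the length of the longest docid. A minimal Python sketch of the decoding round trip, mirroring extractDocids above (the file name is illustrative):

    from safetensors.torch import load_file

    # Rows of zero-padded int64 ASCII codes -> docid strings.
    docids_tensor = load_file("part00_docids.safetensors")["docids"]
    docids = ["".join(chr(int(c)) for c in row if int(c) != 0)
              for row in docids_tensor.tolist()]
    print(docids[:3])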
diff --git a/src/main/java/io/anserini/index/generator/SafeTensorsDenseVectorDocumentGenerator.java b/src/main/java/io/anserini/index/generator/SafeTensorsDenseVectorDocumentGenerator.java
new file mode 100644
index 0000000000..495205c49b
--- /dev/null
+++ b/src/main/java/io/anserini/index/generator/SafeTensorsDenseVectorDocumentGenerator.java
@@ -0,0 +1,114 @@
+/*
+ * Anserini: A Lucene toolkit for reproducible information retrieval research
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.anserini.index.generator;
+
+import io.anserini.collection.SourceDocument;
+import io.anserini.index.Constants;
+import org.apache.lucene.document.BinaryDocValuesField;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.document.KnnFloatVectorField;
+import org.apache.lucene.document.StringField;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.util.BytesRef;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+/**
+ * A document generator for creating Lucene documents with SafeTensors dense vector data.
+ * Implements the LuceneDocumentGenerator interface.
+ *
+ * @param <T> the type of SourceDocument
+ */
+public class SafeTensorsDenseVectorDocumentGenerator<T extends SourceDocument> implements LuceneDocumentGenerator<T> {
+  private static final Logger LOG = LogManager.getLogger(SafeTensorsDenseVectorDocumentGenerator.class);
+  private static final ConcurrentHashMap<String, AtomicBoolean> processedDocuments = new ConcurrentHashMap<>(); // Track processed documents
+
+  /**
+   * Creates a Lucene document from the source document.
+   *
+   * @param src the source document
+   * @return the created Lucene document
+   * @throws InvalidDocumentException if the document is invalid
+   */
+  @SuppressWarnings("unused")
+  @Override
+  public Document createDocument(T src) throws InvalidDocumentException {
+    String docId = src.id();
+
+    try {
+      LOG.info("Processing document ID: " + src.id() + " with thread: " + Thread.currentThread().getName());
+
+      // Parse vector data from document contents
+      float[] contents = parseVectorFromContents(src.contents());
+      if (contents == null || contents.length == 0) {
+        LOG.error("Vector data is null or empty for document ID: " + src.id());
+        throw new InvalidDocumentException();
+      }
+
+      LOG.info("Vector length: " + contents.length + " for document ID: " + src.id());
+
+      // Create and populate the Lucene document
+      final Document document = new Document();
+      document.add(new StringField(Constants.ID, src.id(), Field.Store.YES));
+      document.add(new BinaryDocValuesField(Constants.ID, new BytesRef(src.id())));
+      document.add(new KnnFloatVectorField(Constants.VECTOR, contents, VectorSimilarityFunction.DOT_PRODUCT));
+
+      LOG.info("Document created for ID: " + src.id());
+      return document;
+    } catch (Exception e) {
+      LOG.error("Error creating document for ID: " + src.id(), e);
+      throw new InvalidDocumentException();
+    } finally {
+      // Reset the processed flag for this document, if one was registered
+      AtomicBoolean processed = processedDocuments.get(docId);
+      if (processed != null) {
+        processed.set(false);
+      }
+    }
+  }
+
+  /**
+   * Parses the vector data from the document contents.
+   *
+   * @param contents the contents of the document
+   * @return the parsed vector as an array of floats, or null if parsing fails
+   */
+  private float[] parseVectorFromContents(String contents) {
+    if (contents == null || contents.isEmpty()) {
+      LOG.error("Contents are null or empty, cannot parse vectors.");
+      return null;
+    }
+
+    try {
+      String[] parts = contents.replace("[", "").replace("]", "").split(",");
+      float[] vector = new float[parts.length];
+      for (int i = 0; i < parts.length; i++) {
+        vector[i] = Float.parseFloat(parts[i].trim());
+      }
+      return vector;
+    } catch (NumberFormatException e) {
+      LOG.error("Error parsing vector contents: " + contents, e);
+      return null;
+    }
+  }
+}
\ No newline at end of file
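The generator never touches the safetensors files directly: it relies on Document.contents() emitting an Arrays.toString-style rendering of the vector, which parseVectorFromContents then strips and splits. A minimal Python sketch of that string contract:

    # "[0.1, 0.2, 0.3]" is the form Arrays.toString(double[]) produces;
    # the generator removes the brackets and splits on commas.
    contents = "[0.1, 0.2, 0.3]"
    vector = [float(part) for part in contents.strip("[]").split(",")]
    assert vector == [0.1, 0.2, 0.3]

Note the round trip goes double -> string -> float, so vectors are indexed at float32 precision regardless of the F64 storage format.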
{docid}") + continue + + content_docid = content_entry.get('docid') + + if docid != content_docid: + differences.append(f"Docid mismatch for docid: {docid}, content docid: {content_docid}") + else: + if not vector == content_entry.get('vector'): + differences.append(f"Vector mismatch for docid: {docid}") + + if differences: + print("Differences found:") + for difference in differences: + print(difference) + else: + print("No differences found. The files are identical.") + +if __name__ == "__main__": + if len(sys.argv) != 3: + print("Usage: python compare_safetensors.py ") + sys.exit(1) + + vectors_path = sys.argv[1] + docids_path = sys.argv[2] + + # Convert SafeTensors to dictionaries + vectors_dict, contents_dict = convert_safetensors_to_dicts(vectors_path, docids_path) + + # Compare the dictionaries + compare_dicts(vectors_dict, contents_dict) diff --git a/src/main/python/safetensors/json_to_bin.py b/src/main/python/safetensors/json_to_bin.py new file mode 100644 index 0000000000..57a233128f --- /dev/null +++ b/src/main/python/safetensors/json_to_bin.py @@ -0,0 +1,159 @@ +import os +import json +import gzip +import torch +import argparse +import shutil +import logging +from safetensors.torch import save_file, load_file +from concurrent.futures import ThreadPoolExecutor, as_completed +from tqdm import tqdm + + +def setup_logging(): + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(message)s", + level=logging.ERROR, + handlers=[ + logging.StreamHandler() # Logs to the terminal + ] + ) + + +def read_jsonl_file(file_path: str) -> list[dict]: + data = [] + try: + if file_path.endswith(".gz"): + with gzip.open(file_path, "rt", encoding="utf-8") as f: + for line in f: + data.append(json.loads(line)) + else: + with open(file_path, "r", encoding="utf-8") as f: + for line in f: + data.append(json.loads(line)) + except Exception as e: + logging.error(f"Failed to read file {file_path}: {e}") + raise RuntimeError(f"Failed to read file {file_path}: {e}") + return data + + +def convert_file_to_safetensors(input_file_path: str, vectors_path: str, docids_path: str) -> int: + try: + data = read_jsonl_file(input_file_path) + vectors = [] + docids = [] + + for entry in data: + if isinstance(entry.get('vector', [None])[0], float): + vectors.append(entry['vector']) + docid = entry['docid'] + docid_ascii = [ord(char) for char in docid] # Convert docid to ASCII values + docids.append(docid_ascii) + else: + logging.warning(f"Skipped invalid vector entry with docid: {entry.get('docid', 'N/A')}") + + # Convert to tensors + vectors_tensor = torch.tensor(vectors, dtype=torch.float64) + docids_tensor = torch.nn.utils.rnn.pad_sequence([torch.tensor(d, dtype=torch.int64) for d in docids], batch_first=True) + + # Save as Safetensors + save_file({'vectors': vectors_tensor}, vectors_path) + save_file({'docids': docids_tensor}, docids_path) + + return len(vectors) # Return number of processed entries + + except Exception as e: + logging.error(f"Error converting {input_file_path} to Safetensors: {e}") + raise RuntimeError(f"Error converting {input_file_path} to Safetensors: {e}") + + +def validate_safetensor_conversion(vectors_path: str, docids_path: str, original_data: list[dict]) -> bool: + try: + loaded_vectors = load_file(vectors_path)['vectors'] + loaded_docids = load_file(docids_path)['docids'] + + # Validate the sizes + if loaded_vectors.size(0) != len(original_data): + raise ValueError(f"Validation failed for {vectors_path}: number of vectors does not match the original data") + + 
diff --git a/src/main/python/safetensors/json_to_bin.py b/src/main/python/safetensors/json_to_bin.py
new file mode 100644
index 0000000000..57a233128f
--- /dev/null
+++ b/src/main/python/safetensors/json_to_bin.py
@@ -0,0 +1,159 @@
+import os
+import json
+import gzip
+import torch
+import argparse
+import shutil
+import logging
+from safetensors.torch import save_file, load_file
+from tqdm import tqdm
+
+
+def setup_logging():
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(message)s",
+        level=logging.ERROR,
+        handlers=[
+            logging.StreamHandler()  # Logs to the terminal
+        ]
+    )
+
+
+def read_jsonl_file(file_path: str) -> list[dict]:
+    data = []
+    try:
+        if file_path.endswith(".gz"):
+            with gzip.open(file_path, "rt", encoding="utf-8") as f:
+                for line in f:
+                    data.append(json.loads(line))
+        else:
+            with open(file_path, "r", encoding="utf-8") as f:
+                for line in f:
+                    data.append(json.loads(line))
+    except Exception as e:
+        logging.error(f"Failed to read file {file_path}: {e}")
+        raise RuntimeError(f"Failed to read file {file_path}: {e}")
+    return data
+
+
+def convert_file_to_safetensors(input_file_path: str, vectors_path: str, docids_path: str) -> int:
+    try:
+        data = read_jsonl_file(input_file_path)
+        vectors = []
+        docids = []
+
+        for entry in data:
+            if isinstance(entry.get('vector', [None])[0], float):
+                vectors.append(entry['vector'])
+                docid = entry['docid']
+                docid_ascii = [ord(char) for char in docid]  # Convert docid to ASCII values
+                docids.append(docid_ascii)
+            else:
+                logging.warning(f"Skipped invalid vector entry with docid: {entry.get('docid', 'N/A')}")
+
+        # Convert to tensors
+        vectors_tensor = torch.tensor(vectors, dtype=torch.float64)
+        docids_tensor = torch.nn.utils.rnn.pad_sequence([torch.tensor(d, dtype=torch.int64) for d in docids], batch_first=True)
+
+        # Save as Safetensors
+        save_file({'vectors': vectors_tensor}, vectors_path)
+        save_file({'docids': docids_tensor}, docids_path)
+
+        return len(vectors)  # Return number of processed entries
+    except Exception as e:
+        logging.error(f"Error converting {input_file_path} to Safetensors: {e}")
+        raise RuntimeError(f"Error converting {input_file_path} to Safetensors: {e}")
+
+
+def validate_safetensor_conversion(vectors_path: str, docids_path: str, original_data: list[dict]) -> bool:
+    try:
+        loaded_vectors = load_file(vectors_path)['vectors']
+        loaded_docids = load_file(docids_path)['docids']
+
+        # Validate the sizes against the entries that were actually converted
+        # (entries with invalid vectors are skipped during conversion).
+        valid_rows = sum(1 for entry in original_data if isinstance(entry.get('vector', [None])[0], float))
+        if loaded_vectors.size(0) != valid_rows:
+            raise ValueError(f"Validation failed for {vectors_path}: number of vectors does not match the original data")
+        if loaded_docids.size(0) != valid_rows:
+            raise ValueError(f"Validation failed for {docids_path}: number of docids does not match the original data")
+
+        logging.info(f"Validation successful for {vectors_path} and {docids_path}")
+        return True
+    except Exception as e:
+        logging.error(f"Validation failed for {vectors_path} or {docids_path}: {e}")
+        raise e
+
+
+def convert_and_validate_file(input_file_path: str, vectors_path: str, docids_path: str) -> int:
+    row_count = convert_file_to_safetensors(input_file_path, vectors_path, docids_path)
+    original_data = read_jsonl_file(input_file_path)
+    validate_safetensor_conversion(vectors_path, docids_path, original_data)
+    logging.info(f"Converted {input_file_path} to {vectors_path} and {docids_path}")
+    return row_count
+
+
+def convert_jsonl_to_safetensors(input_dir: str, output_dir: str, overwrite=False) -> None:
+    if overwrite and os.path.exists(output_dir):
+        shutil.rmtree(output_dir)
+
+    os.makedirs(output_dir, exist_ok=True)
+
+    seen_basenames = set()
+    total_files = 0
+    total_rows = 0
+
+    files_to_process = []
+    for file_name in os.listdir(input_dir):
+        input_file_path = os.path.join(input_dir, file_name)
+
+        if file_name.endswith(".jsonl"):
+            basename = file_name[:-6]
+        elif file_name.endswith(".jsonl.gz"):
+            basename = file_name[:-9]
+        else:
+            continue
+
+        if basename in seen_basenames:
+            continue
+
+        seen_basenames.add(basename)
+        vectors_path = os.path.join(output_dir, f"{basename}_vectors.safetensors")
+        docids_path = os.path.join(output_dir, f"{basename}_docids.safetensors")
+        files_to_process.append((input_file_path, vectors_path, docids_path))
+
+    with tqdm(total=len(files_to_process), desc="Processing Files") as pbar:
+        for input_path, vectors_path, docids_path in files_to_process:
+            try:
+                logging.info(f"Processing file: {input_path}")
+                row_count = convert_and_validate_file(input_path, vectors_path, docids_path)
+                total_files += 1
+                total_rows += row_count
+            except Exception as e:
+                logging.error(f"Failed to process {input_path}: {e}")
+            finally:
+                pbar.update(1)
+
+    logging.info(f"Total files processed: {total_files}")
+    logging.info(f"Total rows processed: {total_rows}")
+
+
+if __name__ == "__main__":
+    setup_logging()
+
+    parser = argparse.ArgumentParser(
+        description="Convert JSONL files to Safetensor format and validate."
+    )
+    parser.add_argument(
+        "--input", required=True, help="Input directory containing JSONL files."
+    )
+    parser.add_argument(
+        "--output", required=True, help="Output directory for Safetensor files."
+    )
+    parser.add_argument(
+        "--overwrite",
+        action="store_true",
+        default=False,
+        help="Overwrite the output directory.",
+    )
+    args = parser.parse_args()
+
+    convert_jsonl_to_safetensors(args.input, args.output, args.overwrite)
diff --git a/src/main/python/safetensors/requirements.txt b/src/main/python/safetensors/requirements.txt
new file mode 100644
index 0000000000..64a61aaa8a
--- /dev/null
+++ b/src/main/python/safetensors/requirements.txt
@@ -0,0 +1,27 @@
+certifi==2024.2.2
+charset-normalizer==3.3.2
+contourpy==1.2.1
+cycler==0.12.1
+filelock==3.14.0
+fonttools==4.52.1
+fsspec==2024.5.0
+idna==3.7
+Jinja2==3.1.4
+kiwisolver==1.4.5
+MarkupSafe==2.1.5
+matplotlib==3.9.0
+mpmath==1.3.0
+networkx==3.3
+numpy==1.26.4
+packaging==24.0
+pillow==10.3.0
+pyparsing==3.1.2
+python-dateutil==2.9.0.post0
+requests==2.32.2
+safetensors==0.4.3
+six==1.16.0
+sympy==1.12
+torch==2.3.0
+typing_extensions==4.11.0
+urllib3==2.2.1
+tqdm==4.66.5
\ No newline at end of file
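For reference, a sketch of the JSONL schema json_to_bin.py expects, with a programmatic invocation (assumes it runs from src/main/python/safetensors so json_to_bin is importable; directory names are illustrative):

    import json
    import os

    from json_to_bin import convert_jsonl_to_safetensors

    # Each line is an object with a "docid" string and a "vector" list of floats.
    os.makedirs("in", exist_ok=True)
    with open("in/sample.jsonl", "w", encoding="utf-8") as f:
        f.write(json.dumps({"docid": "doc1", "vector": [0.1, 0.2]}) + "\n")
        f.write(json.dumps({"docid": "doc2", "vector": [0.3, 0.4]}) + "\n")

    convert_jsonl_to_safetensors("in", "out", overwrite=True)
    # -> out/sample_vectors.safetensors and out/sample_docids.safetensors,
    #    which SafeTensorsDenseVectorCollection can then read.

Of the pins in requirements.txt, the scripts themselves only import torch, safetensors, and tqdm; the remaining packages are transitive or auxiliary dependencies.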