
Commit ff75047

fixed

Panizghi committed Jul 9, 2024
1 parent 670623c commit ff75047
Showing 4 changed files with 207 additions and 136 deletions.
8 changes: 8 additions & 0 deletions src/main/java/io/anserini/index/AbstractIndexer.java
@@ -229,6 +229,10 @@ public AbstractIndexer(Args args) {
       Class.forName("io.anserini.collection." + args.collectionClass);
       this.collection = collectionClass.getConstructor(Path.class).newInstance(collectionPath);
     } catch (Exception e) {
+      LOG.error(e);
+      // Print more detail about the failure.
+      LOG.error("Error loading collection class: " + args.collectionClass);
+      LOG.error("Collection path: " + collectionPath);
       throw new IllegalArgumentException(String.format("Unable to load collection class \"%s\".", args.collectionClass));
     }
   }
@@ -329,6 +333,10 @@ protected void processSegments(ThreadPoolExecutor executor, List<Path> segmentPaths,

         executor.execute(new IndexerThread(segmentPath, generator));
       } catch (InstantiationException | IllegalAccessException | InvocationTargetException | NoSuchMethodException e) {
+        LOG.error(e);
+        LOG.error("Error instantiating generator class: " + generatorClass);
+        LOG.error("Segment path: " + segmentPath);
+        LOG.error("Args: " + args);
         throw new IllegalArgumentException(String.format("Unable to load LuceneDocumentGenerator \"%s\".", generatorClass.getSimpleName()));
       }
     });
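For readers unfamiliar with this code path, the pattern the two hunks above instrument is reflective construction of a collection class. The following is a minimal sketch of that pattern, not Anserini's actual code; the class and method names are illustrative. It also chains the caught exception as the cause, which preserves the stack trace that a plain formatted message loses:

import java.nio.file.Path;

// Minimal sketch of reflective collection loading with contextual error logging.
public class ReflectiveLoaderSketch {
  @SuppressWarnings("unchecked")
  public static <T> T newCollection(String collectionClass, Path collectionPath) {
    try {
      Class<?> clazz = Class.forName("io.anserini.collection." + collectionClass);
      return (T) clazz.getConstructor(Path.class).newInstance(collectionPath);
    } catch (Exception e) {
      // Surface the inputs that led to the failure, as the commit's logging does.
      System.err.println("Error loading collection class: " + collectionClass);
      System.err.println("Collection path: " + collectionPath);
      // Chaining e keeps the root cause visible in the stack trace.
      throw new IllegalArgumentException(
          String.format("Unable to load collection class \"%s\".", collectionClass), e);
    }
  }
}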
src/main/java/io/anserini/index/generator/HnswJsonWithSafeTensorsDenseVectorDocumentGenerator.java
@@ -1,9 +1,8 @@
 package io.anserini.index.generator;

 import io.anserini.collection.SourceDocument;
+import io.anserini.index.AbstractIndexer;
 import io.anserini.index.Constants;
-import io.anserini.index.IndexHnswDenseVectors;
-import io.anserini.index.IndexHnswDenseVectors.Args;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.lucene.document.BinaryDocValuesField;
@@ -13,32 +12,31 @@
 import org.apache.lucene.document.StringField;
 import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.util.BytesRef;
+import com.fasterxml.jackson.databind.ObjectMapper;

 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
+import java.nio.charset.StandardCharsets;
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.nio.file.Paths;
 import java.util.Arrays;
 import java.util.HashSet;
+import java.util.Map;
+import java.util.List;
 import java.util.stream.Stream;

 public class HnswJsonWithSafeTensorsDenseVectorDocumentGenerator<T extends SourceDocument>
     implements LuceneDocumentGenerator<T> {
   private static final Logger LOG = LogManager.getLogger(HnswJsonWithSafeTensorsDenseVectorDocumentGenerator.class);
-  protected Args args;
+  protected AbstractIndexer.Args args;
   private HashSet<String> allowedFileSuffix;

-  public HnswJsonWithSafeTensorsDenseVectorDocumentGenerator() {
-    this.allowedFileSuffix = new HashSet<>(Arrays.asList(".json", ".jsonl", ".gz"));
-    LOG.info("V1 Initializing HnswJsonWithSafeTensorsDenseVectorDocumentGenerator...");
-  }
-
-  public void setArgs(IndexHnswDenseVectors.Args args) {
+  public HnswJsonWithSafeTensorsDenseVectorDocumentGenerator(AbstractIndexer.Args args) {
     this.args = args;
-    LOG.info("Args set via setter method:");
-    LOG.info("  - Input path: " + this.args.input);
+    this.allowedFileSuffix = new HashSet<>(Arrays.asList(".json", ".jsonl", ".gz"));
+    LOG.info("Initializing HnswJsonWithSafeTensorsDenseVectorDocumentGenerator with Args...");
   }

   @Override
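With the no-arg constructor and setter replaced by a single Args constructor, the indexer presumably instantiates the generator reflectively through that constructor. A sketch under that assumption follows; GeneratorFactorySketch is an illustrative helper, not actual Anserini code:

import io.anserini.index.AbstractIndexer;

// Sketch: obtaining a generator now that configuration happens in the constructor.
final class GeneratorFactorySketch {
  static <G> G create(Class<G> generatorClass, AbstractIndexer.Args args)
      throws ReflectiveOperationException {
    // getConstructor throws NoSuchMethodException when the one-arg constructor is
    // missing, which is one of the cases the catch block in processSegments reports.
    return generatorClass.getConstructor(AbstractIndexer.Args.class).newInstance(args);
  }
}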
@@ -69,29 +67,18 @@ public Document createDocument(T src) throws InvalidDocumentException {
       throw new InvalidDocumentException();
     }

-    // Read and deserialize the SafeTensors files
-    byte[] vectorsData = Files.readAllBytes(Paths.get(filePaths.vectorsFilePath));
-    byte[] docidsData = Files.readAllBytes(Paths.get(filePaths.docidsFilePath));
+    // Read vectors and docids from the safetensors files
+    double[][] vectors = readVectors(filePaths.vectorsFilePath);
+    String[] docids = readDocidAsciiValues(filePaths.docidsFilePath);

-    // Deserialize vectors and docid ASCII values
-    double[][] vectors = extractVectors(vectorsData);
-    int[][] docidAsciiValues = extractDocidAsciiValues(docidsData);
-
     // Create the Lucene document
     String id = src.id();
-    LOG.info("Processing document ID: " + id);
-    int[] docidAscii = id.chars().toArray();
-
-    Integer index = null;
-    for (int i = 0; i < docidAsciiValues.length; i++) {
-      if (Arrays.equals(docidAscii, docidAsciiValues[i])) {
-        index = i;
-        break;
-      }
-    }
+    int index = Arrays.asList(docids).indexOf(id);

-    if (index == null) {
+    if (index == -1) {
       LOG.error("Error finding index for document ID: " + id);
+      LOG.error("Document ID ASCII: " + Arrays.toString(id.chars().toArray()));
+      LOG.error("Available IDs: " + Arrays.toString(docids));
       throw new InvalidDocumentException();
     }
@@ -148,31 +135,117 @@ public FilePaths(String vectorsFilePath, String docidsFilePath) {
     }
   }

-  private double[][] extractVectors(byte[] data) {
+  private double[][] readVectors(String filePath) throws IOException {
+    byte[] data = Files.readAllBytes(Paths.get(filePath));
+    Map<String, Object> header = parseHeader(data);
+    return extractVectors(data, header);
+  }
+
+  private String[] readDocidAsciiValues(String filePath) throws IOException {
+    byte[] data = Files.readAllBytes(Paths.get(filePath));
+    Map<String, Object> header = parseHeader(data);
+    return extractDocids(data, header);
+  }
+
+  @SuppressWarnings("unchecked")
+  private static Map<String, Object> parseHeader(byte[] data) throws IOException {
     ByteBuffer buffer = ByteBuffer.wrap(data).order(ByteOrder.LITTLE_ENDIAN);
-    int rows = buffer.getInt();
-    int cols = buffer.getInt();
+    long headerSize = buffer.getLong();
+    byte[] headerBytes = new byte[(int) headerSize];
+    buffer.get(headerBytes);
+    String headerJson = new String(headerBytes, StandardCharsets.UTF_8).trim();
+    System.out.println("Header JSON: " + headerJson);
+    ObjectMapper objectMapper = new ObjectMapper();
+    return objectMapper.readValue(headerJson, Map.class);
+  }
+
+  private static double[][] extractVectors(byte[] data, Map<String, Object> header) {
+    @SuppressWarnings("unchecked")
+    Map<String, Object> vectorsInfo = (Map<String, Object>) header.get("vectors");
+    String dtype = (String) vectorsInfo.get("dtype");
+
+    @SuppressWarnings("unchecked")
+    List<Integer> shapeList = (List<Integer>) vectorsInfo.get("shape");
+    int rows = shapeList.get(0);
+    int cols = shapeList.get(1);
+    @SuppressWarnings("unchecked")
+    List<Number> dataOffsets = (List<Number>) vectorsInfo.get("data_offsets");
+    long begin = dataOffsets.get(0).longValue();
+    long end = dataOffsets.get(1).longValue();
+
+    System.out.println("Vectors shape: " + rows + "x" + cols);
+    System.out.println("Data offsets: " + begin + " to " + end);
+    System.out.println("Data type: " + dtype);
+
+    ByteBuffer buffer = ByteBuffer.wrap(data).order(ByteOrder.LITTLE_ENDIAN);
+    // The data section starts after the 8-byte length prefix and the JSON header;
+    // buffer.getLong(0) re-reads the header size stored at offset 0.
+    buffer.position((int) (begin + buffer.getLong(0) + 8));

     double[][] vectors = new double[rows][cols];
-    for (int i = 0; i < rows; i++) {
-      for (int j = 0; j < cols; j++) {
-        vectors[i][j] = buffer.getDouble();
+    if (dtype.equals("F64")) {
+      for (int i = 0; i < rows; i++) {
+        for (int j = 0; j < cols; j++) {
+          vectors[i][j] = buffer.getDouble();
+        }
       }
+    } else {
+      throw new UnsupportedOperationException("Unsupported data type: " + dtype);
     }
+
+    // Log the first few rows and columns to verify the content
+    System.out.println("First few vectors:");
+    for (int i = 0; i < Math.min(5, rows); i++) {
+      for (int j = 0; j < Math.min(10, cols); j++) {
+        System.out.print(vectors[i][j] + " ");
+      }
+      System.out.println();
+    }

     return vectors;
   }

-  private int[][] extractDocidAsciiValues(byte[] data) {
-    ByteBuffer buffer = ByteBuffer.wrap(data).order(ByteOrder.LITTLE_ENDIAN);
-    int rows = buffer.getInt();
-    int maxCols = buffer.getInt();
+  @SuppressWarnings("unchecked")
+  private static String[] extractDocids(byte[] data, Map<String, Object> header) {
+    Map<String, Object> docidsInfo = (Map<String, Object>) header.get("docids");
+    String dtype = (String) docidsInfo.get("dtype");
+
+    List<Integer> shapeList = (List<Integer>) docidsInfo.get("shape");
+    int length = shapeList.get(0);
+    int maxCols = shapeList.get(1);

-    int[][] docidAsciiValues = new int[rows][maxCols];
-    for (int i = 0; i < rows; i++) {
-      for (int j = 0; j < maxCols; j++) {
-        docidAsciiValues[i][j] = buffer.getInt();
+    List<Number> dataOffsets = (List<Number>) docidsInfo.get("data_offsets");
+    long begin = dataOffsets.get(0).longValue();
+    long end = dataOffsets.get(1).longValue();
+
+    System.out.println("Docids shape: " + length + "x" + maxCols);
+    System.out.println("Data offsets: " + begin + " to " + end);
+    System.out.println("Data type: " + dtype);
+
+    ByteBuffer buffer = ByteBuffer.wrap(data).order(ByteOrder.LITTLE_ENDIAN);
+    // Same offset arithmetic as in extractVectors.
+    buffer.position((int) (begin + buffer.getLong(0) + 8));
+
+    String[] docids = new String[length];
+    StringBuilder sb = new StringBuilder();
+    if (dtype.equals("I64")) {
+      for (int i = 0; i < length; i++) {
+        sb.setLength(0);
+        for (int j = 0; j < maxCols; j++) {
+          char c = (char) buffer.getLong();
+          if (c != 0) sb.append(c);
+        }
+        docids[i] = sb.toString();
       }
+    } else {
+      throw new UnsupportedOperationException("Unsupported data type: " + dtype);
     }
-    return docidAsciiValues;
+
+    // Log the first few docids to verify the content
+    System.out.println("First few docids:");
+    for (int i = 0; i < Math.min(10, docids.length); i++) {
+      System.out.println(docids[i]);
+    }
+
+    return docids;
   }
 }
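The offset arithmetic in extractVectors and extractDocids relies on the safetensors file layout: bytes 0 through 7 hold a little-endian u64 header length N, the next N bytes hold a JSON header mapping tensor names to dtype, shape, and data_offsets, and those offsets are relative to the first byte after the header. The absolute position of a tensor is therefore 8 + N + begin, which is exactly what begin + buffer.getLong(0) + 8 computes. A compact restatement of that calculation:

import java.nio.ByteBuffer;
import java.nio.ByteOrder;

// Sketch of the safetensors offset arithmetic used above.
final class SafeTensorsOffsetsSketch {
  // Absolute file position of a tensor whose data_offsets entry starts at `begin`.
  static long absolutePosition(byte[] file, long begin) {
    long headerSize = ByteBuffer.wrap(file).order(ByteOrder.LITTLE_ENDIAN).getLong();
    return 8 + headerSize + begin;  // length prefix + JSON header + relative offset
  }
}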
9 changes: 8 additions & 1 deletion src/main/python/safetensors/json_to_bin.py
@@ -3,7 +3,7 @@
 import os
 import argparse
 import gzip
-from safetensors.torch import save_file
+from safetensors.torch import save_file, load_file

 # Set up argument parser
 parser = argparse.ArgumentParser(description='Process vectors and docids from JSONL or GZ files.')
@@ -66,3 +66,10 @@

 print(f"Saved vectors to {vectors_path}")
 print(f"Saved docids to {docids_path}")
+
+# Load vectors and docids back to sanity-check the round trip
+loaded_vectors = load_file(vectors_path)['vectors']
+loaded_docids = load_file(docids_path)['docids']
+
+print(f"Loaded vectors: {loaded_vectors}")
+print(f"Loaded document IDs (ASCII): {loaded_docids}")
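For reference, the Java parseHeader above consumes the header that save_file writes. A header for a 2x3 float64 vectors tensor would look roughly like the literal below (48 = 2 x 3 x 8 bytes); the JSON is an assumed example, not captured from a real file:

import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.Map;

// Self-contained illustration of parsing the expected header JSON with Jackson.
public class HeaderParseExample {
  @SuppressWarnings("unchecked")
  public static void main(String[] args) throws Exception {
    String headerJson =
        "{\"vectors\":{\"dtype\":\"F64\",\"shape\":[2,3],\"data_offsets\":[0,48]}}";
    Map<String, Object> header = new ObjectMapper().readValue(headerJson, Map.class);
    Map<String, Object> vectorsInfo = (Map<String, Object>) header.get("vectors");
    System.out.println(vectorsInfo.get("dtype"));         // F64
    System.out.println(vectorsInfo.get("shape"));         // [2, 3]
    System.out.println(vectorsInfo.get("data_offsets"));  // [0, 48]
  }
}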