Skip to content

Commit

Permalink
Added Int8 and Bool vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
jkrukowski committed Nov 8, 2024
1 parent 046a3be commit 84551eb
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 5 deletions.
19 changes: 16 additions & 3 deletions Sources/SQLiteVec/Database.swift
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ public actor Database {
params: [any Sendable] = []
) throws -> [[String: any Sendable]] {
let stmt = try prepare(sql, params: params)
return query(stmt)
return try query(stmt)
}

private func prepare(_ sql: String, params: [any Sendable]) throws -> OpaquePointer {
Expand Down Expand Up @@ -278,6 +278,14 @@ public actor Database {
result = sqlite3_bind_blob(
stmt, Int32(index + 1), value, Int32(MemoryLayout<Float>.stride * value.count),
SQLITE_STATIC)
case let value as [Int8]:
result = sqlite3_bind_blob(
stmt, Int32(index + 1), value, Int32(MemoryLayout<Int8>.stride * value.count),
SQLITE_STATIC)
case let value as [Bool]:
result = sqlite3_bind_blob(
stmt, Int32(index + 1), value, Int32(MemoryLayout<Bool>.stride * value.count),
SQLITE_STATIC)
default:
result = sqlite3_bind_null(stmt, Int32(index + 1))
}
Expand All @@ -291,11 +299,16 @@ public actor Database {
try SQLiteVecError.check(sqlite3_step(stmt))
}

private func query(_ stmt: OpaquePointer) -> [[String: any Sendable]] {
private func query(_ stmt: OpaquePointer) throws -> [[String: any Sendable]] {
defer { sqlite3_finalize(stmt) }
var rows = [[String: any Sendable]]()
var columnInfo: (names: [String], types: [Int32])?
while sqlite3_step(stmt) == SQLITE_ROW {
while true {
let result = sqlite3_step(stmt)
try SQLiteVecError.check(result, handler.handle)
if result != SQLITE_ROW {
break
}
if columnInfo == nil {
let columnCount = sqlite3_column_count(stmt)
var names = [String]()
Expand Down
10 changes: 10 additions & 0 deletions Tests/SQLiteVecTests/CoreTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,14 @@ final class CoreTests: XCTestCase {
func testInitialize() throws {
try XCTAssertNoThrow(SQLiteVec.initialize(), "Initializing should not throw")
}

func testLoadedExtensions() async throws {
try SQLiteVec.initialize()
let db = try Database(.inMemory)
let result = try await db.query("PRAGMA module_list")
let extensionNames = result.compactMap { $0["name"] as? String }

XCTAssertTrue(extensionNames.contains("vec_each"), "vec_each should be loaded")
XCTAssertTrue(extensionNames.contains("vec0"), "vec0 should be loaded")
}
}
77 changes: 75 additions & 2 deletions Tests/SQLiteVecTests/DatabaseTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,46 @@ final class DatabaseTests: XCTestCase {
XCTAssertEqual(data.bytes, [212])
}

func testVectorAdd() async throws {
func testVectorInit() async throws {
let db = try Database(.inMemory)
let result1 = try await db.query(
"""
SELECT vec_int8(?) as result
""",
params: [
[0, 1, 2, 3, 4] as [Int8]
]
)
let data1 = try XCTUnwrap(result1[0]["result"] as? Data)
let array1: [Int8] = data1.toArray()
XCTAssertEqual(array1, [0, 1, 2, 3, 4])

let result2 = try await db.query(
"""
SELECT vec_bit(?) as result
""",
params: [
[false, false, false, true, true] as [Bool]
]
)
let data2 = try XCTUnwrap(result2[0]["result"] as? Data)
let array2: [Bool] = data2.toArray()
XCTAssertEqual(array2, [false, false, false, true, true])

let result3 = try await db.query(
"""
SELECT vec_f32(?) as result
""",
params: [
[0, 1, 2, 3, 4] as [Float]
]
)
let data3 = try XCTUnwrap(result3[0]["result"] as? Data)
let array3: [Float] = data3.toArray()
XCTAssertEqual(array3, [0, 1, 2, 3, 4], accuracy: Float(accuracy))
}

func testVectorAddFloat() async throws {
let db = try Database(.inMemory)
let result = try await db.query(
"""
Expand All @@ -99,7 +138,24 @@ final class DatabaseTests: XCTestCase {
)
}

func testVectorSub() async throws {
func testVectorAddInt8() async throws {
let db = try Database(.inMemory)
let result = try await db.query(
"""
SELECT vec_add(vec_int8(?), vec_int8(?)) as result
""",
params: [
[0, 1, 2, 3] as [Int8],
[5, 6, 7, 8] as [Int8],
]
)
XCTAssertEqual(result.count, 1)
let data = try XCTUnwrap(result[0]["result"] as? Data)
let array: [Int8] = data.toArray()
XCTAssertEqual(array, [5, 7, 9, 11])
}

func testVectorSubFloat() async throws {
let db = try Database(.inMemory)
let result = try await db.query(
"""
Expand All @@ -119,6 +175,23 @@ final class DatabaseTests: XCTestCase {
)
}

func testVectorSubInt8() async throws {
let db = try Database(.inMemory)
let result = try await db.query(
"""
SELECT vec_sub(vec_int8(?), vec_int8(?)) as result
""",
params: [
[0, 1, 2, 3] as [Int8],
[9, 3, 8, 0] as [Int8],
]
)
XCTAssertEqual(result.count, 1)
let data = try XCTUnwrap(result[0]["result"] as? Data)
let array: [Int8] = data.toArray()
XCTAssertEqual(array, [-9, -2, -6, 3])
}

func testEmbeddingDistanceQuery() async throws {
let data: [(index: Int, vector: [Float])] = [
(1, [0.1, 0.1, 0.1, 0.1]),
Expand Down

0 comments on commit 84551eb

Please sign in to comment.