Skip to content

Commit

Permalink
update lancedb fix embeddings database
Browse files Browse the repository at this point in the history
  • Loading branch information
rjmacarthy committed Nov 12, 2024
1 parent f4229a4 commit 411f225
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 123 deletions.
89 changes: 29 additions & 60 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@
"web-tree-sitter": "^0.22.1"
},
"dependencies": {
"@lancedb/lancedb": "^0.9.0",
"@lancedb/lancedb": "^0.12.0",
"@tiptap/extension-mention": "^2.5.9",
"@tiptap/extension-placeholder": "^2.5.9",
"@tiptap/pm": "^2.5.9",
Expand Down
4 changes: 0 additions & 4 deletions src/common/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,6 @@ export const EXTENSION_CONTEXT_NAME = {
twinnyOverlapSize: "twinnyOverlapSize",
twinnyRelevantFilePaths: "twinnyRelevantFilePaths",
twinnyRelevantCodeSnippets: "twinnyRelevantCodeSnippets",
twinnyVectorSearchMetric: "twinnyVectorSearchMetric",
twinnySymmetryTab: "twinnySymmetryTab",
twinnyEnableRag: "twinnyEnableRag",
}
Expand Down Expand Up @@ -338,9 +337,6 @@ export const WASM_LANGUAGES: { [key: string]: string } = {

export const DEFAULT_RELEVANT_FILE_COUNT = 10
export const DEFAULT_RELEVANT_CODE_COUNT = 5
export const DEFAULT_VECTOR_SEARCH_METRIC = "l2"

export const EMBEDDING_METRICS = ["cosine", "l2", "dot"]

export const MULTILINE_OUTSIDE = [
"class_body",
Expand Down
2 changes: 1 addition & 1 deletion src/common/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ export interface ChunkOptions {
}

export type Embedding = {
embeddings: number[]
embeddings: [number[]]
}

export type EmbeddedDocument = {
Expand Down
15 changes: 0 additions & 15 deletions src/extension/chat-service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import {
DEFAULT_RELEVANT_CODE_COUNT,
DEFAULT_RELEVANT_FILE_COUNT,
DEFAULT_RERANK_THRESHOLD,
DEFAULT_VECTOR_SEARCH_METRIC,
EVENT_NAME,
EXTENSION_CONTEXT_NAME,
EXTENSION_SESSION_NAME,
Expand Down Expand Up @@ -118,18 +117,11 @@ export class ChatService extends Base {
) as number
const relevantFileCount = Number(stored) || DEFAULT_RELEVANT_FILE_COUNT

const storedMetric = this._context?.globalState.get(
`${EVENT_NAME.twinnyGlobalContext}-${EXTENSION_CONTEXT_NAME.twinnyVectorSearchMetric}`
) as number

const metric = storedMetric || DEFAULT_VECTOR_SEARCH_METRIC

const filePaths =
(await this._db.getDocuments(
embedding,
relevantFileCount,
table,
metric as "cosine" | "l2" | "dot"
)) || []

if (!filePaths.length) return []
Expand Down Expand Up @@ -226,11 +218,6 @@ export class ChatService extends Base {

if (!embedding) return ""

const storedMetric = this._context?.globalState.get(
`${EVENT_NAME.twinnyGlobalContext}-${EXTENSION_CONTEXT_NAME.twinnyVectorSearchMetric}`
) as number
const metric = storedMetric || DEFAULT_VECTOR_SEARCH_METRIC

const query = relevantFiles?.length
? `file IN ("${relevantFiles.map((file) => file[0]).join("\",\"")}")`
: ""
Expand All @@ -240,7 +227,6 @@ export class ChatService extends Base {
embedding,
Math.round(relevantCodeCount / 2),
table,
metric as "cosine" | "l2" | "dot",
query
)) || []

Expand All @@ -249,7 +235,6 @@ export class ChatService extends Base {
embedding,
Math.round(relevantCodeCount / 2),
table,
metric as "cosine" | "l2" | "dot"
)) || []

const documents = [...embeddedDocuments, ...queryEmbeddedDocuments]
Expand Down
40 changes: 25 additions & 15 deletions src/extension/embeddings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@ import {
import { fetchEmbedding } from "./api"
import { Base } from "./base"
import { TwinnyProvider } from "./provider-manager"
import {
getDocumentSplitChunks,
readGitSubmodulesFile
} from "./utils"
import { getDocumentSplitChunks, readGitSubmodulesFile } from "./utils"

export class EmbeddingDatabase extends Base {
private _documents: EmbeddedDocument[] = []
Expand Down Expand Up @@ -110,7 +107,7 @@ export class EmbeddingDatabase extends Base {

ig.add(embeddingIgnoredGlobs)
ig.add([".git", ".gitignore"])

for (const dirent of dirents) {
const fullPath = path.join(dirPath, dirent.name)
const relativePath = path.relative(rootPath, fullPath)
Expand Down Expand Up @@ -150,26 +147,28 @@ export class EmbeddingDatabase extends Base {
if (!this._extensionContext) return
const promises = filePaths.map(async (filePath) => {
const content = await fs.promises.readFile(filePath, "utf-8")

const chunks = await getDocumentSplitChunks(
content,
filePath,
this._extensionContext
)
const filePathEmbedding = await this.fetchModelEmbedding(filePath)

const fileNameEmbedding = await this.fetchModelEmbedding(filePath)

this._filePaths.push({
content: filePath,
vector: filePathEmbedding,
vector: fileNameEmbedding,
file: filePath
})

for (const chunk of chunks) {
const vector = await this.fetchModelEmbedding(filePath)
const chunkEmbedding = await this.fetchModelEmbedding(chunk)
if (this.getIsDuplicateItem(chunk, chunks)) return
this._documents.push({
content: chunk,
file: filePath,
vector: vector
vector: chunkEmbedding,
file: filePath
})
}

Expand All @@ -196,11 +195,23 @@ export class EmbeddingDatabase extends Base {
try {
const tableNames = await this._db?.tableNames()
if (!tableNames?.includes(`${this._workspaceName}-documents`)) {
await this._db?.createTable(this._documentTableName, this._documents)
await this._db?.createTable(
this._documentTableName,
this._documents,
{
mode: "overwrite"
}
)
}

if (!tableNames?.includes(`${this._workspaceName}-file-paths`)) {
await this._db?.createTable(this._filePathTableName, this._filePaths)
await this._db?.createTable(
this._filePathTableName,
this._filePaths,
{
mode: "overwrite"
}
)
return
}

Expand All @@ -224,12 +235,11 @@ export class EmbeddingDatabase extends Base {
vector: IntoVector,
limit: number,
tableName: string,
metric: "cosine" | "l2" | "dot" = "cosine",
where?: string
): Promise<EmbeddedDocument[] | undefined> {
try {
const table = await this._db?.openTable(tableName)
const query = table?.search(vector).limit(limit).distanceType(metric) // add type assertion
const query = table?.vectorSearch(vector).select("content").limit(limit)
if (where) query?.where(where)
return query?.toArray()
} catch (e) {
Expand All @@ -255,6 +265,6 @@ export class EmbeddingDatabase extends Base {
return (response as LMStudioEmbedding).data?.[0].embedding
}

return (response as Embedding).embeddings
return (response as Embedding).embeddings[0]
}
}
Loading

0 comments on commit 411f225

Please sign in to comment.