Skip to content

Commit

Permalink
refactor for rag
Browse files Browse the repository at this point in the history
  • Loading branch information
jtarchie committed Jan 29, 2025
1 parent 28ce7ab commit 20677ba
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 15 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

# Output of the go coverage tool, specifically when used with LiteIDE
*.out
*.out.*

# Dependency directories (remove the comment below to include it)
# vendor/
Expand Down
39 changes: 26 additions & 13 deletions rag.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,39 @@ import (
"context"
_ "embed"
"fmt"
"os"
"time"

"github.com/philippgille/chromem-go"
"github.com/sashabaranov/go-openai"
)

// OpenAIConfig holds the settings needed to talk to an OpenAI-compatible
// API: which models to use for embeddings and chat completions, the service
// endpoint, and the authentication token. NewRAG accepts a *OpenAIConfig and
// substitutes OpenAI defaults (including $OPENAI_API_KEY) when it is nil.
type OpenAIConfig struct {
	// EmbedModel is the model used to compute document/query embeddings
	// (e.g. "text-embedding-3-small" for OpenAI, or an Ollama model name).
	EmbedModel string
	// Endpoint is the base URL of the OpenAI-compatible API,
	// e.g. "https://api.openai.com/v1" or "http://localhost:11434/v1".
	Endpoint string
	// LLMModel is the chat-completion model used to answer queries
	// (e.g. "gpt-4o-mini").
	LLMModel string
	// Token is the API key sent to the endpoint; may be empty for local
	// servers that do not require authentication.
	Token string
}

// RAG bundles an embedded chromem vector database with the OpenAI-compatible
// configuration and embedding function used to index documents and answer
// queries (retrieval-augmented generation).
type RAG struct {
	// db is the chromem vector store; in-memory or file-backed depending on
	// the filename passed to NewRAG.
	db *chromem.DB

	// config is the (possibly defaulted) API configuration captured by NewRAG;
	// Ask reads Token, Endpoint, and LLMModel from it.
	config OpenAIConfig
	// embeddingFunc computes embeddings via the OpenAI-compatible endpoint,
	// built from config by chromem.NewEmbeddingFuncOpenAICompat.
	embeddingFunc chromem.EmbeddingFunc
}

func NewRAG(filename string, embedModel string, llmModel string, endpoint string, token string) (*RAG, error) {
func NewRAG(filename string, config *OpenAIConfig) (*RAG, error) {
if config == nil {
config = &OpenAIConfig{
// https://platform.openai.com/docs/guides/embeddings#embedding-models
EmbedModel: "text-embedding-3-small",
Endpoint: "https://api.openai.com/v1",
// https://platform.openai.com/docs/model
LLMModel: "gpt-4o-mini",
Token: os.Getenv("OPENAI_API_KEY"),
}
}

db := chromem.NewDB()

if filename != ":memory:" {
Expand All @@ -35,11 +51,8 @@ func NewRAG(filename string, embedModel string, llmModel string, endpoint string
return &RAG{
db: db,

embedModel: embedModel,
endpoint: endpoint,
llmModel: llmModel,
token: token,
embeddingFunc: chromem.NewEmbeddingFuncOpenAICompat(endpoint, token, embedModel, nil),
config: *config,
embeddingFunc: chromem.NewEmbeddingFuncOpenAICompat(config.Endpoint, config.Token, config.EmbedModel, nil),
}, nil
}

Expand Down Expand Up @@ -89,8 +102,8 @@ func (r *RAG) Ask(query string) (string, error) {
return "", fmt.Errorf("failed to search: %w", err)
}

config := openai.DefaultConfig(r.token)
config.BaseURL = r.endpoint
config := openai.DefaultConfig(r.config.Token)
config.BaseURL = r.config.Endpoint
client := openai.NewClientWithConfig(config)

userPrompt := "Query: " + query + "\n\nDocuments:\n\n"
Expand All @@ -103,7 +116,7 @@ func (r *RAG) Ask(query string) (string, error) {
response, err := client.CreateChatCompletion(
context.Background(),
openai.ChatCompletionRequest{
Model: r.llmModel,
Model: r.config.LLMModel,
Messages: []openai.ChatCompletionMessage{
{
Role: "system",
Expand Down
10 changes: 8 additions & 2 deletions rag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,13 @@ The Internet has transformed:

var _ = FDescribe("RAG", func() {
It("adds documents", func() {
rag, err := builder.NewRAG(":memory:", "nomic-embed-text", "llama3.2", "http://localhost:11434/v1", "")
config := &builder.OpenAIConfig{
EmbedModel: "nomic-embed-text",
Endpoint: "http://localhost:11434/v1",
LLMModel: "llama3.2",
Token: "",
}
rag, err := builder.NewRAG(":memory:", config)
Expect(err).NotTo(HaveOccurred())

for index, doc := range []string{doc1, doc2, doc3} {
Expand All @@ -86,6 +92,6 @@ var _ = FDescribe("RAG", func() {

answer, err := rag.Ask("What is the largest planet?")
Expect(err).NotTo(HaveOccurred())
fmt.Println(answer)
fmt.Println("answer:" + answer)
})
})

0 comments on commit 20677ba

Please sign in to comment.