diff --git a/frontend/src/js/ui/views/settings.ts b/frontend/src/js/ui/views/settings.ts index 806b6c2..99836e3 100644 --- a/frontend/src/js/ui/views/settings.ts +++ b/frontend/src/js/ui/views/settings.ts @@ -88,7 +88,7 @@ export default (): m.Component => { m.redraw(); }) .catch((err) => { - error('Could not fetch AI models... retrying...'); + error('Could not fetch AI models. retrying... (' + err + ')'); setTimeout(fetchAiModels, 3000); }); }; @@ -293,7 +293,8 @@ export default (): m.Component => { }, aiContextWindow: { label: 'Context Window', - description: 'The context window for the AI service', + description: + 'The context window for the AI service. This window is used to provide the AI with examples in case of the entry generator. Higher values should provide better results.', }, aiMaxTokens: { label: 'Max Tokens', diff --git a/rpc/ai.go b/rpc/ai.go index 1d61ddf..d5741e2 100644 --- a/rpc/ai.go +++ b/rpc/ai.go @@ -222,10 +222,7 @@ func RegisterAI(route *echo.Group, db database.Database) { }) bind.MustBind(route, "/aiModels", func(provider string) ([]string, error) { - // TODO: dynamically fetch models switch provider { - case "OpenAI": - return []string{"gpt-3.5-turbo", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-16k", "gpt-4-1106-preview", "gpt-4", "gpt-4-32k"}, nil case "Custom (e.g. Local)": return []string{"Custom"}, nil } @@ -235,7 +232,25 @@ func RegisterAI(route *echo.Group, db database.Database) { return nil, err } - resp, err := http.Get(endpoint + "/v1/models") + req, err := http.NewRequest("GET", endpoint+"/v1/models", nil) + if err != nil { + return nil, err + } + + if provider == "OpenAI" { + settings, err := db.GetSettings() + if err != nil { + return nil, err + } + + if settings.AIApiKey == "" { + return nil, errors.New("OpenAI API key not set") + } + + req.Header.Set("Authorization", "Bearer "+settings.AIApiKey) + } + + resp, err := http.DefaultClient.Do(req) if err != nil { return nil, err } @@ -245,6 +260,10 @@ func RegisterAI(route *echo.Group, db database.Database) { return nil, err } + if strings.Contains(string(res), "invalid_api_key") { + return nil, errors.New("invalid API key") + } + var models ModelsList err = json.Unmarshal(res, &models) if err != nil {