feat: enhance API, expose more parameters (mudler#24)
Signed-off-by: mudler <[email protected]>
mudler authored Apr 16, 2023
1 parent c371752 commit b062f31
Showing 1 changed file with 59 additions and 23 deletions.
82 changes: 59 additions & 23 deletions api/api.go
@@ -26,10 +26,10 @@ type OpenAIResponse struct {
 }
 
 type Choice struct {
-	Index        int     `json:"index,omitempty"`
-	FinishReason string  `json:"finish_reason,omitempty"`
-	Message      Message `json:"message,omitempty"`
-	Text         string  `json:"text,omitempty"`
+	Index        int      `json:"index,omitempty"`
+	FinishReason string   `json:"finish_reason,omitempty"`
+	Message      *Message `json:"message,omitempty"`
+	Text         string   `json:"text,omitempty"`
 }
 
 type Message struct {
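Editor's note: the switch from Message to *Message is what lets the json:"message,omitempty" tag actually drop the field. encoding/json never treats a non-pointer struct value as "empty", so completion responses (which set only Text) would otherwise always carry a stub "message":{} object; a nil pointer is omitted. A minimal, self-contained sketch of the difference (the choiceValue and choicePointer names are ours, not part of the commit):

package main

import (
	"encoding/json"
	"fmt"
)

type Message struct {
	Role    string `json:"role,omitempty"`
	Content string `json:"content,omitempty"`
}

// Value field: omitempty never fires for a struct, so a stub
// "message":{} object is always emitted.
type choiceValue struct {
	Message Message `json:"message,omitempty"`
	Text    string  `json:"text,omitempty"`
}

// Pointer field: a nil pointer counts as empty and is dropped.
type choicePointer struct {
	Message *Message `json:"message,omitempty"`
	Text    string   `json:"text,omitempty"`
}

func main() {
	v, _ := json.Marshal(choiceValue{Text: "hi"})
	p, _ := json.Marshal(choicePointer{Text: "hi"})
	fmt.Println(string(v)) // {"message":{},"text":"hi"}
	fmt.Println(string(p)) // {"text":"hi"}
}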
@@ -47,20 +47,29 @@ type OpenAIRequest struct {
 
 	// Prompt is read only by completion API calls
 	Prompt string `json:"prompt"`
 
 	// Messages is read only by chat/completion API calls
 	Messages []Message `json:"messages"`
 
+	Echo bool `json:"echo"`
 	// Common options between all the API calls
 	TopP        float64 `json:"top_p"`
 	TopK        int     `json:"top_k"`
 	Temperature float64 `json:"temperature"`
 	Maxtokens   int     `json:"max_tokens"`
+
+	N int `json:"n"`
+
+	// Custom parameters - not present in the OpenAI API
+	Batch     int  `json:"batch"`
+	F16       bool `json:"f16kv"`
+	IgnoreEOS bool `json:"ignore_eos"`
 }
 
 //go:embed index.html
 var indexHTML embed.FS
 
 // https://platform.openai.com/docs/api-reference/completions
 func openAIEndpoint(chat bool, defaultModel *llama.LLama, loader *model.ModelLoader, threads int, defaultMutex *sync.Mutex, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
 		var err error
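Editor's note: with the request struct above, a call exercising the newly exposed knobs serializes as shown below. A minimal sketch that only prints the JSON body the handler expects; the model name and parameter values are illustrative, and the Messages field is left out for brevity:

package main

import (
	"encoding/json"
	"fmt"
)

// Mirror of the OpenAIRequest fields and tags from the diff above
// (Messages omitted for brevity).
type OpenAIRequest struct {
	Model       string  `json:"model"`
	Prompt      string  `json:"prompt"`
	Echo        bool    `json:"echo"`
	TopP        float64 `json:"top_p"`
	TopK        int     `json:"top_k"`
	Temperature float64 `json:"temperature"`
	Maxtokens   int     `json:"max_tokens"`
	N           int     `json:"n"`
	Batch       int     `json:"batch"`
	F16         bool    `json:"f16kv"`
	IgnoreEOS   bool    `json:"ignore_eos"`
}

func main() {
	body, _ := json.MarshalIndent(OpenAIRequest{
		Model:       "ggml-model.bin", // illustrative model name
		Prompt:      "Once upon a time",
		Echo:        true, // prepend the prompt to each completion
		Temperature: 0.7,
		TopP:        0.9,
		TopK:        40,
		Maxtokens:   128,
		N:           2,    // ask for two choices
		Batch:       512,  // llama.cpp batch size
		F16:         true, // f16 key/value cache
		IgnoreEOS:   true, // keep sampling past end-of-sequence
	}, "", "  ")
	fmt.Println(string(body))
}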
@@ -139,31 +148,58 @@ func openAIEndpoint(chat bool, defaultModel *llama.LLama, loader *model.ModelLoader, threads int, defaultMutex *sync.Mutex, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error {
 			predInput = templatedInput
 		}
 
-		// Generate the prediction using the language model
-		prediction, err := model.Predict(
-			predInput,
-			llama.SetTemperature(temperature),
-			llama.SetTopP(topP),
-			llama.SetTopK(topK),
-			llama.SetTokens(tokens),
-			llama.SetThreads(threads),
-		)
-		if err != nil {
-			return err
+		result := []Choice{}
+
+		n := input.N
+
+		if input.N == 0 {
+			n = 1
 		}
 
-		if chat {
-			// Return the chat prediction in the response body
-			return c.JSON(OpenAIResponse{
-				Model:   input.Model,
-				Choices: []Choice{{Message: Message{Role: "assistant", Content: prediction}}},
-			})
+		for i := 0; i < n; i++ {
+			// Generate the prediction using the language model
+			predictOptions := []llama.PredictOption{
+				llama.SetTemperature(temperature),
+				llama.SetTopP(topP),
+				llama.SetTopK(topK),
+				llama.SetTokens(tokens),
+				llama.SetThreads(threads),
+			}
+
+			if input.Batch != 0 {
+				predictOptions = append(predictOptions, llama.SetBatch(input.Batch))
+			}
+
+			if input.F16 {
+				predictOptions = append(predictOptions, llama.EnableF16KV)
+			}
+
+			if input.IgnoreEOS {
+				predictOptions = append(predictOptions, llama.IgnoreEOS)
+			}
+
+			prediction, err := model.Predict(
+				predInput,
+				predictOptions...,
+			)
+			if err != nil {
+				return err
+			}
+
+			if input.Echo {
+				prediction = predInput + prediction
+			}
+			if chat {
+				result = append(result, Choice{Message: &Message{Role: "assistant", Content: prediction}})
+			} else {
+				result = append(result, Choice{Text: prediction})
+			}
 		}
 
 		// Return the prediction in the response body
 		return c.JSON(OpenAIResponse{
 			Model:   input.Model,
-			Choices: []Choice{{Text: prediction}},
+			Choices: result,
 		})
 	}
 }
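Editor's note: the rewritten handler gathers llama.PredictOption values in a slice, appending the optional ones (batch, f16kv, ignore_eos) only when the request sets them, then runs the predict loop n times to honor the OpenAI-style n parameter; echo prefixes each completion with the prompt. A minimal client sketch against a local instance; the listen address and the /v1/completions route are assumptions, since this diff does not show the router:

package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	payload := []byte(`{
		"model": "ggml-model.bin",
		"prompt": "Once upon a time",
		"n": 2,
		"echo": true,
		"batch": 512,
		"f16kv": true,
		"ignore_eos": true,
		"max_tokens": 64
	}`)

	// Route and port are assumed for illustration.
	resp, err := http.Post("http://localhost:8080/v1/completions",
		"application/json", bytes.NewReader(payload))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	out, _ := io.ReadAll(resp.Body)
	// With "n": 2 the handler appends two entries to "choices";
	// with "echo": true each "text" starts with the prompt.
	fmt.Println(string(out))
}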
