-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* #24: Refactor OpenAI provider config struct * #24: Update OpenAiProviderConfig Messages field validation --------- Co-authored-by: Max <[email protected]>
- Loading branch information
1 parent
dd17fc6
commit 86c624c
Showing
1 changed file
with
45 additions
and
95 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,103 +1,53 @@ | ||
package openai | ||
|
||
type ProviderConfig struct { | ||
Model ConfigItem `json:"model" validate:"required,lowercase"` | ||
Messages ConfigItem `json:"messages" validate:"required"` | ||
MaxTokens ConfigItem `json:"max_tokens" validate:"omitempty,gte=0"` | ||
Temperature ConfigItem `json:"temperature" validate:"omitempty,gte=0,lte=2"` | ||
TopP ConfigItem `json:"top_p" validate:"omitempty,gte=0,lte=1"` | ||
N ConfigItem `json:"n" validate:"omitempty,gte=1"` | ||
Stream ConfigItem `json:"stream" validate:"omitempty, boolean"` | ||
Stop ConfigItem `json:"stop"` | ||
PresencePenalty ConfigItem `json:"presence_penalty" validate:"omitempty,gte=-2,lte=2"` | ||
FrequencyPenalty ConfigItem `json:"frequency_penalty" validate:"omitempty,gte=-2,lte=2"` | ||
LogitBias ConfigItem `json:"logit_bias" validate:"omitempty"` | ||
User ConfigItem `json:"user"` | ||
Seed ConfigItem `json:"seed" validate:"omitempty,gte=0"` | ||
Tools ConfigItem `json:"tools"` | ||
ToolChoice ConfigItem `json:"tool_choice"` | ||
ResponseFormat ConfigItem `json:"response_format"` | ||
type OpenAiProviderConfig struct { | ||
Model string `json:"model" validate:"required,lowercase"` | ||
Messages string `json:"messages" validate:"required"` // does this need to be updated to []string? | ||
MaxTokens int `json:"max_tokens" validate:"omitempty,gte=0"` | ||
Temperature int `json:"temperature" validate:"omitempty,gte=0,lte=2"` | ||
TopP int `json:"top_p" validate:"omitempty,gte=0,lte=1"` | ||
N int `json:"n" validate:"omitempty,gte=1"` | ||
Stream bool `json:"stream" validate:"omitempty, boolean"` | ||
Stop interface{} `json:"stop"` | ||
PresencePenalty int `json:"presence_penalty" validate:"omitempty,gte=-2,lte=2"` | ||
FrequencyPenalty int `json:"frequency_penalty" validate:"omitempty,gte=-2,lte=2"` | ||
LogitBias *map[int]float64 `json:"logit_bias" validate:"omitempty"` | ||
User interface{} `json:"user"` | ||
Seed interface{} `json:"seed" validate:"omitempty,gte=0"` | ||
Tools []string `json:"tools"` | ||
ToolChoice interface{} `json:"tool_choice"` | ||
ResponseFormat interface{} `json:"response_format"` | ||
} | ||
|
||
type ConfigItem struct { | ||
Param string `json:"param" validate:"required"` | ||
Required bool `json:"required" validate:"omitempty,boolean"` | ||
Default interface{} `json:"default"` | ||
} | ||
var defaultMessage = `[ | ||
{ | ||
"role": "system", | ||
"content": "You are a helpful assistant." | ||
}, | ||
{ | ||
"role": "user", | ||
"content": "Hello!" | ||
} | ||
]` | ||
|
||
// Provide the request body for OpenAI's ChatCompletion API | ||
func OpenAiChatDefaultConfig() ProviderConfig { | ||
return ProviderConfig{ | ||
Model: ConfigItem{ | ||
Param: "model", | ||
Required: true, | ||
Default: "gpt-3.5-turbo", | ||
}, | ||
Messages: ConfigItem{ | ||
Param: "messages", | ||
Required: true, | ||
Default: "", | ||
}, | ||
MaxTokens: ConfigItem{ | ||
Param: "max_tokens", | ||
Required: false, | ||
Default: 100, | ||
}, | ||
Temperature: ConfigItem{ | ||
Param: "temperature", | ||
Required: false, | ||
Default: 1, | ||
}, | ||
TopP: ConfigItem{ | ||
Param: "top_p", | ||
Required: false, | ||
Default: 1, | ||
}, | ||
N: ConfigItem{ | ||
Param: "n", | ||
Required: false, | ||
Default: 1, | ||
}, | ||
Stream: ConfigItem{ | ||
Param: "stream", | ||
Required: false, | ||
Default: false, | ||
}, | ||
Stop: ConfigItem{ | ||
Param: "stop", | ||
Required: false, | ||
}, | ||
PresencePenalty: ConfigItem{ | ||
Param: "presence_penalty", | ||
Required: false, | ||
}, | ||
FrequencyPenalty: ConfigItem{ | ||
Param: "frequency_penalty", | ||
Required: false, | ||
}, | ||
LogitBias: ConfigItem{ | ||
Param: "logit_bias", | ||
Required: false, | ||
}, | ||
User: ConfigItem{ | ||
Param: "user", | ||
Required: false, | ||
}, | ||
Seed: ConfigItem{ | ||
Param: "seed", | ||
Required: false, | ||
}, | ||
Tools: ConfigItem{ | ||
Param: "tools", | ||
Required: false, | ||
}, | ||
ToolChoice: ConfigItem{ | ||
Param: "tool_choice", | ||
Required: false, | ||
}, | ||
ResponseFormat: ConfigItem{ | ||
Param: "response_format", | ||
Required: false, | ||
}, | ||
func OpenAiChatDefaultConfig() OpenAiProviderConfig { | ||
return OpenAiProviderConfig{ | ||
Model: "gpt-3.5-turbo", | ||
Messages: defaultMessage, | ||
MaxTokens: 100, | ||
Temperature: 1, | ||
TopP: 1, | ||
N: 1, | ||
Stream: false, | ||
Stop: nil, | ||
PresencePenalty: 0, | ||
FrequencyPenalty: 0, | ||
LogitBias: nil, | ||
User: nil, | ||
Seed: nil, | ||
Tools: nil, | ||
ToolChoice: nil, | ||
ResponseFormat: nil, | ||
} | ||
} |