Skip to content

Commit

Permalink
#29: Add file path validation and error handling in openaiclient.go
Browse files Browse the repository at this point in the history
  • Loading branch information
mkrueger12 committed Dec 22, 2023
1 parent 3f8bf68 commit 0065cde
Showing 1 changed file with 31 additions and 2 deletions.
33 changes: 31 additions & 2 deletions pkg/providers/openai/openaiclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"log/slog"
"net/http"
"os"
"path/filepath"
"time"

"gopkg.in/yaml.v2"
Expand Down Expand Up @@ -142,7 +143,21 @@ func findProviderByModel(providers []providers.Provider, providerName string, mo
}

func readProviderVars(filePath string) ([]providers.ProviderVars, error) {
data, err := os.ReadFile(filePath)
absPath, err := filepath.Abs(filePath)
if err != nil {
return nil, fmt.Errorf("failed to get absolute file path: %w", err)
}

// Validate that the absolute path is a file
fileInfo, err := os.Stat(absPath)
if err != nil {
return nil, fmt.Errorf("failed to get file info: %w", err)
}
if fileInfo.IsDir() {
return nil, fmt.Errorf("provided path is a directory, not a file")
}

data, err := os.ReadFile(absPath)
if err != nil {
return nil, fmt.Errorf("failed to read provider vars file: %w", err)
}
Expand Down Expand Up @@ -170,7 +185,21 @@ func getDefaultBaseURL(provVars []providers.ProviderVars, providerName string) (
}

func readConfig(filePath string) (providers.GatewayConfig, error) {
data, err := os.ReadFile(filePath)
absPath, err := filepath.Abs(filePath)
if err != nil {
return providers.GatewayConfig{}, fmt.Errorf("failed to get absolute file path: %w", err)
}

// Validate that the absolute path is a file
fileInfo, err := os.Stat(absPath)
if err != nil {
return providers.GatewayConfig{}, fmt.Errorf("failed to get file info: %w", err)
}
if fileInfo.IsDir() {
return providers.GatewayConfig{}, fmt.Errorf("provided path is a directory, not a file")
}

data, err := os.ReadFile(absPath)
if err != nil {
slog.Error("Error:", err)
return providers.GatewayConfig{}, fmt.Errorf("failed to read config file: %w", err)
Expand Down

0 comments on commit 0065cde

Please sign in to comment.