diff --git a/sqlconnect/internal/snowflake/authentication_test.go b/sqlconnect/internal/snowflake/authentication_test.go index 2dd71ad..8f9a4cc 100644 --- a/sqlconnect/internal/snowflake/authentication_test.go +++ b/sqlconnect/internal/snowflake/authentication_test.go @@ -1,7 +1,13 @@ package snowflake_test import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" "os" + "strings" "testing" "github.com/stretchr/testify/require" @@ -38,4 +44,59 @@ func TestSnowflakeAuthentication(t *testing.T) { defer func() { _ = db.Close() }() require.NoError(t, db.Ping(), "it should be able to ping the database") }) + t.Run("oauth", func(t *testing.T) { + authCode, ok := os.LookupEnv("SNOWFLAKE_TEST_AUTH_OAUTH_CODE") + if !ok { + t.Skip("skipping test due to lack of a test environment") + } + + configJSON, ok := os.LookupEnv("SNOWFLAKE_TEST_ENVIRONMENT_CREDENTIALS") + require.True(t, ok, "it should be able to get the environment credentials") + var conf snowflake.Config + require.NoError(t, json.Unmarshal([]byte(configJSON), &conf), "it should be able to unmarshal the config") + // reset username and password + conf.User = "" + conf.Password = "" + + // Issue a token + var accessToken string + { + var oauthCreds struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + } + oauthCredsJSON, ok := os.LookupEnv("SNOWFLAKE_TEST_AUTH_OAUTH_CREDENTIALS") + require.True(t, ok, "it should be able to get the oauth creds") + require.NoError(t, json.Unmarshal([]byte(oauthCredsJSON), &oauthCreds), "it should be able to unmarshal the oauth creds") + body := url.Values{} + body.Add("redirect_uri", "https://localhost.com") + body.Add("code", authCode) + body.Add("grant_type", "authorization_code") + body.Add("scope", fmt.Sprintf("session:role:%s", conf.Role)) + r, _ := http.NewRequest(http.MethodPost, fmt.Sprintf("https://%s.snowflakecomputing.com/oauth/token-request", conf.Account), strings.NewReader(body.Encode())) + r.Header.Add("Content-Type", "application/x-www-form-urlencoded;charset=UTF-8") + r.SetBasicAuth(oauthCreds.ClientID, oauthCreds.ClientSecret) + resp, err := http.DefaultClient.Do(r) + require.NoError(t, err, "it should be able to issue a token") + defer func() { _ = resp.Body.Close() }() + respBody, err := io.ReadAll(resp.Body) + require.NoError(t, err, "it should be able to read the response body") + require.Equalf(t, http.StatusOK, resp.StatusCode, "it should be able to issue a token: %s", string(respBody)) + var token struct { + AccessToken string `json:"access_token"` + } + require.NoError(t, json.Unmarshal(respBody, &token), "it should be able to decode the token") + accessToken = token.AccessToken + } + + conf.UseOAuth = true + conf.OAuthToken = accessToken + oauthConfigJSON, err := json.Marshal(conf) + require.NoError(t, err, "it should be able to marshal the config") + db, err := sqlconnect.NewDB(snowflake.DatabaseType, oauthConfigJSON) + require.NoError(t, err, "it should be able to create a new DB") + defer func() { _ = db.Close() }() + require.NoError(t, db.Ping(), "it should be able to ping the database") + require.NoError(t, db.QueryRow("SELECT 1").Err()) + }) } diff --git a/sqlconnect/internal/snowflake/config.go b/sqlconnect/internal/snowflake/config.go index 4485687..5dd0ab9 100644 --- a/sqlconnect/internal/snowflake/config.go +++ b/sqlconnect/internal/snowflake/config.go @@ -20,6 +20,11 @@ type Config struct { User string `json:"user"` Schema string `json:"schema"` Role string `json:"role"` + Region string `json:"region"` + + Protocol string `json:"protocol"` // http or https (optional) + Host string `json:"host"` // hostname (optional) + Port int `json:"port"` // port (optional) Password string `json:"password"` @@ -27,6 +32,9 @@ type Config struct { PrivateKey string `json:"privateKey"` PrivateKeyPassphrase string `json:"privateKeyPassphrase"` + UseOAuth bool `json:"useOAuth"` + OAuthToken string `json:"oauthToken"` + Application string `json:"application"` LoginTimeout time.Duration `json:"loginTimeout"` // default: 5m @@ -46,6 +54,10 @@ func (c Config) ConnectionString() (dsn string, err error) { Warehouse: c.Warehouse, Schema: c.Schema, Role: c.Role, + Region: c.Region, + Protocol: c.Protocol, + Host: c.Host, + Port: c.Port, Application: c.Application, LoginTimeout: c.LoginTimeout, Params: make(map[string]*string), @@ -58,6 +70,9 @@ func (c Config) ConnectionString() (dsn string, err error) { return "", fmt.Errorf("parsing private key: %w", err) } sc.PrivateKey = privateKey + } else if c.UseOAuth { + sc.Authenticator = gosnowflake.AuthTypeOAuth + sc.Token = c.OAuthToken } if c.KeepSessionAlive {