Skip to content

Commit

Permalink
feat(snowflake): support oauth authentication (#236)
Browse files Browse the repository at this point in the history
# Description

This PR adds support for oauth authentication is snowflake.
- For backward compatibility, the default value of `Authenticator` is
`AuthTypeSnowflake`. If the client provides its value, it will be
overridden
- Added `Host` and `Token` fields in snowflake config

## Linear Ticket

< Replace_with_Linear_Link >

## Security

- [ ] The code changed/added as part of this pull request won't create
any security issues with how the software is being used.

---------

Co-authored-by: skShekhar <[email protected]>
Co-authored-by: Arnab Pal <[email protected]>
Co-authored-by: Aris Tzoumas <[email protected]>
  • Loading branch information
4 people authored Jan 15, 2025
1 parent c9a6308 commit 9bc7188
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 0 deletions.
61 changes: 61 additions & 0 deletions sqlconnect/internal/snowflake/authentication_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
package snowflake_test

import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"strings"
"testing"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -38,4 +44,59 @@ func TestSnowflakeAuthentication(t *testing.T) {
defer func() { _ = db.Close() }()
require.NoError(t, db.Ping(), "it should be able to ping the database")
})
t.Run("oauth", func(t *testing.T) {
authCode, ok := os.LookupEnv("SNOWFLAKE_TEST_AUTH_OAUTH_CODE")
if !ok {
t.Skip("skipping test due to lack of a test environment")
}

configJSON, ok := os.LookupEnv("SNOWFLAKE_TEST_ENVIRONMENT_CREDENTIALS")
require.True(t, ok, "it should be able to get the environment credentials")
var conf snowflake.Config
require.NoError(t, json.Unmarshal([]byte(configJSON), &conf), "it should be able to unmarshal the config")
// reset username and password
conf.User = ""
conf.Password = ""

// Issue a token
var accessToken string
{
var oauthCreds struct {
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
}
oauthCredsJSON, ok := os.LookupEnv("SNOWFLAKE_TEST_AUTH_OAUTH_CREDENTIALS")
require.True(t, ok, "it should be able to get the oauth creds")
require.NoError(t, json.Unmarshal([]byte(oauthCredsJSON), &oauthCreds), "it should be able to unmarshal the oauth creds")
body := url.Values{}
body.Add("redirect_uri", "https://localhost.com")
body.Add("code", authCode)
body.Add("grant_type", "authorization_code")
body.Add("scope", fmt.Sprintf("session:role:%s", conf.Role))
r, _ := http.NewRequest(http.MethodPost, fmt.Sprintf("https://%s.snowflakecomputing.com/oauth/token-request", conf.Account), strings.NewReader(body.Encode()))
r.Header.Add("Content-Type", "application/x-www-form-urlencoded;charset=UTF-8")
r.SetBasicAuth(oauthCreds.ClientID, oauthCreds.ClientSecret)
resp, err := http.DefaultClient.Do(r)
require.NoError(t, err, "it should be able to issue a token")
defer func() { _ = resp.Body.Close() }()
respBody, err := io.ReadAll(resp.Body)
require.NoError(t, err, "it should be able to read the response body")
require.Equalf(t, http.StatusOK, resp.StatusCode, "it should be able to issue a token: %s", string(respBody))
var token struct {
AccessToken string `json:"access_token"`
}
require.NoError(t, json.Unmarshal(respBody, &token), "it should be able to decode the token")
accessToken = token.AccessToken
}

conf.UseOAuth = true
conf.OAuthToken = accessToken
oauthConfigJSON, err := json.Marshal(conf)
require.NoError(t, err, "it should be able to marshal the config")
db, err := sqlconnect.NewDB(snowflake.DatabaseType, oauthConfigJSON)
require.NoError(t, err, "it should be able to create a new DB")
defer func() { _ = db.Close() }()
require.NoError(t, db.Ping(), "it should be able to ping the database")
require.NoError(t, db.QueryRow("SELECT 1").Err())
})
}
15 changes: 15 additions & 0 deletions sqlconnect/internal/snowflake/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,21 @@ type Config struct {
User string `json:"user"`
Schema string `json:"schema"`
Role string `json:"role"`
Region string `json:"region"`

Protocol string `json:"protocol"` // http or https (optional)
Host string `json:"host"` // hostname (optional)
Port int `json:"port"` // port (optional)

Password string `json:"password"`

UseKeyPairAuth bool `json:"useKeyPairAuth"`
PrivateKey string `json:"privateKey"`
PrivateKeyPassphrase string `json:"privateKeyPassphrase"`

UseOAuth bool `json:"useOAuth"`
OAuthToken string `json:"oauthToken"`

Application string `json:"application"`

LoginTimeout time.Duration `json:"loginTimeout"` // default: 5m
Expand All @@ -46,6 +54,10 @@ func (c Config) ConnectionString() (dsn string, err error) {
Warehouse: c.Warehouse,
Schema: c.Schema,
Role: c.Role,
Region: c.Region,
Protocol: c.Protocol,
Host: c.Host,
Port: c.Port,
Application: c.Application,
LoginTimeout: c.LoginTimeout,
Params: make(map[string]*string),
Expand All @@ -58,6 +70,9 @@ func (c Config) ConnectionString() (dsn string, err error) {
return "", fmt.Errorf("parsing private key: %w", err)
}
sc.PrivateKey = privateKey
} else if c.UseOAuth {
sc.Authenticator = gosnowflake.AuthTypeOAuth
sc.Token = c.OAuthToken
}

if c.KeepSessionAlive {
Expand Down

0 comments on commit 9bc7188

Please sign in to comment.