From 606462d8994aad3ae70b57f2bc37dda01e72d114 Mon Sep 17 00:00:00 2001 From: Aris Tzoumas Date: Thu, 14 Mar 2024 18:23:26 +0200 Subject: [PATCH] feat: redshift sdk driver --- .github/workflows/test.yaml | 1 + go.mod | 49 ++- go.sum | 123 +++--- .../internal/bigquery/driver/driver_test.go | 6 +- sqlconnect/internal/redshift/config.go | 50 ++- sqlconnect/internal/redshift/db.go | 64 ++- sqlconnect/internal/redshift/driver/client.go | 25 ++ .../internal/redshift/driver/connection.go | 378 ++++++++++++++++++ .../redshift/driver/connection_test.go | 47 +++ .../internal/redshift/driver/connector.go | 30 ++ sqlconnect/internal/redshift/driver/driver.go | 32 ++ .../internal/redshift/driver/driver_test.go | 255 ++++++++++++ sqlconnect/internal/redshift/driver/dsn.go | 213 ++++++++++ .../internal/redshift/driver/dsn_test.go | 93 +++++ sqlconnect/internal/redshift/driver/errors.go | 12 + sqlconnect/internal/redshift/driver/logger.go | 35 ++ sqlconnect/internal/redshift/driver/result.go | 55 +++ sqlconnect/internal/redshift/driver/rows.go | 120 ++++++ .../internal/redshift/driver/statement.go | 35 ++ sqlconnect/internal/redshift/driver/tx.go | 16 + sqlconnect/internal/redshift/driver/utils.go | 18 + .../internal/redshift/integration_test.go | 20 +- sqlconnect/internal/redshift/mappings.go | 4 + .../testdata/column-mapping-test-rows.json | 4 +- .../testdata/column-mapping-test-seed.sql | 2 +- .../legacy-column-mapping-test-rows.json | 4 +- 26 files changed, 1587 insertions(+), 104 deletions(-) create mode 100644 sqlconnect/internal/redshift/driver/client.go create mode 100644 sqlconnect/internal/redshift/driver/connection.go create mode 100644 sqlconnect/internal/redshift/driver/connection_test.go create mode 100644 sqlconnect/internal/redshift/driver/connector.go create mode 100644 sqlconnect/internal/redshift/driver/driver.go create mode 100644 sqlconnect/internal/redshift/driver/driver_test.go create mode 100644 sqlconnect/internal/redshift/driver/dsn.go create mode 100644 sqlconnect/internal/redshift/driver/dsn_test.go create mode 100644 sqlconnect/internal/redshift/driver/errors.go create mode 100644 sqlconnect/internal/redshift/driver/logger.go create mode 100644 sqlconnect/internal/redshift/driver/result.go create mode 100644 sqlconnect/internal/redshift/driver/rows.go create mode 100644 sqlconnect/internal/redshift/driver/statement.go create mode 100644 sqlconnect/internal/redshift/driver/tx.go create mode 100644 sqlconnect/internal/redshift/driver/utils.go diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 6236071..5ada137 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -50,6 +50,7 @@ jobs: make test exclude="${{ matrix.exclude }}" package=${{ matrix.package }} env: REDSHIFT_TEST_ENVIRONMENT_CREDENTIALS: ${{ secrets.REDSHIFT_TEST_ENVIRONMENT_CREDENTIALS }} + REDSHIFT_SDK_TEST_ENVIRONMENT_CREDENTIALS: ${{ secrets.REDSHIFT_SDK_TEST_ENVIRONMENT_CREDENTIALS }} SNOWFLAKE_TEST_ENVIRONMENT_CREDENTIALS: ${{ secrets.SNOWFLAKE_TEST_ENVIRONMENT_CREDENTIALS }} BIGQUERY_TEST_ENVIRONMENT_CREDENTIALS: ${{ secrets.BIGQUERY_TEST_ENVIRONMENT_CREDENTIALS }} DATABRICKS_TEST_ENVIRONMENT_CREDENTIALS: ${{ secrets.DATABRICKS_TEST_ENVIRONMENT_CREDENTIALS }} diff --git a/go.mod b/go.mod index 77d4a91..406865e 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,10 @@ go 1.22.0 require ( cloud.google.com/go v0.112.0 cloud.google.com/go/bigquery v1.59.1 + github.com/aws/aws-sdk-go-v2 v1.25.3 + github.com/aws/aws-sdk-go-v2/config v1.25.3 + github.com/aws/aws-sdk-go-v2/credentials v1.16.2 + github.com/aws/aws-sdk-go-v2/service/redshiftdata v1.25.2 github.com/databricks/databricks-sql-go v1.5.3 github.com/dlclark/regexp2 v1.11.0 github.com/go-sql-driver/mysql v1.7.1 @@ -15,6 +19,7 @@ require ( github.com/sirupsen/logrus v1.9.3 github.com/snowflakedb/gosnowflake v1.7.2 github.com/stretchr/testify v1.9.0 + github.com/tidwall/gjson v1.17.1 github.com/tidwall/sjson v1.2.5 github.com/trinodb/trino-go-client v0.313.0 google.golang.org/api v0.169.0 @@ -26,30 +31,33 @@ require ( cloud.google.com/go/iam v1.1.6 // indirect github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 // indirect github.com/99designs/keyring v1.2.2 // indirect - github.com/Azure/azure-sdk-for-go/sdk/azcore v1.4.0 // indirect - github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2 // indirect - github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.0.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.1.0 // indirect github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c // indirect github.com/Microsoft/go-winio v0.6.0 // indirect github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect - github.com/andybalholm/brotli v1.0.5 // indirect + github.com/andybalholm/brotli v1.0.6 // indirect github.com/apache/arrow/go/v12 v12.0.1 // indirect github.com/apache/arrow/go/v14 v14.0.2 // indirect github.com/apache/thrift v0.17.0 // indirect - github.com/aws/aws-sdk-go-v2 v1.17.7 // indirect - github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.10 // indirect - github.com/aws/aws-sdk-go-v2/credentials v1.13.18 // indirect - github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.59 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.31 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.25 // indirect - github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.23 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.11 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.26 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.25 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.14.0 // indirect - github.com/aws/aws-sdk-go-v2/service/s3 v1.31.0 // indirect - github.com/aws/smithy-go v1.13.5 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.5.1 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.4 // indirect + github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.14.0 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.3 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.7.1 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.2.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.1 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.2.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.16.3 // indirect + github.com/aws/aws-sdk-go-v2/service/s3 v1.43.0 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.17.2 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.20.0 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.25.3 // indirect + github.com/aws/smithy-go v1.20.1 // indirect github.com/cenkalti/backoff/v4 v4.2.1 // indirect github.com/containerd/continuity v0.3.0 // indirect github.com/coreos/go-oidc/v3 v3.5.0 // indirect @@ -82,8 +90,8 @@ require ( github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/googleapis/gax-go/v2 v2.12.2 // indirect github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c // indirect - github.com/hashicorp/go-cleanhttp v0.5.1 // indirect - github.com/hashicorp/go-retryablehttp v0.7.1 // indirect + github.com/hashicorp/go-cleanhttp v0.5.2 // indirect + github.com/hashicorp/go-retryablehttp v0.7.5 // indirect github.com/hashicorp/go-uuid v1.0.3 // indirect github.com/imdario/mergo v0.3.13 // indirect github.com/jcmturner/gofork v1.7.6 // indirect @@ -106,9 +114,8 @@ require ( github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/rs/zerolog v1.28.0 // indirect - github.com/tidwall/gjson v1.17.1 // indirect github.com/tidwall/match v1.1.1 // indirect - github.com/tidwall/pretty v1.2.0 // indirect + github.com/tidwall/pretty v1.2.1 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect github.com/xeipuuv/gojsonschema v1.2.0 // indirect diff --git a/go.sum b/go.sum index de38fec..c2bce46 100644 --- a/go.sum +++ b/go.sum @@ -20,18 +20,20 @@ github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 h1:/vQbFIOMb github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4/go.mod h1:hN7oaIRCjzsZ2dE+yG5k+rsdt3qcwykqK6HVGcKwsw4= github.com/99designs/keyring v1.2.2 h1:pZd3neh/EmUzWONb35LxQfvuY7kiSXAq3HQd97+XBn0= github.com/99designs/keyring v1.2.2/go.mod h1:wes/FrByc8j7lFOAGLGSNEg8f/PaI3cgTBqhFkHUrPk= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.4.0 h1:rTnT/Jrcm+figWlYz4Ixzt0SJVR2cMC8lvZcimipiEY= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.4.0/go.mod h1:ON4tFdPTwRcgWEaVDrN3584Ef+b7GgSJaXxe5fW9t4M= -github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.1.0 h1:QkAcEIAKbNL4KoFr4SathZPhDhF4mVwpBMFlYjyAqy8= -github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.1.0/go.mod h1:bhXu1AjYL+wutSL/kpSq6s7733q2Rb0yuot9Zgfqa/0= -github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2 h1:+5VZ72z0Qan5Bog5C+ZkgSqUbeVUd9wgtHOrIKuc5b8= -github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2/go.mod h1:eWRD7oawr1Mu1sLCawqVc0CUiF43ia3qQMxLscsKQ9w= -github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.0.0 h1:u/LLAOFgsMv7HmNL4Qufg58y+qElGOt5qv0z1mURkRY= -github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.0.0/go.mod h1:2e8rMJtl2+2j+HXbTBwnyGpm5Nou7KhvSfxOq8JpTag= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0 h1:8kDqDngH+DmVBiCtIjCFTGa7MBnsIOkF9IccInFEbjk= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.0/go.mod h1:bjGvMhVMb+EEm3VRNQawDMUyMMjo+S5ewNjflkep/0Q= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0 h1:vcYCAze6p19qBW7MhZybIsqD8sMV8js0NyQM8JDnVtg= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.0/go.mod h1:OQeznEEkTZ9OrhHJoDD8ZDq51FHgXjqtP9z6bEwBq9U= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0 h1:sXr+ck84g/ZlZUOZiNELInmMgOsuGwdjjVkEIde0OtY= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0/go.mod h1:okt5dMMTOFjX/aovMlrjvvXoPMBVSPzk9185BT0+eZM= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.2.0 h1:Ma67P/GGprNwsslzEH6+Kb8nybI8jpDTm4Wmzu2ReK8= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.2.0/go.mod h1:c+Lifp3EDEamAkPVzMooRNOK6CZjNSdEnf1A7jsI9u4= +github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.1.0 h1:nVocQV40OQne5613EeLayJiRAJuKlBGy+m22qWG+WRg= +github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.1.0/go.mod h1:7QJP7dr2wznCMeqIrhMgWGf7XpAQnVrJqDm9nvV3Cu4= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= -github.com/AzureAD/microsoft-authentication-library-for-go v0.5.1 h1:BWe8a+f/t+7KY7zH2mqygeUD0t8hNFXe08p1Pb3/jKE= -github.com/AzureAD/microsoft-authentication-library-for-go v0.5.1/go.mod h1:Vt9sXTKwMyGcOxSmLDMnGPgqsUg7m8pe215qMLrDXw4= +github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0 h1:OBhqkivkhkMqLPymWEppkm7vgPQY2XsHoEkaMQ0AdZY= +github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0/go.mod h1:kgDmCTgBzIEPFElEF+FK0SdjAor06dRq2Go927dnQ6o= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c h1:RGWPOewvKIROun94nF7v2cua9qP+thov/7M50KEoeSU= github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c/go.mod h1:X0CRv0ky0k6m906ixxpzmDRLvX58TFUKS2eePweuyxk= @@ -39,52 +41,54 @@ github.com/Microsoft/go-winio v0.6.0 h1:slsWYD/zyx7lCXoZVlvQrj0hPTM1HI4+v1sIda2y github.com/Microsoft/go-winio v0.6.0/go.mod h1:cTAf44im0RAYeL23bpB+fzCyDH2MJiz2BO69KH/soAE= github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 h1:TngWCqHvy9oXAN6lEVMRuU21PR1EtLVZJmdB18Gu3Rw= github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5/go.mod h1:lmUJ/7eu/Q8D7ML55dXQrVaamCz2vxCfdQBasLZfHKk= -github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs= -github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= +github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI= +github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/apache/arrow/go/v12 v12.0.1 h1:JsR2+hzYYjgSUkBSaahpqCetqZMr76djX80fF/DiJbg= github.com/apache/arrow/go/v12 v12.0.1/go.mod h1:weuTY7JvTG/HDPtMQxEUp7pU73vkLWMLpY67QwZ/WWw= github.com/apache/arrow/go/v14 v14.0.2 h1:N8OkaJEOfI3mEZt07BIkvo4sC6XDbL+48MBPWO5IONw= github.com/apache/arrow/go/v14 v14.0.2/go.mod h1:u3fgh3EdgN/YQ8cVQRguVW3R+seMybFg8QBQ5LU+eBY= github.com/apache/thrift v0.17.0 h1:cMd2aj52n+8VoAtvSvLn4kDC3aZ6IAkBuqWQ2IDu7wo= github.com/apache/thrift v0.17.0/go.mod h1:OLxhMRJxomX+1I/KUw03qoV3mMz16BwaKI+d4fPBx7Q= -github.com/aws/aws-sdk-go-v2 v1.17.7 h1:CLSjnhJSTSogvqUGhIC6LqFKATMRexcxLZ0i/Nzk9Eg= -github.com/aws/aws-sdk-go-v2 v1.17.7/go.mod h1:uzbQtefpm44goOPmdKyAlXSNcwlRgF3ePWVW6EtJvvw= -github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.10 h1:dK82zF6kkPeCo8J1e+tGx4JdvDIQzj7ygIoLg8WMuGs= -github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.10/go.mod h1:VeTZetY5KRJLuD/7fkQXMU6Mw7H5m/KP2J5Iy9osMno= -github.com/aws/aws-sdk-go-v2/config v1.18.19 h1:AqFK6zFNtq4i1EYu+eC7lcKHYnZagMn6SW171la0bGw= -github.com/aws/aws-sdk-go-v2/config v1.18.19/go.mod h1:XvTmGMY8d52ougvakOv1RpiTLPz9dlG/OQHsKU/cMmY= -github.com/aws/aws-sdk-go-v2/credentials v1.13.18 h1:EQMdtHwz0ILTW1hoP+EwuWhwCG1hD6l3+RWFQABET4c= -github.com/aws/aws-sdk-go-v2/credentials v1.13.18/go.mod h1:vnwlwjIe+3XJPBYKu1et30ZPABG3VaXJYr8ryohpIyM= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.1 h1:gt57MN3liKiyGopcqgNzJb2+d9MJaKT/q1OksHNXVE4= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.1/go.mod h1:lfUx8puBRdM5lVVMQlwt2v+ofiG/X6Ms+dy0UkG/kXw= -github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.59 h1:E3Y+OfzOK1+rmRo/K2G0ml8Vs+Xqk0kOnf4nS0kUtBc= -github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.59/go.mod h1:1M4PLSBUVfBI0aP+C9XI7SM6kZPCGYyI6izWz0TGprE= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.31 h1:sJLYcS+eZn5EeNINGHSCRAwUJMFVqklwkH36Vbyai7M= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.31/go.mod h1:QT0BqUvX1Bh2ABdTGnjqEjvjzrCfIniM9Sc8zn9Yndo= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.25 h1:1mnRASEKnkqsntcxHaysxwgVoUUp5dkiB+l3llKnqyg= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.25/go.mod h1:zBHOPwhBc3FlQjQJE/D3IfPWiWaQmT06Vq9aNukDo0k= -github.com/aws/aws-sdk-go-v2/internal/ini v1.3.32 h1:p5luUImdIqywn6JpQsW3tq5GNOxKmOnEpybzPx+d1lk= -github.com/aws/aws-sdk-go-v2/internal/ini v1.3.32/go.mod h1:XGhIBZDEgfqmFIugclZ6FU7v75nHhBDtzuB4xB/tEi4= -github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.23 h1:DWYZIsyqagnWL00f8M/SOr9fN063OEQWn9LLTbdYXsk= -github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.23/go.mod h1:uIiFgURZbACBEQJfqTZPb/jxO7R+9LeoHUFudtIdeQI= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.11 h1:y2+VQzC6Zh2ojtV2LoC0MNwHWc6qXv/j2vrQtlftkdA= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.11/go.mod h1:iV4q2hsqtNECrfmlXyord9u4zyuFEJX9eLgLpSPzWA8= -github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.26 h1:CeuSeq/8FnYpPtnuIeLQEEvDv9zUjneuYi8EghMBdwQ= -github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.26/go.mod h1:2UqAAwMUXKeRkAHIlDJqvMVgOWkUi/AUXPk/YIe+Dg4= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.25 h1:5LHn8JQ0qvjD9L9JhMtylnkcw7j05GDZqM9Oin6hpr0= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.25/go.mod h1:/95IA+0lMnzW6XzqYJRpjjsAbKEORVeO0anQqjd2CNU= -github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.14.0 h1:e2ooMhpYGhDnBfSvIyusvAwX7KexuZaHbQY2Dyei7VU= -github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.14.0/go.mod h1:bh2E0CXKZsQN+faiKVqC40vfNMAWheoULBCnEgO9K+8= -github.com/aws/aws-sdk-go-v2/service/s3 v1.31.0 h1:B1G2pSPvbAtQjilPq+Y7jLIzCOwKzuVEl+aBBaNG0AQ= -github.com/aws/aws-sdk-go-v2/service/s3 v1.31.0/go.mod h1:ncltU6n4Nof5uJttDtcNQ537uNuwYqsZZQcpkd2/GUQ= -github.com/aws/aws-sdk-go-v2/service/sso v1.12.6 h1:5V7DWLBd7wTELVz5bPpwzYy/sikk0gsgZfj40X+l5OI= -github.com/aws/aws-sdk-go-v2/service/sso v1.12.6/go.mod h1:Y1VOmit/Fn6Tz1uFAeCO6Q7M2fmfXSCLeL5INVYsLuY= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.14.6 h1:B8cauxOH1W1v7rd8RdI/MWnoR4Ze0wIHWrb90qczxj4= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.14.6/go.mod h1:Lh/bc9XUf8CfOY6Jp5aIkQtN+j1mc+nExc+KXj9jx2s= -github.com/aws/aws-sdk-go-v2/service/sts v1.18.7 h1:bWNgNdRko2x6gqa0blfATqAZKZokPIeM1vfmQt2pnvM= -github.com/aws/aws-sdk-go-v2/service/sts v1.18.7/go.mod h1:JuTnSoeePXmMVe9G8NcjjwgOKEfZ4cOjMuT2IBT/2eI= -github.com/aws/smithy-go v1.13.5 h1:hgz0X/DX0dGqTYpGALqXJoRKRj5oQ7150i5FdTePzO8= -github.com/aws/smithy-go v1.13.5/go.mod h1:Tg+OJXh4MB2R/uN61Ko2f6hTZwB/ZYGOtib8J3gBHzA= +github.com/aws/aws-sdk-go-v2 v1.25.3 h1:xYiLpZTQs1mzvz5PaI6uR0Wh57ippuEthxS4iK5v0n0= +github.com/aws/aws-sdk-go-v2 v1.25.3/go.mod h1:35hUlJVYd+M++iLI3ALmVwMOyRYMmRqUXpTtRGW+K9I= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.5.1 h1:ZY3108YtBNq96jNZTICHxN1gSBSbnvIdYwwqnvCV4Mc= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.5.1/go.mod h1:t8PYl/6LzdAqsU4/9tz28V/kU+asFePvpOMkdul0gEQ= +github.com/aws/aws-sdk-go-v2/config v1.25.3 h1:E4m9LbwJOoncDNt3e9MPLbz/saxWcGUlZVBydydD6+8= +github.com/aws/aws-sdk-go-v2/config v1.25.3/go.mod h1:tAByZy03nH5jcq0vZmkcVoo6tRzRHEwSFx3QW4NmDw8= +github.com/aws/aws-sdk-go-v2/credentials v1.16.2 h1:0sdZ5cwfOAipTzZ7eOL0gw4LAhk/RZnTa16cDqIt8tg= +github.com/aws/aws-sdk-go-v2/credentials v1.16.2/go.mod h1:sDdvGhXrSVT5yzBDR7qXz+rhbpiMpUYfF3vJ01QSdrc= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.4 h1:9wKDWEjwSnXZre0/O3+ZwbBl1SmlgWYBbrTV10X/H1s= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.4/go.mod h1:t4i+yGHMCcUNIX1x7YVYa6bH/Do7civ5I6cG/6PMfyA= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.14.0 h1:1KdubQbnw76M0Sr8480q6OXBlymBVqpkK+RuCqJz+nQ= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.14.0/go.mod h1:UcgIwJ9KHquYxs6Q5skC9qXjhYMK+JASDYcXQ4X7JZE= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.3 h1:ifbIbHZyGl1alsAhPIYsHOg5MuApgqOvVeI8wIugXfs= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.3/go.mod h1:oQZXg3c6SNeY6OZrDY+xHcF4VGIEoNotX2B4PrDeoJI= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.3 h1:Qvodo9gHG9F3E8SfYOspPeBt0bjSbsevK8WhRAUHcoY= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.3/go.mod h1:vCKrdLXtybdf/uQd/YfVR2r5pcbNuEYKzMQpcxmeSJw= +github.com/aws/aws-sdk-go-v2/internal/ini v1.7.1 h1:uR9lXYjdPX0xY+NhvaJ4dD8rpSRz5VY81ccIIoNG+lw= +github.com/aws/aws-sdk-go-v2/internal/ini v1.7.1/go.mod h1:6fQQgfuGmw8Al/3M2IgIllycxV7ZW7WCdVSqfBeUiCY= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.2.3 h1:lMwCXiWJlrtZot0NJTjbC8G9zl+V3i68gBTBBvDeEXA= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.2.3/go.mod h1:5yzAuE9i2RkVAttBl8yxZgQr5OCq4D5yDnG7j9x2L0U= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.1 h1:rpkF4n0CyFcrJUG/rNNohoTmhtWlFTRI4BsZOh9PvLs= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.1/go.mod h1:l9ymW25HOqymeU2m1gbUQ3rUIsTwKs8gYHXkqDQUhiI= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.2.3 h1:xbwRyCy7kXrOj89iIKLB6NfE2WCpP9HoKyk8dMDvnIQ= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.2.3/go.mod h1:R+/S1O4TYpcktbVwddeOYg+uwUfLhADP2S/x4QwsCTM= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.3 h1:kJOolE8xBAD13xTCgOakByZkyP4D/owNmvEiioeUNAg= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.3/go.mod h1:Owv1I59vaghv1Ax8zz8ELY8DN7/Y0rGS+WWAmjgi950= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.16.3 h1:KV0z2RDc7euMtg8aUT1czv5p29zcLlXALNFsd3jkkEc= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.16.3/go.mod h1:KZgs2ny8HsxRIRbDwgvJcHHBZPOzQr/+NtGwnP+w2ec= +github.com/aws/aws-sdk-go-v2/service/redshiftdata v1.25.2 h1:Oxycfwi0LeSLnyjnMXwSTWQ0Oo5uz3bfBZfZofK4L+A= +github.com/aws/aws-sdk-go-v2/service/redshiftdata v1.25.2/go.mod h1:ER2zJ2zmzI19TDja171MU9mr+x3X/W8bFGSs6eCKSpY= +github.com/aws/aws-sdk-go-v2/service/s3 v1.43.0 h1:cwTuq73Tv6jtNJIMgTDKsih5O2YsVrKGpg20H98tbmo= +github.com/aws/aws-sdk-go-v2/service/s3 v1.43.0/go.mod h1:NXRKkiRF+erX2hnybnVU660cYT5/KChRD4iUgJ97cI8= +github.com/aws/aws-sdk-go-v2/service/sso v1.17.2 h1:V47N5eKgVZoRSvx2+RQ0EpAEit/pqOhqeSQFiS4OFEQ= +github.com/aws/aws-sdk-go-v2/service/sso v1.17.2/go.mod h1:/pE21vno3q1h4bbhUOEi+6Zu/aT26UK2WKkDXd+TssQ= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.20.0 h1:/XiEU7VIFcVWRDQLabyrSjBoKIm8UkYgsvWDuFW8Img= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.20.0/go.mod h1:dWqm5G767qwKPuayKfzm4rjzFmVjiBFbOJrpSPnAMDs= +github.com/aws/aws-sdk-go-v2/service/sts v1.25.3 h1:M2w4kiMGJCCM6Ljmmx/l6mmpfa3gPJVpBencfnsgvqs= +github.com/aws/aws-sdk-go-v2/service/sts v1.25.3/go.mod h1:4EqRHDCKP78hq3zOnmFXu5k0j4bXbRFfCh/zQ6KnEfQ= +github.com/aws/smithy-go v1.20.1 h1:4SZlSlMr36UEqC7XOyRVb27XMeZubNcBNN+9IgEPIQw= +github.com/aws/smithy-go v1.20.1/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM= github.com/cenkalti/backoff/v4 v4.2.1/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= @@ -107,8 +111,8 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= -github.com/dnaeon/go-vcr v1.1.0 h1:ReYa/UBrRyQdant9B4fNHGoCNKw6qh6P0fsdGmZpR7c= -github.com/dnaeon/go-vcr v1.1.0/go.mod h1:M7tiix8f0r6mKKJ3Yq/kqU1OYf3MnfmBWVbPx/yU9ko= +github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI= +github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= github.com/dnephin/pflag v1.0.7 h1:oxONGlWxhmUct0YzKTgrpQv9AUA1wtPBn7zuSjJqptk= github.com/dnephin/pflag v1.0.7/go.mod h1:uxE91IoWURlOiTUIA8Mq5ZZkAv3dPUfZNaT80Zm7OQE= github.com/docker/cli v20.10.17+incompatible h1:eO2KS7ZFeov5UJeaDmIs1NFEDRf32PaqRpvoEkKBy5M= @@ -154,8 +158,8 @@ github.com/godbus/dbus v0.0.0-20190726142602-4481cbc300e2/go.mod h1:bBOAhwG1umN6 github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= -github.com/golang-jwt/jwt v3.2.1+incompatible h1:73Z+4BJcrTC+KczS6WvTPvRGOp1WmfEP4Q1lOd9Z/+c= -github.com/golang-jwt/jwt v3.2.1+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= +github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= +github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= @@ -205,12 +209,12 @@ github.com/googleapis/gax-go/v2 v2.12.2 h1:mhN09QQW1jEWeMF74zGR81R30z4VJzjZsfkUh github.com/googleapis/gax-go/v2 v2.12.2/go.mod h1:61M8vcyyXR2kqKFxKrfA22jaA8JGF7Dc8App1U3H6jc= github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c h1:6rhixN/i8ZofjG1Y75iExal34USq5p+wiN1tpie8IrU= github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c/go.mod h1:NMPJylDgVpX0MLRlPy15sqSwOFv/U1GZ2m21JhFfek0= -github.com/hashicorp/go-cleanhttp v0.5.1 h1:dH3aiDG9Jvb5r5+bYHsikaOUIpcM0xvgMXVoDkXMzJM= -github.com/hashicorp/go-cleanhttp v0.5.1/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80= +github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= +github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= github.com/hashicorp/go-hclog v0.9.2 h1:CG6TE5H9/JXsFWJCfoIVpKFIkFe6ysEuHirp4DxCsHI= github.com/hashicorp/go-hclog v0.9.2/go.mod h1:5CU+agLiy3J7N7QjHK5d05KxGsuXiQLrjA0H7acj2lQ= -github.com/hashicorp/go-retryablehttp v0.7.1 h1:sUiuQAnLlbvmExtFQs72iFW/HXeUn8Z1aJLQ4LJJbTQ= -github.com/hashicorp/go-retryablehttp v0.7.1/go.mod h1:vAew36LZh98gCBJNLH42IQ1ER/9wtLZZ8meHqQvEYWY= +github.com/hashicorp/go-retryablehttp v0.7.5 h1:bJj+Pj19UZMIweq/iie+1u5YCdGrnxCT9yvm0e+Nd5M= +github.com/hashicorp/go-retryablehttp v0.7.5/go.mod h1:Jy/gPYAdjqffZ/yFGCFV2doI5wjtH1ewM9u8iYVjtX8= github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/imdario/mergo v0.3.13 h1:lFzP57bqS/wsqKssCGmtLAb8A0wKjLGrve2q3PPVcBk= @@ -312,8 +316,9 @@ github.com/tidwall/gjson v1.17.1 h1:wlYEnwqAHgzmhNUFfw7Xalt2JzQvsMx2Se4PcoFCT/U= github.com/tidwall/gjson v1.17.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= -github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/trinodb/trino-go-client v0.313.0 h1:lp8N9JKTqMuZ9LlAwLjgUtkwDnJc8fjpJmunpZ3afjk= diff --git a/sqlconnect/internal/bigquery/driver/driver_test.go b/sqlconnect/internal/bigquery/driver/driver_test.go index 66b427b..7bd8d9e 100644 --- a/sqlconnect/internal/bigquery/driver/driver_test.go +++ b/sqlconnect/internal/bigquery/driver/driver_test.go @@ -45,6 +45,10 @@ func TestBigqueryDriver(t *testing.T) { }) schema := GenerateTestSchema() + t.Cleanup(func() { + _, err := db.Exec(fmt.Sprintf("DROP SCHEMA IF EXISTS `%s` CASCADE", schema)) + require.NoError(t, err, "it should be able to drop the schema") + }) t.Run("Ping", func(t *testing.T) { require.NoError(t, db.Ping(), "it should be able to ping the database") @@ -224,5 +228,5 @@ type config struct { } func GenerateTestSchema() string { - return strings.ToLower(fmt.Sprintf("tbqdrv_%s_%d", rand.String(12), time.Now().Unix())) + return strings.ToLower(fmt.Sprintf("tsqlcon_%s_%d", rand.String(12), time.Now().Unix())) } diff --git a/sqlconnect/internal/redshift/config.go b/sqlconnect/internal/redshift/config.go index 118660c..78c8ff4 100644 --- a/sqlconnect/internal/redshift/config.go +++ b/sqlconnect/internal/redshift/config.go @@ -1,5 +1,53 @@ package redshift -import "github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/postgres" +import ( + "encoding/json" + "time" + "github.com/tidwall/sjson" + + "github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/postgres" +) + +const SDKConfigType = "sdk" + +// Config is the configuration for a redshift database when using postgres driver type Config = postgres.Config + +// SDKConfig is the configuration for a redshift database when using the AWS SDK +type SDKConfig struct { + ClusterIdentifier string `json:"clusterIdentifier"` + Database string `json:"database"` + User string `json:"user"` + Region string `json:"region"` + WorkgroupName string `json:"workgroupName"` + + SecretsARN string `json:"secretsARN"` + + SharedConfigProfile string `json:"sharedConfigProfile"` + + AccessKeyID string `json:"accessKeyId"` + SecretAccessKey string `json:"secretAccessKey"` + SessionToken string `json:"sessionToken"` + + Timeout time.Duration `json:"timeout"` // default 15m + Polling time.Duration `json:"polling"` // default 10ms + + UseLegacyMappings bool `json:"useLegacyMappings"` +} + +func (c *SDKConfig) MarshalJSON() ([]byte, error) { + bytes, err := json.Marshal(*c) + if err != nil { + return nil, err + } + return sjson.SetBytes(bytes, "type", SDKConfigType) +} + +func (c *SDKConfig) Parse(input json.RawMessage) error { + err := json.Unmarshal(input, c) + if err != nil { + return err + } + return nil +} diff --git a/sqlconnect/internal/redshift/db.go b/sqlconnect/internal/redshift/db.go index 76cc833..f6e3b1d 100644 --- a/sqlconnect/internal/redshift/db.go +++ b/sqlconnect/internal/redshift/db.go @@ -6,10 +6,12 @@ import ( "fmt" _ "github.com/lib/pq" // postgres driver + "github.com/tidwall/gjson" "github.com/rudderlabs/sqlconnect-go/sqlconnect" "github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/base" "github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/postgres" + "github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/redshift/driver" ) const ( @@ -18,13 +20,17 @@ const ( // NewDB creates a new redshift db client func NewDB(credentialsJSON json.RawMessage) (*DB, error) { - var config Config - err := config.Parse(credentialsJSON) - if err != nil { - return nil, err + var ( + db *sql.DB + err error + ) + useLegacyMappings := gjson.GetBytes(credentialsJSON, "useLegacyMappings").Bool() + // Use the SDK if the credentials are for the SDK + if configType := gjson.GetBytes(credentialsJSON, "type").Str; configType == SDKConfigType { + db, err = newSdkDB(credentialsJSON) + } else { + db, err = newPgDB(credentialsJSON) } - - db, err := sql.Open(postgres.DatabaseType, config.ConnectionString()) if err != nil { return nil, err } @@ -32,8 +38,8 @@ func NewDB(credentialsJSON json.RawMessage) (*DB, error) { return &DB{ DB: base.NewDB( db, - base.WithColumnTypeMappings(getColumnTypeMappings(config)), - base.WithJsonRowMapper(getJonRowMapper(config)), + base.WithColumnTypeMappings(getColumnTypeMappings(useLegacyMappings)), + base.WithJsonRowMapper(getJonRowMapper(useLegacyMappings)), base.WithSQLCommandsOverride(func(cmds base.SQLCommands) base.SQLCommands { cmds.ListSchemas = func() (string, string) { return "SELECT schema_name FROM svv_redshift_schemas", "schema_name" @@ -47,6 +53,40 @@ func NewDB(credentialsJSON json.RawMessage) (*DB, error) { }, nil } +func newPgDB(credentialsJSON json.RawMessage) (*sql.DB, error) { + var config Config + err := config.Parse(credentialsJSON) + if err != nil { + return nil, err + } + + return sql.Open(postgres.DatabaseType, config.ConnectionString()) +} + +func newSdkDB(credentialsJSON json.RawMessage) (*sql.DB, error) { + var config SDKConfig + err := config.Parse(credentialsJSON) + if err != nil { + return nil, err + } + cfg := driver.RedshiftConfig{ + ClusterIdentifier: config.ClusterIdentifier, + Database: config.Database, + DbUser: config.User, + WorkgroupName: config.WorkgroupName, + SecretsARN: config.SecretsARN, + Region: config.Region, + AccessKeyID: config.AccessKeyID, + SecretAccessKey: config.SecretAccessKey, + SessionToken: config.SessionToken, + Timeout: config.Timeout, + Polling: config.Polling, + } + connector := driver.NewRedshiftConnector(cfg) + + return sql.OpenDB(connector), nil +} + func init() { sqlconnect.RegisterDBFactory(DatabaseType, func(credentialsJSON json.RawMessage) (sqlconnect.DB, error) { return NewDB(credentialsJSON) @@ -57,15 +97,15 @@ type DB struct { *base.DB } -func getColumnTypeMappings(config postgres.Config) map[string]string { - if config.UseLegacyMappings { +func getColumnTypeMappings(useLegacyMappings bool) map[string]string { + if useLegacyMappings { return legacyColumnTypeMappings } return columnTypeMappings } -func getJonRowMapper(config postgres.Config) func(databaseTypeName string, value any) any { - if config.UseLegacyMappings { +func getJonRowMapper(useLegacyMappings bool) func(databaseTypeName string, value any) any { + if useLegacyMappings { return legacyJsonRowMapper } return jsonRowMapper diff --git a/sqlconnect/internal/redshift/driver/client.go b/sqlconnect/internal/redshift/driver/client.go new file mode 100644 index 0000000..d7d0b1c --- /dev/null +++ b/sqlconnect/internal/redshift/driver/client.go @@ -0,0 +1,25 @@ +package driver + +import ( + "context" + + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/redshiftdata" +) + +type RedshiftClient interface { + ExecuteStatement(ctx context.Context, params *redshiftdata.ExecuteStatementInput, optFns ...func(*redshiftdata.Options)) (*redshiftdata.ExecuteStatementOutput, error) + DescribeStatement(ctx context.Context, params *redshiftdata.DescribeStatementInput, optFns ...func(*redshiftdata.Options)) (*redshiftdata.DescribeStatementOutput, error) + CancelStatement(ctx context.Context, params *redshiftdata.CancelStatementInput, optFns ...func(*redshiftdata.Options)) (*redshiftdata.CancelStatementOutput, error) + BatchExecuteStatement(ctx context.Context, params *redshiftdata.BatchExecuteStatementInput, optFns ...func(*redshiftdata.Options)) (*redshiftdata.BatchExecuteStatementOutput, error) + redshiftdata.GetStatementResultAPIClient +} + +func newRedshiftDataClient(ctx context.Context, cfg *RedshiftConfig, opts ...func(*config.LoadOptions) error) (RedshiftClient, error) { + awsCfg, err := config.LoadDefaultConfig(ctx, opts...) + if err != nil { + return nil, err + } + client := redshiftdata.NewFromConfig(awsCfg, cfg.Opts()...) + return client, nil +} diff --git a/sqlconnect/internal/redshift/driver/connection.go b/sqlconnect/internal/redshift/driver/connection.go new file mode 100644 index 0000000..5ab360a --- /dev/null +++ b/sqlconnect/internal/redshift/driver/connection.go @@ -0,0 +1,378 @@ +package driver + +import ( + "context" + "database/sql" + "database/sql/driver" + "fmt" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/redshiftdata" + "github.com/aws/aws-sdk-go-v2/service/redshiftdata/types" +) + +type redshiftConnection struct { + client RedshiftClient + cfg *RedshiftConfig + aliveCh chan struct{} + isClosed bool + + inTx bool + txOpts driver.TxOptions + sqls []string + delayedResult []*redshiftDelayedResult +} + +func newConnection(client RedshiftClient, cfg *RedshiftConfig) *redshiftConnection { + return &redshiftConnection{ + client: client, + cfg: cfg, + aliveCh: make(chan struct{}), + } +} + +func (c *redshiftConnection) Ping(ctx context.Context) error { + _, err := c.ExecContext(ctx, "select 1", nil) + return err +} + +func (c *redshiftConnection) PrepareContext(_ context.Context, query string) (driver.Stmt, error) { + return &redshiftStatement{ + connection: c, + query: query, + }, nil +} + +func (c *redshiftConnection) Prepare(query string) (driver.Stmt, error) { + return c.PrepareContext(context.Background(), query) +} + +func (c *redshiftConnection) Close() error { + if c.isClosed { + return nil + } + c.isClosed = true + close(c.aliveCh) + return nil +} + +func (c *redshiftConnection) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + if c.inTx { + return nil, ErrInTx + } + if opts.Isolation != driver.IsolationLevel(sql.LevelDefault) { + return nil, fmt.Errorf("transaction isolation level change: %w", ErrNotSupported) + } + c.inTx = true + c.txOpts = opts + cleanup := func() error { // nolint: unparam + c.inTx = false + c.sqls = nil + c.delayedResult = nil + return nil + } + tx := &redshiftTx{ + onRollback: func() error { + if !c.inTx { + return ErrNotInTx + } + return cleanup() + }, + onCommit: func() error { + if !c.inTx { + return ErrNotInTx + } + if len(c.sqls) == 0 { + return cleanup() + } + if len(c.sqls) != len(c.delayedResult) { + panic(fmt.Sprintf("sqls and delayedResult length is not match: sqls=%d delayedResult=%d", len(c.sqls), len(c.delayedResult))) + } + if len(c.sqls) == 1 { + result, err := c.ExecContext(ctx, c.sqls[0], []driver.NamedValue{}) + if err != nil { + return err + } + if c.delayedResult[0] != nil { + c.delayedResult[0].Result = result + } + return nil + } + input := &redshiftdata.BatchExecuteStatementInput{ + Sqls: append(make([]string, 0, len(c.sqls)), c.sqls...), + } + _, desc, err := c.batchExecuteStatement(ctx, input) + if err != nil { + return err + } + for i := range input.Sqls { + if i >= len(desc.SubStatements) { + return fmt.Errorf("sub statement not found: %d", i) + } + if c.delayedResult[i] != nil { + c.delayedResult[i].Result = newResultWithSubStatementData(desc.SubStatements[i]) + } + } + return cleanup() + }, + } + + return tx, nil +} + +func (c *redshiftConnection) Begin() (driver.Tx, error) { + return c.BeginTx(context.Background(), driver.TxOptions{}) +} + +func (c *redshiftConnection) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + if c.inTx { + return nil, fmt.Errorf("query in transaction: %w", ErrNotSupported) + } + + params := &redshiftdata.ExecuteStatementInput{ + Sql: nullStringIfEmpty(rewriteQuery(query, len(args) > 0)), + Parameters: convertArgsToParameters(args), + } + p, output, err := c.executeStatement(ctx, params) + if err != nil { + return nil, err + } + return newRows(ctx, coalesce(output.Id), p) +} + +func (c *redshiftConnection) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + if c.inTx { + if len(args) > 0 { + return nil, fmt.Errorf("exec with args in transaction: %w", ErrNotSupported) + } + if c.txOpts.ReadOnly { + return nil, fmt.Errorf("exec in read only transaction: %w", ErrNotSupported) + } + c.sqls = append(c.sqls, query) + result := &redshiftDelayedResult{} + c.delayedResult = append(c.delayedResult, result) + debugLogger.Printf("delayedResult[%d] creaed for %q", len(c.delayedResult)-1, query) + return result, nil + } + + params := &redshiftdata.ExecuteStatementInput{ + Sql: nullStringIfEmpty(rewriteQuery(query, len(args) > 0)), + Parameters: convertArgsToParameters(args), + } + _, output, err := c.executeStatement(ctx, params) + if err != nil { + return nil, err + } + return newResult(output), nil +} + +func rewriteQuery(query string, scanParams bool) string { + if !scanParams { + return query + } + runes := make([]rune, 0, len(query)) + stack := make([]rune, 0) + var exclamationCount int + for _, r := range query { + if len(stack) > 0 { + if r == stack[len(stack)-1] { + stack = stack[:len(stack)-1] + runes = append(runes, r) + continue + } + } else { + switch r { + case '?': + exclamationCount++ + runes = append(runes, []rune(fmt.Sprintf(":%d", exclamationCount))...) + continue + case '$': + runes = append(runes, ':') + continue + } + } + switch r { + case '"', '\'': + stack = append(stack, r) + } + runes = append(runes, r) + } + return string(runes) +} + +func convertArgsToParameters(args []driver.NamedValue) []types.SqlParameter { + if len(args) == 0 { + return nil + } + scanParams := make([]types.SqlParameter, 0, len(args)) + for _, arg := range args { + scanParams = append(scanParams, types.SqlParameter{ + Name: aws.String(coalesce(nullStringIfEmpty(arg.Name), aws.String(fmt.Sprintf("%d", arg.Ordinal)))), + Value: aws.String(fmt.Sprintf("%v", arg.Value)), + }) + } + return scanParams +} + +func (c *redshiftConnection) executeStatement(ctx context.Context, params *redshiftdata.ExecuteStatementInput) (*redshiftdata.GetStatementResultPaginator, *redshiftdata.DescribeStatementOutput, error) { + debugLogger.Printf("query: %s", coalesce(params.Sql)) + params.ClusterIdentifier = nullStringIfEmpty(c.cfg.ClusterIdentifier) + params.Database = nullStringIfEmpty(c.cfg.Database) + params.DbUser = nullStringIfEmpty(c.cfg.DbUser) + params.WorkgroupName = nullStringIfEmpty(c.cfg.WorkgroupName) + params.SecretArn = nullStringIfEmpty(c.cfg.SecretsARN) + + executeOutput, err := c.client.ExecuteStatement(ctx, params) + if err != nil { + return nil, nil, fmt.Errorf("execute statement:%w", err) + } + queryStart := time.Now() + debugLogger.Printf("[%s] success execute statement: %s", *executeOutput.Id, coalesce(params.Sql)) + describeOutput, err := c.waitWithCancel(ctx, executeOutput.Id, queryStart) + if err != nil { + return nil, nil, err + } + if describeOutput.Status == types.StatusStringAborted { + return nil, nil, fmt.Errorf("query aborted: %s", *describeOutput.Error) + } + if describeOutput.Status == types.StatusStringFailed { + return nil, nil, fmt.Errorf("query failed: %s", *describeOutput.Error) + } + if describeOutput.Status != types.StatusStringFinished { + return nil, nil, fmt.Errorf("query status is not finished: %s", describeOutput.Status) + } + debugLogger.Printf("[%s] success query: elapsed_time=%s", *executeOutput.Id, time.Since(queryStart)) + if !*describeOutput.HasResultSet { + return nil, describeOutput, nil + } + debugLogger.Printf("[%s] query has result set: result_rows=%d", *executeOutput.Id, describeOutput.ResultRows) + p := redshiftdata.NewGetStatementResultPaginator(c.client, &redshiftdata.GetStatementResultInput{ + Id: executeOutput.Id, + }) + return p, describeOutput, nil +} + +func (c *redshiftConnection) batchExecuteStatement(ctx context.Context, params *redshiftdata.BatchExecuteStatementInput) ([]*redshiftdata.GetStatementResultPaginator, *redshiftdata.DescribeStatementOutput, error) { + params.ClusterIdentifier = nullStringIfEmpty(c.cfg.ClusterIdentifier) + params.Database = nullStringIfEmpty(c.cfg.Database) + params.DbUser = nullStringIfEmpty(c.cfg.DbUser) + params.WorkgroupName = nullStringIfEmpty(c.cfg.WorkgroupName) + params.SecretArn = nullStringIfEmpty(c.cfg.SecretsARN) + + batchExecuteOutput, err := c.client.BatchExecuteStatement(ctx, params) + if err != nil { + return nil, nil, fmt.Errorf("execute statement:%w", err) + } + queryStart := time.Now() + debugLogger.Printf("[%s] success execute statement: %d sqls", *batchExecuteOutput.Id, len(params.Sqls)) + describeOutput, err := c.waitWithCancel(ctx, batchExecuteOutput.Id, queryStart) + if err != nil { + return nil, nil, err + } + if describeOutput.Status == types.StatusStringAborted { + return nil, nil, fmt.Errorf("query aborted: %s", *describeOutput.Error) + } + if describeOutput.Status == types.StatusStringFailed { + return nil, nil, fmt.Errorf("query failed: %s", *describeOutput.Error) + } + if describeOutput.Status != types.StatusStringFinished { + return nil, nil, fmt.Errorf("query status is not finished: %s", describeOutput.Status) + } + debugLogger.Printf("[%s] success query: elapsed_time=%s", *batchExecuteOutput.Id, time.Since(queryStart)) + ps := make([]*redshiftdata.GetStatementResultPaginator, len(params.Sqls)) + for i, st := range describeOutput.SubStatements { + if *st.HasResultSet { + continue + } + ps[i] = redshiftdata.NewGetStatementResultPaginator(c.client, &redshiftdata.GetStatementResultInput{ + Id: st.Id, + }) + } + return ps, describeOutput, nil +} + +func isFinishedStatus(status types.StatusString) bool { + return status == types.StatusStringFinished || status == types.StatusStringFailed || status == types.StatusStringAborted +} + +func (c *redshiftConnection) wait(ctx context.Context, id *string, queryStart time.Time) (*redshiftdata.DescribeStatementOutput, error) { + timeout := c.cfg.Timeout + if timeout == 0 { + timeout = 15 * time.Minute + } + polling := c.cfg.Polling + if polling == 0 { + polling = 10 * time.Millisecond + } + ectx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + debugLogger.Printf("[%s] wating finish query: elapsed_time=%s", *id, time.Since(queryStart)) + describeOutput, err := c.client.DescribeStatement(ctx, &redshiftdata.DescribeStatementInput{ + Id: id, + }) + if err != nil { + return nil, fmt.Errorf("describe statement: %w", err) + } + debugLogger.Printf("[%s] describe statement: status=%s pid=%d query_id=%d", *id, describeOutput.Status, describeOutput.RedshiftPid, describeOutput.RedshiftQueryId) + if isFinishedStatus(describeOutput.Status) { + return describeOutput, nil + } + delay := time.NewTimer(polling) + for { + select { + case <-ectx.Done(): + if !delay.Stop() { + <-delay.C + } + return nil, ectx.Err() + case <-delay.C: + case <-c.aliveCh: + if !delay.Stop() { + <-delay.C + } + return nil, ErrConnClosed + } + debugLogger.Printf("[%s] wating finsih query: elapsed_time=%s", *id, time.Since(queryStart)) + describeOutput, err = c.client.DescribeStatement(ctx, &redshiftdata.DescribeStatementInput{ + Id: id, + }) + if err != nil { + return nil, fmt.Errorf("describe statement:%w", err) + } + if isFinishedStatus(describeOutput.Status) { + return describeOutput, nil + } + delay.Reset(polling) + } +} + +func (c *redshiftConnection) waitWithCancel(ctx context.Context, id *string, queryStart time.Time) (*redshiftdata.DescribeStatementOutput, error) { + desc, err := c.wait(ctx, id, queryStart) + cctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + if desc == nil { + var rErr error + desc, rErr = c.client.DescribeStatement(cctx, &redshiftdata.DescribeStatementInput{ + Id: id, + }) + if rErr != nil { + return nil, err + } + } + if isFinishedStatus(desc.Status) { + return desc, err + } + debugLogger.Printf("[%s] try cancel statement", *id) + output, cErr := c.client.CancelStatement(cctx, &redshiftdata.CancelStatementInput{ + Id: id, + }) + if cErr != nil { + errLogger.Printf("[%s] failed cancel statement: %w", *id, err) + return desc, err + } + if !*output.Status { + debugLogger.Printf("[%s] cancel statement status is false", *id) + } + return desc, err +} diff --git a/sqlconnect/internal/redshift/driver/connection_test.go b/sqlconnect/internal/redshift/driver/connection_test.go new file mode 100644 index 0000000..a77b183 --- /dev/null +++ b/sqlconnect/internal/redshift/driver/connection_test.go @@ -0,0 +1,47 @@ +package driver + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRewriteQuery(t *testing.T) { + cases := []struct { + casename string + query string + params bool + expected string + }{ + { + casename: "no params", + query: `SELECT * FROM pg_user`, + params: false, + expected: `SELECT * FROM pg_user`, + }, + { + casename: "no change", + query: `SELECT * FROM pg_user WHERE usename = :name`, + params: true, + expected: `SELECT * FROM pg_user WHERE usename = :name`, + }, + { + casename: "? rewrite", + query: `SELECT 'hoge?' FROM pg_user WHERE usename = ? AND usesysid > ?`, + params: true, + expected: `SELECT 'hoge?' FROM pg_user WHERE usename = :1 AND usesysid > :2`, + }, + { + casename: "$ rewrite", + query: `SELECT '3$1$' FROM table WHERE "$column" = $1 AND column1 > $2 AND column2 < $1`, + params: true, + expected: `SELECT '3$1$' FROM table WHERE "$column" = :1 AND column1 > :2 AND column2 < :1`, + }, + } + for _, c := range cases { + t.Run(c.casename, func(t *testing.T) { + actual := rewriteQuery(c.query, c.params) + require.Equal(t, c.expected, actual) + }) + } +} diff --git a/sqlconnect/internal/redshift/driver/connector.go b/sqlconnect/internal/redshift/driver/connector.go new file mode 100644 index 0000000..070a3ae --- /dev/null +++ b/sqlconnect/internal/redshift/driver/connector.go @@ -0,0 +1,30 @@ +package driver + +import ( + "context" + "database/sql/driver" +) + +func NewRedshiftConnector(cfg RedshiftConfig) driver.Connector { + return &redshiftDataConnector{ + d: &redshiftDataDriver{}, + cfg: &cfg, + } +} + +type redshiftDataConnector struct { + d *redshiftDataDriver + cfg *RedshiftConfig +} + +func (c *redshiftDataConnector) Connect(ctx context.Context) (driver.Conn, error) { + client, err := newRedshiftDataClient(ctx, c.cfg, c.cfg.LoadOpts()...) + if err != nil { + return nil, err + } + return newConnection(client, c.cfg), nil +} + +func (c *redshiftDataConnector) Driver() driver.Driver { + return c.d +} diff --git a/sqlconnect/internal/redshift/driver/driver.go b/sqlconnect/internal/redshift/driver/driver.go new file mode 100644 index 0000000..abc3e43 --- /dev/null +++ b/sqlconnect/internal/redshift/driver/driver.go @@ -0,0 +1,32 @@ +package driver + +import ( + "context" + "database/sql" + "database/sql/driver" +) + +func init() { + sql.Register("redshift-data", &redshiftDataDriver{}) +} + +type redshiftDataDriver struct{} + +func (d *redshiftDataDriver) Open(dsn string) (driver.Conn, error) { + connector, err := d.OpenConnector(dsn) + if err != nil { + return nil, err + } + return connector.Connect(context.Background()) +} + +func (d *redshiftDataDriver) OpenConnector(dsn string) (driver.Connector, error) { + cfg, err := ParseDSN(dsn) + if err != nil { + return nil, err + } + return &redshiftDataConnector{ + d: d, + cfg: cfg, + }, nil +} diff --git a/sqlconnect/internal/redshift/driver/driver_test.go b/sqlconnect/internal/redshift/driver/driver_test.go new file mode 100644 index 0000000..f9d1fee --- /dev/null +++ b/sqlconnect/internal/redshift/driver/driver_test.go @@ -0,0 +1,255 @@ +package driver_test + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "os" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/rudderlabs/rudder-go-kit/testhelper/rand" + "github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/redshift/driver" +) + +func TestRedshiftDriver(t *testing.T) { + configJSON, ok := os.LookupEnv("REDSHIFT_SDK_TEST_ENVIRONMENT_CREDENTIALS") + if !ok { + t.Skip("skipping redshift sdk driver integration test due to lack of a test environment") + } + var cfg driver.RedshiftConfig + err := json.Unmarshal([]byte(configJSON), &cfg) + require.NoError(t, err, "it should be able to unmarshal the config") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + connector := driver.NewRedshiftConnector(cfg) + db := sql.OpenDB(connector) + schema := GenerateTestSchema() + t.Cleanup(func() { + _, err := db.Exec(fmt.Sprintf(`DROP SCHEMA IF EXISTS "%s" CASCADE`, schema)) + require.NoError(t, err, "it should be able to drop the schema") + }) + + t.Run("Ping", func(t *testing.T) { + require.NoError(t, db.Ping(), "it should be able to ping the database") + require.NoError(t, db.PingContext(ctx), "it should be able to ping the database using a context") + }) + + t.Run("Exec", func(t *testing.T) { + _, err := db.Exec(fmt.Sprintf(`CREATE SCHEMA "%s"`, schema)) + require.NoError(t, err, "it should be able to create a schema") + }) + + t.Run("ExecContext", func(t *testing.T) { + _, err := db.ExecContext(ctx, fmt.Sprintf(`CREATE TABLE "%s"."test_table" ("C1" INT4, "C2" VARCHAR)`, schema)) + require.NoError(t, err, "it should be able to create a table") + }) + + t.Run("prepared statement", func(t *testing.T) { + t.Run("QueryRow", func(t *testing.T) { + stmt, err := db.Prepare(fmt.Sprintf(`SELECT COUNT(*) FROM "%s"."test_table"`, schema)) + require.NoError(t, err, "it should be able to prepare a statement") + defer func() { + require.NoError(t, stmt.Close(), "it should be able to close the prepared statement") + }() + + var count int + err = stmt.QueryRow().Scan(&count) + require.NoError(t, err, "it should be able to execute a prepared statement") + }) + + t.Run("Exec", func(t *testing.T) { + stmt, err := db.Prepare(fmt.Sprintf(`INSERT INTO "%s"."test_table" (C1) VALUES (?)`, schema)) + require.NoError(t, err, "it should be able to prepare a statement") + defer func() { + require.NoError(t, stmt.Close(), "it should be able to close the prepared statement") + }() + result, err := stmt.Exec(1) + require.NoError(t, err, "it should be able to execute a prepared statement") + + _, err = result.LastInsertId() + require.Error(t, err) + require.ErrorIs(t, err, driver.ErrNotSupported) + + rowsAffected, err := result.RowsAffected() + require.NoError(t, err, "it should be able to get rows affected") + require.EqualValues(t, 1, rowsAffected, "rows affected should be 1") + }) + + t.Run("Query", func(t *testing.T) { + stmt, err := db.Prepare(fmt.Sprintf(`SELECT C1 FROM "%s"."test_table" WHERE C1 = $1`, schema)) + require.NoError(t, err, "it should be able to prepare a statement") + defer func() { + require.NoError(t, stmt.Close(), "it should be able to close the prepared statement") + }() + rows, err := stmt.Query(1) + require.NoError(t, err, "it should be able to execute a prepared statement") + defer func() { + require.NoError(t, rows.Close(), "it should be able to close the rows") + }() + require.True(t, rows.Next(), "it should be able to get a row") + var c1 int + err = rows.Scan(&c1) + require.NoError(t, err, "it should be able to scan the row") + require.EqualValues(t, 1, c1, "it should be able to get the correct value") + require.False(t, rows.Next(), "it shouldn't have next row") + + require.NoError(t, rows.Err()) + }) + t.Run("Query with named parameters", func(t *testing.T) { + stmt, err := db.PrepareContext(ctx, fmt.Sprintf(`SELECT C1, C2 FROM "%s"."test_table" WHERE C1 = :c1_value`, schema)) + require.NoError(t, err, "it should be able to prepare a statement") + defer func() { + require.NoError(t, stmt.Close(), "it should be able to close the prepared statement") + }() + rows, err := stmt.QueryContext(ctx, sql.Named("c1_value", 1)) + require.NoError(t, err, "it should be able to execute a prepared statement") + defer func() { + require.NoError(t, rows.Close(), "it should be able to close the rows") + }() + + cols, err := rows.Columns() + require.NoError(t, err, "it should be able to get the columns") + require.EqualValues(t, []string{"c1", "c2"}, cols, "it should be able to get the correct columns") + + colTypes, err := rows.ColumnTypes() + require.NoError(t, err, "it should be able to get the column types") + require.Len(t, colTypes, 2, "it should be able to get the correct number of column types") + require.EqualValues(t, "INT4", colTypes[0].DatabaseTypeName(), "it should be able to get the correct column type") + require.EqualValues(t, "VARCHAR", colTypes[1].DatabaseTypeName(), "it should be able to get the correct column type") + + require.True(t, rows.Next(), "it should be able to get a row") + var c1 int + var c2 any + err = rows.Scan(&c1, &c2) + require.NoError(t, err, "it should be able to scan the row") + require.EqualValues(t, 1, c1, "it should be able to get the correct value") + require.Nil(t, c2, "it should be able to get the correct value") + require.False(t, rows.Next(), "it shouldn't have next row") + + require.NoError(t, rows.Err()) + }) + }) + t.Run("query", func(t *testing.T) { + t.Run("QueryRow", func(t *testing.T) { + var count int + err := db.QueryRow(fmt.Sprintf(`SELECT COUNT(*) FROM "%s"."test_table"`, schema)).Scan(&count) + require.NoError(t, err, "it should be able to execute a prepared statement") + require.Equal(t, 1, count, "it should be able to get the correct value") + }) + + t.Run("Exec", func(t *testing.T) { + result, err := db.Exec(fmt.Sprintf(`INSERT INTO "%s"."test_table" (C1) VALUES ($1)`, schema), 2) + require.NoError(t, err, "it should be able to execute a prepared statement") + rowsAffected, err := result.RowsAffected() + require.NoError(t, err, "it should be able to get rows affected") + require.EqualValues(t, 1, rowsAffected, "rows affected should be 1") + }) + + t.Run("Query", func(t *testing.T) { + rows, err := db.Query(fmt.Sprintf(`SELECT C1 FROM "%s"."test_table" WHERE C1 = ?`, schema), 2) + require.NoError(t, err, "it should be able to execute a prepared statement") + defer func() { + require.NoError(t, rows.Close(), "it should be able to close the rows") + }() + require.True(t, rows.Next(), "it should be able to get a row") + var c1 int + err = rows.Scan(&c1) + require.NoError(t, err, "it should be able to scan the row") + require.EqualValues(t, 2, c1, "it should be able to get the correct value") + require.False(t, rows.Next(), "it shouldn't have next row") + + require.NoError(t, rows.Err()) + }) + + t.Run("Query with named parameters", func(t *testing.T) { + rows, err := db.QueryContext(ctx, fmt.Sprintf(`SELECT C1 FROM "%s"."test_table" WHERE C1 = :c1_value`, schema), sql.Named("c1_value", 2)) + require.NoError(t, err, "it should be able to execute a prepared statement") + defer func() { + require.NoError(t, rows.Close(), "it should be able to close the rows") + }() + + cols, err := rows.Columns() + require.NoError(t, err, "it should be able to get the columns") + require.EqualValues(t, []string{"c1"}, cols, "it should be able to get the correct columns") + + colTypes, err := rows.ColumnTypes() + require.NoError(t, err, "it should be able to get the column types") + require.Len(t, colTypes, 1, "it should be able to get the correct number of column types") + require.EqualValues(t, "INT4", colTypes[0].DatabaseTypeName(), "it should be able to get the correct column type") + + require.True(t, rows.Next(), "it should be able to get a row") + var c1 int + err = rows.Scan(&c1) + require.NoError(t, err, "it should be able to scan the row") + require.EqualValues(t, 2, c1, "it should be able to get the correct value") + require.False(t, rows.Next(), "it shouldn't have next row") + + require.NoError(t, rows.Err()) + }) + }) + + t.Run("transaction support", func(t *testing.T) { + t.Run("Begin and Commit", func(t *testing.T) { + var countBefore int + err := db.QueryRowContext(ctx, fmt.Sprintf(`SELECT COUNT(*) FROM "%s"."test_table"`, schema)).Scan(&countBefore) + require.NoError(t, err, "it should be able to execute a prepared statement") + + tx, err := db.Begin() + require.NoError(t, err, "it should be able to begin a transaction") + _, err = tx.ExecContext(ctx, fmt.Sprintf(`INSERT INTO "%s"."test_table" (C1) VALUES (3)`, schema)) + require.NoError(t, err, "it should be able to execute a prepared statement") + _, err = tx.ExecContext(ctx, fmt.Sprintf(`INSERT INTO "%s"."test_table" (C1) VALUES (4)`, schema)) + require.NoError(t, err, "it should be able to execute a prepared statement") + _, err = tx.ExecContext(ctx, fmt.Sprintf(`INSERT INTO "%s"."test_table" (C1) VALUES (?)`, schema), 5) + require.Error(t, err, "it should not be able to execute a prepared statement with parameters in a transaction") + require.ErrorIs(t, err, driver.ErrNotSupported) + + var countDuring int + err = db.QueryRowContext(ctx, fmt.Sprintf(`SELECT COUNT(*) FROM "%s"."test_table"`, schema)).Scan(&countDuring) + require.NoError(t, err, "it should be able to execute a prepared statement") + require.Equal(t, countBefore, countDuring, "it should not be able to see the changes from the transaction") + + err = tx.Commit() + require.NoError(t, err, "it should be able to commit the transaction") + + var countAfter int + err = db.QueryRowContext(ctx, fmt.Sprintf(`SELECT COUNT(*) FROM "%s"."test_table"`, schema)).Scan(&countAfter) + require.NoError(t, err, "it should be able to execute a prepared statement") + require.Equal(t, countBefore+2, countAfter, "it should be able to see the changes from the transaction") + }) + t.Run("BeginTx and Rollback", func(t *testing.T) { + var countBefore int + err := db.QueryRowContext(ctx, fmt.Sprintf(`SELECT COUNT(*) FROM "%s"."test_table"`, schema)).Scan(&countBefore) + require.NoError(t, err, "it should be able to execute a prepared statement") + + tx, err := db.BeginTx(ctx, nil) + require.NoError(t, err, "it should be able to begin a transaction") + _, err = tx.ExecContext(ctx, fmt.Sprintf(`INSERT INTO "%s"."test_table" (C1) VALUES (5)`, schema)) + require.NoError(t, err, "it should be able to execute a prepared statement") + + var countDuring int + err = db.QueryRowContext(ctx, fmt.Sprintf(`SELECT COUNT(*) FROM "%s"."test_table"`, schema)).Scan(&countDuring) + require.NoError(t, err, "it should be able to execute a prepared statement") + require.Equal(t, countBefore, countDuring, "it should not be able to see the changes from the transaction") + + err = tx.Rollback() + require.NoError(t, err, "it should be able to rollback the transaction") + + var countAfter int + err = db.QueryRowContext(ctx, fmt.Sprintf(`SELECT COUNT(*) FROM "%s"."test_table"`, schema)).Scan(&countAfter) + require.NoError(t, err, "it should be able to execute a prepared statement") + require.Equal(t, countBefore, countAfter, "changes from the transaction should be rolled back") + }) + }) +} + +func GenerateTestSchema() string { + return strings.ToLower(fmt.Sprintf("tsqlcon_%s_%d", rand.String(12), time.Now().Unix())) +} diff --git a/sqlconnect/internal/redshift/driver/dsn.go b/sqlconnect/internal/redshift/driver/dsn.go new file mode 100644 index 0000000..faa156b --- /dev/null +++ b/sqlconnect/internal/redshift/driver/dsn.go @@ -0,0 +1,213 @@ +package driver + +import ( + "errors" + "fmt" + "net/url" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/redshiftdata" +) + +type RedshiftConfig struct { + ClusterIdentifier string `json:"clusterIdentifier"` + Database string `json:"database"` + DbUser string `json:"user"` + WorkgroupName string `json:"workgroupName"` + SecretsARN string `json:"secretsARN"` + Region string `json:"region"` + SharedConfigProfile string `json:"sharedConfigProfile"` + AccessKeyID string `json:"accessKeyId"` + SecretAccessKey string `json:"secretAccessKey"` + SessionToken string `json:"sessionToken"` + Timeout time.Duration `json:"timeout"` // default 15m + Polling time.Duration `json:"polling"` // default 10ms + + Params url.Values +} + +func (cfg *RedshiftConfig) LoadOpts() []func(*config.LoadOptions) error { + var opts []func(*config.LoadOptions) error + if cfg.SharedConfigProfile != "" { + opts = append(opts, config.WithSharedConfigProfile(cfg.SharedConfigProfile)) + } + if cfg.AccessKeyID != "" && cfg.SecretAccessKey != "" { + opts = append(opts, config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider( + cfg.AccessKeyID, + cfg.SecretAccessKey, + cfg.SessionToken, + ))) + } + return opts +} + +func (cfg *RedshiftConfig) Opts() []func(*redshiftdata.Options) { + var opts []func(*redshiftdata.Options) + if cfg.Region != "" { + opts = append(opts, func(o *redshiftdata.Options) { + o.Region = cfg.Region + }) + } + return opts +} + +func (cfg *RedshiftConfig) String() string { + base := strings.TrimPrefix(cfg.baseString(), "//") + if base == "" { + return "" + } + params := url.Values{} + for key, value := range cfg.Params { + params[key] = append([]string{}, value...) + } + if cfg.Timeout != 0 { + params.Add("timeout", cfg.Timeout.String()) + } else { + params.Del("timeout") + } + if cfg.Polling != 0 { + params.Add("polling", cfg.Polling.String()) + } else { + params.Del("polling") + } + if cfg.Region != "" { + params.Add("region", cfg.Region) + } else { + params.Del("region") + } + if cfg.SharedConfigProfile != "" { + params.Add("sharedConfigProfile", cfg.SharedConfigProfile) + } else { + params.Del("sharedConfigProfile") + } + if cfg.AccessKeyID != "" { + params.Add("accessKeyId", cfg.AccessKeyID) + } else { + params.Del("accessKeyId") + } + if cfg.SecretAccessKey != "" { + params.Add("secretAccessKey", cfg.SecretAccessKey) + } else { + params.Del("secretAccessKey") + } + if cfg.SessionToken != "" { + params.Add("sessionToken", cfg.SessionToken) + } else { + params.Del("sessionToken") + } + encodedParams := params.Encode() + if encodedParams != "" { + return base + "?" + encodedParams + } + return base +} + +func (cfg *RedshiftConfig) setParams(params url.Values) error { + var err error + cfg.Params = params + if params.Has("timeout") { + cfg.Timeout, err = time.ParseDuration(params.Get("timeout")) + if err != nil { + return fmt.Errorf("parse timeout as duration: %w", err) + } + cfg.Params.Del("timeout") + } + if params.Has("polling") { + cfg.Polling, err = time.ParseDuration(params.Get("polling")) + if err != nil { + return fmt.Errorf("parse polling as duration: %w", err) + } + cfg.Params.Del("polling") + } + if params.Has("region") { + cfg.Region = params.Get("region") + cfg.Params.Del("region") + } + if params.Has("sharedConfigProfile") { + cfg.SharedConfigProfile = params.Get("sharedConfigProfile") + cfg.Params.Del("sharedConfigProfile") + } + if params.Has("accessKeyId") { + cfg.AccessKeyID = params.Get("accessKeyId") + cfg.Params.Del("accessKeyId") + } + if params.Has("secretAccessKey") { + cfg.SecretAccessKey = params.Get("secretAccessKey") + cfg.Params.Del("secretAccessKey") + } + if params.Has("sessionToken") { + cfg.SecretAccessKey = params.Get("sessionToken") + cfg.Params.Del("sessionToken") + } + if len(cfg.Params) == 0 { + cfg.Params = nil + } + return nil +} + +func (cfg *RedshiftConfig) baseString() string { + if cfg.SecretsARN != "" { + return cfg.SecretsARN + } + var u url.URL + if cfg.ClusterIdentifier != "" && cfg.DbUser != "" { + u.Host = fmt.Sprintf("cluster(%s)", cfg.ClusterIdentifier) + u.User = url.User(cfg.DbUser) + } + if cfg.WorkgroupName != "" { + u.Host = fmt.Sprintf("workgroup(%s)", cfg.WorkgroupName) + } + if u.Host == "" || cfg.Database == "" { + return "" + } + u.Path = cfg.Database + return u.String() +} + +func ParseDSN(dsn string) (*RedshiftConfig, error) { + if dsn == "" { + return nil, ErrDSNEmpty + } + if strings.HasPrefix(dsn, "arn:") { + parts := strings.Split(dsn, "?") + cfg := &RedshiftConfig{ + SecretsARN: parts[0], + } + if len(parts) >= 2 { + params, err := url.ParseQuery(strings.Join(parts[1:], "?")) + if err != nil { + return nil, fmt.Errorf("dsn is invalid: can not parse query params: %w", err) + } + if err := cfg.setParams(params); err != nil { + return nil, fmt.Errorf("dsn is invalid: set query params: %w", err) + } + } + return cfg, nil + } + u, err := url.Parse("redshift-sdk://" + dsn) + if err != nil { + return nil, fmt.Errorf("dsn is invalid: %w", err) + } + cfg := &RedshiftConfig{ + Database: strings.TrimPrefix(u.Path, "/"), + } + if cfg.Database == "" { + return nil, errors.New("dsn is invalid: missing database") + } + if err := cfg.setParams(u.Query()); err != nil { + return nil, fmt.Errorf("dsn is invalid: set query params: %w", err) + } + if strings.HasPrefix(u.Host, "cluster(") { + cfg.DbUser = u.User.Username() + cfg.ClusterIdentifier = strings.TrimSuffix(strings.TrimPrefix(u.Host, "cluster("), ")") + return cfg, nil + } + if strings.HasPrefix(u.Host, "workgroup(") { + cfg.WorkgroupName = strings.TrimSuffix(strings.TrimPrefix(u.Host, "workgroup("), ")") + return cfg, nil + } + return nil, errors.New("dsn is invalid: workgroup(name)/database or username@cluster(name)/database or secrets_arn") +} diff --git a/sqlconnect/internal/redshift/driver/dsn_test.go b/sqlconnect/internal/redshift/driver/dsn_test.go new file mode 100644 index 0000000..d0bc49e --- /dev/null +++ b/sqlconnect/internal/redshift/driver/dsn_test.go @@ -0,0 +1,93 @@ +package driver + +import ( + "net/url" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestRedshiftDataConfig__String(t *testing.T) { + cases := []struct { + dsn *RedshiftConfig + expected string + }{ + { + dsn: &RedshiftConfig{}, + expected: "", + }, + { + dsn: &RedshiftConfig{ + ClusterIdentifier: "default", + DbUser: "admin", + Database: "dev", + }, + expected: "admin@cluster(default)/dev", + }, + { + dsn: &RedshiftConfig{ + ClusterIdentifier: "default", + DbUser: "admin", + Database: "dev", + Timeout: 15 * time.Minute, + }, + expected: "admin@cluster(default)/dev?timeout=15m0s", + }, + { + dsn: &RedshiftConfig{ + ClusterIdentifier: "default", + DbUser: "admin", + Database: "dev", + Polling: 5 * time.Millisecond, + }, + expected: "admin@cluster(default)/dev?polling=5ms", + }, + { + dsn: &RedshiftConfig{ + ClusterIdentifier: "default", + DbUser: "admin", + Database: "dev", + Params: url.Values{ + "extra": []string{"hoge"}, + }, + }, + expected: "admin@cluster(default)/dev?extra=hoge", + }, + { + dsn: &RedshiftConfig{ + ClusterIdentifier: "default", + DbUser: "admin", + Database: "dev", + Region: "us-east-1", + }, + expected: "admin@cluster(default)/dev?region=us-east-1", + }, + { + dsn: &RedshiftConfig{ + WorkgroupName: "default", + Database: "dev", + }, + expected: "workgroup(default)/dev", + }, + { + dsn: &RedshiftConfig{ + SecretsARN: "arn:aws:secretsmanager:us-east-1:0123456789012:secret:redshift", + Timeout: 30 * time.Second, + }, + expected: "arn:aws:secretsmanager:us-east-1:0123456789012:secret:redshift?timeout=30s", + }, + } + + for _, c := range cases { + t.Run(c.expected, func(t *testing.T) { + actual := c.dsn.String() + require.Equal(t, c.expected, actual) + if c.expected != "" { + cfg, err := ParseDSN(actual) + require.NoError(t, err) + require.EqualValues(t, c.dsn, cfg) + } + }) + } +} diff --git a/sqlconnect/internal/redshift/driver/errors.go b/sqlconnect/internal/redshift/driver/errors.go new file mode 100644 index 0000000..05b8e75 --- /dev/null +++ b/sqlconnect/internal/redshift/driver/errors.go @@ -0,0 +1,12 @@ +package driver + +import "errors" + +var ( + ErrNotSupported = errors.New("not supported") + ErrDSNEmpty = errors.New("dsn is empty") + ErrConnClosed = errors.New("connection closed") + ErrBeforeCommit = errors.New("transaction is not committed") + ErrNotInTx = errors.New("not in transaction") + ErrInTx = errors.New("already in transaction") +) diff --git a/sqlconnect/internal/redshift/driver/logger.go b/sqlconnect/internal/redshift/driver/logger.go new file mode 100644 index 0000000..f691f90 --- /dev/null +++ b/sqlconnect/internal/redshift/driver/logger.go @@ -0,0 +1,35 @@ +package driver + +import ( + "errors" + "io" + "log" + "os" +) + +type Logger interface { + Printf(format string, v ...any) + SetOutput(w io.Writer) + Writer() io.Writer +} + +var ( + errLogger = Logger(log.New(os.Stderr, "[redshift-data][error]", log.Ldate|log.Ltime|log.Lshortfile)) + debugLogger = Logger(log.New(io.Discard, "[redshift-data][debug]", log.Ldate|log.Ltime|log.Lshortfile)) +) + +func SetLogger(l Logger) error { + if l == nil { + return errors.New("logger is nil") + } + errLogger = l + return nil +} + +func SetDebugLogger(l Logger) error { + if l == nil { + return errors.New("logger is nil") + } + debugLogger = l + return nil +} diff --git a/sqlconnect/internal/redshift/driver/result.go b/sqlconnect/internal/redshift/driver/result.go new file mode 100644 index 0000000..3683fae --- /dev/null +++ b/sqlconnect/internal/redshift/driver/result.go @@ -0,0 +1,55 @@ +package driver + +import ( + "database/sql/driver" + "fmt" + + "github.com/aws/aws-sdk-go-v2/service/redshiftdata" + "github.com/aws/aws-sdk-go-v2/service/redshiftdata/types" +) + +type redshiftResult struct { + affectedRows int64 +} + +func newResult(output *redshiftdata.DescribeStatementOutput) *redshiftResult { + debugLogger.Printf("[%s] create result", coalesce(output.Id)) + return &redshiftResult{ + affectedRows: output.ResultRows, + } +} + +func newResultWithSubStatementData(st types.SubStatementData) *redshiftResult { + debugLogger.Printf("[%s] create result", coalesce(st.Id)) + return &redshiftResult{ + affectedRows: st.ResultRows, + } +} + +func (r *redshiftResult) LastInsertId() (int64, error) { + return 0, fmt.Errorf("LastInsertId %w", ErrNotSupported) +} + +func (r *redshiftResult) RowsAffected() (int64, error) { + return r.affectedRows, nil +} + +type redshiftDelayedResult struct { + driver.Result +} + +func (r *redshiftDelayedResult) LastInsertId() (int64, error) { + debugLogger.Printf("delayed result LastInsertId called") + if r.Result != nil { + return r.Result.LastInsertId() + } + return 0, ErrBeforeCommit +} + +func (r *redshiftDelayedResult) RowsAffected() (int64, error) { + debugLogger.Printf("delayed result RowsAffected called") + if r.Result != nil { + return r.Result.RowsAffected() + } + return 0, ErrBeforeCommit +} diff --git a/sqlconnect/internal/redshift/driver/rows.go b/sqlconnect/internal/redshift/driver/rows.go new file mode 100644 index 0000000..94e03cf --- /dev/null +++ b/sqlconnect/internal/redshift/driver/rows.go @@ -0,0 +1,120 @@ +package driver + +import ( + "context" + "database/sql/driver" + "io" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/service/redshiftdata" + "github.com/aws/aws-sdk-go-v2/service/redshiftdata/types" +) + +type redshiftRows struct { + ctx context.Context + id string + p *redshiftdata.GetStatementResultPaginator + + page *redshiftdata.GetStatementResultOutput + columns []types.ColumnMetadata + columnNames []string + index int +} + +func newRows(ctx context.Context, id string, p *redshiftdata.GetStatementResultPaginator) (*redshiftRows, error) { + debugLogger.Printf("[%s] create rows", id) + rows := &redshiftRows{ + ctx: ctx, + id: id, + p: p, + } + return rows, rows.getStatementResult() +} + +func (rows *redshiftRows) Close() (err error) { + debugLogger.Printf("[%s] rows close called", rows.id) + return nil +} + +func (rows *redshiftRows) Columns() []string { + return rows.columnNames +} + +func (rows *redshiftRows) ColumnTypeDatabaseTypeName(index int) string { + return strings.ToUpper(*rows.columns[index].TypeName) +} + +func (rows *redshiftRows) getStatementResult() error { + if rows.p == nil { + return io.EOF + } + var err error + rows.page, err = rows.p.NextPage(rows.ctx) + if err != nil { + return err + } + rows.columns = rows.page.ColumnMetadata + rows.columnNames = make([]string, 0, len(rows.columns)) + for _, meta := range rows.columns { + rows.columnNames = append(rows.columnNames, *meta.Name) + } + rows.index = 0 + return nil +} + +func (rows *redshiftRows) Next(dest []driver.Value) error { + debugLogger.Printf("[%s] rows next called", rows.id) + if rows.page == nil || rows.index >= len(rows.page.Records) { + if !rows.p.HasMorePages() { + return io.EOF + } + if err := rows.getStatementResult(); err != nil { + return err + } + rows.index = 0 + if len(rows.page.Records) == 0 { + return io.EOF + } + } + record := rows.page.Records[rows.index] + for i := range dest { + if i < len(record) { + switch field := record[i].(type) { + case *types.FieldMemberIsNull: + dest[i] = nil + case *types.FieldMemberStringValue: + switch { + case strings.EqualFold(*rows.page.ColumnMetadata[i].TypeName, "timestamp"): + t, err := time.Parse("2006-01-02 15:04:05", field.Value) + if err != nil { + errLogger.Printf(`time.Parse("2006-01-02 15:04:05", "%s"): %v`, field.Value, err) + dest[i] = field.Value + } else { + dest[i] = t.UTC().Format(time.RFC3339) + } + case strings.EqualFold(*rows.page.ColumnMetadata[i].TypeName, "timestamptz"): + t, err := time.Parse("2006-01-02 15:04:05-07", field.Value) + if err != nil { + errLogger.Printf(`time.Parse("2006-01-02 15:04:05-07", "%s"): %v`, field.Value, err) + dest[i] = field.Value + } else { + dest[i] = t.UTC().Format(time.RFC3339) + } + default: + dest[i] = field.Value + } + case *types.FieldMemberLongValue: + dest[i] = field.Value + case *types.FieldMemberBooleanValue: + dest[i] = field.Value + case *types.FieldMemberDoubleValue: + dest[i] = field.Value + case *types.FieldMemberBlobValue: + dest[i] = field.Value + } + } + } + rows.index++ + return nil +} diff --git a/sqlconnect/internal/redshift/driver/statement.go b/sqlconnect/internal/redshift/driver/statement.go new file mode 100644 index 0000000..d62b146 --- /dev/null +++ b/sqlconnect/internal/redshift/driver/statement.go @@ -0,0 +1,35 @@ +package driver + +import ( + "context" + "database/sql/driver" +) + +type redshiftStatement struct { + connection *redshiftConnection + query string +} + +func (_ *redshiftStatement) Close() error { + return nil +} + +func (_ *redshiftStatement) NumInput() int { + return -1 +} + +func (s *redshiftStatement) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + return s.connection.ExecContext(ctx, s.query, args) +} + +func (s *redshiftStatement) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + return s.connection.QueryContext(ctx, s.query, args) +} + +func (_ *redshiftStatement) Exec(args []driver.Value) (driver.Result, error) { + return nil, driver.ErrSkip +} + +func (_ *redshiftStatement) Query(args []driver.Value) (driver.Rows, error) { + return nil, driver.ErrSkip +} diff --git a/sqlconnect/internal/redshift/driver/tx.go b/sqlconnect/internal/redshift/driver/tx.go new file mode 100644 index 0000000..65a7a3b --- /dev/null +++ b/sqlconnect/internal/redshift/driver/tx.go @@ -0,0 +1,16 @@ +package driver + +type redshiftTx struct { + onCommit func() error + onRollback func() error +} + +func (tx *redshiftTx) Commit() error { + debugLogger.Printf("tx commit called") + return tx.onCommit() +} + +func (tx *redshiftTx) Rollback() error { + debugLogger.Printf("tx rollback called") + return tx.onRollback() +} diff --git a/sqlconnect/internal/redshift/driver/utils.go b/sqlconnect/internal/redshift/driver/utils.go new file mode 100644 index 0000000..3a7a956 --- /dev/null +++ b/sqlconnect/internal/redshift/driver/utils.go @@ -0,0 +1,18 @@ +package driver + +func nullStringIfEmpty(str string) *string { + if str == "" { + return nil + } + return &str +} + +func coalesce[T any](strs ...*T) T { + for _, str := range strs { + if str != nil { + return *str + } + } + var zero T + return zero +} diff --git a/sqlconnect/internal/redshift/integration_test.go b/sqlconnect/internal/redshift/integration_test.go index 251a06d..50ae448 100644 --- a/sqlconnect/internal/redshift/integration_test.go +++ b/sqlconnect/internal/redshift/integration_test.go @@ -10,10 +10,20 @@ import ( ) func TestRedshiftDB(t *testing.T) { - configJSON, ok := os.LookupEnv("REDSHIFT_TEST_ENVIRONMENT_CREDENTIALS") - if !ok { - t.Skip("skipping redshift integration test due to lack of a test environment") - } + t.Run("postgres driver", func(t *testing.T) { + configJSON, ok := os.LookupEnv("REDSHIFT_TEST_ENVIRONMENT_CREDENTIALS") + if !ok { + t.Skip("skipping redshift pg integration test due to lack of a test environment") + } - integrationtest.TestDatabaseScenarios(t, redshift.DatabaseType, []byte(configJSON), strings.ToLower, integrationtest.Options{LegacySupport: true}) + integrationtest.TestDatabaseScenarios(t, redshift.DatabaseType, []byte(configJSON), strings.ToLower, integrationtest.Options{LegacySupport: true}) + }) + + t.Run("sdk driver", func(t *testing.T) { + configJSON, ok := os.LookupEnv("REDSHIFT_SDK_TEST_ENVIRONMENT_CREDENTIALS") + if !ok { + t.Skip("skipping redshift sdk integration test due to lack of a test environment") + } + integrationtest.TestDatabaseScenarios(t, redshift.DatabaseType, []byte(configJSON), strings.ToLower, integrationtest.Options{LegacySupport: true}) + }) } diff --git a/sqlconnect/internal/redshift/mappings.go b/sqlconnect/internal/redshift/mappings.go index fbf6bdb..65dd3e7 100644 --- a/sqlconnect/internal/redshift/mappings.go +++ b/sqlconnect/internal/redshift/mappings.go @@ -44,6 +44,10 @@ func jsonRowMapper(databaseTypeName string, value any) any { if n, err := strconv.ParseFloat(string(v), 64); err == nil { return n } + case string: + if n, err := strconv.ParseFloat(v, 64); err == nil { + return n + } } default: switch v := value.(type) { diff --git a/sqlconnect/internal/redshift/testdata/column-mapping-test-rows.json b/sqlconnect/internal/redshift/testdata/column-mapping-test-rows.json index e4eaa73..8c7dbb4 100644 --- a/sqlconnect/internal/redshift/testdata/column-mapping-test-rows.json +++ b/sqlconnect/internal/redshift/testdata/column-mapping-test-rows.json @@ -8,9 +8,9 @@ "_integer": 1, "_smallint": 1, "_bigint": 1, - "_real": 1.1, + "_real": 1, "_float": 1.1, - "_float4": 1.1, + "_float4": 1, "_float8": 1.1, "_numeric": 1.1, "_double": 1.1, diff --git a/sqlconnect/internal/redshift/testdata/column-mapping-test-seed.sql b/sqlconnect/internal/redshift/testdata/column-mapping-test-seed.sql index 57c29aa..1fc69b5 100644 --- a/sqlconnect/internal/redshift/testdata/column-mapping-test-seed.sql +++ b/sqlconnect/internal/redshift/testdata/column-mapping-test-seed.sql @@ -30,6 +30,6 @@ CREATE TABLE "{{.schema}}"."column_mappings_test" ( INSERT INTO "{{.schema}}"."column_mappings_test" (_order, _int, _int2, _int4, _int8, _integer, _smallint, _bigint, _real, _float, _float4, _float8, _numeric, _double, _text, _varchar, _charvar, _nchar, _bpchar, _character, _timestamptz, _timestampntz, _timestampwtz, _timestamp, _boolean, _bool) VALUES - (1, 1, 1, 1, 1, 1, 1, 1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 'abc', 'abc', 'abc', 'abc', 'abc', 'abc', '2004-10-19 10:23:54+02', '2004-10-19 10:23:54', '2004-10-19 10:23:54+02', '2004-10-19 10:23:54+02', true, true ), + (1, 1, 1, 1, 1, 1, 1, 1, 1, 1.1, 1, 1.1, 1.1, 1.1, 'abc', 'abc', 'abc', 'abc', 'abc', 'abc', '2004-10-19 10:23:54+02', '2004-10-19 10:23:54', '2004-10-19 10:23:54+02', '2004-10-19 10:23:54+02', true, true ), (2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, '', '', '', '', '', '', '2004-10-19 10:23:54+02', '2004-10-19 10:23:54', '2004-10-19 10:23:54+02', '2004-10-19 10:23:54+02', false, false), (3, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL ); \ No newline at end of file diff --git a/sqlconnect/internal/redshift/testdata/legacy-column-mapping-test-rows.json b/sqlconnect/internal/redshift/testdata/legacy-column-mapping-test-rows.json index b1d1428..c65fbb7 100644 --- a/sqlconnect/internal/redshift/testdata/legacy-column-mapping-test-rows.json +++ b/sqlconnect/internal/redshift/testdata/legacy-column-mapping-test-rows.json @@ -8,9 +8,9 @@ "_integer": 1, "_smallint": 1, "_bigint": 1, - "_real": 1.1, + "_real": 1, "_float": 1.1, - "_float4": 1.1, + "_float4": 1, "_float8": 1.1, "_numeric": "1.10", "_double": 1.1,