Skip to content

Commit

Permalink
destination-s3: assume roal auth
Browse files Browse the repository at this point in the history
  • Loading branch information
stephane-airbyte committed May 15, 2024
1 parent bc83bee commit 0870cd1
Show file tree
Hide file tree
Showing 25 changed files with 413 additions and 87 deletions.
1 change: 1 addition & 0 deletions airbyte-cdk/java/airbyte-cdk/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ corresponds to that version.

| Version | Date | Pull Request | Subject |
|:--------| :--------- | :--------------------------------------------------------- |:---------------------------------------------------------------------------------------------------------------------------------------------------------------|
| 0.35.3 | 2024-05-15 | [\#38204](https://github.com/airbytehq/airbyte/pull/38204) | add assume-role authentication to s3 |
| 0.35.2 | 2024-05-13 | [\#38104](https://github.com/airbytehq/airbyte/pull/38104) | Handle transient error messages |
| 0.35.0 | 2024-05-13 | [\#38127](https://github.com/airbytehq/airbyte/pull/38127) | Destinations: Populate generation/sync ID on StreamConfig |
| 0.34.4 | 2024-05-10 | [\#37712](https://github.com/airbytehq/airbyte/pull/37712) | make sure the exceptionHandler always terminates |
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version=0.35.2
version=0.35.3
Original file line number Diff line number Diff line change
Expand Up @@ -367,10 +367,14 @@ abstract class DestinationAcceptanceTest {
workspaceRoot.toString(),
localRoot.toString(),
"host",
emptyMap()
getConnectorEnv()
)
}

open fun getConnectorEnv(): Map<String, String> {
return emptyMap()
}

@AfterEach
@Throws(Exception::class)
fun tearDownInternal() {
Expand Down
1 change: 1 addition & 0 deletions airbyte-cdk/java/airbyte-cdk/s3-destinations/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dependencies {

// Re-export dependencies for gcs-destinations.
api 'com.amazonaws:aws-java-sdk-s3:1.12.647'
api 'com.amazonaws:aws-java-sdk-sts:1.12.647'
api ('com.github.airbytehq:json-avro-converter:1.1.0') { exclude group: 'ch.qos.logback', module: 'logback-classic'}
api 'com.github.alexmojaki:s3-stream-upload:2.2.4'
api 'org.apache.avro:avro:1.11.3'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@ import org.slf4j.LoggerFactory

abstract class BaseS3Destination
protected constructor(
protected val configFactory: S3DestinationConfigFactory = S3DestinationConfigFactory()
protected val configFactory: S3DestinationConfigFactory = S3DestinationConfigFactory(),
protected val environment: Map<String, String> = System.getenv()
) : BaseConnector(), Destination {
private val nameTransformer: NamingConventionTransformer = S3NameTransformer()

override fun check(config: JsonNode): AirbyteConnectionStatus? {
try {
val destinationConfig = configFactory.getS3DestinationConfig(config, storageProvider())
val destinationConfig = configFactory.getS3DestinationConfig(config, storageProvider(), environment)
val s3Client = destinationConfig.getS3Client()

S3BaseChecks.testIAMUserHasListObjectPermission(s3Client, destinationConfig.bucketName)
Expand Down Expand Up @@ -60,7 +61,7 @@ protected constructor(
catalog: ConfiguredAirbyteCatalog,
outputRecordCollector: Consumer<AirbyteMessage>
): AirbyteMessageConsumer? {
val s3Config = configFactory.getS3DestinationConfig(config, storageProvider())
val s3Config = configFactory.getS3DestinationConfig(config, storageProvider(), environment)
return S3ConsumerFactory()
.create(
outputRecordCollector,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@ import com.amazonaws.services.s3.AmazonS3
import com.amazonaws.services.s3.AmazonS3ClientBuilder
import com.fasterxml.jackson.databind.JsonNode
import io.airbyte.cdk.integrations.destination.s3.constant.S3Constants
import io.airbyte.cdk.integrations.destination.s3.credential.S3AWSDefaultProfileCredentialConfig
import io.airbyte.cdk.integrations.destination.s3.credential.S3AccessKeyCredentialConfig
import io.airbyte.cdk.integrations.destination.s3.credential.S3CredentialConfig
import io.airbyte.cdk.integrations.destination.s3.credential.S3CredentialType
import io.airbyte.cdk.integrations.destination.s3.credential.*
import java.util.*
import javax.annotation.Nonnull
import org.slf4j.Logger
Expand All @@ -34,6 +31,7 @@ open class S3DestinationConfig {
val formatConfig: UploadFormatConfig?
var fileNamePattern: String? = null
private set
var environment: Map<String, String>

private val lock = Any()
private var s3Client: AmazonS3?
Expand Down Expand Up @@ -70,6 +68,7 @@ open class S3DestinationConfig {
this.s3CredentialConfig = credentialConfig
this.formatConfig = formatConfig
this.s3Client = s3Client
this.environment = System.getenv()
}

constructor(
Expand All @@ -83,7 +82,8 @@ open class S3DestinationConfig {
s3Client: AmazonS3?,
fileNamePattern: String?,
checkIntegrity: Boolean,
uploadThreadsCount: Int
uploadThreadsCount: Int,
environment: Map<String, String> = System.getenv()
) {
this.endpoint = endpoint
this.bucketName = bucketName
Expand All @@ -96,6 +96,7 @@ open class S3DestinationConfig {
this.fileNamePattern = fileNamePattern
this.isCheckIntegrity = checkIntegrity
this.uploadThreadsCount = uploadThreadsCount
this.environment = environment
}

fun resetS3Client(): AmazonS3 {
Expand All @@ -113,36 +114,38 @@ open class S3DestinationConfig {
val credentialsProvider = s3CredentialConfig!!.s3CredentialsProvider
val credentialType = s3CredentialConfig.credentialType

if (S3CredentialType.DEFAULT_PROFILE == credentialType) {
return AmazonS3ClientBuilder.standard()
.withRegion(bucketRegion)
.withCredentials(credentialsProvider) // the SDK defaults to RetryMode.LEGACY
// (https://docs.aws.amazon.com/sdkref/latest/guide/feature-retry-behavior.html)
// this _can_ be configured via environment variable, but it seems more reliable to
// configure it
// programmatically
.withClientConfiguration(ClientConfiguration().withRetryMode(RetryMode.STANDARD))
.build()
}

if (null == endpoint || endpoint.isEmpty()) {
return AmazonS3ClientBuilder.standard()
.withCredentials(credentialsProvider)
.withRegion(bucketRegion)
.build()
val clientBuilder = AmazonS3ClientBuilder.standard().withCredentials(credentialsProvider)
when (credentialType) {
S3CredentialType.DEFAULT_PROFILE,
S3CredentialType.ASSUME_ROLE ->
clientBuilder
.withRegion(bucketRegion)
// the SDK defaults to RetryMode.LEGACY
// (https://docs.aws.amazon.com/sdkref/latest/guide/feature-retry-behavior.html)
// this _can_ be configured via environment variable, but it seems more reliable
// to
// configure it
// programmatically
.withClientConfiguration(
ClientConfiguration().withRetryMode(RetryMode.STANDARD)
)
S3CredentialType.ACCESS_KEY -> {
if (null == endpoint || endpoint.isEmpty()) {
clientBuilder.withRegion(bucketRegion)
} else {
val clientConfiguration = ClientConfiguration().withProtocol(Protocol.HTTPS)
clientConfiguration.signerOverride = "AWSS3V4SignerType"

clientBuilder
.withEndpointConfiguration(
AwsClientBuilder.EndpointConfiguration(endpoint, bucketRegion)
)
.withPathStyleAccessEnabled(true)
.withClientConfiguration(clientConfiguration)
}
}
}

val clientConfiguration = ClientConfiguration().withProtocol(Protocol.HTTPS)
clientConfiguration.signerOverride = "AWSS3V4SignerType"

return AmazonS3ClientBuilder.standard()
.withEndpointConfiguration(
AwsClientBuilder.EndpointConfiguration(endpoint, bucketRegion)
)
.withPathStyleAccessEnabled(true)
.withClientConfiguration(clientConfiguration)
.withCredentials(credentialsProvider)
.build()
return clientBuilder.build()
}

override fun equals(other: Any?): Boolean {
Expand Down Expand Up @@ -188,6 +191,7 @@ open class S3DestinationConfig {
private var checkIntegrity = true

private var uploadThreadsCount = S3StorageOperations.DEFAULT_UPLOAD_THREADS
private var environment: Map<String, String> = System.getenv()

fun withBucketName(bucketName: String): Builder {
this.bucketName = bucketName
Expand Down Expand Up @@ -249,6 +253,11 @@ open class S3DestinationConfig {
return this
}

fun withEnvironment(environment: Map<String, String>): Builder {
this.environment = environment
return this
}

fun get(): S3DestinationConfig {
return S3DestinationConfig(
endpoint,
Expand All @@ -261,7 +270,8 @@ open class S3DestinationConfig {
s3Client,
fileNamePattern,
checkIntegrity,
uploadThreadsCount
uploadThreadsCount,
environment
)
}
}
Expand All @@ -284,14 +294,15 @@ open class S3DestinationConfig {
}

@JvmStatic
fun getS3DestinationConfig(@Nonnull config: JsonNode): S3DestinationConfig {
return getS3DestinationConfig(config, StorageProvider.AWS_S3)
fun getS3DestinationConfig(@Nonnull config: JsonNode, environment: Map<String, String> = System.getenv()): S3DestinationConfig {
return getS3DestinationConfig(config, StorageProvider.AWS_S3, environment)
}

@JvmStatic
fun getS3DestinationConfig(
@Nonnull config: JsonNode,
@Nonnull storageProvider: StorageProvider
@Nonnull storageProvider: StorageProvider = StorageProvider.AWS_S3,
environment: Map<String, String> = System.getenv()
): S3DestinationConfig {
var builder =
create(
Expand Down Expand Up @@ -343,6 +354,11 @@ open class S3DestinationConfig {
getProperty(config, S3Constants.ACCESS_KEY_ID),
getProperty(config, S3Constants.SECRET_ACCESS_KEY)
)
} else if (config.has(S3Constants.ROLE_ARN)) {
S3AssumeRoleCredentialConfig(
getProperty(config, S3Constants.ROLE_ARN)!!,
environment
)
} else {
S3AWSDefaultProfileCredentialConfig()
}
Expand All @@ -357,7 +373,7 @@ open class S3DestinationConfig {
UploadFormatConfigFactory.getUploadFormatConfig(config)
)
}

builder.withEnvironment(environment)
return builder.get()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ import javax.annotation.Nonnull
open class S3DestinationConfigFactory {
open fun getS3DestinationConfig(
config: JsonNode,
@Nonnull storageProvider: StorageProvider
@Nonnull storageProvider: StorageProvider,
environment: Map<String, String>
): S3DestinationConfig {
return S3DestinationConfig.Companion.getS3DestinationConfig(config, storageProvider)
return S3DestinationConfig.Companion.getS3DestinationConfig(config=config, storageProvider=storageProvider, environment=environment)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class S3Constants {
const val SECRET_ACCESS_KEY: String = "secret_access_key"
const val S_3_BUCKET_NAME: String = "s3_bucket_name"
const val S_3_BUCKET_REGION: String = "s3_bucket_region"
const val ROLE_ARN: String = "role_arn"

// r2 requires account_id
const val ACCOUNT_ID: String = "account_id"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright (c) 2024 Airbyte, Inc., all rights reserved.
*/
package io.airbyte.cdk.integrations.destination.s3.credential

import com.amazonaws.auth.*
import com.amazonaws.regions.Regions
import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClient

private const val AIRBYTE_STS_SESSION_NAME = "airbyte-sts-session"

/**
* The S3AssumeRoleCredentialConfig implementation of the S3CredentialConfig returns an
* STSAssumeRoleSessionCredentialsProvider. The STSAssumeRoleSessionCredentialsProvider
* automatically refreshes assumed role credentials on a background thread.
* The roleArn comes from the spec and the externalId, which
* is used to protect against confused deputy problems, and also is provided through the
* orchestrator via an environment variable. As of 5/2024, the externalId is set to the workspaceId.
*
* @param roleArn The Amazon Resource Name (ARN) of the role to assume.
*/
class S3AssumeRoleCredentialConfig(private val roleArn: String, environment: Map<String, String>) :
S3CredentialConfig {
private val externalId: String = environment.getValue("AWS_ASSUME_ROLE_EXTERNAL_ID")

override val credentialType: S3CredentialType = S3CredentialType.ASSUME_ROLE

/**
* AWSCredentialsProvider implementation that uses the AWS Security Token Service to assume a
* Role and create temporary, short-lived sessions to use for authentication. This credentials
* provider uses a background thread to refresh credentials. This background thread can be shut
* down via the close() method when the credentials provider is no longer used.
*/
override val s3CredentialsProvider: AWSCredentialsProvider by lazy {
STSAssumeRoleSessionCredentialsProvider.Builder(roleArn, AIRBYTE_STS_SESSION_NAME)
.withExternalId(externalId)
.withStsClient(
AWSSecurityTokenServiceClient.builder()
.withRegion(Regions.DEFAULT_REGION)
.withCredentials(
getCredentialProvider(environment)
)
.build()
)
.build()
}

companion object {
@JvmStatic
fun getCredentialProvider(environment: Map<String, String>): AWSStaticCredentialsProvider {
return AWSStaticCredentialsProvider(BasicAWSCredentials(
environment.getValue("AWS_ACCESS_KEY_ID"),
environment.getValue("AWS_SECRET_ACCESS_KEY")
))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package io.airbyte.cdk.integrations.destination.s3.credential

import com.amazonaws.auth.AWSCredentialsProvider

interface S3CredentialConfig : BlobStorageCredentialConfig<S3CredentialType> {
interface S3CredentialConfig {
val s3CredentialsProvider: AWSCredentialsProvider
val credentialType: S3CredentialType
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ package io.airbyte.cdk.integrations.destination.s3.credential

enum class S3CredentialType {
ACCESS_KEY,
DEFAULT_PROFILE
DEFAULT_PROFILE,
ASSUME_ROLE
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package io.airbyte.cdk.integrations.destination.s3.credential

import org.junit.jupiter.api.Test

class S3AssumeRoleTest {
@Test
fun testFailsWithNoEnvCredentials() {

}

@Test
fun testPassesWithAllCredentials() {

}

@Test
fun testFailsWithWrongExternalId() {

}


fun testAutomaticRenewal() {

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,11 @@ protected constructor(protected val outputFormat: FileUploadFormat) : Destinatio
.set<JsonNode>("format", formatConfig)
this.configJson = configJson
this.s3DestinationConfig =
S3DestinationConfig.getS3DestinationConfig(configJson, storageProvider())
S3DestinationConfig.getS3DestinationConfig(
configJson,
storageProvider(),
getConnectorEnv()
)
LOGGER.info(
"Test full path: {}/{}",
s3DestinationConfig.bucketName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ plugins {
airbyteJavaConnector {
cdkVersionRequired = '0.30.8'
features = ['db-destinations', 's3-destinations']
useLocalCdk = false
useLocalCdk = true
}

airbyteJavaConnector.addCdkDependencies()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ data:
connectorSubtype: file
connectorType: destination
definitionId: 4816b78f-1489-44c1-9060-4b19d5fa9362
dockerImageTag: 0.6.1
dockerImageTag: 0.6.2
dockerRepository: airbyte/destination-s3
githubIssueLabel: destination-s3
icon: s3.svg
Expand Down
Loading

0 comments on commit 0870cd1

Please sign in to comment.