diff --git a/doc/reference.md b/doc/reference.md
index f6a2856a9..1abc877e7 100644
--- a/doc/reference.md
+++ b/doc/reference.md
@@ -227,6 +227,11 @@
 OSS Cassandra this should never be used.
 true
 Sets whether to record connector specific metrics on write
+<tr>
+  <td><code>input.ratelimiterprovider</code></td>
+  <td>main.scala.com.datastax.spark.connector.writer.LeakyBucketRateLimiterProvider</td>
+  <td>Determines which rate limiter provider to use in reads</td>
+</tr>
 input.reads_per_sec
 2147483647
@@ -314,6 +319,11 @@
 finer control see the CassandraOption class
 true
 Sets whether to record connector specific metrics on write
+<tr>
+  <td><code>output.ratelimiterprovider</code></td>
+  <td>main.scala.com.datastax.spark.connector.writer.LeakyBucketRateLimiterProvider</td>
+  <td>Determines which rate limiter provider to use in writes</td>
+</tr>
 output.throughput_mb_per_sec
 2.147483647E9
diff --git a/spark-cassandra-connector/src/main/java/com/datastax/spark/connector/japi/RDDAndDStreamCommonJavaFunctions.java b/spark-cassandra-connector/src/main/java/com/datastax/spark/connector/japi/RDDAndDStreamCommonJavaFunctions.java
index 045263039..f7e07abd6 100644
--- a/spark-cassandra-connector/src/main/java/com/datastax/spark/connector/japi/RDDAndDStreamCommonJavaFunctions.java
+++ b/spark-cassandra-connector/src/main/java/com/datastax/spark/connector/japi/RDDAndDStreamCommonJavaFunctions.java
@@ -178,7 +178,7 @@ public WriterBuilder withBatchSize(BatchSize batchSize) {
                         new WriteConf(batchSize, writeConf.batchGroupingBufferSize(), writeConf.batchGroupingKey(),
                                 writeConf.consistencyLevel(), writeConf.ifNotExists(), writeConf.ignoreNulls(),
                                 writeConf.parallelismLevel(), writeConf.throughputMiBPS(), writeConf.ttl(), writeConf.timestamp(),
-                                writeConf.taskMetricsEnabled()));
+                                writeConf.taskMetricsEnabled(), writeConf.rateLimiterProvider()));
             else return this;
         }

@@ -197,7 +197,7 @@ public WriterBuilder withBatchGroupingBufferSize(int batchGroupingBufferSize) {
                         new WriteConf(writeConf.batchSize(), batchGroupingBufferSize, writeConf.batchGroupingKey(),
                                 writeConf.consistencyLevel(), writeConf.ifNotExists(), writeConf.ignoreNulls(),
                                 writeConf.parallelismLevel(), writeConf.throughputMiBPS(), writeConf.ttl(), writeConf.timestamp(),
-                                writeConf.taskMetricsEnabled()));
+                                writeConf.taskMetricsEnabled(), writeConf.rateLimiterProvider()));
             else return this;
         }

@@ -216,7 +216,7 @@ public WriterBuilder withBatchGroupingKey(BatchGroupingKey batchGroupingKey) {
                         new WriteConf(writeConf.batchSize(), writeConf.batchGroupingBufferSize(), batchGroupingKey,
                                 writeConf.consistencyLevel(), writeConf.ifNotExists(), writeConf.ignoreNulls(),
                                 writeConf.parallelismLevel(), writeConf.throughputMiBPS(), writeConf.ttl(), writeConf.timestamp(),
-                                writeConf.taskMetricsEnabled()));
+                                writeConf.taskMetricsEnabled(), writeConf.rateLimiterProvider()));
             else return this;
         }

@@ -234,7 +234,8 @@ public WriterBuilder withConsistencyLevel(ConsistencyLevel consistencyLevel) {
                 return withWriteConf(
                         new WriteConf(writeConf.batchSize(), writeConf.batchGroupingBufferSize(), writeConf.batchGroupingKey(),
                                 consistencyLevel, writeConf.ifNotExists(), writeConf.ignoreNulls(), writeConf.parallelismLevel(),
-                                writeConf.throughputMiBPS(), writeConf.ttl(), writeConf.timestamp(), writeConf.taskMetricsEnabled()));
+                                writeConf.throughputMiBPS(), writeConf.ttl(), writeConf.timestamp(), writeConf.taskMetricsEnabled(),
+                                writeConf.rateLimiterProvider()));
             else return this;
         }

@@ -252,7 +253,8 @@ public WriterBuilder withParallelismLevel(int parallelismLevel) {
                 return withWriteConf(
                         new WriteConf(writeConf.batchSize(), writeConf.batchGroupingBufferSize(),
                                 writeConf.batchGroupingKey(), writeConf.consistencyLevel(), writeConf.ifNotExists(), writeConf.ignoreNulls(), parallelismLevel,
-                                writeConf.throughputMiBPS(), writeConf.ttl(), writeConf.timestamp(), writeConf.taskMetricsEnabled()));
+                                writeConf.throughputMiBPS(), writeConf.ttl(), writeConf.timestamp(), writeConf.taskMetricsEnabled(),
+                                writeConf.rateLimiterProvider()));
             else return this;
         }

@@ -271,7 +273,7 @@ public WriterBuilder withThroughputMBPS(int throughputMBPS) {
                         new WriteConf(writeConf.batchSize(), writeConf.batchGroupingBufferSize(), writeConf.batchGroupingKey(),
                                 writeConf.consistencyLevel(), writeConf.ifNotExists(), writeConf.ignoreNulls(),
                                 writeConf.parallelismLevel(), throughputMBPS, writeConf.ttl(), writeConf.timestamp(),
-                                writeConf.taskMetricsEnabled()));
+                                writeConf.taskMetricsEnabled(), writeConf.rateLimiterProvider()));
             else return this;
         }

@@ -290,7 +292,7 @@ public WriterBuilder withTaskMetricsEnabled(boolean taskMetricsEnabled) {
                         new WriteConf(writeConf.batchSize(), writeConf.batchGroupingBufferSize(), writeConf.batchGroupingKey(),
                                 writeConf.consistencyLevel(), writeConf.ifNotExists(), writeConf.ignoreNulls(),
                                 writeConf.parallelismLevel(), writeConf.throughputMiBPS(), writeConf.ttl(),
-                                writeConf.timestamp(), taskMetricsEnabled));
+                                writeConf.timestamp(), taskMetricsEnabled, writeConf.rateLimiterProvider()));
             else return this;
         }

@@ -307,7 +309,7 @@ public WriterBuilder withIfNotExists(boolean ifNotExists) {
                 return withWriteConf(
                         new WriteConf(writeConf.batchSize(), writeConf.batchGroupingBufferSize(), writeConf.batchGroupingKey(),
                                 writeConf.consistencyLevel(), ifNotExists, writeConf.ignoreNulls(), writeConf.parallelismLevel(),
-                                writeConf.throughputMiBPS(), writeConf.ttl(), writeConf.timestamp(), writeConf.taskMetricsEnabled()));
+                                writeConf.throughputMiBPS(), writeConf.ttl(), writeConf.timestamp(), writeConf.taskMetricsEnabled(), writeConf.rateLimiterProvider()));
             else return this;
         }

@@ -324,7 +326,7 @@ public WriterBuilder withIgnoreNulls(boolean ignoreNulls) {
                 return withWriteConf(
                         new WriteConf(writeConf.batchSize(), writeConf.batchGroupingBufferSize(), writeConf.batchGroupingKey(),
                                 writeConf.consistencyLevel(), writeConf.ifNotExists(), ignoreNulls, writeConf.parallelismLevel(),
-                                writeConf.throughputMiBPS(), writeConf.ttl(), writeConf.timestamp(), writeConf.taskMetricsEnabled()));
+                                writeConf.throughputMiBPS(), writeConf.ttl(), writeConf.timestamp(), writeConf.taskMetricsEnabled(), writeConf.rateLimiterProvider()));
             else return this;
         }

@@ -343,7 +345,8 @@ private WriterBuilder withTimestamp(TimestampOption timestamp) {
                     writeConf.throughputMiBPS(),
                     writeConf.ttl(),
                     timestamp,
-                    writeConf.taskMetricsEnabled()));
+                    writeConf.taskMetricsEnabled(),
+                    writeConf.rateLimiterProvider()));
         }

@@ -424,7 +427,8 @@ private WriterBuilder withTTL(TTLOption ttl) {
                     writeConf.throughputMiBPS(),
                     ttl,
                     writeConf.timestamp(),
-                    writeConf.taskMetricsEnabled()));
+                    writeConf.taskMetricsEnabled(),
+                    writeConf.rateLimiterProvider()));
         }

         /**
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/RDDFunctions.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/RDDFunctions.scala
index 3cc6646aa..2f7819db1 100644
--- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/RDDFunctions.scala
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/RDDFunctions.scala
@@ -31,7 +31,6 @@ class RDDFunctions[T](rdd: RDD[T]) extends WritableToCassandra[T] with Serializa
                       implicit connector: CassandraConnector = CassandraConnector(sparkContext),
                       rwf: RowWriterFactory[T]): Unit = {
-
     val writer = TableWriter(connector, keyspaceName, tableName, columns, writeConf)
     rdd.sparkContext.runJob(rdd, writer.write _)
   }
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraJoinRDD.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraJoinRDD.scala
index 31d1c8630..f25aa4bfe 100644
--- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraJoinRDD.scala
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraJoinRDD.scala
@@ -5,6 +5,7 @@ import com.datastax.spark.connector._
 import com.datastax.spark.connector.cql._
 import com.datastax.spark.connector.rdd.reader._
 import com.datastax.spark.connector.writer._
+import com.datastax.spark.connector.util.RateLimiterUtil
 import com.google.common.util.concurrent.{FutureCallback, Futures, SettableFuture}
 import org.apache.spark.rdd.RDD

@@ -118,10 +119,11 @@ class CassandraJoinRDD[L, R] private[connector](
     rowMetadata: CassandraRowMetadata,
     leftIterator: Iterator[L]
   ): Iterator[(L, R)] = {
-    val rateLimiter = new RateLimiter(
-      readConf.readsPerSec, readConf.readsPerSec
+    val rateLimiter = RateLimiterUtil.getRateLimiter(
+      readConf.rateLimiterProvider,
+      readConf.readsPerSec,
+      readConf.readsPerSec
     )
-
     val queryExecutor = QueryExecutor(session, readConf.parallelismLevel, None, None)

     def pairWithRight(left: L): SettableFuture[Iterator[(L, R)]] = {
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraLeftJoinRDD.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraLeftJoinRDD.scala
index 1a76288ea..db7cdb7a3 100644
--- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraLeftJoinRDD.scala
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraLeftJoinRDD.scala
@@ -4,6 +4,7 @@ import com.datastax.driver.core.{ResultSet, Session}
 import com.datastax.spark.connector._
 import com.datastax.spark.connector.cql._
 import com.datastax.spark.connector.rdd.reader._
+import com.datastax.spark.connector.util.RateLimiterUtil
 import com.datastax.spark.connector.writer._
 import com.google.common.util.concurrent.{FutureCallback, Futures, SettableFuture}
 import org.apache.spark.rdd.RDD

@@ -141,8 +142,10 @@ class CassandraLeftJoinRDD[L, R] private[connector](
     rowMetadata: CassandraRowMetadata,
     leftIterator: Iterator[L]
   ): Iterator[(L, Option[R])] = {
-    val rateLimiter = new RateLimiter(
-      readConf.readsPerSec, readConf.readsPerSec
+    val rateLimiter = RateLimiterUtil.getRateLimiter(
+      readConf.rateLimiterProvider,
+      readConf.readsPerSec,
+      readConf.readsPerSec
     )
     val queryExecutor = QueryExecutor(session, readConf.parallelismLevel,None, None)
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraRDD.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraRDD.scala
index 8ec614c88..0740163fe 100644
--- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraRDD.scala
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraRDD.scala
@@ -1,6 +1,5 @@
 package com.datastax.spark.connector.rdd

-import com.datastax.driver.core.HostDistance
 import com.datastax.spark.connector.cql._
 import com.datastax.spark.connector.rdd.ClusteringOrder.{Ascending, Descending}
 import com.datastax.spark.connector.rdd.reader._
@@ -27,7 +26,6 @@ abstract class CassandraRDD[R : ClassTag](

   ConfigCheck.checkConfig(sc.getConf)

-
   protected[connector] def keyspaceName: String

   protected[connector] def tableName: String
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/ReadConf.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/ReadConf.scala
index 6c65a28d6..333543aaa 100644
--- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/ReadConf.scala
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/ReadConf.scala
@@ -1,7 +1,8 @@
 package com.datastax.spark.connector.rdd

 import com.datastax.driver.core.ConsistencyLevel
-import com.datastax.spark.connector.util.{ConfigParameter, ConfigCheck, Logging}
+import com.datastax.spark.connector.util.{ConfigCheck, ConfigParameter, Logging}
+import main.scala.com.datastax.spark.connector.writer.LeakyBucketRateLimiterProvider
 import org.apache.spark.SparkConf

 /** Read settings for RDD
@@ -15,7 +16,9 @@ import org.apache.spark.SparkConf
   * @param taskMetricsEnabled whether or not enable task metrics updates (requires Spark 1.2+)
   * @param readsPerSec maximum read throughput allowed per single core in requests/s while
   *                    joining an RDD with C* table (joinWithCassandraTable operation)
-  *                    also used by enterprise integrations*/
+  *                    also used by enterprise integrations
+  * @param rateLimiterProvider fully qualified name of a custom rate limiter provider
+  */
 case class ReadConf(
   splitCount: Option[Int] = None,
   splitSizeInMB: Int = ReadConf.SplitSizeInMBParam.default,
@@ -23,7 +26,8 @@ case class ReadConf(
   consistencyLevel: ConsistencyLevel = ReadConf.ConsistencyLevelParam.default,
   taskMetricsEnabled: Boolean = ReadConf.TaskMetricParam.default,
   parallelismLevel: Int = ReadConf.ParallelismLevelParam.default,
-  readsPerSec: Int = ReadConf.ReadsPerSecParam.default
+  readsPerSec: Int = ReadConf.ReadsPerSecParam.default,
+  rateLimiterProvider: String = ReadConf.RateLimiterProviderParam.default
 )

@@ -93,6 +97,13 @@ object ReadConf extends Logging {
     """Sets max requests per core per second for joinWithCassandraTable and some Enterprise integrations"""
   )

+  val RateLimiterProviderParam = ConfigParameter[String] (
+    name = "spark.cassandra.input.ratelimiterprovider",
+    section = ReferenceSection,
+    default = new LeakyBucketRateLimiterProvider().getClass.getName,
+    description = """Determines which rate limiter provider to use in reads"""
+  )
+
   // Whitelist for allowed Read environment variables
   val Properties = Set(
     SplitCountParam,
@@ -102,7 +113,8 @@ object ReadConf extends Logging {
     SplitSizeInMBParam,
     TaskMetricParam,
     ThroughputJoinQueryPerSecParam,
-    ParallelismLevelParam
+    ParallelismLevelParam,
+    RateLimiterProviderParam
   )

   def fromSparkConf(conf: SparkConf): ReadConf = {
@@ -136,7 +148,8 @@ object ReadConf extends Logging {
       readsPerSec = conf.getInt(ReadsPerSecParam.name,
         throughtputJoinQueryPerSec.getOrElse(ReadsPerSecParam.default)),
       parallelismLevel = conf.getInt(ParallelismLevelParam.name, ParallelismLevelParam.default),
-      splitCount = conf.getOption(SplitCountParam.name).map(_.toInt)
+      splitCount = conf.getOption(SplitCountParam.name).map(_.toInt),
+      rateLimiterProvider = conf.get(RateLimiterProviderParam.name, RateLimiterProviderParam.default)
     )
   }
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/util/RateLimiterUtil.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/util/RateLimiterUtil.scala
new file mode 100644
index 000000000..03e49f5b5
--- /dev/null
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/util/RateLimiterUtil.scala
@@ -0,0 +1,39 @@
+package com.datastax.spark.connector.util
+
+import com.datastax.spark.connector.writer.{BaseRateLimiter, RateLimiterProvider}
+
+/**
+ * Provides a method to retrieve a custom rate limiter based on dynamic configuration.
+ */
+object RateLimiterUtil extends Logging {
+  var provider: RateLimiterProvider = _
+
+  /**
+   * Instantiates a rate limiter provider based on its fully qualified classname and, should that
+   * not be possible, logs the problem and rethrows the underlying exception.
+   *
+   * @param className fully qualified classname of the rate limiter provider to instantiate
+   * @param args optional sequence of arguments passed on to the provider
+   * @return an instantiated rate limiter
+   */
+  def getRateLimiter(className: String, args: Any*): BaseRateLimiter = {
+    try {
+      provider = Class.forName(className).newInstance.asInstanceOf[RateLimiterProvider]
+    } catch {
+      case e: ClassNotFoundException =>
+        logError("Could not find custom rate limiter provider. Error: " + e)
+        throw e
+      case e: InstantiationException =>
+        logError("Could not instantiate custom rate limiter provider. Error: " + e)
+        throw e
+      case e: Throwable =>
+        logError("Error: " + e)
+        throw e
+    }
+
+    provider.getRateLimiterWithConf(args: _*)
+  }
+}
\ No newline at end of file
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/BaseRateLimiter.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/BaseRateLimiter.scala
new file mode 100644
index 000000000..855d9a365
--- /dev/null
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/BaseRateLimiter.scala
@@ -0,0 +1,16 @@
+package com.datastax.spark.connector.writer
+
+/**
+ * Represents a rate limiter.
+ */
+trait BaseRateLimiter {
+
+  /**
+   * Processes a single packet and it is up to the implementing class to determine whether
+   * or not the thread should sleep.
+   *
+   * @param packetSize the size of the packet currently being processed
+   */
+  def maybeSleep(packetSize: Long): Unit
+
+}
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/RateLimiter.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/LeakyBucketRateLimiter.scala
similarity index 92%
rename from spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/RateLimiter.scala
rename to spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/LeakyBucketRateLimiter.scala
index 7d9067b1d..d1052afb9 100644
--- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/RateLimiter.scala
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/LeakyBucketRateLimiter.scala
@@ -22,11 +22,11 @@ import scala.annotation.tailrec
   * @param sleep a function to call to slow down the calling thread;
   *              must use the same time units as `time`
   */
-class RateLimiter(
-    rate: Long,
-    bucketSize: Long,
-    time: () => Long = System.currentTimeMillis,
-    sleep: Long => Any = Thread.sleep) {
+class LeakyBucketRateLimiter(
+    rate: Long,
+    bucketSize: Long,
+    time: () => Long = System.currentTimeMillis,
+    sleep: Long => Any = Thread.sleep) extends BaseRateLimiter {

   require(rate > 0, "A positive rate is required")
   require(bucketSize > 0, "A positive bucket size is required")
@@ -62,5 +62,4 @@ class RateLimiter(
     if (delay > 0L) sleep(delay)
   }
-
 }
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/LeakyBucketRateLimiterProvider.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/LeakyBucketRateLimiterProvider.scala
new file mode 100644
index 000000000..1975c78cb
--- /dev/null
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/LeakyBucketRateLimiterProvider.scala
@@ -0,0 +1,38 @@
+package main.scala.com.datastax.spark.connector.writer
+
+import com.datastax.spark.connector.util.Logging
+import com.datastax.spark.connector.writer.{BaseRateLimiter, LeakyBucketRateLimiter, RateLimiterProvider}
+
+/**
+ * Instantiates a leaky bucket rate limiter based on the supplied configuration.
+ */
+class LeakyBucketRateLimiterProvider extends RateLimiterProvider with Logging {
+
+  override def getRateLimiterWithConf(args: Any*): BaseRateLimiter = {
+    val rate = args(0).asInstanceOf[Number].longValue
+    val bucketSize = args(1).asInstanceOf[Number].longValue
+
+    // If the optional arguments are present but cannot be cast correctly,
+    // omit them and instantiate the rate limiter with only rate and bucketSize.
+    try {
+      if (args.size > 3) {
+        val time = args(2).asInstanceOf[() => Long]
+        val sleep = args(3).asInstanceOf[Long => Any]
+        new LeakyBucketRateLimiter(rate, bucketSize, time, sleep)
+      } else if (args.size > 2) {
+        val time = args(2).asInstanceOf[() => Long]
+        new LeakyBucketRateLimiter(rate, bucketSize, time)
+      } else {
+        new LeakyBucketRateLimiter(rate, bucketSize)
+      }
+    } catch {
+      case _: Exception =>
+        logError("Invalid optional arguments when instantiating leaky bucket rate limiter")
+        new LeakyBucketRateLimiter(rate, bucketSize)
+    }
+  }
+}
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/RateLimiterProvider.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/RateLimiterProvider.scala
new file mode 100644
index 000000000..1a4563dde
--- /dev/null
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/RateLimiterProvider.scala
@@ -0,0 +1,14 @@
+package com.datastax.spark.connector.writer
+
+/**
+ * Represents a provider that creates and returns a rate limiter with possible configuration.
+ */
+trait RateLimiterProvider {
+  /**
+   * Given a set of arguments, instantiates and returns a rate limiter.
+   *
+   * @param args sequence of arguments that can customize the returned rate limiter
+   * @return the created rate limiter
+   */
+  def getRateLimiterWithConf(args: Any*): BaseRateLimiter
+}
\ No newline at end of file
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/TableWriter.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/TableWriter.scala
index 097bd4102..5ace698ce 100644
--- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/TableWriter.scala
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/TableWriter.scala
@@ -8,7 +8,7 @@ import com.datastax.spark.connector._
 import com.datastax.spark.connector.cql._
 import com.datastax.spark.connector.types.{CollectionColumnType, ListType, MapType}
 import com.datastax.spark.connector.util.Quote._
-import com.datastax.spark.connector.util.{CountingIterator, Logging}
+import com.datastax.spark.connector.util.{CountingIterator, Logging, RateLimiterUtil}
 import org.apache.spark.TaskContext
 import org.apache.spark.metrics.OutputMetricsUpdater

@@ -193,8 +193,9 @@ class TableWriter[T] private (
   /**
    * Write data with Cql INSERT statement
    */
-  def insert(taskContext: TaskContext, data: Iterator[T]):Unit =
+  def insert(taskContext: TaskContext, data: Iterator[T]):Unit = {
     writeInternal(queryTemplateUsingInsert, taskContext, data)
+  }

   /**
    * Cql DELETE statement
@@ -226,8 +227,11 @@ class TableWriter[T] private (
     val batchKeyGenerator = batchRoutingKey(session, routingKeyGenerator) _
     val batchBuilder = new GroupingBatchBuilder(boundStmtBuilder, batchStmtBuilder, batchKeyGenerator,
       writeConf.batchSize, writeConf.batchGroupingBufferSize, rowIterator)
-    val rateLimiter = new RateLimiter((writeConf.throughputMiBPS * 1024 * 1024).toLong, 1024 * 1024)
-
+    val rateLimiter = RateLimiterUtil.getRateLimiter(
+      writeConf.rateLimiterProvider,
+      (writeConf.throughputMiBPS * 1024 * 1024).toLong,
+      1024 * 1024
+    )
     logDebug(s"Writing data partition to $keyspaceName.$tableName in batches of ${writeConf.batchSize}.")

     for (stmtToWrite <- batchBuilder) {
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/WriteConf.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/WriteConf.scala
index 2d39ba8e0..b33ae40f1 100644
--- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/WriteConf.scala
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/WriteConf.scala
@@ -2,9 +2,11 @@ package com.datastax.spark.connector.writer

 import com.datastax.driver.core.{ConsistencyLevel, DataType}
 import com.datastax.spark.connector.cql.{ColumnDef, RegularColumn}
+import com.datastax.spark.connector.rdd.ReadConf
 import com.datastax.spark.connector.types.ColumnType
-import com.datastax.spark.connector.util.{ConfigParameter, ConfigCheck}
+import com.datastax.spark.connector.util.{ConfigCheck, ConfigParameter}
 import com.datastax.spark.connector.{BatchSize, BytesInBatch, RowsInBatch}
+import main.scala.com.datastax.spark.connector.writer.LeakyBucketRateLimiterProvider
 import org.apache.commons.configuration.ConfigurationException
 import org.apache.spark.SparkConf

@@ -21,6 +23,7 @@ import org.apache.spark.SparkConf
   * @param ttl the default TTL value which is used when it is defined (in seconds)
   * @param timestamp the default timestamp value which is used when it is defined (in microseconds)
   * @param taskMetricsEnabled whether or not enable task metrics updates (requires Spark 1.2+)
+  * @param rateLimiterProvider fully qualified name of a custom rate limiter provider
   */

 case class WriteConf(batchSize: BatchSize = BatchSize.Automatic,
@@ -33,7 +36,8 @@ case class WriteConf(batchSize: BatchSize = BatchSize.Automatic,
                      throughputMiBPS: Double = WriteConf.ThroughputMiBPSParam.default,
                      ttl: TTLOption = TTLOption.defaultValue,
                      timestamp: TimestampOption = TimestampOption.defaultValue,
-                     taskMetricsEnabled: Boolean = WriteConf.TaskMetricsParam.default) {
+                     taskMetricsEnabled: Boolean = WriteConf.TaskMetricsParam.default,
+                     rateLimiterProvider: String = WriteConf.RateLimiterProviderParam.default) {

   private[writer] val optionPlaceholders: Seq[String] = Seq(ttl, timestamp).collect {
     case WriteOption(PerRowWriteOptionValue(placeholder)) => placeholder
@@ -123,7 +127,7 @@ object WriteConf {
     default = 5,
     description = """Maximum number of batches executed in parallel by a
       | single Spark task""".stripMargin)
-
+
   val ThroughputMiBPSParam = ConfigParameter[Double] (
     name = "spark.cassandra.output.throughput_mb_per_sec",
     section = ReferenceSection,
@@ -154,6 +158,13 @@ object WriteConf {
     description = """Sets whether to record connector specific metrics on write"""
   )

+  val RateLimiterProviderParam = ConfigParameter[String](
+    name = "spark.cassandra.output.ratelimiterprovider",
+    section = ReferenceSection,
+    default = new LeakyBucketRateLimiterProvider().getClass.getName,
+    description = """Determines which rate limiter provider to use in writes"""
+  )
+
   // Whitelist for allowed Write environment variables
   val Properties: Set[ConfigParameter[_]] = Set(
     BatchSizeBytesParam,
@@ -167,7 +178,8 @@ object WriteConf {
     ThroughputMiBPSParam,
     TTLParam,
     TimestampParam,
-    TaskMetricsParam
+    TaskMetricsParam,
+    RateLimiterProviderParam
   )

   def fromSparkConf(conf: SparkConf): WriteConf = {
@@ -215,7 +227,7 @@ object WriteConf {
         TTLOption.defaultValue
       else
         TTLOption.constant(ttlSeconds)
-
+
     val timestampMicros = conf.getLong(TimestampParam.name, TimestampParam.default)

     val timestampOption =
@@ -224,6 +236,8 @@ object WriteConf {
       else
         TimestampOption.constant(timestampMicros)

+    val rateLimiterProvider = conf.get(RateLimiterProviderParam.name, RateLimiterProviderParam.default)
+
     WriteConf(
       batchSize = batchSize,
       batchGroupingBufferSize = batchBufferSize,
@@ -235,7 +249,8 @@ object WriteConf {
       ttl = ttlOption,
       timestamp = timestampOption,
       ignoreNulls = ignoreNulls,
-      ifNotExists = ifNotExists)
+      ifNotExists = ifNotExists,
+      rateLimiterProvider = rateLimiterProvider)
   }
 }
diff --git a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/ReadConfTest.scala b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/ReadConfTest.scala
index 5959a366f..e581bb696 100644
--- a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/ReadConfTest.scala
+++ b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/ReadConfTest.scala
@@ -1,5 +1,6 @@
 package com.datastax.spark.connector.rdd

+import com.datastax.spark.connector.writer.WriteConf
 import org.apache.spark.SparkConf
 import org.scalatest.{FlatSpec, Matchers}

@@ -27,4 +28,11 @@ class ReadConfTest extends FlatSpec with Matchers {
     readConf.readsPerSec should be (expected)
   }

+  it should "allow to set custom rate limiter provider" in {
+    val conf = new SparkConf(false)
+      .set("spark.cassandra.input.ratelimiterprovider", "custom.ratelimiter.provider")
+    val readConf = ReadConf.fromSparkConf(conf)
+    readConf.rateLimiterProvider should be("custom.ratelimiter.provider")
+  }
+
 }
diff --git a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/util/RateLimiterUtilSpec.scala b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/util/RateLimiterUtilSpec.scala
new file mode 100644
index 000000000..7edf8879d
--- /dev/null
+++ b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/util/RateLimiterUtilSpec.scala
@@ -0,0 +1,40 @@
+package com.datastax.spark.connector.util
+
+import org.scalatest.{FlatSpec, Matchers}
+
+import com.datastax.spark.connector.writer.{BaseRateLimiter, RateLimiterProvider}
+
+class RateLimiterUtilSpec extends FlatSpec with Matchers {
+
+  "RateLimiterUtil" should "return a rate limiter built by a custom provider when one is specified" in {
+    val mockProvider = new MockProvider()
+    val rateLimiter = RateLimiterUtil.getRateLimiter(mockProvider.getClass.getName)
+    rateLimiter.getClass.getName should equal (mockProvider.getRateLimiterWithConf().getClass.getName)
+  }
+
+  it should "throw an error when custom rate limiter provider cannot be instantiated" in {
+    a [ClassNotFoundException] should be thrownBy RateLimiterUtil.getRateLimiter("non.existing.class")
+    an [InstantiationException] should be thrownBy RateLimiterUtil.getRateLimiter(NonInstantiable.getClass.getName)
+  }
+
+  // mock object that cannot be instantiated
+  object NonInstantiable {}
+}
+
+// mock provider with a public no-arg constructor that can be instantiated
+class MockProvider extends RateLimiterProvider {
+
+  override def getRateLimiterWithConf(args: Any*): BaseRateLimiter = {
+    new MockRateLimiter
+  }
+}
+
+// mock rate limiter that is returned by MockProvider
+class MockRateLimiter extends BaseRateLimiter {
+  override def maybeSleep(packetSize: Long): Unit = {}
+}
+
diff --git a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/writer/RateLimiterSpec.scala b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/writer/LeakyBucketRateLimiterSpec.scala
similarity index 83%
rename from spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/writer/RateLimiterSpec.scala
rename to spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/writer/LeakyBucketRateLimiterSpec.scala
index 95cfe0763..78017d2ad 100644
--- a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/writer/RateLimiterSpec.scala
+++ b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/writer/LeakyBucketRateLimiterSpec.scala
@@ -5,7 +5,7 @@ import org.scalatest.concurrent.Eventually
 import org.scalatest.{FlatSpec, Matchers}


-class RateLimiterSpec extends FlatSpec with Matchers with MockFactory with Eventually{
+class LeakyBucketRateLimiterSpec extends FlatSpec with Matchers with MockFactory with Eventually{

   val TestRates = Seq(1L, 2L, 4L, 6L, 8L, 16L, 32L, WriteConf.ThroughputMiBPSParam.default.toLong)

@@ -14,7 +14,7 @@ class RateLimiterSpec extends FlatSpec with Matchers with MockFactory with Event
     val sleep = mockFunction[Long, Any]("sleep")
     sleep.expects(*).never()

-    val limiter = new RateLimiter(Long.MaxValue, 1000, () => now, sleep)
+    val limiter = new LeakyBucketRateLimiter(Long.MaxValue, 1000, () => now, sleep)
     for (i <- 1 to 1000000) {
       now += 1
       limiter.maybeSleep(1000)
@@ -33,7 +33,7 @@ class RateLimiterSpec extends FlatSpec with Matchers with MockFactory with Event
     // 10 units per second + 5 units burst allowed
     val bucketSize = 5
     val rate = 10
-    val limiter = new RateLimiter(rate, bucketSize, () => now, sleep)
+    val limiter = new LeakyBucketRateLimiter(rate, bucketSize, () => now, sleep)

     val iterations = 25
     for (i <- 1 to iterations)
@@ -53,7 +53,7 @@ class RateLimiterSpec extends FlatSpec with Matchers with MockFactory with Event
       now += delay
     }

-    val limiter = new RateLimiter(rate, rate * 2, () => now, sleep)
+    val limiter = new LeakyBucketRateLimiter(rate, rate * 2, () => now, sleep)
     for (leakNum <- 1 to 1000) {
       assert(
         limiter.bucketFill.get() >= 0,
diff --git a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/writer/WriteConfTest.scala b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/writer/WriteConfTest.scala
index c4f5887e3..d2b290ccd 100644
--- a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/writer/WriteConfTest.scala
+++ b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/writer/WriteConfTest.scala
@@ -78,5 +78,12 @@ class WriteConfTest extends FlatSpec with Matchers {
     writeConf.batchGroupingBufferSize should be(30000)
   }

+  it should "allow to set custom rate limiter provider" in {
+    val conf = new SparkConf(false)
+      .set("spark.cassandra.output.ratelimiterprovider", "custom.ratelimiter.provider")
+    val writeConf = WriteConf.fromSparkConf(conf)
+    writeConf.rateLimiterProvider should be("custom.ratelimiter.provider")
+  }
+
 }
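For reviewers trying this change out, a minimal sketch (not part of the diff above) of how the new configuration keys and the RateLimiterProvider extension point could be wired up follows. The com.example.limiter package, the FixedDelayRateLimiterProvider class, and its fixed one-millisecond pause are hypothetical illustrations, not connector code; only the trait signatures and the spark.cassandra.*.ratelimiterprovider keys come from the diff.

// Usage sketch, assuming the hypothetical names described above.
package com.example.limiter

import org.apache.spark.SparkConf
import com.datastax.spark.connector.writer.{BaseRateLimiter, RateLimiterProvider}

// Illustrative provider: ignores the rate arguments and pauses briefly per packet.
class FixedDelayRateLimiterProvider extends RateLimiterProvider {
  override def getRateLimiterWithConf(args: Any*): BaseRateLimiter = new BaseRateLimiter {
    // The connector passes rate and bucket-size style arguments via `args`;
    // a real implementation would use them to decide whether and how long to sleep.
    override def maybeSleep(packetSize: Long): Unit = Thread.sleep(1)
  }
}

object FixedDelayRateLimiterExample {
  // Point both the read and the write path at the custom provider.
  val conf: SparkConf = new SparkConf()
    .set("spark.cassandra.input.ratelimiterprovider", "com.example.limiter.FixedDelayRateLimiterProvider")
    .set("spark.cassandra.output.ratelimiterprovider", "com.example.limiter.FixedDelayRateLimiterProvider")
}

Because RateLimiterUtil loads the provider with Class.forName(...).newInstance, a custom provider needs a public no-argument constructor and must be on the executor classpath.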