diff --git a/doc/reference.md b/doc/reference.md
index f6a2856a9..1abc877e7 100644
--- a/doc/reference.md
+++ b/doc/reference.md
@@ -227,6 +227,11 @@ OSS Cassandra this should never be used.
true |
Sets whether to record connector specific metrics on write |
+
+ input.ratelimiterprovider |
+ main.scala.com.datastax.spark.connector.writer.LeakyBucketRateLimiterProvider |
+    Fully qualified class name of the rate limiter provider to use for reads |
+
input.reads_per_sec |
2147483647 |
@@ -314,6 +319,11 @@ finer control see the CassandraOption class
true |
Sets whether to record connector specific metrics on write |
+
+ output.ratelimiterprovider |
+ main.scala.com.datastax.spark.connector.writer.LeakyBucketRateLimiterProvider |
+    Fully qualified class name of the rate limiter provider to use for writes |
+
output.throughput_mb_per_sec |
2.147483647E9 |
diff --git a/spark-cassandra-connector/src/main/java/com/datastax/spark/connector/japi/RDDAndDStreamCommonJavaFunctions.java b/spark-cassandra-connector/src/main/java/com/datastax/spark/connector/japi/RDDAndDStreamCommonJavaFunctions.java
index 045263039..f7e07abd6 100644
--- a/spark-cassandra-connector/src/main/java/com/datastax/spark/connector/japi/RDDAndDStreamCommonJavaFunctions.java
+++ b/spark-cassandra-connector/src/main/java/com/datastax/spark/connector/japi/RDDAndDStreamCommonJavaFunctions.java
@@ -178,7 +178,7 @@ public WriterBuilder withBatchSize(BatchSize batchSize) {
new WriteConf(batchSize, writeConf.batchGroupingBufferSize(), writeConf.batchGroupingKey(),
writeConf.consistencyLevel(), writeConf.ifNotExists(), writeConf.ignoreNulls(),
writeConf.parallelismLevel(), writeConf.throughputMiBPS(), writeConf.ttl(), writeConf.timestamp(),
- writeConf.taskMetricsEnabled()));
+ writeConf.taskMetricsEnabled(), writeConf.rateLimiterProvider()));
else
return this;
}
@@ -197,7 +197,7 @@ public WriterBuilder withBatchGroupingBufferSize(int batchGroupingBufferSize) {
new WriteConf(writeConf.batchSize(), batchGroupingBufferSize, writeConf.batchGroupingKey(),
writeConf.consistencyLevel(), writeConf.ifNotExists(), writeConf.ignoreNulls(),
writeConf.parallelismLevel(), writeConf.throughputMiBPS(), writeConf.ttl(), writeConf.timestamp(),
- writeConf.taskMetricsEnabled()));
+ writeConf.taskMetricsEnabled(), writeConf.rateLimiterProvider()));
else
return this;
}
@@ -216,7 +216,7 @@ public WriterBuilder withBatchGroupingKey(BatchGroupingKey batchGroupingKey) {
new WriteConf(writeConf.batchSize(), writeConf.batchGroupingBufferSize(), batchGroupingKey,
writeConf.consistencyLevel(), writeConf.ifNotExists(), writeConf.ignoreNulls(),
writeConf.parallelismLevel(), writeConf.throughputMiBPS(), writeConf.ttl(), writeConf.timestamp(),
- writeConf.taskMetricsEnabled()));
+ writeConf.taskMetricsEnabled(), writeConf.rateLimiterProvider()));
else
return this;
}
@@ -234,7 +234,8 @@ public WriterBuilder withConsistencyLevel(ConsistencyLevel consistencyLevel) {
return withWriteConf(
new WriteConf(writeConf.batchSize(), writeConf.batchGroupingBufferSize(), writeConf.batchGroupingKey(),
consistencyLevel, writeConf.ifNotExists(), writeConf.ignoreNulls(), writeConf.parallelismLevel(),
- writeConf.throughputMiBPS(), writeConf.ttl(), writeConf.timestamp(), writeConf.taskMetricsEnabled()));
+ writeConf.throughputMiBPS(), writeConf.ttl(), writeConf.timestamp(), writeConf.taskMetricsEnabled(),
+ writeConf.rateLimiterProvider()));
else
return this;
}
@@ -252,7 +253,8 @@ public WriterBuilder withParallelismLevel(int parallelismLevel) {
return withWriteConf(
new WriteConf(writeConf.batchSize(), writeConf.batchGroupingBufferSize(), writeConf.batchGroupingKey(),
writeConf.consistencyLevel(), writeConf.ifNotExists(), writeConf.ignoreNulls(), parallelismLevel,
- writeConf.throughputMiBPS(), writeConf.ttl(), writeConf.timestamp(), writeConf.taskMetricsEnabled()));
+ writeConf.throughputMiBPS(), writeConf.ttl(), writeConf.timestamp(), writeConf.taskMetricsEnabled(),
+ writeConf.rateLimiterProvider()));
else
return this;
}
@@ -271,7 +273,7 @@ public WriterBuilder withThroughputMBPS(int throughputMBPS) {
new WriteConf(writeConf.batchSize(), writeConf.batchGroupingBufferSize(), writeConf.batchGroupingKey(),
writeConf.consistencyLevel(), writeConf.ifNotExists(), writeConf.ignoreNulls(),
writeConf.parallelismLevel(), throughputMBPS, writeConf.ttl(), writeConf.timestamp(),
- writeConf.taskMetricsEnabled()));
+ writeConf.taskMetricsEnabled(), writeConf.rateLimiterProvider()));
else
return this;
}
@@ -290,7 +292,7 @@ public WriterBuilder withTaskMetricsEnabled(boolean taskMetricsEnabled) {
new WriteConf(writeConf.batchSize(), writeConf.batchGroupingBufferSize(), writeConf.batchGroupingKey(),
writeConf.consistencyLevel(), writeConf.ifNotExists(), writeConf.ignoreNulls(),
writeConf.parallelismLevel(), writeConf.throughputMiBPS(), writeConf.ttl(),
- writeConf.timestamp(), taskMetricsEnabled));
+ writeConf.timestamp(), taskMetricsEnabled, writeConf.rateLimiterProvider()));
else
return this;
}
@@ -307,7 +309,7 @@ public WriterBuilder withIfNotExists(boolean ifNotExists) {
return withWriteConf(
new WriteConf(writeConf.batchSize(), writeConf.batchGroupingBufferSize(), writeConf.batchGroupingKey(),
writeConf.consistencyLevel(), ifNotExists, writeConf.ignoreNulls(), writeConf.parallelismLevel(),
- writeConf.throughputMiBPS(), writeConf.ttl(), writeConf.timestamp(), writeConf.taskMetricsEnabled()));
+ writeConf.throughputMiBPS(), writeConf.ttl(), writeConf.timestamp(), writeConf.taskMetricsEnabled(), writeConf.rateLimiterProvider()));
else
return this;
}
@@ -324,7 +326,7 @@ public WriterBuilder withIgnoreNulls(boolean ignoreNulls) {
return withWriteConf(
new WriteConf(writeConf.batchSize(), writeConf.batchGroupingBufferSize(), writeConf.batchGroupingKey(),
writeConf.consistencyLevel(), writeConf.ifNotExists(), ignoreNulls, writeConf.parallelismLevel(),
- writeConf.throughputMiBPS(), writeConf.ttl(), writeConf.timestamp(), writeConf.taskMetricsEnabled()));
+ writeConf.throughputMiBPS(), writeConf.ttl(), writeConf.timestamp(), writeConf.taskMetricsEnabled(), writeConf.rateLimiterProvider()));
else
return this;
}
@@ -343,7 +345,8 @@ private WriterBuilder withTimestamp(TimestampOption timestamp) {
writeConf.throughputMiBPS(),
writeConf.ttl(),
timestamp,
- writeConf.taskMetricsEnabled()));
+ writeConf.taskMetricsEnabled(),
+ writeConf.rateLimiterProvider()));
}
@@ -424,7 +427,8 @@ private WriterBuilder withTTL(TTLOption ttl) {
writeConf.throughputMiBPS(),
ttl,
writeConf.timestamp(),
- writeConf.taskMetricsEnabled()));
+ writeConf.taskMetricsEnabled(),
+ writeConf.rateLimiterProvider()));
}
/**
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/RDDFunctions.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/RDDFunctions.scala
index 3cc6646aa..2f7819db1 100644
--- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/RDDFunctions.scala
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/RDDFunctions.scala
@@ -31,7 +31,6 @@ class RDDFunctions[T](rdd: RDD[T]) extends WritableToCassandra[T] with Serializa
implicit
connector: CassandraConnector = CassandraConnector(sparkContext),
rwf: RowWriterFactory[T]): Unit = {
-
val writer = TableWriter(connector, keyspaceName, tableName, columns, writeConf)
rdd.sparkContext.runJob(rdd, writer.write _)
}
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraJoinRDD.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraJoinRDD.scala
index 31d1c8630..f25aa4bfe 100644
--- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraJoinRDD.scala
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraJoinRDD.scala
@@ -5,6 +5,7 @@ import com.datastax.spark.connector._
import com.datastax.spark.connector.cql._
import com.datastax.spark.connector.rdd.reader._
import com.datastax.spark.connector.writer._
+import com.datastax.spark.connector.util.RateLimiterUtil
import com.google.common.util.concurrent.{FutureCallback, Futures, SettableFuture}
import org.apache.spark.rdd.RDD
@@ -118,10 +119,11 @@ class CassandraJoinRDD[L, R] private[connector](
rowMetadata: CassandraRowMetadata,
leftIterator: Iterator[L]
): Iterator[(L, R)] = {
- val rateLimiter = new RateLimiter(
- readConf.readsPerSec, readConf.readsPerSec
+ val rateLimiter = RateLimiterUtil.getRateLimiter(
+ readConf.rateLimiterProvider,
+ readConf.readsPerSec,
+ readConf.readsPerSec
)
-
val queryExecutor = QueryExecutor(session, readConf.parallelismLevel, None, None)
def pairWithRight(left: L): SettableFuture[Iterator[(L, R)]] = {
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraLeftJoinRDD.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraLeftJoinRDD.scala
index 1a76288ea..db7cdb7a3 100644
--- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraLeftJoinRDD.scala
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraLeftJoinRDD.scala
@@ -4,6 +4,7 @@ import com.datastax.driver.core.{ResultSet, Session}
import com.datastax.spark.connector._
import com.datastax.spark.connector.cql._
import com.datastax.spark.connector.rdd.reader._
+import com.datastax.spark.connector.util.RateLimiterUtil
import com.datastax.spark.connector.writer._
import com.google.common.util.concurrent.{FutureCallback, Futures, SettableFuture}
import org.apache.spark.rdd.RDD
@@ -141,8 +142,10 @@ class CassandraLeftJoinRDD[L, R] private[connector](
rowMetadata: CassandraRowMetadata,
leftIterator: Iterator[L]
): Iterator[(L, Option[R])] = {
- val rateLimiter = new RateLimiter(
- readConf.readsPerSec, readConf.readsPerSec
+ val rateLimiter = RateLimiterUtil.getRateLimiter(
+ readConf.rateLimiterProvider,
+ readConf.readsPerSec,
+ readConf.readsPerSec
)
val queryExecutor = QueryExecutor(session, readConf.parallelismLevel,None, None)
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraRDD.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraRDD.scala
index 8ec614c88..0740163fe 100644
--- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraRDD.scala
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/CassandraRDD.scala
@@ -1,6 +1,5 @@
package com.datastax.spark.connector.rdd
-import com.datastax.driver.core.HostDistance
import com.datastax.spark.connector.cql._
import com.datastax.spark.connector.rdd.ClusteringOrder.{Ascending, Descending}
import com.datastax.spark.connector.rdd.reader._
@@ -27,7 +26,6 @@ abstract class CassandraRDD[R : ClassTag](
ConfigCheck.checkConfig(sc.getConf)
-
protected[connector] def keyspaceName: String
protected[connector] def tableName: String
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/ReadConf.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/ReadConf.scala
index 6c65a28d6..333543aaa 100644
--- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/ReadConf.scala
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/rdd/ReadConf.scala
@@ -1,7 +1,8 @@
package com.datastax.spark.connector.rdd
import com.datastax.driver.core.ConsistencyLevel
-import com.datastax.spark.connector.util.{ConfigParameter, ConfigCheck, Logging}
+import com.datastax.spark.connector.util.{ConfigCheck, ConfigParameter, Logging}
+import main.scala.com.datastax.spark.connector.writer.LeakyBucketRateLimiterProvider
import org.apache.spark.SparkConf
/** Read settings for RDD
@@ -15,7 +16,9 @@ import org.apache.spark.SparkConf
* @param taskMetricsEnabled whether or not enable task metrics updates (requires Spark 1.2+)
* @param readsPerSec maximum read throughput allowed per single core in requests/s while
* joining an RDD with C* table (joinWithCassandraTable operation)
- * also used by enterprise integrations*/
+ * also used by enterprise integrations
+  * @param rateLimiterProvider fully qualified class name of the rate limiter provider used for reads
+ */
case class ReadConf(
splitCount: Option[Int] = None,
splitSizeInMB: Int = ReadConf.SplitSizeInMBParam.default,
@@ -23,7 +26,8 @@ case class ReadConf(
consistencyLevel: ConsistencyLevel = ReadConf.ConsistencyLevelParam.default,
taskMetricsEnabled: Boolean = ReadConf.TaskMetricParam.default,
parallelismLevel: Int = ReadConf.ParallelismLevelParam.default,
- readsPerSec: Int = ReadConf.ReadsPerSecParam.default
+ readsPerSec: Int = ReadConf.ReadsPerSecParam.default,
+ rateLimiterProvider: String = ReadConf.RateLimiterProviderParam.default
)
@@ -93,6 +97,13 @@ object ReadConf extends Logging {
"""Sets max requests per core per second for joinWithCassandraTable and some Enterprise integrations"""
)
+  val RateLimiterProviderParam = ConfigParameter[String](
+    name = "spark.cassandra.input.ratelimiterprovider",
+    section = ReferenceSection,
+    default = classOf[LeakyBucketRateLimiterProvider].getName,
+    description = """Fully qualified class name of the rate limiter provider to use for reads"""
+ )
+
// Whitelist for allowed Read environment variables
val Properties = Set(
SplitCountParam,
@@ -102,7 +113,8 @@ object ReadConf extends Logging {
SplitSizeInMBParam,
TaskMetricParam,
ThroughputJoinQueryPerSecParam,
- ParallelismLevelParam
+ ParallelismLevelParam,
+ RateLimiterProviderParam
)
def fromSparkConf(conf: SparkConf): ReadConf = {
@@ -136,7 +148,8 @@ object ReadConf extends Logging {
readsPerSec = conf.getInt(ReadsPerSecParam.name,
throughtputJoinQueryPerSec.getOrElse(ReadsPerSecParam.default)),
parallelismLevel = conf.getInt(ParallelismLevelParam.name, ParallelismLevelParam.default),
- splitCount = conf.getOption(SplitCountParam.name).map(_.toInt)
+ splitCount = conf.getOption(SplitCountParam.name).map(_.toInt),
+ rateLimiterProvider = conf.get(RateLimiterProviderParam.name, RateLimiterProviderParam.default)
)
}
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/util/RateLimiterUtil.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/util/RateLimiterUtil.scala
new file mode 100644
index 000000000..03e49f5b5
--- /dev/null
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/util/RateLimiterUtil.scala
@@ -0,0 +1,44 @@
+package com.datastax.spark.connector.util
+
+import com.datastax.spark.connector.writer.{BaseRateLimiter, RateLimiterProvider}
+
+/**
+ * Provides a method for obtaining a rate limiter from a rate limiter provider class named in configuration.
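+ *
+ * Typical usage, sketched after the read path (the argument values shown are illustrative only):
+ * {{{
+ * val limiter = RateLimiterUtil.getRateLimiter(
+ *   readConf.rateLimiterProvider, readConf.readsPerSec, readConf.readsPerSec)
+ * limiter.maybeSleep(1)
+ * }}}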
+ */
+object RateLimiterUtil extends Logging {
+
+  /**
+   * Instantiates a rate limiter provider from its fully qualified class name and asks it for
+   * a rate limiter configured with the supplied arguments. If the provider class cannot be
+   * found or instantiated, the error is logged and rethrown.
+   *
+   * @param className fully qualified class name of the rate limiter provider to instantiate
+   * @param args optional sequence of arguments passed on to the provider
+   * @return the rate limiter created by the provider
+   */
+  def getRateLimiter(className: String, args: Any*): BaseRateLimiter = {
+    val provider =
+      try {
+        Class.forName(className).newInstance.asInstanceOf[RateLimiterProvider]
+      } catch {
+        case e: ClassNotFoundException =>
+          logError("Could not find custom rate limiter provider. Error: " + e)
+          throw e
+        case e: InstantiationException =>
+          logError("Could not instantiate custom rate limiter provider. Error: " + e)
+          throw e
+        case e: Throwable =>
+          logError("Could not create custom rate limiter provider. Error: " + e)
+          throw e
+      }
+
+    provider.getRateLimiterWithConf(args: _*)
+ }
+}
\ No newline at end of file
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/BaseRateLimiter.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/BaseRateLimiter.scala
new file mode 100644
index 000000000..855d9a365
--- /dev/null
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/BaseRateLimiter.scala
@@ -0,0 +1,23 @@
+package com.datastax.spark.connector.writer
+
+/**
+ * Represents a rate limiter.
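+ *
+ * A minimal pass-through implementation, shown only as an illustrative sketch:
+ * {{{
+ * class UnthrottledRateLimiter extends BaseRateLimiter {
+ *   override def maybeSleep(packetSize: Long): Unit = ()
+ * }
+ * }}}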
+ */
+trait BaseRateLimiter {
+
+ /**
+   * Processes a single packet; the implementing class decides whether, and for how long,
+   * the calling thread should sleep to stay within the configured rate.
+ *
+ * @param packetSize the size of the packet currently being processed
+ */
+ def maybeSleep(packetSize: Long): Unit
+
+}
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/RateLimiter.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/LeakyBucketRateLimiter.scala
similarity index 92%
rename from spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/RateLimiter.scala
rename to spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/LeakyBucketRateLimiter.scala
index 7d9067b1d..d1052afb9 100644
--- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/RateLimiter.scala
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/LeakyBucketRateLimiter.scala
@@ -22,11 +22,11 @@ import scala.annotation.tailrec
* @param sleep a function to call to slow down the calling thread;
* must use the same time units as `time`
*/
-class RateLimiter(
- rate: Long,
- bucketSize: Long,
- time: () => Long = System.currentTimeMillis,
- sleep: Long => Any = Thread.sleep) {
+class LeakyBucketRateLimiter(
+ rate: Long,
+ bucketSize: Long,
+ time: () => Long = System.currentTimeMillis,
+ sleep: Long => Any = Thread.sleep) extends BaseRateLimiter {
require(rate > 0, "A positive rate is required")
require(bucketSize > 0, "A positive bucket size is required")
@@ -62,5 +62,4 @@ class RateLimiter(
if (delay > 0L)
sleep(delay)
}
-
}
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/LeakyBucketRateLimiterProvider.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/LeakyBucketRateLimiterProvider.scala
new file mode 100644
index 000000000..1975c78cb
--- /dev/null
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/LeakyBucketRateLimiterProvider.scala
@@ -0,0 +1,36 @@
+package main.scala.com.datastax.spark.connector.writer
+
+import com.datastax.spark.connector.util.Logging
+import com.datastax.spark.connector.writer.{BaseRateLimiter, LeakyBucketRateLimiter, RateLimiterProvider}
+
+/**
+ * Instantiates a leaky bucket rate limiter based on the supplied configuration.
+ */
+class LeakyBucketRateLimiterProvider extends RateLimiterProvider with Logging {
+
+  override def getRateLimiterWithConf(args: Any*): BaseRateLimiter = {
+    val rate = args(0).asInstanceOf[Number].longValue
+    val bucketSize = args(1).asInstanceOf[Number].longValue
+
+    /*
+     * If the optional time and sleep arguments are present but cannot be cast to the
+     * expected types, fall back to a rate limiter built from rate and bucketSize only.
+     */
+    try {
+      if (args.size > 3) {
+        val time = args(2).asInstanceOf[() => Long]
+        val sleep = args(3).asInstanceOf[Long => Any]
+        new LeakyBucketRateLimiter(rate, bucketSize, time, sleep)
+      } else if (args.size > 2) {
+        val time = args(2).asInstanceOf[() => Long]
+        new LeakyBucketRateLimiter(rate, bucketSize, time)
+      } else {
+        new LeakyBucketRateLimiter(rate, bucketSize)
+      }
+    } catch {
+      case _: Exception =>
+        logError("Invalid optional arguments when instantiating leaky bucket rate limiter")
+        new LeakyBucketRateLimiter(rate, bucketSize)
+    }
+  }
+}
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/RateLimiterProvider.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/RateLimiterProvider.scala
new file mode 100644
index 000000000..1a4563dde
--- /dev/null
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/RateLimiterProvider.scala
@@ -0,0 +1,24 @@
+package com.datastax.spark.connector.writer
+
+/**
+ * Represents a provider that creates and returns a rate limiter configured from the supplied arguments.
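+ *
+ * An illustrative sketch of a custom provider (the class shown here is hypothetical and not
+ * part of the connector):
+ * {{{
+ * class MyRateLimiterProvider extends RateLimiterProvider {
+ *   override def getRateLimiterWithConf(args: Any*): BaseRateLimiter =
+ *     new LeakyBucketRateLimiter(
+ *       args(0).asInstanceOf[Number].longValue, args(1).asInstanceOf[Number].longValue)
+ * }
+ * }}}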
+ */
+trait RateLimiterProvider {
+ /**
+ * Given a set of arguments, instantiates and returns a rate limiter.
+ *
+ * @param args sequence of arguments that can customize the returned rate limiter
+ * @return the created rate limiter
+ */
+ def getRateLimiterWithConf(args: Any*): BaseRateLimiter
+}
\ No newline at end of file
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/TableWriter.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/TableWriter.scala
index 097bd4102..5ace698ce 100644
--- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/TableWriter.scala
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/TableWriter.scala
@@ -8,7 +8,7 @@ import com.datastax.spark.connector._
import com.datastax.spark.connector.cql._
import com.datastax.spark.connector.types.{CollectionColumnType, ListType, MapType}
import com.datastax.spark.connector.util.Quote._
-import com.datastax.spark.connector.util.{CountingIterator, Logging}
+import com.datastax.spark.connector.util.{CountingIterator, Logging, RateLimiterUtil}
import org.apache.spark.TaskContext
import org.apache.spark.metrics.OutputMetricsUpdater
@@ -193,8 +193,9 @@ class TableWriter[T] private (
/**
* Write data with Cql INSERT statement
*/
- def insert(taskContext: TaskContext, data: Iterator[T]):Unit =
+ def insert(taskContext: TaskContext, data: Iterator[T]):Unit = {
writeInternal(queryTemplateUsingInsert, taskContext, data)
+ }
/**
* Cql DELETE statement
@@ -226,8 +227,11 @@ class TableWriter[T] private (
val batchKeyGenerator = batchRoutingKey(session, routingKeyGenerator) _
val batchBuilder = new GroupingBatchBuilder(boundStmtBuilder, batchStmtBuilder, batchKeyGenerator,
writeConf.batchSize, writeConf.batchGroupingBufferSize, rowIterator)
- val rateLimiter = new RateLimiter((writeConf.throughputMiBPS * 1024 * 1024).toLong, 1024 * 1024)
-
+ val rateLimiter = RateLimiterUtil.getRateLimiter(
+ writeConf.rateLimiterProvider,
+ (writeConf.throughputMiBPS * 1024 * 1024).toLong,
+ 1024 * 1024
+ )
logDebug(s"Writing data partition to $keyspaceName.$tableName in batches of ${writeConf.batchSize}.")
for (stmtToWrite <- batchBuilder) {
diff --git a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/WriteConf.scala b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/WriteConf.scala
index 2d39ba8e0..b33ae40f1 100644
--- a/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/WriteConf.scala
+++ b/spark-cassandra-connector/src/main/scala/com/datastax/spark/connector/writer/WriteConf.scala
@@ -2,9 +2,11 @@ package com.datastax.spark.connector.writer
import com.datastax.driver.core.{ConsistencyLevel, DataType}
import com.datastax.spark.connector.cql.{ColumnDef, RegularColumn}
+import com.datastax.spark.connector.rdd.ReadConf
import com.datastax.spark.connector.types.ColumnType
-import com.datastax.spark.connector.util.{ConfigParameter, ConfigCheck}
+import com.datastax.spark.connector.util.{ConfigCheck, ConfigParameter}
import com.datastax.spark.connector.{BatchSize, BytesInBatch, RowsInBatch}
+import main.scala.com.datastax.spark.connector.writer.LeakyBucketRateLimiterProvider
import org.apache.commons.configuration.ConfigurationException
import org.apache.spark.SparkConf
@@ -21,6 +23,7 @@ import org.apache.spark.SparkConf
* @param ttl the default TTL value which is used when it is defined (in seconds)
* @param timestamp the default timestamp value which is used when it is defined (in microseconds)
* @param taskMetricsEnabled whether or not enable task metrics updates (requires Spark 1.2+)
+ * @param rateLimiterProvider fully qualified class name of the rate limiter provider used for writes
*/
case class WriteConf(batchSize: BatchSize = BatchSize.Automatic,
@@ -33,7 +36,8 @@ case class WriteConf(batchSize: BatchSize = BatchSize.Automatic,
throughputMiBPS: Double = WriteConf.ThroughputMiBPSParam.default,
ttl: TTLOption = TTLOption.defaultValue,
timestamp: TimestampOption = TimestampOption.defaultValue,
- taskMetricsEnabled: Boolean = WriteConf.TaskMetricsParam.default) {
+ taskMetricsEnabled: Boolean = WriteConf.TaskMetricsParam.default,
+ rateLimiterProvider: String = WriteConf.RateLimiterProviderParam.default) {
private[writer] val optionPlaceholders: Seq[String] = Seq(ttl, timestamp).collect {
case WriteOption(PerRowWriteOptionValue(placeholder)) => placeholder
@@ -123,7 +127,7 @@ object WriteConf {
default = 5,
description = """Maximum number of batches executed in parallel by a
| single Spark task""".stripMargin)
-
+
val ThroughputMiBPSParam = ConfigParameter[Double] (
name = "spark.cassandra.output.throughput_mb_per_sec",
section = ReferenceSection,
@@ -154,6 +158,13 @@ object WriteConf {
description = """Sets whether to record connector specific metrics on write"""
)
+ val RateLimiterProviderParam = ConfigParameter[String](
+ name = "spark.cassandra.output.ratelimiterprovider",
+ section = ReferenceSection,
+    default = classOf[LeakyBucketRateLimiterProvider].getName,
+    description = """Fully qualified class name of the rate limiter provider to use for writes"""
+ )
+
// Whitelist for allowed Write environment variables
val Properties: Set[ConfigParameter[_]] = Set(
BatchSizeBytesParam,
@@ -167,7 +178,8 @@ object WriteConf {
ThroughputMiBPSParam,
TTLParam,
TimestampParam,
- TaskMetricsParam
+ TaskMetricsParam,
+ RateLimiterProviderParam
)
def fromSparkConf(conf: SparkConf): WriteConf = {
@@ -215,7 +227,7 @@ object WriteConf {
TTLOption.defaultValue
else
TTLOption.constant(ttlSeconds)
-
+
val timestampMicros = conf.getLong(TimestampParam.name, TimestampParam.default)
val timestampOption =
@@ -224,6 +236,8 @@ object WriteConf {
else
TimestampOption.constant(timestampMicros)
+ val rateLimiterProvider = conf.get(RateLimiterProviderParam.name, RateLimiterProviderParam.default)
+
WriteConf(
batchSize = batchSize,
batchGroupingBufferSize = batchBufferSize,
@@ -235,7 +249,8 @@ object WriteConf {
ttl = ttlOption,
timestamp = timestampOption,
ignoreNulls = ignoreNulls,
- ifNotExists = ifNotExists)
+ ifNotExists = ifNotExists,
+ rateLimiterProvider = rateLimiterProvider)
}
}
diff --git a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/ReadConfTest.scala b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/ReadConfTest.scala
index 5959a366f..e581bb696 100644
--- a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/ReadConfTest.scala
+++ b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/rdd/ReadConfTest.scala
@@ -1,5 +1,6 @@
package com.datastax.spark.connector.rdd
+import com.datastax.spark.connector.writer.WriteConf
import org.apache.spark.SparkConf
import org.scalatest.{FlatSpec, Matchers}
@@ -27,4 +28,11 @@ class ReadConfTest extends FlatSpec with Matchers {
readConf.readsPerSec should be (expected)
}
+ it should "allow to set custom rate limiter provider" in {
+ val conf = new SparkConf(false)
+ .set("spark.cassandra.input.ratelimiterprovider", "custom.ratelimiter.provider")
+ val readConf = ReadConf.fromSparkConf(conf)
+ readConf.rateLimiterProvider should be("custom.ratelimiter.provider")
+ }
+
}
diff --git a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/util/RateLimiterUtilSpec.scala b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/util/RateLimiterUtilSpec.scala
new file mode 100644
index 000000000..7edf8879d
--- /dev/null
+++ b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/util/RateLimiterUtilSpec.scala
@@ -0,0 +1,35 @@
+package com.datastax.spark.connector.util
+
+import com.datastax.spark.connector.writer.{BaseRateLimiter, RateLimiterProvider}
+import org.scalatest.{FlatSpec, Matchers}
+
+class RateLimiterUtilSpec extends FlatSpec with Matchers {
+
+  "RateLimiterUtil" should "return a rate limiter from a custom provider when one is specified" in {
+ val mockProvider = new MockProvider()
+ val rateLimiter = RateLimiterUtil.getRateLimiter(mockProvider.getClass.getName)
+ rateLimiter.getClass.getName should equal (mockProvider.getRateLimiterWithConf().getClass.getName)
+ }
+
+ it should "throw an error when custom rate limiter provider cannot be instantiated" in {
+ a [ClassNotFoundException] should be thrownBy RateLimiterUtil.getRateLimiter("non.existing.class")
+ an [InstantiationException] should be thrownBy RateLimiterUtil.getRateLimiter(NonInstantiable.getClass.getName)
+ }
+
+ // mock object that cannot be instantiated
+ object NonInstantiable {}
+}
+
+// mock provider with public constructor that can be instantiated
+class MockProvider extends RateLimiterProvider {
+
+ override def getRateLimiterWithConf(args: Any*): BaseRateLimiter = {
+ new MockRateLimiter
+ }
+}
+
+// mock rate limiter that is returned by MockProvider
+class MockRateLimiter extends BaseRateLimiter {
+ override def maybeSleep(packetSize: Long): Unit = {}
+}
+
diff --git a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/writer/RateLimiterSpec.scala b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/writer/LeakyBucketRateLimiterSpec.scala
similarity index 83%
rename from spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/writer/RateLimiterSpec.scala
rename to spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/writer/LeakyBucketRateLimiterSpec.scala
index 95cfe0763..78017d2ad 100644
--- a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/writer/RateLimiterSpec.scala
+++ b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/writer/LeakyBucketRateLimiterSpec.scala
@@ -5,7 +5,7 @@ import org.scalatest.concurrent.Eventually
import org.scalatest.{FlatSpec, Matchers}
-class RateLimiterSpec extends FlatSpec with Matchers with MockFactory with Eventually{
+class LeakyBucketRateLimiterSpec extends FlatSpec with Matchers with MockFactory with Eventually{
val TestRates = Seq(1L, 2L, 4L, 6L, 8L, 16L, 32L, WriteConf.ThroughputMiBPSParam.default.toLong)
@@ -14,7 +14,7 @@ class RateLimiterSpec extends FlatSpec with Matchers with MockFactory with Event
val sleep = mockFunction[Long, Any]("sleep")
sleep.expects(*).never()
- val limiter = new RateLimiter(Long.MaxValue, 1000, () => now, sleep)
+ val limiter = new LeakyBucketRateLimiter(Long.MaxValue, 1000, () => now, sleep)
for (i <- 1 to 1000000) {
now += 1
limiter.maybeSleep(1000)
@@ -33,7 +33,7 @@ class RateLimiterSpec extends FlatSpec with Matchers with MockFactory with Event
// 10 units per second + 5 units burst allowed
val bucketSize = 5
val rate = 10
- val limiter = new RateLimiter(rate, bucketSize, () => now, sleep)
+ val limiter = new LeakyBucketRateLimiter(rate, bucketSize, () => now, sleep)
val iterations = 25
for (i <- 1 to iterations)
@@ -53,7 +53,7 @@ class RateLimiterSpec extends FlatSpec with Matchers with MockFactory with Event
now += delay
}
- val limiter = new RateLimiter(rate, rate * 2, () => now, sleep)
+ val limiter = new LeakyBucketRateLimiter(rate, rate * 2, () => now, sleep)
for (leakNum <- 1 to 1000) {
assert(
limiter.bucketFill.get() >= 0,
diff --git a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/writer/WriteConfTest.scala b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/writer/WriteConfTest.scala
index c4f5887e3..d2b290ccd 100644
--- a/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/writer/WriteConfTest.scala
+++ b/spark-cassandra-connector/src/test/scala/com/datastax/spark/connector/writer/WriteConfTest.scala
@@ -78,5 +78,12 @@ class WriteConfTest extends FlatSpec with Matchers {
writeConf.batchGroupingBufferSize should be(30000)
}
+ it should "allow to set custom rate limiter provider" in {
+ val conf = new SparkConf(false)
+ .set("spark.cassandra.output.ratelimiterprovider", "custom.ratelimiter.provider")
+ val writeConf = WriteConf.fromSparkConf(conf)
+ writeConf.rateLimiterProvider should be("custom.ratelimiter.provider")
+ }
+
}