diff --git a/src/main/scala/com/redislabs/provider/redis/ConnectionPool.scala b/src/main/scala/com/redislabs/provider/redis/ConnectionPool.scala index 322d8c5f..bb4c5f46 100644 --- a/src/main/scala/com/redislabs/provider/redis/ConnectionPool.scala +++ b/src/main/scala/com/redislabs/provider/redis/ConnectionPool.scala @@ -1,16 +1,17 @@ package com.redislabs.provider.redis -import redis.clients.jedis.{JedisPoolConfig, Jedis, JedisPool} +import redis.clients.jedis.{Jedis, JedisPool, JedisPoolConfig, JedisSentinelPool} import redis.clients.jedis.exceptions.JedisConnectionException - import java.util.concurrent.ConcurrentHashMap +import redis.clients.jedis.util.Pool + import scala.collection.JavaConversions._ object ConnectionPool { - @transient private lazy val pools: ConcurrentHashMap[RedisEndpoint, JedisPool] = - new ConcurrentHashMap[RedisEndpoint, JedisPool]() + @transient private lazy val pools: ConcurrentHashMap[RedisEndpoint, Pool[Jedis]] = + new ConcurrentHashMap[RedisEndpoint, Pool[Jedis]]() def connect(re: RedisEndpoint): Jedis = { val pool = pools.getOrElseUpdate(re, @@ -25,7 +26,12 @@ object ConnectionPool { poolConfig.setTimeBetweenEvictionRunsMillis(30000) poolConfig.setNumTestsPerEvictionRun(-1) - new JedisPool(poolConfig, re.host, re.port, re.timeout, re.auth, re.dbNum, re.ssl) + if (null == re.master || re.master.trim.isEmpty) { + new JedisPool(poolConfig, re.host, re.port, re.timeout, re.auth, re.dbNum, re.ssl) + } else { + val sentinels = re.host.split(",").map(x => x + ":" + re.port).toSet + new JedisSentinelPool(re.master.trim, sentinels, poolConfig, re.auth) + } } ) var sleepTime: Int = 4 diff --git a/src/main/scala/com/redislabs/provider/redis/RedisConfig.scala b/src/main/scala/com/redislabs/provider/redis/RedisConfig.scala index 91e2f05e..dbe7eece 100644 --- a/src/main/scala/com/redislabs/provider/redis/RedisConfig.scala +++ b/src/main/scala/com/redislabs/provider/redis/RedisConfig.scala @@ -24,7 +24,8 @@ case class RedisEndpoint(host: String = Protocol.DEFAULT_HOST, auth: String = null, dbNum: Int = Protocol.DEFAULT_DATABASE, timeout: Int = Protocol.DEFAULT_TIMEOUT, - ssl: Boolean = false) + ssl: Boolean = false, + master: String = null) extends Serializable { /** @@ -39,7 +40,8 @@ case class RedisEndpoint(host: String = Protocol.DEFAULT_HOST, conf.get("spark.redis.auth", null), conf.getInt("spark.redis.db", Protocol.DEFAULT_DATABASE), conf.getInt("spark.redis.timeout", Protocol.DEFAULT_TIMEOUT), - conf.getBoolean("spark.redis.ssl", false) + conf.getBoolean("spark.redis.ssl", false), + conf.get("spark.redis.sentinel.master", null) ) } @@ -254,7 +256,7 @@ class RedisConfig(val initialHost: RedisEndpoint) extends Serializable { //simply re-enter this function witht he master host/port getNonClusterNodes(initialHost = new RedisEndpoint(host, port, - initialHost.auth, initialHost.dbNum, ssl = initialHost.ssl)) + initialHost.auth, initialHost.dbNum, ssl = initialHost.ssl, master = initialHost.master)) } else { //this is a master - take its slaves @@ -270,7 +272,7 @@ class RedisConfig(val initialHost: RedisEndpoint) extends Serializable { val range = nodes.length (0 until range).map(i => RedisNode(RedisEndpoint(nodes(i)._1, nodes(i)._2, initialHost.auth, initialHost.dbNum, - initialHost.timeout, initialHost.ssl), + initialHost.timeout, initialHost.ssl, initialHost.master), 0, 16383, i, range)).toArray } } @@ -300,7 +302,7 @@ class RedisConfig(val initialHost: RedisEndpoint) extends Serializable { val host = SafeEncoder.encode(node.get(0).asInstanceOf[Array[scala.Byte]]) val port = node.get(1).toString.toInt RedisNode(RedisEndpoint(host, port, initialHost.auth, initialHost.dbNum, - initialHost.timeout, initialHost.ssl), + initialHost.timeout, initialHost.ssl, initialHost.master), sPos, ePos, i, diff --git a/src/main/scala/org/apache/spark/sql/redis/RedisSourceRelation.scala b/src/main/scala/org/apache/spark/sql/redis/RedisSourceRelation.scala index f2c84911..d2d44356 100644 --- a/src/main/scala/org/apache/spark/sql/redis/RedisSourceRelation.scala +++ b/src/main/scala/org/apache/spark/sql/redis/RedisSourceRelation.scala @@ -41,7 +41,8 @@ class RedisSourceRelation(override val sqlContext: SQLContext, val dbNum = parameters.get("dbNum").map(_.toInt).getOrElse(Protocol.DEFAULT_DATABASE) val timeout = parameters.get("timeout").map(_.toInt).getOrElse(Protocol.DEFAULT_TIMEOUT) val ssl = parameters.get("ssl").map(_.toBoolean).getOrElse(false) - RedisEndpoint(host, port, auth, dbNum, timeout, ssl) + val master = parameters.getOrElse("sentinel.master", null) + RedisEndpoint(host, port, auth, dbNum, timeout, ssl, master) } ) }