diff --git a/src/main/scala/org/apache/spark/sql/redis/BinaryRedisPersistence.scala b/src/main/scala/org/apache/spark/sql/redis/BinaryRedisPersistence.scala
index c9b0a981..2f666693 100644
--- a/src/main/scala/org/apache/spark/sql/redis/BinaryRedisPersistence.scala
+++ b/src/main/scala/org/apache/spark/sql/redis/BinaryRedisPersistence.scala
@@ -25,7 +25,7 @@ class BinaryRedisPersistence extends RedisPersistence[Array[Byte]] {
   override def load(pipeline: Pipeline, key: String, requiredColumns: Seq[String]): Unit =
     pipeline.get(key.getBytes(UTF_8))
 
-  override def encodeRow(keyName: String, value: Row): Array[Byte] = {
+  override def encodeRow(keyName: String, value: Row, ttlColumn: Option[String] = None): Array[Byte] = {
     val fields = value.schema.fields.map(_.name)
     val valuesArray = fields.map(f => value.getAs[Any](f))
     SerializationUtils.serialize(valuesArray)
diff --git a/src/main/scala/org/apache/spark/sql/redis/HashRedisPersistence.scala b/src/main/scala/org/apache/spark/sql/redis/HashRedisPersistence.scala
index a22c4614..6743c791 100644
--- a/src/main/scala/org/apache/spark/sql/redis/HashRedisPersistence.scala
+++ b/src/main/scala/org/apache/spark/sql/redis/HashRedisPersistence.scala
@@ -27,7 +27,7 @@ class HashRedisPersistence extends RedisPersistence[Any] {
     pipeline.hmget(key, requiredColumns: _*)
   }
 
-  override def encodeRow(keyName: String, value: Row): Map[String, String] = {
+  override def encodeRow(keyName: String, value: Row, ttlColumn: Option[String] = None): Map[String, String] = {
     val fields = value.schema.fields.map(_.name)
     val kvMap = value.getValuesMap[Any](fields)
     kvMap
@@ -39,6 +39,13 @@ class HashRedisPersistence extends RedisPersistence[Any] {
         // don't store key values
         k != keyName
       }
+      .filter { case (k, _) =>
+        // don't store TTLs
+        ttlColumn match {
+          case Some(ttl) => k != ttl
+          case None => true
+        }
+      }
      .map { case (k, v) =>
        k -> String.valueOf(v)
      }
diff --git a/src/main/scala/org/apache/spark/sql/redis/RedisPersistence.scala b/src/main/scala/org/apache/spark/sql/redis/RedisPersistence.scala
index d69eef66..bf5f1c58 100644
--- a/src/main/scala/org/apache/spark/sql/redis/RedisPersistence.scala
+++ b/src/main/scala/org/apache/spark/sql/redis/RedisPersistence.scala
@@ -18,9 +18,10 @@ trait RedisPersistence[T] extends Serializable {
    *
    * @param keyName field name that should be encoded in special way, e.g. in Redis keys.
    * @param value   row to encode.
+   * @param ttlColumn field name whose value is used for setting the TTL; it is not stored as a value.
    * @return encoded row
    */
-  def encodeRow(keyName: String, value: Row): T
+  def encodeRow(keyName: String, value: Row, ttlColumn: Option[String] = None): T
 
   /**
    * Decode dataframe row stored in Redis.
diff --git a/src/main/scala/org/apache/spark/sql/redis/RedisSourceRelation.scala b/src/main/scala/org/apache/spark/sql/redis/RedisSourceRelation.scala
index 06487471..b4db6618 100644
--- a/src/main/scala/org/apache/spark/sql/redis/RedisSourceRelation.scala
+++ b/src/main/scala/org/apache/spark/sql/redis/RedisSourceRelation.scala
@@ -78,6 +78,7 @@ class RedisSourceRelation(override val sqlContext: SQLContext,
   private val persistence = RedisPersistence(persistenceModel)
   private val tableNameOpt: Option[String] = parameters.get(SqlOptionTableName)
   private val ttl = parameters.get(SqlOptionTTL).map(_.toInt).getOrElse(0)
+  private val ttlColumn: Option[String] = parameters.get(SqlOptionTTLColumn)
 
   /**
    * redis key pattern for rows, based either on the 'keys.pattern' or 'table' parameter
@@ -106,6 +107,12 @@ class RedisSourceRelation(override val sqlContext: SQLContext,
       s"You should only use either one.")
   }
 
+  // check if both ttl column and ttl are set
+  if (ttlColumn.isDefined && ttl > 0) {
+    throw new IllegalArgumentException(s"Both options '$SqlOptionTTL' and '$SqlOptionTTLColumn' are set. " +
+      s"You should only use either one.")
+  }
+
   override def schema: StructType = {
     if (currentSchema == null) {
       currentSchema = userSpecifiedSchema.getOrElse {
@@ -142,8 +149,9 @@ class RedisSourceRelation(override val sqlContext: SQLContext,
       val conn = node.connect()
       foreachWithPipeline(conn, keys) { (pipeline, key) =>
         val row = rowsWithKey(key)
-        val encodedRow = persistence.encodeRow(keyName, row)
-        persistence.save(pipeline, key, encodedRow, ttl)
+        val encodedRow = persistence.encodeRow(keyName, row, ttlColumn)
+        val recordTTL = if (ttlColumn.isEmpty) ttl else row.getAs[Int](ttlColumn.get)
+        persistence.save(pipeline, key, encodedRow, recordTTL)
       }
       conn.close()
     }
diff --git a/src/main/scala/org/apache/spark/sql/redis/redis.scala b/src/main/scala/org/apache/spark/sql/redis/redis.scala
index 82b1c1d7..8499fd5d 100644
--- a/src/main/scala/org/apache/spark/sql/redis/redis.scala
+++ b/src/main/scala/org/apache/spark/sql/redis/redis.scala
@@ -21,6 +21,7 @@ package object redis {
   val SqlOptionInferSchema = "infer.schema"
   val SqlOptionKeyColumn = "key.column"
   val SqlOptionTTL = "ttl"
+  val SqlOptionTTLColumn = "ttl.column"
   val SqlOptionMaxPipelineSize = "max.pipeline.size"
   val SqlOptionScanCount = "scan.count"
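
For reference, a minimal write-path sketch using the new 'ttl.column' option. The DataFrame
contents and the column names ('name', 'age', 'expiry') are illustrative only; the 'table' and
'key.column' options already exist in spark-redis, and the TTL column must be an Int because
the write path reads it with row.getAs[Int].

import org.apache.spark.sql.SparkSession

object TtlColumnExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("ttl-column-example")
      .master("local[*]")
      .config("spark.redis.host", "localhost")
      .config("spark.redis.port", "6379")
      .getOrCreate()
    import spark.implicits._

    // 'expiry' carries a per-row TTL in seconds; with 'ttl.column' set, it is
    // used for the key's expiration and filtered out of the stored hash fields.
    val df = Seq(("john", 30, 100), ("jane", 25, 200)).toDF("name", "age", "expiry")

    df.write
      .format("org.apache.spark.sql.redis")
      .option("table", "person")
      .option("key.column", "name")
      .option("ttl.column", "expiry") // mutually exclusive with the 'ttl' option
      .save()

    spark.stop()
  }
}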