diff --git a/src/main/scala/io/tarantool/spark/connector/config/TarantoolConfig.scala b/src/main/scala/io/tarantool/spark/connector/config/TarantoolConfig.scala index 1f0203f..8514a5b 100644 --- a/src/main/scala/io/tarantool/spark/connector/config/TarantoolConfig.scala +++ b/src/main/scala/io/tarantool/spark/connector/config/TarantoolConfig.scala @@ -11,7 +11,7 @@ case class Timeouts(connect: Option[Int], read: Option[Int], request: Option[Int object ErrorTypes extends Enumeration { type ErrorType = Value - val NONE, NETWORK = Value + val NONE, NETWORK, CONFLICT, ALL = Value } case class Retries(errorType: ErrorType, retryAttempts: Option[Int], delay: Option[Int]) extends Serializable @@ -110,7 +110,7 @@ object TarantoolConfig { if (strErrorType.isDefined) { val errorType = ErrorTypes.withName(strErrorType.get.toUpperCase) - if (errorType == ErrorTypes.NETWORK) { + if (errorType != ErrorTypes.NONE) { if (retryAttempts.isEmpty) { throw new IllegalArgumentException("Number of retry attempts must be specified") } diff --git a/src/main/scala/io/tarantool/spark/connector/connection/TarantoolConnection.scala b/src/main/scala/io/tarantool/spark/connector/connection/TarantoolConnection.scala index f473c88..99574b9 100644 --- a/src/main/scala/io/tarantool/spark/connector/connection/TarantoolConnection.scala +++ b/src/main/scala/io/tarantool/spark/connector/connection/TarantoolConnection.scala @@ -5,10 +5,12 @@ import io.tarantool.driver.api.{TarantoolClient, TarantoolClientConfig, Tarantoo import io.tarantool.driver.auth.SimpleTarantoolCredentials import io.tarantool.driver.api.TarantoolClientFactory import io.tarantool.driver.api.retry.TarantoolRequestRetryPolicies.AttemptsBoundRetryPolicyFactory +import io.tarantool.driver.api.retry.TarantoolRequestRetryPolicies.retryNetworkErrors +import io.tarantool.driver.exceptions.TarantoolInternalException import io.tarantool.driver.protocol.Packable import io.tarantool.spark.connector.Logging import io.tarantool.spark.connector.config.{ErrorTypes, TarantoolConfig} -import io.tarantool.spark.connector.util.ScalaToJavaHelper.toJavaUnaryOperator +import io.tarantool.spark.connector.util.ScalaToJavaHelper.{toJavaPredicate, toJavaUnaryOperator} import java.io.{Closeable, Serializable} import java.util @@ -24,6 +26,13 @@ object TarantoolConnection { def apply(): TarantoolConnection[TarantoolTuple, TarantoolResult[TarantoolTuple]] = TarantoolConnection(defaultClient) + private def isConflictError(e: Throwable): Boolean = { + e.isInstanceOf[TarantoolInternalException] && + e.getMessage.indexOf("Transaction has been aborted by conflict") > 0 + } + + private def retryConflictErrors(): Predicate[Throwable] = toJavaPredicate(isConflictError) + private def defaultClient( clientConfig: TarantoolConfig ): TarantoolClient[TarantoolTuple, TarantoolResult[TarantoolTuple]] = { @@ -36,9 +45,15 @@ object TarantoolConnection { if (clientConfig.retries.isDefined) { val retries = clientConfig.retries.get - if (retries.errorType == ErrorTypes.NETWORK) { + if (retries.errorType != ErrorTypes.NONE) { + val predicate = retries.errorType match { + case ErrorTypes.NETWORK => retryNetworkErrors + case ErrorTypes.CONFLICT => retryConflictErrors + case _ => retryNetworkErrors.and(retryConflictErrors) + } clientFactory = clientFactory.withRetryingByNumberOfAttempts( retries.retryAttempts.get, + predicate, toJavaUnaryOperator { policyBuilder: AttemptsBoundRetryPolicyFactory.Builder[Predicate[Throwable]] => policyBuilder.withDelay(retries.delay.get) } diff --git a/src/main/scala/io/tarantool/spark/connector/util/ScalaToJavaHelper.scala b/src/main/scala/io/tarantool/spark/connector/util/ScalaToJavaHelper.scala index 14caf2a..ed1dce5 100644 --- a/src/main/scala/io/tarantool/spark/connector/util/ScalaToJavaHelper.scala +++ b/src/main/scala/io/tarantool/spark/connector/util/ScalaToJavaHelper.scala @@ -4,6 +4,7 @@ import java.util.function.{ BiFunction => JBiFunction, Consumer => JConsumer, Function => JFunction, + Predicate => JPredicate, Supplier => JSupplier, UnaryOperator => JUnaryOperator } @@ -33,6 +34,13 @@ object ScalaToJavaHelper { override def apply(t: T1): R = f.apply(t) } + /** + * Converts a Scala {@link Function1} to a Java {@link java.util.function.Predicate} + */ + def toJavaPredicate[T1](f: T1 => Boolean): JPredicate[T1] = new JPredicate[T1] { + override def test(t: T1): Boolean = f.apply(t) + } + /** * Converts a Scala {@link Function1} to a Java {@link java.util.function.UnaryOperator} */