diff --git a/build.sbt b/build.sbt index e876fb71..43ecc29d 100644 --- a/build.sbt +++ b/build.sbt @@ -34,7 +34,9 @@ lazy val commonSettings = ), scalacOptions ++= { CrossVersion.partialVersion(scalaVersion.value) match { - case Some((2, v)) if v <= 12 => + case Some((2, 11)) => + Seq("-Xsource:2.12", "-Ypartial-unification") + case Some((2, 12)) => Seq("-Ypartial-unification") case _ => Nil @@ -86,7 +88,7 @@ lazy val specs2 = (project in file("specs2")) lazy val cats = (project in file("cats")) .dependsOn(core) - .dependsOn(common % "compile-internal, test-internal") + .dependsOn(common % "compile-internal, test-internal, test->test") .dependsOn(macroSub % "compile-internal, test-internal") .settings( name := "mockito-scala-cats", @@ -118,8 +120,11 @@ lazy val common = (project in file("common")) .dependsOn(macroCommon) .settings( commonSettings, - libraryDependencies ++= Dependencies.commonLibraries, - libraryDependencies += Dependencies.scalaReflection(scalaVersion.value), + libraryDependencies ++= Dependencies.commonLibraries ++ Seq( + Dependencies.scalaReflection(scalaVersion.value), + Dependencies.catsLaws % "test", + Dependencies.scalacheck % "test" + ), publish := {}, publishLocal := {}, publishArtifact := false @@ -186,4 +191,4 @@ lazy val root = (project in file(".")) .settings( publish := {}, publishLocal := {} - ) aggregate (core, scalatest, specs2, cats, scalaz) + ) aggregate (common, core, scalatest, specs2, cats, scalaz) diff --git a/cats/src/main/scala/org/mockito/cats/instances/argumentMatcher.scala b/cats/src/main/scala/org/mockito/cats/instances/argumentMatcher.scala index 788a858c..acef000c 100644 --- a/cats/src/main/scala/org/mockito/cats/instances/argumentMatcher.scala +++ b/cats/src/main/scala/org/mockito/cats/instances/argumentMatcher.scala @@ -2,28 +2,16 @@ package org.mockito.cats import cats._ import org.mockito.ArgumentMatcher -import org.mockito.internal.matchers.And - -object AnyArgumentMatcher extends ArgumentMatcher[Any] { - override def matches(a: Any) = true -} - -case class MappedArgumentMatcher[A, B](fa: ArgumentMatcher[A], f: B => A) extends ArgumentMatcher[B] { - override def matches(b: B) = fa.matches(f(b)) -} - -case class ProductArgumentMatcher[A, B](fa: ArgumentMatcher[A], fb: ArgumentMatcher[B]) extends ArgumentMatcher[(A, B)] { - override def matches(ab: (A, B)) = ab match { case (a, b) => fa.matches(a) && fb.matches(b) } -} +import org.mockito.matchers._ trait ArgumentMatcherInstances { implicit val argumentMatcherInstance: ContravariantMonoidal[ArgumentMatcher] with MonoidK[ArgumentMatcher] = new ContravariantMonoidal[ArgumentMatcher] with MonoidK[ArgumentMatcher] { - override def unit = narrow(AnyArgumentMatcher) - override def empty[A] = narrow(AnyArgumentMatcher) - override def contramap[A, B](fa: ArgumentMatcher[A])(f: B => A) = MappedArgumentMatcher(fa, f) - override def product[A, B](fa: ArgumentMatcher[A], fb: ArgumentMatcher[B]) = ProductArgumentMatcher(fa, fb) - override def combineK[A](x: ArgumentMatcher[A], y: ArgumentMatcher[A]) = new And(x, y).asInstanceOf[ArgumentMatcher[A]] + override def unit = narrow(AnyArg) + override def empty[A] = narrow(AnyArg) + override def contramap[A, B](fa: ArgumentMatcher[A])(f: B => A) = Transformed(fa)(f) + override def product[A, B](fa: ArgumentMatcher[A], fb: ArgumentMatcher[B]) = ProductOf(fa, fb) + override def combineK[A](x: ArgumentMatcher[A], y: ArgumentMatcher[A]) = AllOf(x, y) } } diff --git a/cats/src/test/scala/org/mockito/cats/instances/ArgumentMatcherInstancesTest.scala b/cats/src/test/scala/org/mockito/cats/instances/ArgumentMatcherInstancesTest.scala index e9c79edd..c08d2590 100644 --- a/cats/src/test/scala/org/mockito/cats/instances/ArgumentMatcherInstancesTest.scala +++ b/cats/src/test/scala/org/mockito/cats/instances/ArgumentMatcherInstancesTest.scala @@ -8,6 +8,7 @@ import cats.laws.discipline.arbitrary._ import cats.laws.discipline.eq._ import org.mockito.{ ArgumentMatcher, ArgumentMatchers, ArgumentMatchersSugar, IdiomaticMockito } import org.mockito.internal.matchers._ +import org.mockito.matchers.{ EqTo, Generators } import org.scalacheck.Arbitrary import org.scalatest.matchers.should.Matchers import org.scalatest.funsuite.AnyFunSuiteLike @@ -15,12 +16,11 @@ import org.scalatest.prop.Configuration import org.typelevel.discipline.scalatest.FunSuiteDiscipline class ArgumentMatcherInstancesTest extends AnyFunSuiteLike with FunSuiteDiscipline with Configuration with ArgumentMatchersSugar with IdiomaticMockito with Matchers { + import Generators._ + implicit def eqArgumentMatcherExhaustive[A: ExhaustiveCheck]: Eq[ArgumentMatcher[A]] = Eq.instance((f, g) => ExhaustiveCheck[A].allValues.forall(a => f.matches(a) == g.matches(a))) - implicit def arbArgumentMatcher[A](implicit a: Arbitrary[A => Boolean]): Arbitrary[ArgumentMatcher[A]] = - Arbitrary(a.arbitrary.map(p => new ArgumentMatcher[A] { def matches(a: A) = p(a) })) - checkAll("ArgumentMatcher[MiniInt]", ContravariantMonoidalTests[ArgumentMatcher].contravariantMonoidal[MiniInt, MiniInt, MiniInt]) checkAll("ArgumentMatcher[MiniInt]", MonoidKTests[ArgumentMatcher].monoidK[MiniInt]) @@ -76,4 +76,13 @@ class ArgumentMatcherInstancesTest extends AnyFunSuiteLike with FunSuiteDiscipli aMock.returnsOptionString("prefix-middle-suffix") shouldBe Some("mocked!") } + + test("EqTo works with cats syntax") { + val aMock = mock[Foo] + + val matcher = (EqTo("foo"), EqTo(new Integer(42))).tupled + aMock.takesTuple(argThat(matcher)) returns "mocked!" + + aMock.takesTuple("foo", 42) shouldBe "mocked!" + } } diff --git a/common/src/main/scala/org/mockito/matchers/AllOf.scala b/common/src/main/scala/org/mockito/matchers/AllOf.scala new file mode 100644 index 00000000..e9a3c46b --- /dev/null +++ b/common/src/main/scala/org/mockito/matchers/AllOf.scala @@ -0,0 +1,23 @@ +package org.mockito +package matchers + +/** Combine multiple matchers using AND + */ +case class AllOf[A] private (matchers: List[ArgumentMatcher[A]]) extends ArgumentMatcher[A] { + override def matches(a: A) = matchers.forall(_.matches(a)) + + override def toString = + matchers match { + case Nil => "" + case matcher :: Nil => matcher.toString + case _ => matchers.mkString("allOf(", ", ", ")") + } +} + +object AllOf { + def apply[A](matchers: ArgumentMatcher[A]*): ArgumentMatcher[A] = + new AllOf(matchers.flatMap { + case AllOf(ms) => ms + case m => List(m) + }.toList) +} diff --git a/common/src/main/scala/org/mockito/matchers/EqTo.scala b/common/src/main/scala/org/mockito/matchers/EqTo.scala index 66e42755..b752c56b 100644 --- a/common/src/main/scala/org/mockito/matchers/EqTo.scala +++ b/common/src/main/scala/org/mockito/matchers/EqTo.scala @@ -12,3 +12,9 @@ case class EqTo[T: Equality: ValueClassExtractor](value: T)(implicit $pt: Pretti override def toString: String = $pt(value) } + +object EqTo { + // Smart constructor to return ArgumentMatcher[T] rather than a subtype + def apply[T: Equality: ValueClassExtractor](value: T)(implicit $pt: Prettifier): ArgumentMatcher[T] = + new EqTo(value) +} diff --git a/common/src/main/scala/org/mockito/matchers/ProductOf.scala b/common/src/main/scala/org/mockito/matchers/ProductOf.scala new file mode 100644 index 00000000..a34e415e --- /dev/null +++ b/common/src/main/scala/org/mockito/matchers/ProductOf.scala @@ -0,0 +1,14 @@ +package org.mockito +package matchers + +/** The product (2-tuple) of two matchers + */ +case class ProductOf[A, B] private (ma: ArgumentMatcher[A], mb: ArgumentMatcher[B]) extends ArgumentMatcher[(A, B)] { + override def matches(ab: (A, B)) = ab match { case (a, b) => ma.matches(a) && mb.matches(b) } + override def toString = s"productOf($ma, $mb)" +} + +object ProductOf { + def apply[A, B](ma: ArgumentMatcher[A], mb: ArgumentMatcher[B]): ArgumentMatcher[(A, B)] = + new ProductOf(ma, mb) +} diff --git a/common/src/main/scala/org/mockito/matchers/Transformed.scala b/common/src/main/scala/org/mockito/matchers/Transformed.scala new file mode 100644 index 00000000..1b5296ab --- /dev/null +++ b/common/src/main/scala/org/mockito/matchers/Transformed.scala @@ -0,0 +1,16 @@ +package org.mockito +package matchers + +/** Matcher tranformed from one type to another with a function to modify the input + * + * Technically this is 'contramapped' but that seemed like an unnecessarily jargony name. + */ +case class Transformed[A, B] private (ma: ArgumentMatcher[A])(f: B => A) extends ArgumentMatcher[B] { + override def matches(b: B) = ma.matches(f(b)) + override def toString = s"transformed($ma: $f)" +} + +object Transformed { + def apply[A, B](ma: ArgumentMatcher[A])(f: B => A): ArgumentMatcher[B] = + new Transformed(ma)(f) +} diff --git a/common/src/main/scala/org/mockito/matchers/package.scala b/common/src/main/scala/org/mockito/matchers/package.scala new file mode 100644 index 00000000..3b97bc66 --- /dev/null +++ b/common/src/main/scala/org/mockito/matchers/package.scala @@ -0,0 +1,7 @@ +package org.mockito + +package object matchers { + private val AnyArgMatcher: ArgumentMatcher[Any] = AllOf[Any]() + + def AnyArg[A]: ArgumentMatcher[A] = AnyArgMatcher.asInstanceOf[ArgumentMatcher[A]] +} diff --git a/common/src/test/scala/org/mockito/matchers/Generators.scala b/common/src/test/scala/org/mockito/matchers/Generators.scala new file mode 100644 index 00000000..847e9ab3 --- /dev/null +++ b/common/src/test/scala/org/mockito/matchers/Generators.scala @@ -0,0 +1,9 @@ +package org.mockito +package matchers + +import org.scalacheck.Arbitrary + +object Generators { + implicit def arbArgumentMatcher[A](implicit a: Arbitrary[A => Boolean]): Arbitrary[ArgumentMatcher[A]] = + Arbitrary(a.arbitrary.map(p => new ArgumentMatcher[A] { def matches(a: A) = p(a) })) +} diff --git a/common/src/test/scala/org/mockito/matchers/MatcherProps.scala b/common/src/test/scala/org/mockito/matchers/MatcherProps.scala new file mode 100644 index 00000000..e52d618a --- /dev/null +++ b/common/src/test/scala/org/mockito/matchers/MatcherProps.scala @@ -0,0 +1,60 @@ +package org.mockito +package matchers + +import cats.laws.discipline.MiniInt +import cats.laws.discipline.arbitrary._ +import org.mockito.internal.matchers._ +import org.scalacheck._ + +import Arbitrary.arbitrary +import Gen._ +import Prop._ + +class MatcherProps extends Properties("matchers") { + import Generators._ + + property("AllOf") = forAll(chooseNum(0, 8))(length => + forAll(listOfN(length, arbitrary[ArgumentMatcher[MiniInt]]), arbitrary[MiniInt]) { + case (matchers, value) => + val allOf = AllOf(matchers: _*) + val stringRep = allOf.toString + + classify(allOf.matches(value), "matches", "doesn't match") { + (allOf.matches(value) ?= matchers.forall(_.matches(value))) :| "matches all underlying" && + matchers.iff { + case Nil => stringRep ?= "" + case matcher :: Nil => stringRep ?= matcher.toString() + case _ => stringRep ?= s"allOf(${matchers.mkString(", ")})" + } :| "renders to string correctly" + + } + } + ) + + property("ProductOf") = forAll { (ma: ArgumentMatcher[MiniInt], mb: ArgumentMatcher[String], a: MiniInt, b: String) => + val productOf = ProductOf(ma, mb) + val product = (a, b) + + val maMatches = ma.matches(a) + val mbMatches = mb.matches(b) + val productMatches = productOf.matches(product) + + classify(productMatches, "matches", "doesn't match") { + all( + (productMatches ==> maMatches) :| "ma matches if product does", + (productMatches ==> mbMatches) :| "mb matches if product does", + ((maMatches && mbMatches) ==> productMatches) :| "product matches if both ma and mb do", + (productOf.toString ?= s"productOf($ma, $mb)") :| "renders to string correctly" + ) + } + } + + property("Transformed") = forAll { (ma: ArgumentMatcher[String], f: MiniInt => String, value: MiniInt) => + val transformed = Transformed(ma)(f) + val matches = transformed.matches(value) + classify(matches, "matches", "doesn't match") { + (matches ?= ma.matches(f(value))) :| "matches if underlying matches transfomed value" && + (transformed.toString ?= s"transformed($ma: $f)") :| "renders to string correctly" + } + } +} diff --git a/project/Dependencies.scala b/project/Dependencies.scala index e8e5ff73..b8ebd9ce 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -10,6 +10,8 @@ object Dependencies { "ru.vyarus" % "generics-resolver" % "3.0.2", ) + val scalacheck = "org.scalacheck" %% "scalacheck" % "1.14.3" + val scalatest = "org.scalatest" %% "scalatest" % scalatestVersion val specs2 = Seq(