From 154c65a2e3524793e8d97d2959656925ef5be85f Mon Sep 17 00:00:00 2001 From: Chris Twiner Date: Tue, 6 Jun 2023 19:03:04 +0200 Subject: [PATCH] #385 - chained ops impl --- .../scala/frameless/FramelessSyntax.scala | 5 + .../scala/frameless/ops/ChainedJoinOps.scala | 159 ++++++++++++++++++ .../src/test/scala/frameless/JoinTests.scala | 147 +++++++++++++--- 3 files changed, 283 insertions(+), 28 deletions(-) create mode 100644 dataset/src/main/scala/frameless/ops/ChainedJoinOps.scala diff --git a/dataset/src/main/scala/frameless/FramelessSyntax.scala b/dataset/src/main/scala/frameless/FramelessSyntax.scala index 5ba294921..1d4c04762 100644 --- a/dataset/src/main/scala/frameless/FramelessSyntax.scala +++ b/dataset/src/main/scala/frameless/FramelessSyntax.scala @@ -1,5 +1,6 @@ package frameless +import frameless.ops.ChainedJoinOps import org.apache.spark.sql.{Column, DataFrame, Dataset} trait FramelessSyntax { @@ -15,4 +16,8 @@ trait FramelessSyntax { implicit class DataframeSyntax(self: DataFrame){ def unsafeTyped[T: TypedEncoder]: TypedDataset[T] = TypedDataset.createUnsafe(self) } + + implicit class ChainedJoinSyntax[T](ds: TypedDataset[T]) { + def join[U](other: TypedDataset[U]): ChainedJoinOps[T, U] = new ChainedJoinOps[T, U](ds, other) + } } diff --git a/dataset/src/main/scala/frameless/ops/ChainedJoinOps.scala b/dataset/src/main/scala/frameless/ops/ChainedJoinOps.scala new file mode 100644 index 000000000..f2386046c --- /dev/null +++ b/dataset/src/main/scala/frameless/ops/ChainedJoinOps.scala @@ -0,0 +1,159 @@ +package frameless.ops + +import frameless.{TypedColumn, TypedDataset, TypedEncoder} + +/** + * Collection of forwarding functions that optionally provide a reference to the incoming dataset for chaining of joins + * @param ds the dataset on which .join(other) was called + * @param other the dataset to which ds is joined + * @tparam T the type of ds + * @tparam U the type of other + */ +case class ChainedJoinOps[T, U](ds: TypedDataset[T], other: TypedDataset[U]) { + /** Computes the right outer join of `this` `Dataset` with the `other` `Dataset`, + * returning a `Tuple2` for each pair where condition evaluates to true. + * + * This version passes in the current dataset in the chain to the conditionF allowing you access to this TypedDataset's columns + */ + def right(conditionF: TypedDataset[T] => TypedColumn[T with U, Boolean])(implicit e: TypedEncoder[(Option[T], U)]): TypedDataset[(Option[T], U)] = + ds.joinRight(other)(conditionF(ds)) + + /** Computes the right outer join of `this` `Dataset` with the `other` `Dataset`, + * returning a `Tuple2` for each pair where condition evaluates to true. + * + * This version passes in the current and joined datasets in the chain to the conditionF allowing you access to this TypedDataset's columns and the joins + */ + def right(conditionF: (TypedDataset[T], TypedDataset[U]) => TypedColumn[T with U, Boolean])(implicit e: TypedEncoder[(Option[T], U)]): TypedDataset[(Option[T], U)] = + ds.joinRight(other)(conditionF(ds, other)) + + /** Computes the right outer join of `this` `Dataset` with the `other` `Dataset`, + * returning a `Tuple2` for each pair where condition evaluates to true. + */ + def right(condition: TypedColumn[T with U, Boolean])(implicit e: TypedEncoder[(Option[T], U)]): TypedDataset[(Option[T], U)] = + ds.joinRight(other)(condition) + + /** Computes the cartesian project of `this` `Dataset` with the `other` `Dataset` */ + def cross() // here for completeness + (implicit e: TypedEncoder[(T, U)]): TypedDataset[(T, U)] = + ds.joinCross(other) + + /** Computes the full outer join of `this` `Dataset` with the `other` `Dataset`, + * returning a `Tuple2` for each pair where condition evaluates to true. + */ + def full(condition: TypedColumn[T with U, Boolean]) + (implicit e: TypedEncoder[(Option[T], Option[U])]): TypedDataset[(Option[T], Option[U])] = + ds.joinFull(other)(condition) + + /** Computes the full outer join of `this` `Dataset` with the `other` `Dataset`, + * returning a `Tuple2` for each pair where condition evaluates to true. + * + * This version passes in the current dataset in the chain to the conditionF allowing you access to this TypedDataset's columns + */ + def full(conditionF: TypedDataset[T] => TypedColumn[T with U, Boolean]) + (implicit e: TypedEncoder[(Option[T], Option[U])]): TypedDataset[(Option[T], Option[U])] = + ds.joinFull(other)(conditionF(ds)) + + /** Computes the full outer join of `this` `Dataset` with the `other` `Dataset`, + * returning a `Tuple2` for each pair where condition evaluates to true. + * + * This version passes in the current and joined datasets in the chain to the conditionF allowing you access to this TypedDataset's columns and the joins + */ + def full(conditionF: (TypedDataset[T], TypedDataset[U]) => TypedColumn[T with U, Boolean]) + (implicit e: TypedEncoder[(Option[T], Option[U])]): TypedDataset[(Option[T], Option[U])] = + ds.joinFull(other)(conditionF(ds, other)) + + /** Computes the inner join of `this` `Dataset` with the `other` `Dataset`, + * returning a `Tuple2` for each pair where condition evaluates to true. + */ + def inner(condition: TypedColumn[T with U, Boolean]) + (implicit e: TypedEncoder[(T, U)]): TypedDataset[(T, U)] = + ds.joinInner(other)(condition) + + /** Computes the inner join of `this` `Dataset` with the `other` `Dataset`, + * returning a `Tuple2` for each pair where condition evaluates to true. + * + * This version passes in the current dataset in the chain to the conditionF allowing you access to this TypedDataset's columns + */ + def inner(conditionF: TypedDataset[T] => TypedColumn[T with U, Boolean]) + (implicit e: TypedEncoder[(T, U)]): TypedDataset[(T, U)] = + ds.joinInner(other)(conditionF(ds)) + + /** Computes the inner join of `this` `Dataset` with the `other` `Dataset`, + * returning a `Tuple2` for each pair where condition evaluates to true. + * + * This version passes in the current and joined datasets in the chain to the conditionF allowing you access to this TypedDataset's columns and the joins + */ + def inner(conditionF: (TypedDataset[T], TypedDataset[U]) => TypedColumn[T with U, Boolean]) + (implicit e: TypedEncoder[(T, U)]): TypedDataset[(T, U)] = + ds.joinInner(other)(conditionF(ds, other)) + + /** Computes the left outer join of `this` `Dataset` with the `other` `Dataset`, + * returning a `Tuple2` for each pair where condition evaluates to true. + */ + def left(condition: TypedColumn[T with U, Boolean]) + (implicit e: TypedEncoder[(T, Option[U])]): TypedDataset[(T, Option[U])] = + ds.joinLeft(other)(condition) + + /** Computes the left outer join of `this` `Dataset` with the `other` `Dataset`, + * returning a `Tuple2` for each pair where condition evaluates to true. + * + * This version passes in the current dataset in the chain to the conditionF allowing you access to this TypedDataset's columns + */ + def left(conditionF: TypedDataset[T] => TypedColumn[T with U, Boolean]) + (implicit e: TypedEncoder[(T, Option[U])]): TypedDataset[(T, Option[U])] = + ds.joinLeft(other)(conditionF(ds)) + + /** Computes the left outer join of `this` `Dataset` with the `other` `Dataset`, + * returning a `Tuple2` for each pair where condition evaluates to true. + * + * This version passes in the current and joined datasets in the chain to the conditionF allowing you access to this TypedDataset's columns and the joins + */ + def left(conditionF: (TypedDataset[T], TypedDataset[U]) => TypedColumn[T with U, Boolean]) + (implicit e: TypedEncoder[(T, Option[U])]): TypedDataset[(T, Option[U])] = + ds.joinLeft(other)(conditionF(ds,other)) + + /** Computes the left semi join of `this` `Dataset` with the `other` `Dataset`, + * returning a `Tuple2` for each pair where condition evaluates to true. + */ + def leftSemi(condition: TypedColumn[T with U, Boolean]): TypedDataset[T] = + ds.joinLeftSemi(other)(condition) + + /** Computes the left semi join of `this` `Dataset` with the `other` `Dataset`, + * returning a `Tuple2` for each pair where condition evaluates to true. + * + * This version passes in the current dataset in the chain to the conditionF allowing you access to this TypedDataset's columns + */ + def leftSemi(conditionF: TypedDataset[T] => TypedColumn[T with U, Boolean]): TypedDataset[T] = + ds.joinLeftSemi(other)(conditionF(ds)) + + /** Computes the left semi join of `this` `Dataset` with the `other` `Dataset`, + * returning a `Tuple2` for each pair where condition evaluates to true. + * + * This version passes in the current and joined datasets in the chain to the conditionF allowing you access to this TypedDataset's columns and the joins + */ + def leftSemi(conditionF: (TypedDataset[T], TypedDataset[U]) => TypedColumn[T with U, Boolean]): TypedDataset[T] = + ds.joinLeftSemi(other)(conditionF(ds, other)) + + /** Computes the left anti join of `this` `Dataset` with the `other` `Dataset`, + * returning a `Tuple2` for each pair where condition evaluates to true. + */ + def leftAnti(condition: TypedColumn[T with U, Boolean]): TypedDataset[T] = + ds.joinLeftAnti(other)(condition) + + /** Computes the left anti join of `this` `Dataset` with the `other` `Dataset`, + * returning a `Tuple2` for each pair where condition evaluates to true. + * + * This version passes in the current dataset in the chain to the conditionF allowing you access to this TypedDataset's columns + */ + def leftAnti(conditionF: TypedDataset[T] => TypedColumn[T with U, Boolean]): TypedDataset[T] = + ds.joinLeftAnti(other)(conditionF(ds)) + + /** Computes the left anti join of `this` `Dataset` with the `other` `Dataset`, + * returning a `Tuple2` for each pair where condition evaluates to true. + * + * This version passes in the current and joined datasets in the chain to the conditionF allowing you access to this TypedDataset's columns and the joins + */ + def leftAnti(conditionF: (TypedDataset[T], TypedDataset[U]) => TypedColumn[T with U, Boolean]): TypedDataset[T] = + ds.joinLeftAnti(other)(conditionF(ds, other)) + +} diff --git a/dataset/src/test/scala/frameless/JoinTests.scala b/dataset/src/test/scala/frameless/JoinTests.scala index b34911c4f..0eda70031 100644 --- a/dataset/src/test/scala/frameless/JoinTests.scala +++ b/dataset/src/test/scala/frameless/JoinTests.scala @@ -10,11 +10,10 @@ class JoinTests extends TypedDatasetSuite { A : TypedEncoder : Ordering, B : TypedEncoder : Ordering, C : TypedEncoder : Ordering - ](left: List[X2[A, B]], right: List[X2[A, C]]): Prop = { + ](f: (TypedDataset[X2[A,B]], TypedDataset[X2[A,C]]) => TypedDataset[(X2[A,B], X2[A,C])])(left: List[X2[A, B]], right: List[X2[A, C]]): Prop = { val leftDs = TypedDataset.create(left) val rightDs = TypedDataset.create(right) - val joinedDs = leftDs - .joinCross(rightDs) + val joinedDs = f(leftDs, rightDs) val joinedData = joinedDs.collect().run().toVector.sorted @@ -32,7 +31,12 @@ class JoinTests extends TypedDatasetSuite { (joined.sorted ?= joinedData) && equalSchemas } - check(forAll(prop[Int, Long, String] _)) + check(forAll(prop[Int, Long, String]((leftDs, rightDs) => leftDs + .joinCross(rightDs)) _)) + + import frameless.syntax.ChainedJoinSyntax + check(forAll(prop[Int, Long, String]((leftDs, rightDs) => leftDs + .join(rightDs).cross) _)) } test("ab.joinFull(ac)(ab.a == ac.a)") { @@ -40,11 +44,10 @@ class JoinTests extends TypedDatasetSuite { A : TypedEncoder : Ordering, B : TypedEncoder : Ordering, C : TypedEncoder : Ordering - ](left: List[X2[A, B]], right: List[X2[A, C]]): Prop = { + ](f: (TypedDataset[X2[A,B]], TypedDataset[X2[A,C]]) => TypedDataset[(Option[X2[A,B]], Option[X2[A,C]])])(left: List[X2[A, B]], right: List[X2[A, C]]): Prop = { val leftDs = TypedDataset.create(left) val rightDs = TypedDataset.create(right) - val joinedDs = leftDs - .joinFull(rightDs)(leftDs.col('a) === rightDs.col('a)) + val joinedDs = f(leftDs, rightDs) val joinedData = joinedDs.collect().run().toVector.sorted @@ -72,7 +75,16 @@ class JoinTests extends TypedDatasetSuite { (joined.sorted ?= joinedData) && equalSchemas } - check(forAll(prop[Int, Long, String] _)) + check(forAll(prop[Int, Long, String]((leftDs, rightDs) => leftDs + .joinFull(rightDs)(leftDs.col('a) === rightDs.col('a))) _)) + + import frameless.syntax.ChainedJoinSyntax + check(forAll(prop[Int, Long, String]((leftDs, rightDs) => leftDs + .join(rightDs).full(leftDs.col('a) === rightDs.col('a))) _)) + check(forAll(prop[Int, Long, String]((leftDs, rightDs) => leftDs + .join(rightDs).full(_.col('a) === rightDs.col('a))) _)) + check(forAll(prop[Int, Long, String]((leftDs, rightDs) => leftDs + .join(rightDs).full(_.col('a) === _.col('a))) _)) } test("ab.joinInner(ac)(ab.a == ac.a)") { @@ -80,11 +92,10 @@ class JoinTests extends TypedDatasetSuite { A : TypedEncoder : Ordering, B : TypedEncoder : Ordering, C : TypedEncoder : Ordering - ](left: List[X2[A, B]], right: List[X2[A, C]]): Prop = { + ](f: (TypedDataset[X2[A,B]], TypedDataset[X2[A,C]]) => TypedDataset[(X2[A,B], X2[A,C])])(left: List[X2[A, B]], right: List[X2[A, C]]): Prop = { val leftDs = TypedDataset.create(left) val rightDs = TypedDataset.create(right) - val joinedDs = leftDs - .joinInner(rightDs)(leftDs.col('a) === rightDs.col('a)) + val joinedDs = f(leftDs, rightDs) val joinedData = joinedDs.collect().run().toVector.sorted @@ -102,7 +113,16 @@ class JoinTests extends TypedDatasetSuite { (joined.sorted ?= joinedData) && equalSchemas } - check(forAll(prop[Int, Long, String] _)) + check(forAll(prop[Int, Long, String]((leftDs, rightDs) => leftDs + .joinInner(rightDs)(leftDs.col('a) === rightDs.col('a))) _)) + + import frameless.syntax.ChainedJoinSyntax + check(forAll(prop[Int, Long, String]((leftDs, rightDs) => leftDs + .join(rightDs).inner(leftDs.col('a) === rightDs.col('a))) _)) + check(forAll(prop[Int, Long, String]((leftDs, rightDs) => leftDs + .join(rightDs).inner(_.col('a) === rightDs.col('a))) _)) + check(forAll(prop[Int, Long, String]((leftDs, rightDs) => leftDs + .join(rightDs).inner(_.col('a) === _.col('a))) _)) } test("ab.joinLeft(ac)(ab.a == ac.a)") { @@ -110,11 +130,10 @@ class JoinTests extends TypedDatasetSuite { A : TypedEncoder : Ordering, B : TypedEncoder : Ordering, C : TypedEncoder : Ordering - ](left: List[X2[A, B]], right: List[X2[A, C]]): Prop = { + ](f: (TypedDataset[X2[A,B]], TypedDataset[X2[A,C]]) => TypedDataset[(X2[A,B], Option[X2[A,C]])])(left: List[X2[A, B]], right: List[X2[A, C]]): Prop = { val leftDs = TypedDataset.create(left) val rightDs = TypedDataset.create(right) - val joinedDs = leftDs - .joinLeft(rightDs)(leftDs.col('a) === rightDs.col('a)) + val joinedDs = f(leftDs,rightDs) val joinedData = joinedDs.collect().run().toVector.sorted @@ -137,7 +156,16 @@ class JoinTests extends TypedDatasetSuite { (joined.sorted ?= joinedData) && (joinedData.map(_._1).toSet ?= left.toSet) && equalSchemas } - check(forAll(prop[Int, Long, String] _)) + check(forAll(prop[Int, Long, String]((leftDs, rightDs) => leftDs + .joinLeft(rightDs)(leftDs.col('a) === rightDs.col('a))) _)) + + import frameless.syntax.ChainedJoinSyntax + check(forAll(prop[Int, Long, String]((leftDs, rightDs) => leftDs + .join(rightDs).left(leftDs.col('a) === rightDs.col('a))) _)) + check(forAll(prop[Int, Long, String]((leftDs, rightDs) => leftDs + .join(rightDs).left(_.col('a) === rightDs.col('a))) _)) + check(forAll(prop[Int, Long, String]((leftDs, rightDs) => leftDs + .join(rightDs).left(_.col('a) === _.col('a))) _)) } test("ab.joinLeftAnti(ac)(ab.a == ac.a)") { @@ -145,12 +173,11 @@ class JoinTests extends TypedDatasetSuite { A : TypedEncoder : Ordering, B : TypedEncoder : Ordering, C : TypedEncoder : Ordering - ](left: List[X2[A, B]], right: List[X2[A, C]]): Prop = { + ](f: (TypedDataset[X2[A,B]], TypedDataset[X2[A,C]]) => TypedDataset[X2[A,B]])(left: List[X2[A, B]], right: List[X2[A, C]]): Prop = { val leftDs = TypedDataset.create(left) val rightDs = TypedDataset.create(right) val rightKeys = right.map(_.a).toSet - val joinedDs = leftDs - .joinLeftAnti(rightDs)(leftDs.col('a) === rightDs.col('a)) + val joinedDs = f(leftDs, rightDs) val joinedData = joinedDs.collect().run().toVector.sorted @@ -165,7 +192,16 @@ class JoinTests extends TypedDatasetSuite { (joined.sorted ?= joinedData) && equalSchemas } - check(forAll(prop[Int, Long, String] _)) + check(forAll(prop[Int, Long, String]((leftDs,rightDs) => leftDs + .joinLeftAnti(rightDs)(leftDs.col('a) === rightDs.col('a))) _)) + + import frameless.syntax.ChainedJoinSyntax + check(forAll(prop[Int, Long, String]((leftDs, rightDs) => leftDs + .join(rightDs).leftAnti(leftDs.col('a) === rightDs.col('a))) _)) + check(forAll(prop[Int, Long, String]((leftDs, rightDs) => leftDs + .join(rightDs).leftAnti(_.col('a) === rightDs.col('a))) _)) + check(forAll(prop[Int, Long, String]((leftDs, rightDs) => leftDs + .join(rightDs).leftAnti(_.col('a) === _.col('a))) _)) } test("ab.joinLeftSemi(ac)(ab.a == ac.a)") { @@ -173,12 +209,11 @@ class JoinTests extends TypedDatasetSuite { A : TypedEncoder : Ordering, B : TypedEncoder : Ordering, C : TypedEncoder : Ordering - ](left: List[X2[A, B]], right: List[X2[A, C]]): Prop = { + ](f: (TypedDataset[X2[A,B]], TypedDataset[X2[A,C]]) => TypedDataset[X2[A,B]])(left: List[X2[A, B]], right: List[X2[A, C]]): Prop = { val leftDs = TypedDataset.create(left) val rightDs = TypedDataset.create(right) val rightKeys = right.map(_.a).toSet - val joinedDs = leftDs - .joinLeftSemi(rightDs)(leftDs.col('a) === rightDs.col('a)) + val joinedDs = f(leftDs, rightDs) val joinedData = joinedDs.collect().run().toVector.sorted @@ -193,7 +228,16 @@ class JoinTests extends TypedDatasetSuite { (joined.sorted ?= joinedData) && equalSchemas } - check(forAll(prop[Int, Long, String] _)) + check(forAll(prop[Int, Long, String]((leftDs, rightDs) => leftDs + .joinLeftSemi(rightDs)(leftDs.col('a) === rightDs.col('a))) _)) + + import frameless.syntax.ChainedJoinSyntax + check(forAll(prop[Int, Long, String]((leftDs, rightDs) => leftDs + .join(rightDs).leftSemi(leftDs.col('a) === rightDs.col('a))) _)) + check(forAll(prop[Int, Long, String]((leftDs, rightDs) => leftDs + .join(rightDs).leftSemi(_.col('a) === rightDs.col('a))) _)) + check(forAll(prop[Int, Long, String]((leftDs, rightDs) => leftDs + .join(rightDs).leftSemi(_.col('a) === _.col('a))) _)) } test("ab.joinRight(ac)(ab.a == ac.a)") { @@ -201,11 +245,10 @@ class JoinTests extends TypedDatasetSuite { A : TypedEncoder : Ordering, B : TypedEncoder : Ordering, C : TypedEncoder : Ordering - ](left: List[X2[A, B]], right: List[X2[A, C]]): Prop = { + ](f: (TypedDataset[X2[A,B]], TypedDataset[X2[A,C]]) => TypedDataset[(Option[X2[A,B]], X2[A,C])])(left: List[X2[A, B]], right: List[X2[A, C]]): Prop = { val leftDs = TypedDataset.create(left) val rightDs = TypedDataset.create(right) - val joinedDs = leftDs - .joinRight(rightDs)(leftDs.col('a) === rightDs.col('a)) + val joinedDs = f(leftDs, rightDs) val joinedData = joinedDs.collect().run().toVector.sorted @@ -228,6 +271,54 @@ class JoinTests extends TypedDatasetSuite { (joined.sorted ?= joinedData) && (joinedData.map(_._2).toSet ?= right.toSet) && equalSchemas } - check(forAll(prop[Int, Long, String] _)) + check(forAll(prop[Int, Long, String]((leftDs, rightDs) => + leftDs.joinRight(rightDs)(leftDs.col('a) === rightDs.col('a))) _)) + + import frameless.syntax.ChainedJoinSyntax + check(forAll(prop[Int, Long, String]((leftDs, rightDs) => + leftDs.join(rightDs).right(leftDs.col('a) === rightDs.col('a))) _)) + check(forAll(prop[Int, Long, String]((leftDs, rightDs) => + leftDs.join(rightDs).right(_.col('a) === rightDs.col('a))) _)) + check(forAll(prop[Int, Long, String]((leftDs, rightDs) => + leftDs.join(rightDs).right(_.col('a) === _.col('a))) _)) + } + + test("chained") { + def prop[ + A: TypedEncoder : Ordering, + B: TypedEncoder : Ordering, + C: TypedEncoder : Ordering, + D: TypedEncoder : Ordering + ](left: Seq[X2[A, B]], mid: Seq[X2[A, C]], right: Seq[X2[C, D]] ): Unit = { + val leftDs = TypedDataset.create(left) + val midDs = TypedDataset.create(mid) + val rightDs = TypedDataset.create(right) + /* orig + val joinedDs1 = leftDs + .joinRight(midDs)(leftDs.col('a) === midDs.col('a)) + val joinedDs = joinedDs1 + .joinRight(rightDs)(joinedDs1.col('_2).field('b) === rightDs.col('a)) + */ + import frameless.syntax.ChainedJoinSyntax + + // join right + val joinedDs = leftDs + .join(midDs).right(leftDs.col('a) === midDs.col('a)) + .join(rightDs).right(t => t.col('_2).field('b) === rightDs.col('a)) + + //joinedDs.show().run() + + val joinedData = joinedDs.collect().run().toVector.sorted + assert(joinedData == Seq( + (Some((Some(X2(1,1L)), X2(1,5L))), X2(5L, "5s")), + (Some((Some(X2(2,2L)), X2(2,6L))), X2(6L, "6s")), + (Some((Some(X2(3,3L)), X2(3,7L))), X2(7L, "7s")) + )) + + () + } + + prop[Int, Long, Long, String](Seq(X2(1,1L), X2(2,2L), X2(3,3L)), Seq(X2(1,5L), X2(2,6L), X2(3,7L)), + Seq(X2(5L, "5s"), X2(6L, "6s"), X2(7L, "7s"))) } }