Skip to content

Commit

Permalink
typelevel#385 - safety save
Browse files Browse the repository at this point in the history
  • Loading branch information
chris-twiner committed Jun 5, 2023
1 parent c8ecea8 commit 2763ec6
Showing 1 changed file with 59 additions and 0 deletions.
59 changes: 59 additions & 0 deletions dataset/src/test/scala/frameless/JoinTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -230,4 +230,63 @@ class JoinTests extends TypedDatasetSuite {

check(forAll(prop[Int, Long, String] _))
}

case class Joiner[T, U](ds: TypedDataset[T], other: TypedDataset[U]) {
def right(conditionF: TypedDataset[T] => TypedColumn[T with U, Boolean])(implicit e: TypedEncoder[(Option[T], U)]): TypedDataset[(Option[T], U)] =
ds.joinRight(other)(conditionF(ds))

def right(condition: TypedColumn[T with U, Boolean])(implicit e: TypedEncoder[(Option[T], U)]): TypedDataset[(Option[T], U)] =
ds.joinRight(other)(condition)
}

implicit class Ops[T](ds: TypedDataset[T]) {

def join[U](other: TypedDataset[U]): Joiner[T, U] = new Joiner[T,U](ds, other)

def joinRightChained[U](other: TypedDataset[U])(conditionF: TypedDataset[T] => TypedColumn[T with U, Boolean])(implicit e: TypedEncoder[(Option[T], U)]): TypedDataset[(Option[T], U)] =
ds.joinRight(other)(conditionF(ds))
}

test("chained") {
def prop[
A: TypedEncoder : Ordering,
B: TypedEncoder : Ordering,
C: TypedEncoder : Ordering,
D: TypedEncoder : Ordering
](left: Seq[X2[A, B]], mid: Seq[X2[A, C]], right: Seq[X2[C, D]] ): Unit = {
val leftDs = TypedDataset.create(left)
val midDs = TypedDataset.create(mid)
val rightDs = TypedDataset.create(right)
/* orig
val joinedDs1 = leftDs
.joinRight(midDs)(leftDs.col('a) === midDs.col('a))
val joinedDs = joinedDs1
.joinRight(rightDs)(joinedDs1.col('_2).field('b) === rightDs.col('a))
*/

/* joinRightChained
val joinedDs = leftDs
.joinRight(midDs)(leftDs.col('a) === midDs.col('a))
.joinRightChained(rightDs)(_.col('_2).field('b) === rightDs.col('a))
*/
// join right
val joinedDs = leftDs
.join(midDs).right(leftDs.col('a) === midDs.col('a))
.join(rightDs).right(_.col('_2).field('b) === rightDs.col('a))

//joinedDs.show().run()

val joinedData = joinedDs.collect().run().toVector.sorted
assert(joinedData == Seq(
(Some((Some(X2(1,1L)), X2(1,5L))), X2(5L, "5s")),
(Some((Some(X2(2,2L)), X2(2,6L))), X2(6L, "6s")),
(Some((Some(X2(3,3L)), X2(3,7L))), X2(7L, "7s"))
))

()
}

prop[Int, Long, Long, String](Seq(X2(1,1L), X2(2,2L), X2(3,3L)), Seq(X2(1,5L), X2(2,6L), X2(3,7L)),
Seq(X2(5L, "5s"), X2(6L, "6s"), X2(7L, "7s")))
}
}

0 comments on commit 2763ec6

Please sign in to comment.