Change TypedDataset#apply syntax to use a function #110

Closed
wants to merge 7 commits into from

Changes from all commits
3 changes: 2 additions & 1 deletion dataset/src/main/scala/frameless/TypedColumn.scala
@@ -306,6 +306,7 @@ object TypedColumn {
lgen: LabelledGeneric.Aux[T, H],
selector: Selector.Aux[H, K, V]
): Exists[T, K, V] = new Exists[T, K, V] {}

}

implicit class OrderedTypedColumnSyntax[T, U: CatalystOrdered](col: TypedColumn[T, U]) {
@@ -319,4 +320,4 @@ object TypedColumn {
def >(other: U): TypedColumn[T, Boolean] = (col.untyped > lit(other)(col.uencoder).untyped).typed
def >=(other: U): TypedColumn[T, Boolean] = (col.untyped >= lit(other)(col.uencoder).untyped).typed
}
}
}
6 changes: 2 additions & 4 deletions dataset/src/main/scala/frameless/TypedDataset.scala
@@ -149,11 +149,9 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
*
* It is statically checked that column with such name exists and has type `A`.
*/
def apply[A](column: Witness.Lt[Symbol])(
implicit
exists: TypedColumn.Exists[T, column.T, A],
def apply[A](selector: T => A)(implicit
encoder: TypedEncoder[A]
): TypedColumn[T, A] = col(column)
): TypedColumn[T, A] = macro frameless.column.ColumnMacros.fromFunction[T, A]

/** Returns `TypedColumn` of type `A` given its name.
*
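A rough usage sketch of the new function-based apply, mirroring the ColTests added in this PR (Person is a hypothetical case class, and the usual frameless implicits, i.e. TypedEncoder instances and an implicit SQLContext, are assumed to be in scope):

case class Person(name: String, age: Int)

val people: TypedDataset[Person] = TypedDataset.create(Person("Ada", 37) :: Nil)

// The macro turns the selector `_.age` into the column named "age".
val age: TypedColumn[Person, Int] = people(_.age)

// The result type can also be supplied explicitly and is checked at compile time;
// people[Int](_.name) would fail to compile because the selector returns String.
val name: TypedColumn[Person, String] = people[String](_.name)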
72 changes: 72 additions & 0 deletions dataset/src/main/scala/frameless/column/ColumnMacros.scala
@@ -0,0 +1,72 @@
package frameless.column

import frameless.{TypedColumn, TypedEncoder, TypedExpressionEncoder}
import shapeless.CaseClassMacros

import scala.collection.immutable.Queue
import scala.reflect.macros.whitebox

class ColumnMacros(val c: whitebox.Context) extends CaseClassMacros {
import c.universe._

// could be used to reintroduce apply('foo)
// $COVERAGE-OFF$ Currently unused
def fromSymbol[A : WeakTypeTag, B : WeakTypeTag](selector: c.Expr[scala.Symbol])(encoder: c.Expr[TypedEncoder[B]]): Tree = {
Contributor

Why would we ever want to use this macro instead of the shapeless one?

Contributor Author

The reason I left this here was in case we wanted to support both ds('a) and ds(_.a) at the same time. We can't do that with overloading, because it would ruin type inference for the function syntax. So if we really wanted to allow both, I thought we could have the macro figure it all out instead.

There are other problems with this approach, though. I would prefer to just embrace the function syntax, because it has better type inference and produces about 95% less bytecode (once implicit expansion is all said and done).
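For comparison, the two syntaxes under discussion look roughly like this (a sketch against the X4 fixture used in the tests; the symbol form goes through the existing Witness/Exists machinery via col, the function form goes through ColumnMacros.fromFunction):

val x4 = TypedDataset.create[X4[Int, String, Long, Boolean]](Nil)

// Symbol-based: resolved through shapeless implicits (Witness + Exists).
val viaSymbol: TypedColumn[X4[Int, String, Long, Boolean], Int] = x4.col('a)

// Function-based: the field access is extracted by the macro at compile time.
val viaFunction: TypedColumn[X4[Int, String, Long, Boolean], Int] = x4(_.a)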

val B = weakTypeOf[B].dealias
val witness = c.typecheck(q"_root_.shapeless.Witness.apply(${selector.tree})")
c.typecheck(q"${c.prefix}.col[$B]($witness)")
}
// $COVERAGE-ON$

def fromFunction[A : WeakTypeTag, B : WeakTypeTag](selector: c.Expr[A => B])(encoder: c.Expr[TypedEncoder[B]]): Tree = {
def fail(tree: Tree) = {
val err =
s"Could not create a column identifier from $tree - try using _.a.b syntax"
c.abort(tree.pos, err)
}

val A = weakTypeOf[A].dealias
val B = weakTypeOf[B].dealias

val selectorStr = selector.tree match {
case NameExtractor(str) => str
case Function(_, body) => fail(body)
// $COVERAGE-OFF$ - cannot be reached as typechecking will fail in this case before macro is even invoked
case other => fail(other)
// $COVERAGE-ON$
}

val typedCol = appliedType(
weakTypeOf[TypedColumn[_, _]].typeConstructor, A, B
)

val TEEObj = reify(TypedExpressionEncoder)

val datasetCol = c.typecheck(
q"${c.prefix}.dataset.col($selectorStr).as[$B]($TEEObj.apply[$B]($encoder))"
Contributor

I don't understand why you need an .as here

Contributor Author

Hmm... it's there to go from o.a.s.s.Column to o.a.s.s.TypedColumn. But you're right, it looks like you can make a frameless.TypedColumn from an ordinary Column. I can't remember what advantage I thought this would have.

Contributor

I see, if it's Spark's .as then it's not a problem! I thought it was one of ours that triggers implicit search & co, but it's not.
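For reference, the tree assembled by the quasiquotes above expands to roughly the following (a sketch only; ds stands for the prefix TypedDataset and encoder for the implicitly supplied TypedEncoder):

// ds(_.a) with ds: TypedDataset[X1[Int]] becomes, approximately:
new TypedColumn[X1[Int], Int](
  ds.dataset.col("a")                                       // Spark's untyped Column
    .as[Int](TypedExpressionEncoder.apply[Int](encoder))    // Spark's Column#as, giving a Spark TypedColumn
)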

)

c.typecheck(q"new $typedCol($datasetCol)")
}

case class NameExtractor(name: TermName) { Self =>
def unapply(tree: Tree): Option[Queue[String]] = {
tree match {
case Ident(`name`) => Some(Queue.empty)
case s @ Select(Self(strs), nested) if s.symbol.isTerm && isCaseAccessorLike(s.symbol.asTerm) =>
Some(strs enqueue nested.toString)
// $COVERAGE-OFF$ - Not sure if this case can ever come up and Encoder will still work
case Apply(Self(strs), List()) => Some(strs)
// $COVERAGE-ON$
case _ => None
}
}
}

object NameExtractor {
def unapply(tree: Tree): Option[String] = tree match {
case Function(List(ValDef(_, name, argTyp, _)), body) => NameExtractor(name).unapply(body).map(_.mkString("."))
case _ => None
}
}
}
36 changes: 36 additions & 0 deletions dataset/src/test/scala/frameless/ColTests.scala
@@ -29,6 +29,42 @@ class ColTests extends TypedDatasetSuite {
()
}

test("colApply") {
val x4 = TypedDataset.create[X4[Int, String, Long, Boolean]](Nil)
val t4 = TypedDataset.create[(Int, String, Long, Boolean)](Nil)
val x4x4 = TypedDataset.create[X4X4[Int, String, Long, Boolean]](Nil)

x4(_.a)
t4(_._1)
x4[Int](_.a)
t4[Int](_._1)

illTyped("x4[String](_.a)", "type mismatch;\n found : Int\n required: String")

x4(_.b)
t4(_._2)

x4[String](_.b)
t4[String](_._2)

illTyped("x4[Int](_.b)", "type mismatch;\n found : String\n required: Int")

x4x4(_.xa.a)
x4x4[Int](_.xa.a)
x4x4(_.xa.b)
x4x4[String](_.xa.b)

x4x4(_.xb.a)
x4x4[Int](_.xb.a)
x4x4(_.xb.b)
x4x4[String](_.xb.b)

illTyped("x4x4[String](_.xa.a)", "type mismatch;\n found : Int\n required: String")
illTyped("x4x4(item => item.xa.a * 20)", "Could not create a column identifier from item\\.xa\\.a\\.\\*\\(20\\) - try using _\\.a\\.b syntax")

()
}

test("colMany") {
type X2X2 = X2[X2[Int, String], X2[Long, Boolean]]
val x2x2 = TypedDataset.create[X2X2](Nil)
2 changes: 1 addition & 1 deletion dataset/src/test/scala/frameless/ColumnTests.scala
@@ -41,6 +41,6 @@ class ColumnTests extends TypedDatasetSuite {

test("toString") {
val t = TypedDataset.create((1,2)::Nil)
t('_1).toString ?= t.dataset.col("_1").toString()
t(_._1).toString ?= t.dataset.col("_1").toString()
}
}
8 changes: 4 additions & 4 deletions dataset/src/test/scala/frameless/FilterTests.scala
@@ -59,15 +59,15 @@ class FilterTests extends TypedDatasetSuite {
test("filter with arithmetic expressions: addition") {
check(forAll { (data: Vector[X1[Int]]) =>
val ds = TypedDataset.create(data)
val res = ds.filter((ds('a) + 1) === (ds('a) + 1)).collect().run().toVector
val res = ds.filter((ds(_.a) + 1) === (ds(_.a) + 1)).collect().run().toVector
res ?= data
})
}

test("filter with values (not columns): addition") {
check(forAll { (data: Vector[X1[Int]], const: Int) =>
val ds = TypedDataset.create(data)
val res = ds.filter(ds('a) > const).collect().run().toVector
val res = ds.filter(ds(_.a) > const).collect().run().toVector
res ?= data.filter(_.a > const)
})
}
@@ -76,8 +76,8 @@
val t = X1(1) :: X1(2) :: X1(3) :: Nil
val tds: TypedDataset[X1[Int]] = TypedDataset.create(t)

assert(tds.filter(tds('a) * 2 === 2).collect().run().toVector === Vector(X1(1)))
assert(tds.filter(tds('a) * 3 === 3).collect().run().toVector === Vector(X1(1)))
assert(tds.filter(tds(_.a) * 2 === 2).collect().run().toVector === Vector(X1(1)))
assert(tds.filter(tds(_.a) * 3 === 3).collect().run().toVector === Vector(X1(1)))
}

test("Option equality/inequality for columns") {
22 changes: 10 additions & 12 deletions dataset/src/test/scala/frameless/SelectTests.scala
@@ -276,9 +276,7 @@ class SelectTests extends TypedDatasetSuite {
cb: ClassTag[B]
): Prop = {
val dataset = TypedDataset.create(data)
val AB = dataset.colMany('a, 'b)

val dataset2 = dataset.select(AB).collect().run().toVector
val dataset2 = dataset.select(dataset(_.a.b)).collect().run().toVector
val data2 = data.map { case X2(X2(_, b), _) => b }

dataset2 ?= data2
@@ -297,7 +295,7 @@
): Prop = {
val ds = TypedDataset.create(data)

val dataset2 = ds.select(ds('a) + const).collect().run().toVector
val dataset2 = ds.select(ds(_.a) + const).collect().run().toVector
val data2 = data.map { case X1(a) => num.plus(a, const) }

dataset2 ?= data2
@@ -319,7 +317,7 @@
): Prop = {
val ds = TypedDataset.create(data)

val dataset2 = ds.select(ds('a) * const).collect().run().toVector
val dataset2 = ds.select(ds(_.a) * const).collect().run().toVector
val data2 = data.map { case X1(a) => num.times(a, const) }

dataset2 ?= data2
@@ -341,7 +339,7 @@
): Prop = {
val ds = TypedDataset.create(data)

val dataset2 = ds.select(ds('a) - const).collect().run().toVector
val dataset2 = ds.select(ds(_.a) - const).collect().run().toVector
val data2 = data.map { case X1(a) => num.minus(a, const) }

dataset2 ?= data2
@@ -363,7 +361,7 @@
val ds = TypedDataset.create(data)

if (const != 0) {
val dataset2 = ds.select(ds('a) / const).collect().run().toVector.asInstanceOf[Vector[A]]
val dataset2 = ds.select(ds(_.a) / const).collect().run().toVector.asInstanceOf[Vector[A]]
val data2 = data.map { case X1(a) => frac.div(a, const) }
dataset2 ?= data2
} else 0 ?= 0
@@ -379,22 +377,22 @@
assert(t.select(t.col('_1)).collect().run().toList === List(2))
// Issue #54
val fooT = t.select(t.col('_1)).deserialized.map(x => Tuple1.apply(x)).as[Foo]
assert(fooT.select(fooT('i)).collect().run().toList === List(2))
assert(fooT.select(fooT(_.i)).collect().run().toList === List(2))
}

test("unary - on arithmetic") {
val e = TypedDataset.create[(Int, String, Long)]((1, "a", 2L) :: (2, "b", 4L) :: (2, "b", 1L) :: Nil)
assert(e.select(-e('_1)).collect().run().toVector === Vector(-1, -2, -2))
assert(e.select(-(e('_1) + e('_3))).collect().run().toVector === Vector(-3L, -6L, -3L))
assert(e.select(-e(_._1)).collect().run().toVector === Vector(-1, -2, -2))
assert(e.select(-(e(_._1) + e(_._3))).collect().run().toVector === Vector(-3L, -6L, -3L))
}

test("unary - on strings should not type check") {
val e = TypedDataset.create[(Int, String, Long)]((1, "a", 2L) :: (2, "b", 4L) :: (2, "b", 1L) :: Nil)
illTyped("""e.select( -e('_2) )""")
illTyped("""e.select( -e(_._2) )""")
}

test("select with aggregation operations is not supported") {
val e = TypedDataset.create[(Int, String, Long)]((1, "a", 2L) :: (2, "b", 4L) :: (2, "b", 1L) :: Nil)
illTyped("""e.select(frameless.functions.aggregate.sum(e('_1)))""")
illTyped("""e.select(frameless.functions.aggregate.sum(e(_._1)))""")
}
}
10 changes: 5 additions & 5 deletions dataset/src/test/scala/frameless/WithColumnTest.scala
@@ -7,11 +7,11 @@ class WithColumnTest extends TypedDatasetSuite {
test("append five columns") {
def prop[A: TypedEncoder](value: A): Prop = {
val d = TypedDataset.create(X1(value) :: Nil)
val d1 = d.withColumn(d('a))
val d2 = d1.withColumn(d1('_1))
val d3 = d2.withColumn(d2('_2))
val d4 = d3.withColumn(d3('_3))
val d5 = d4.withColumn(d4('_4))
val d1 = d.withColumn(d(_.a))
val d2 = d1.withColumn(d1(_._1))
val d3 = d2.withColumn(d2(_._2))
val d4 = d3.withColumn(d3(_._3))
val d5 = d4.withColumn(d4(_._4))

(value, value, value, value, value, value) ?= d5.collect().run().head
}
2 changes: 2 additions & 0 deletions dataset/src/test/scala/frameless/XN.scala
@@ -64,3 +64,5 @@ object X5 {
implicit def ordering[A: Ordering, B: Ordering, C: Ordering, D: Ordering, E: Ordering]: Ordering[X5[A, B, C, D, E]] =
Ordering.Tuple5[A, B, C, D, E].on(x => (x.a, x.b, x.c, x.d, x.e))
}

case class X4X4[A, B, C, D](xa: X4[A, B, C, D], xb: X4[A, B, C, D])
@@ -282,7 +282,7 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
check {
forAll(getLowCardinalityKVPairs) { xs: Vector[(Int, Int)] =>
val tds = TypedDataset.create(xs)
val tdsRes: Seq[(Int, Long)] = tds.groupBy(tds('_1)).agg(countDistinct(tds('_2))).collect().run()
val tdsRes: Seq[(Int, Long)] = tds.groupBy(tds(_._1)).agg(countDistinct(tds(_._2))).collect().run()
tdsRes.toMap ?= xs.groupBy(_._1).mapValues(_.map(_._2).distinct.size.toLong).toSeq.toMap
}
}
Expand All @@ -300,7 +300,7 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
forAll(getLowCardinalityKVPairs) { xs: Vector[(Int, Int)] =>
val tds = TypedDataset.create(xs)
val tdsRes: Seq[(Int, Long, Long)] =
tds.groupBy(tds('_1)).agg(countDistinct(tds('_2)), approxCountDistinct(tds('_2))).collect().run()
tds.groupBy(tds(_._1)).agg(countDistinct(tds(_._2)), approxCountDistinct(tds(_._2))).collect().run()
tdsRes.forall { case (_, v1, v2) => approxEqual(v1, v2) }
}
}
Expand All @@ -310,7 +310,7 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
val tds = TypedDataset.create(xs)
val allowedError = 0.1 // 10%
val tdsRes: Seq[(Int, Long, Long)] =
tds.groupBy(tds('_1)).agg(countDistinct(tds('_2)), approxCountDistinct(tds('_2), allowedError)).collect().run()
tds.groupBy(tds(_._1)).agg(countDistinct(tds(_._2)), approxCountDistinct(tds(_._2), allowedError)).collect().run()
tdsRes.forall { case (_, v1, v2) => approxEqual(v1, v2, allowedError) }
}
}
Expand All @@ -319,7 +319,7 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
test("collectList") {
def prop[A: TypedEncoder : Ordering](xs: List[X2[A, A]]): Prop = {
val tds = TypedDataset.create(xs)
val tdsRes: Seq[(A, Vector[A])] = tds.groupBy(tds('a)).agg(collectList(tds('b))).collect().run()
val tdsRes: Seq[(A, Vector[A])] = tds.groupBy(tds(_.a)).agg(collectList(tds(_.b))).collect().run()

tdsRes.toMap.mapValues(_.sorted) ?= xs.groupBy(_.a).mapValues(_.map(_.b).toVector.sorted)
}
Expand All @@ -333,7 +333,7 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
test("collectSet") {
def prop[A: TypedEncoder : Ordering](xs: List[X2[A, A]]): Prop = {
val tds = TypedDataset.create(xs)
val tdsRes: Seq[(A, Vector[A])] = tds.groupBy(tds('a)).agg(collectSet(tds('b))).collect().run()
val tdsRes: Seq[(A, Vector[A])] = tds.groupBy(tds(_.a)).agg(collectSet(tds(_.b))).collect().run()

tdsRes.toMap.mapValues(_.toSet) ?= xs.groupBy(_.a).mapValues(_.map(_.b).toSet)
}
Expand All @@ -347,7 +347,7 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
test("lit") {
def prop[A: TypedEncoder](xs: List[X1[A]], l: A): Prop = {
val tds = TypedDataset.create(xs)
tds.select(tds('a), lit(l)).collect().run() ?= xs.map(x => (x.a, l))
tds.select(tds(_.a), lit(l)).collect().run() ?= xs.map(x => (x.a, l))
}

check(forAll(prop[Long] _))
Expand Down Expand Up @@ -379,7 +379,7 @@ class AggregateFunctionsTests extends TypedDatasetSuite {

val tds = TypedDataset.create(xs)
// Typed implementation of bivar stats function
val tdBivar = tds.groupBy(tds('a)).agg(framelessFun(tds('b), tds('c))).deserialized.map(kv =>
val tdBivar = tds.groupBy(tds(_.a)).agg(framelessFun(tds(_.b), tds(_.c))).deserialized.map(kv =>
(kv._1, kv._2.flatMap(DoubleBehaviourUtils.nanNullHandler))
).collect().run()

Expand Down Expand Up @@ -416,7 +416,7 @@ class AggregateFunctionsTests extends TypedDatasetSuite {

val tds = TypedDataset.create(xs)
//typed implementation of univariate stats function
val tdUnivar = tds.groupBy(tds('a)).agg(framelessFun(tds('b))).deserialized.map(kv =>
val tdUnivar = tds.groupBy(tds(_.a)).agg(framelessFun(tds(_.b))).deserialized.map(kv =>
(kv._1, kv._2.flatMap(DoubleBehaviourUtils.nanNullHandler))
).collect().run()

@@ -11,7 +11,7 @@ class UnaryFunctionsTest extends TypedDatasetSuite {
def prop[A: TypedEncoder](xs: List[X1[Vector[A]]]): Prop = {
val tds = TypedDataset.create(xs)

val framelessResults = tds.select(size(tds('a))).collect().run().toVector
val framelessResults = tds.select(size(tds(_.a))).collect().run().toVector
val scalaResults = xs.map(x => x.a.size).toVector

framelessResults ?= scalaResults
@@ -26,7 +26,7 @@
def prop[A: TypedEncoder : Ordering](xs: List[X1[Vector[A]]]): Prop = {
val tds = TypedDataset.create(xs)

val framelessResults = tds.select(sortAscending(tds('a))).collect().run().toVector
val framelessResults = tds.select(sortAscending(tds(_.a))).collect().run().toVector
val scalaResults = xs.map(x => x.a.sorted).toVector

framelessResults ?= scalaResults
@@ -42,7 +42,7 @@
def prop[A: TypedEncoder : Ordering](xs: List[X1[Vector[A]]]): Prop = {
val tds = TypedDataset.create(xs)

val framelessResults = tds.select(sortDescending(tds('a))).collect().run().toVector
val framelessResults = tds.select(sortDescending(tds(_.a))).collect().run().toVector
val scalaResults = xs.map(x => x.a.sorted.reverse).toVector

framelessResults ?= scalaResults
@@ -58,7 +58,7 @@
def prop[A: TypedEncoder](xs: List[X1[Vector[A]]]): Prop = {
val tds = TypedDataset.create(xs)

val framelessResults = tds.select(explode(tds('a))).collect().run().toSet
val framelessResults = tds.select(explode(tds(_.a))).collect().run().toSet
val scalaResults = xs.flatMap(_.a).toSet

framelessResults ?= scalaResults
2 changes: 1 addition & 1 deletion dataset/src/test/scala/frameless/ops/ColumnTypesTest.scala
@@ -10,7 +10,7 @@ class ColumnTypesTest extends TypedDatasetSuite {
test("test summoning") {
def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder, D: TypedEncoder](data: Vector[X4[A, B, C, D]]): Prop = {
val d: TypedDataset[X4[A, B, C, D]] = TypedDataset.create(data)
val hlist = d('a) :: d('b) :: d('c) :: d('d) :: HNil
val hlist = d(_.a) :: d(_.b) :: d(_.c) :: d(_.d) :: HNil

type TC[N] = TypedColumn[X4[A,B,C,D], N]
