diff --git a/core/src/main/scala/frameless/CatalystOrdered.scala b/core/src/main/scala/frameless/CatalystOrdered.scala
index 4943e09f..b2223225 100644
--- a/core/src/main/scala/frameless/CatalystOrdered.scala
+++ b/core/src/main/scala/frameless/CatalystOrdered.scala
@@ -22,6 +22,7 @@ object CatalystOrdered {
   implicit val framelessBigDecimalOrdered : CatalystOrdered[BigDecimal] = of[BigDecimal]
   implicit val framelessSQLDateOrdered    : CatalystOrdered[SQLDate] = of[SQLDate]
   implicit val framelessSQLTimestampOrdered: CatalystOrdered[SQLTimestamp] = of[SQLTimestamp]
+  implicit val framelessDateOrdered       : CatalystOrdered[java.sql.Date] = of[java.sql.Date]
   implicit val framelessStringOrdered     : CatalystOrdered[String] = of[String]

   implicit def injectionOrdered[A, B]
diff --git a/dataset/src/main/scala/frameless/TypedEncoder.scala b/dataset/src/main/scala/frameless/TypedEncoder.scala
index 346a03d1..2e71050b 100644
--- a/dataset/src/main/scala/frameless/TypedEncoder.scala
+++ b/dataset/src/main/scala/frameless/TypedEncoder.scala
@@ -5,7 +5,7 @@ import org.apache.spark.sql.FramelessInternals.UserDefinedType
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import shapeless._
@@ -191,6 +191,32 @@ object TypedEncoder {
     )
   }

+  implicit val javaSqlDate: TypedEncoder[java.sql.Date] = new TypedEncoder[java.sql.Date] {
+    def nullable: Boolean = false
+
+    def jvmRepr: DataType = ScalaReflection.dataTypeFor[java.sql.Date]
+
+    def catalystRepr: DataType = DateType
+
+    def toCatalyst(path: Expression): Expression =
+      StaticInvoke(
+        DateTimeUtils.getClass,
+        DateType,
+        "fromJavaDate",
+        path :: Nil,
+        propagateNull = true
+      )
+
+    def fromCatalyst(path: Expression): Expression =
+      StaticInvoke(
+        DateTimeUtils.getClass,
+        ObjectType(classOf[java.sql.Date]),
+        "toJavaDate",
+        path :: Nil,
+        propagateNull = true
+      )
+  }
+
   implicit val sqlTimestamp: TypedEncoder[SQLTimestamp] = new TypedEncoder[SQLTimestamp] {
     def nullable: Boolean = false

diff --git a/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala b/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala
index 6ab1f2c2..553a2e1b 100644
--- a/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala
+++ b/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala
@@ -421,10 +421,95 @@ trait NonAggregateFunctions {

   /** Non-Aggregate function: Extracts the year as an integer from a given date/timestamp/string.
     *
-    * Differs from `Column#year` by wrapping it's result into an `Option`.
+    * Differs from `Column#year` by wrapping its result into an `Option` in case the column
+    * cannot be parsed into a valid date.
     *
     * apache/spark
     */
-  def year[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Option[Int]] =
-    str.typed(sparkFunctions.year(str.untyped))
-}
+  def year[T](date: AbstractTypedColumn[T, String]): date.ThisType[T, Option[Int]] =
+    date.typed(sparkFunctions.year(date.untyped))
+
+  /** Non-Aggregate function: Extracts the day of the year as an integer from a given date/timestamp/string.
+    *
+    * Differs from `Column#dayofyear` by wrapping its result into an `Option` in case the column
+    * cannot be parsed into a valid date.
+    *
+    * apache/spark
+    */
+  def dayofyear[T](date: AbstractTypedColumn[T, String]): date.ThisType[T, Option[Int]] =
+    date.typed(sparkFunctions.dayofyear(date.untyped))
+
+  /** Non-Aggregate function: Extracts the week number as an integer from a given date/timestamp/string.
+    *
+    * Differs from `Column#weekofyear` by wrapping its result into an `Option` in case the column
+    * cannot be parsed into a valid date.
+    *
+    * apache/spark
+    */
+  def weekofyear[T](date: AbstractTypedColumn[T, String]): date.ThisType[T, Option[Int]] =
+    date.typed(sparkFunctions.weekofyear(date.untyped))
+
+  /** Non-Aggregate function: Extracts the month as an integer from a given date/timestamp/string.
+    *
+    * Differs from `Column#month` by wrapping its result into an `Option` in case the column
+    * cannot be parsed into a valid date.
+    *
+    * apache/spark
+    */
+  def month[T](date: AbstractTypedColumn[T, String]): date.ThisType[T, Option[Int]] =
+    date.typed(sparkFunctions.month(date.untyped))
+
+  /** Non-Aggregate function: Extracts the day of the month as an integer from a given date/timestamp/string.
+    *
+    * Differs from `Column#dayofmonth` by wrapping its result into an `Option` in case the column
+    * cannot be parsed into a valid date.
+    *
+    * apache/spark
+    */
+  def dayofmonth[T](date: AbstractTypedColumn[T, String]): date.ThisType[T, Option[Int]] =
+    date.typed(sparkFunctions.dayofmonth(date.untyped))
+
+  /** Non-Aggregate function: Extracts the minutes as an integer from a given date/timestamp/string.
+    *
+    * Differs from `Column#minute` by wrapping its result into an `Option` in case the column
+    * cannot be parsed into a valid date.
+    *
+    * apache/spark
+    */
+  def minute[T](date: AbstractTypedColumn[T, String]): date.ThisType[T, Option[Int]] =
+    date.typed(sparkFunctions.minute(date.untyped))
+
+  /** Non-Aggregate function: Extracts the seconds as an integer from a given date/timestamp/string.
+    *
+    * Differs from `Column#second` by wrapping its result into an `Option` in case the column
+    * cannot be parsed into a valid date.
+    *
+    * apache/spark
+    */
+  def second[T](date: AbstractTypedColumn[T, String]): date.ThisType[T, Option[Int]] =
+    date.typed(sparkFunctions.second(date.untyped))
+
+  /**
+    * Non-Aggregate function: Given a date column, returns the first date which is later than the value
+    * of the date column that is on the specified day of the week.
+    *
+    * For example, `next_day('2015-07-27', "Sunday")` returns 2015-08-02 because that is the first
+    * Sunday after 2015-07-27.
+    *
+    * The day-of-week parameter is case insensitive and accepts:
+    * "Su", "Sun", "Sunday",
+    * "Mo", "Mon", "Monday",
+    * "Tu", "Tue", "Tuesday",
+    * "We", "Wed", "Wednesday",
+    * "Th", "Thu", "Thursday",
+    * "Fr", "Fri", "Friday",
+    * "Sa", "Sat", "Saturday".
+    *
+    * Differs from `Column#next_day` by wrapping its result into an `Option` in case the column
+    * cannot be parsed into a valid date.
+    *
+    * apache/spark
+    */
+  def next_day[T](date: AbstractTypedColumn[T, String], dayOfWeek: String): date.ThisType[T, Option[java.sql.Date]] =
+    date.typed(sparkFunctions.next_day(date.untyped, dayOfWeek))
+}
\ No newline at end of file
diff --git a/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala b/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala
index fca4908d..d6160370 100644
--- a/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala
+++ b/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala
@@ -16,7 +16,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
     FileUtils.deleteDirectory(new File(testTempFiles))
     super.afterAll()
   }
-  
+
   def propBitShift[A: TypedEncoder : Encoder, B: TypedEncoder : Encoder](typedDS: TypedDataset[X1[A]])
     (typedCol: TypedColumn[X1[A], B], sparkFunc: (Column,Int) => Column, numBits: Int): Prop = {
     val spark = session
@@ -388,8 +388,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {

     check(
       forAll(
-        Gen.listOfN(listLength, Gen.choose(0,100)),
-        Gen.oneOf(true,false)
+        Gen.listOfN(listLength, Gen.choose(0, 100)),
+        Gen.oneOf(true, false)
       )
       (prop[List])
     )
@@ -1216,7 +1216,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {

   test("Empty vararg tests") {
     import frameless.functions.aggregate._
-    def prop[A : TypedEncoder, B: TypedEncoder](data: Vector[X2[A, B]]) = {
+    def prop[A: TypedEncoder, B: TypedEncoder](data: Vector[X2[A, B]]) = {
       val ds = TypedDataset.create(data)
       val frameless = ds.select(ds('a), concat(), ds('b), concatWs(":")).collect().run().toVector
       val framelessAggr = ds.agg(first(ds('a)), concat(), concatWs("x"), litAggr(2)).collect().run().toVector
@@ -1230,34 +1230,160 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
     //check(forAll(prop[Option[Vector[Boolean]], Long] _))
   }

+  def propYear(typedDS: TypedDataset[X1[String]])
+    (typedCol: TypedColumn[X1[String], Option[Int]], sparkFunc: Column => Column): Prop = {
+    val spark = session
+    import spark.implicits._
+
+    val nullHandler: Row => Option[Int] = _.get(0) match {
+      case i: Int => Some(i)
+      case _ => None
+    }
+
+    val sparkResult = typedDS.dataset
+      .select(sparkFunc($"a"))
+      .map(nullHandler)
+      .collect()
+      .toList
+
+    val typed = typedDS
+      .select(typedCol)
+      .collect()
+      .run()
+      .toList
+
+    typed ?= sparkResult
+  }
+
   test("year") {
     val spark = session
     import spark.implicits._

-    val nullHandler: Row => Option[Int] = _.get(0) match {
-      case i: Int => Some(i)
-      case _ => None
+    def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = {
+      val ds = TypedDataset.create(data)
+      propYear(ds)(year(ds[String]('a)), sparkFunctions.year)
+    }
+
+    check(forAll(dateTimeStringGen)(data => prop(data.map(X1.apply))))
+    check(forAll(prop _))
+  }
+
+  test("dayofyear") {
+    val spark = session
+    import spark.implicits._
+
+    def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = {
+      val ds = TypedDataset.create(data)
+      propYear(ds)(dayofyear(ds[String]('a)), sparkFunctions.dayofyear)
     }

+    check(forAll(dateTimeStringGen)(data => prop(data.map(X1.apply))))
+    check(forAll(prop _))
+  }
+
+  test("weekofyear") {
+    val spark = session
+    import spark.implicits._
+
     def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = {
-      val ds = TypedDataset.create(data)
+      val ds = TypedDataset.create(data)
+      propYear(ds)(weekofyear(ds[String]('a)), sparkFunctions.weekofyear)
+    }

-      val sparkResult = ds.toDF()
-        .select(sparkFunctions.year($"a"))
-        .map(nullHandler)
-        .collect()
-        .toList
+
+    check(forAll(dateTimeStringGen)(data => prop(data.map(X1.apply))))
+    check(forAll(prop _))
+  }

-      val typed = ds
-        .select(year(ds[String]('a)))
-        .collect()
-        .run()
-        .toList
+  test("month") {
+    val spark = session
+    import spark.implicits._

-      typed ?= sparkResult
-    }
+    def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = {
+      val ds = TypedDataset.create(data)
+      propYear(ds)(month(ds[String]('a)), sparkFunctions.month)
+    }
+
+    check(forAll(dateTimeStringGen)(data => prop(data.map(X1.apply))))
+    check(forAll(prop _))
+  }
+
+  test("dayofmonth") {
+    val spark = session
+    import spark.implicits._
+
+    def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = {
+      val ds = TypedDataset.create(data)
+      propYear(ds)(dayofmonth(ds[String]('a)), sparkFunctions.dayofmonth)
+    }
+
+    check(forAll(dateTimeStringGen)(data => prop(data.map(X1.apply))))
+    check(forAll(prop _))
+  }
+
+  test("minute") {
+    val spark = session
+    import spark.implicits._
+
+    def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = {
+      val ds = TypedDataset.create(data)
+      propYear(ds)(minute(ds[String]('a)), sparkFunctions.minute)
+    }
+
+    check(forAll(dateTimeStringGen)(data => prop(data.map(X1.apply))))
+    check(forAll(prop _))
+  }
+
+  test("second") {
+    val spark = session
+    import spark.implicits._
+
+    def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = {
+      val ds = TypedDataset.create(data)
+      propYear(ds)(second(ds[String]('a)), sparkFunctions.second)
+    }

     check(forAll(dateTimeStringGen)(data => prop(data.map(X1.apply))))
     check(forAll(prop _))
   }
-}
+
+  test("next_day") {
+    val spark = session
+    import spark.implicits._
+
+    val weekDays = List(
+      "SU", "SUN", "SUNDAY",
+      "MO", "MON", "MONDAY",
+      "TU", "TUE", "TUESDAY",
+      "WE", "WED", "WEDNESDAY",
+      "TH", "THU", "THURSDAY",
+      "FR", "FRI", "FRIDAY",
+      "SA", "SAT", "SATURDAY"
+    )
+
+    val nullHandler: Row => Option[java.sql.Date] = _.get(0) match {
+      case d: java.sql.Date => Some(d)
+      case _ => None
+    }
+
+    def prop(data: List[X1[String]], dayOfWeek: String)(implicit E: TypedEncoder[Option[java.sql.Date]]): Prop = {
+      val typedDS = TypedDataset.create(data)
+
+      val sparkResult = typedDS.toDF()
+        .select(sparkFunctions.next_day($"a", dayOfWeek))
+        .map(nullHandler)
+        .collect()
+        .toList
+
+      val typed = typedDS
+        .select(next_day(typedDS[String]('a), dayOfWeek))
+        .collect()
+        .run()
+        .toList
+
+      typed ?= sparkResult
+    }
+
+    check(forAll(dateTimeStringGen, Gen.oneOf(weekDays))((data, dayOfWeek) => prop(data.map(X1.apply), dayOfWeek)))
+    check(forAll(prop _))
+  }
+}
\ No newline at end of file
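
Usage note (reviewer sketch, not part of the patch): the snippet below shows one way the new instances and functions might be exercised. The Event and Arrival case classes, the literal values, and the session setup are hypothetical; the import of frameless.functions.nonAggregate follows the existing frameless convention, and an implicit SparkSession is assumed to be in scope for TypedDataset.create.

import java.sql.Date
import org.apache.spark.sql.SparkSession
import frameless._
import frameless.functions.nonAggregate._

// Hypothetical local session for the sketch only.
implicit val spark: SparkSession = SparkSession.builder().master("local[*]").getOrCreate()

case class Event(ts: String)

val events = TypedDataset.create(Seq(Event("2015-07-27 10:30:45"), Event("not a date")))

// Each extractor returns Option[Int] (Option[java.sql.Date] for next_day),
// so the unparseable row collects as None rather than a null.
val parts = events.select(
  year(events('ts)),
  dayofyear(events('ts)),
  month(events('ts)),
  next_day(events('ts), "Sunday")
).collect().run()

// The new TypedEncoder and CatalystOrdered instances allow java.sql.Date to be
// used directly as a column type, including in ordering comparisons.
case class Arrival(d: Date)

val arrivals = TypedDataset.create(Seq(Arrival(Date.valueOf("2015-07-27"))))
val early = arrivals.filter(arrivals('d) < Date.valueOf("2015-08-01")).collect().run()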