From d26eb08a2105e9ba051589c7535dce896227c818 Mon Sep 17 00:00:00 2001 From: Ganesh Mahadevan Date: Wed, 7 Aug 2024 12:15:08 -0500 Subject: [PATCH 1/5] =?UTF-8?q?SNOW-802269=20Add=20ordering=20and=20size?= =?UTF-8?q?=20function=20for=20scala=20and=20java=20modules=E2=80=A6=20(#1?= =?UTF-8?q?37)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit SNOW-802269 Add ordering and size function for scala and java modules (#133) * add java and scala size and ordering functions * add scala unit test for ordering and size function * update comments and add example * add java test cases * fix comments --------- Co-authored-by: sfc-gh-mrojas --- .../snowflake/snowpark_java/Functions.java | 77 +++++++++++++++++++ .../com/snowflake/snowpark/functions.scala | 67 ++++++++++++++++ .../snowpark_test/JavaFunctionSuite.java | 26 +++++++ .../snowpark_test/FunctionSuite.scala | 29 +++++++ 4 files changed, 199 insertions(+) diff --git a/src/main/java/com/snowflake/snowpark_java/Functions.java b/src/main/java/com/snowflake/snowpark_java/Functions.java index 56d8d08b..dbadd87b 100644 --- a/src/main/java/com/snowflake/snowpark_java/Functions.java +++ b/src/main/java/com/snowflake/snowpark_java/Functions.java @@ -2,6 +2,7 @@ import static com.snowflake.snowpark.internal.OpenTelemetry.javaUDF; +import com.snowflake.snowpark.functions; import com.snowflake.snowpark.internal.JavaUtils; import com.snowflake.snowpark_java.types.DataType; import com.snowflake.snowpark_java.udf.*; @@ -3880,6 +3881,82 @@ public static Column listagg(Column col) { return new Column(com.snowflake.snowpark.functions.listagg(col.toScalaColumn())); } + /** + * Returns a Column expression with values sorted in descending order. + * + *

Example: order column values in descending + * + *

{@code
+   * DataFrame df = getSession().sql("select * from values(1),(2),(3) as t(a)");
+   * df.sort(Functions.desc("a")).show();
+   * -------
+   * |"A"  |
+   * -------
+   * |3    |
+   * |2    |
+   * |1    |
+   * -------
+   * }
+ * + * @since 1.14.0 + * @param name The input column name + * @return Column object ordered in descending manner. + */ + public static Column desc(String name) { + return new Column(functions.desc(name)); + } + + /** + * Returns a Column expression with values sorted in ascending order. + * + *

Example: order column values in ascending + * + *

{@code
+   * DataFrame df = getSession().sql("select * from values(3),(1),(2) as t(a)");
+   * df.sort(Functions.asc("a")).show();
+   * -------
+   * |"A"  |
+   * -------
+   * |1    |
+   * |2    |
+   * |3    |
+   * -------
+   * }
+ * + * @since 1.14.0 + * @param name The input column name + * @return Column object ordered in ascending manner. + */ + public static Column asc(String name) { + return new Column(functions.asc(name)); + } + + /** + * Returns the size of the input ARRAY. + * + *

If the specified column contains a VARIANT value that contains an ARRAY, the size of the + * ARRAY is returned; otherwise, NULL is returned if the value is not an ARRAY. + * + *

Example: calculate size of the array in a column + * + *

{@code
+   * DataFrame df = getSession().sql("select array_construct(a,b,c) as arr from values(1,2,3) as T(a,b,c)");
+   * df.select(Functions.size(Functions.col("arr"))).show();
+   * -------------------------
+   * |"ARRAY_SIZE(""ARR"")"  |
+   * -------------------------
+   * |3                      |
+   * -------------------------
+   * }
+ * + * @since 1.14.0 + * @param col The input column name + * @return size of the input ARRAY. + */ + public static Column size(Column col) { + return array_size(col); + } + /** * Calls a user-defined function (UDF) by name. * diff --git a/src/main/scala/com/snowflake/snowpark/functions.scala b/src/main/scala/com/snowflake/snowpark/functions.scala index a7fd9ff0..1cd3eff0 100644 --- a/src/main/scala/com/snowflake/snowpark/functions.scala +++ b/src/main/scala/com/snowflake/snowpark/functions.scala @@ -3140,6 +3140,73 @@ object functions { */ def listagg(col: Column): Column = listagg(col, "", isDistinct = false) + /** + * Returns a Column expression with values sorted in descending order. + * Example: + * {{{ + * val df = session.createDataFrame(Seq(1, 2, 3)).toDF("id") + * df.sort(desc("id")).show() + * + * -------- + * |"ID" | + * -------- + * |3 | + * |2 | + * |1 | + * -------- + * }}} + * + * @since 1.14.0 + * @param colName Column name. + * @return Column object ordered in a descending manner. + */ + def desc(colName: String): Column = col(colName).desc + + /** + * Returns a Column expression with values sorted in ascending order. + * Example: + * {{{ + * val df = session.createDataFrame(Seq(3, 2, 1)).toDF("id") + * df.sort(asc("id")).show() + * + * -------- + * |"ID" | + * -------- + * |1 | + * |2 | + * |3 | + * -------- + * }}} + * @since 1.14.0 + * @param colName Column name. + * @return Column object ordered in an ascending manner. + */ + def asc(colName: String): Column = col(colName).asc + + /** + * Returns the size of the input ARRAY. + * + * If the specified column contains a VARIANT value that contains an ARRAY, the size of the ARRAY + * is returned; otherwise, NULL is returned if the value is not an ARRAY. 
+ * + * Example: + * {{{ + * val df = session.createDataFrame(Seq(Array(1, 2, 3))).toDF("id") + * df.select(size(col("id"))).show() + * + * ------------------------ + * |"ARRAY_SIZE(""ID"")" | + * ------------------------ + * |3 | + * ------------------------ + * }}} + * + * @since 1.14.0 + * @param c Column to get the size. + * @return Size of array column. + */ + def size(c: Column): Column = array_size(c) + /** * Invokes a built-in snowflake function with the specified name and arguments. * Arguments can be of two types diff --git a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java index 6ee298d3..2b3b4fc9 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java @@ -2764,4 +2764,30 @@ public void any_value() { assert result.length == 1; assert result[0].getInt(0) == 1 || result[0].getInt(0) == 2 || result[0].getInt(0) == 3; } + + @Test + public void test_asc() { + DataFrame df = getSession().sql("select * from values(3),(1),(2) as t(a)"); + Row[] expected = {Row.create(1), Row.create(2), Row.create(3)}; + + checkAnswer(df.sort(Functions.asc("a")), expected, false); + } + + @Test + public void test_desc() { + DataFrame df = getSession().sql("select * from values(2),(1),(3) as t(a)"); + Row[] expected = {Row.create(3), Row.create(2), Row.create(1)}; + + checkAnswer(df.sort(Functions.desc("a")), expected, false); + } + + @Test + public void test_size() { + DataFrame df = getSession() + .sql( + "select array_construct(a,b,c) as arr from values(1,2,3) as T(a,b,c)"); + Row[] expected = {Row.create(3)}; + + checkAnswer(df.select(Functions.size(Functions.col("arr"))), expected, false); + } } diff --git a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala index e473de12..770e7c7d 100644 --- 
a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala @@ -2178,6 +2178,35 @@ trait FunctionSuite extends TestData { sort = false) } + test("desc column order") { + val input = Seq(1, 2, 3).toDF("data") + val expected = Seq(3, 2, 1).toDF("data") + + val inputStr = Seq("a", "b", "c").toDF("dataStr") + val expectedStr = Seq("c", "b", "a").toDF("dataStr") + + checkAnswer(input.sort(desc("data")), expected, sort = false) + checkAnswer(inputStr.sort(desc("dataStr")), expectedStr, sort = false) + } + + test("asc column order") { + val input = Seq(3, 2, 1).toDF("data") + val expected = Seq(1, 2, 3).toDF("data") + + val inputStr = Seq("c", "b", "a").toDF("dataStr") + val expectedStr = Seq("a", "b", "c").toDF("dataStr") + + checkAnswer(input.sort(asc("data")), expected, sort = false) + checkAnswer(inputStr.sort(asc("dataStr")), expectedStr, sort = false) + } + + test("column array size") { + + val input = Seq(Array(1, 2, 3)).toDF("size") + val expected = Seq((3)).toDF("size") + checkAnswer(input.select(size(col("size"))), expected, sort = false) + } + } class EagerFunctionSuite extends FunctionSuite with EagerSession From f7647e427315ae5829348b5daaa201db52195e16 Mon Sep 17 00:00:00 2001 From: Ganesh Mahadevan Date: Thu, 8 Aug 2024 12:09:57 -0500 Subject: [PATCH 2/5] SNOW-802269 - Add missing scala and java functions (#139) * Merge changes from fork to feature branch (#138) * add java and scala size and ordering functions * add scala unit test for ordering and size function * update comments and add example * add java test cases * fix comments * add expr function for java and scala * add formatting functions scala * remove format_string func --------- Co-authored-by: sfc-gh-mrojas * add java function and test case * fix test case * fix test file import * fix test file import * fix docs --------- Co-authored-by: sfc-gh-mrojas --- .../snowflake/snowpark_java/Functions.java | 95 +++++++++++++++ 
.../com/snowflake/snowpark/functions.scala | 112 +++++++++++++++++- .../snowpark_test/JavaFunctionSuite.java | 31 +++++ .../snowpark_test/FunctionSuite.scala | 34 ++++++ 4 files changed, 266 insertions(+), 6 deletions(-) diff --git a/src/main/java/com/snowflake/snowpark_java/Functions.java b/src/main/java/com/snowflake/snowpark_java/Functions.java index dbadd87b..8daaf9fc 100644 --- a/src/main/java/com/snowflake/snowpark_java/Functions.java +++ b/src/main/java/com/snowflake/snowpark_java/Functions.java @@ -3957,6 +3957,101 @@ public static Column size(Column col) { return array_size(col); } + /** + * Creates a Column expression from row SQL text. + * + *

Note that the function does not interpret or check the SQL text. + * + *

{@code
+   * DataFrame df = getSession().sql("select a from values(1), (2), (3) as T(a)");
+   * df.filter(Functions.expr("a > 2")).show();
+   * -------
+   * |"A"  |
+   * -------
+   * |3    |
+   * -------
+   * }
+ * + * @since 1.14.0 + * @param s The SQL text + * @return column expression from input statement. + */ + public static Column expr(String s) { + return sqlExpr(s); + } + + /** + * Returns an ARRAY constructed from zero, one, or more inputs. + * + *
{@code
+   * DataFrame df = getSession().sql("select * from values(1,2,3) as T(a,b,c)");
+   * df.select(Functions.array(df.col("a"), df.col("b"), df.col("c")).as("array")).show();
+   *-----------
+   * |"ARRAY"  |
+   * -----------
+   * |[        |
+   * |  1,     |
+   * |  2,     |
+   * |  3      |
+   * |]        |
+   * -----------
+   * }
+ * + * @since 1.14.0 + * @param cols The input column names + * @return Column object as array. + */ + public static Column array(Column... cols) { return array_construct(cols); } + + /** + * + * Converts an input expression into the corresponding date in the specified date format. + * + *
{@code
+   * DataFrame df = getSession().sql("select * from values ('2023-10-10'), ('2022-05-15') as T(a)");
+   * df.select(Functions.date_format(df.col("a"), "YYYY/MM/DD").as("formatted_date")).show();
+   * --------------------
+   * |"FORMATTED_DATE"  |
+   * --------------------
+   * |2023/10/10        |
+   * |2022/05/15        |
+   * --------------------
+   * }
+ * + * @since 1.14.0 + * @param col The input date column name + * @param s string format + * @return formatted column object. + */ + public static Column date_format(Column col, String s) { + return new Column(functions.date_format(col.toScalaColumn(), s)); + } + + /** + * Returns the last value of the column in a group. + * + *
{@code
+   * DataFrame df = getSession().sql("select * from values (5, 'a', 10), (5, 'b', 20),\n" +
+   *             "    (3, 'd', 15), (3, 'e', 40) as T(grade,name,score)");
+   * df.select(Functions.last(df.col("name")).over(Window.partitionBy(df.col("grade")).orderBy(df.col("score").desc()))).show();
+   * ----------------
+   * |"LAST_VALUE"  |
+   * ----------------
+   * |a             |
+   * |a             |
+   * |d             |
+   * |d             |
+   * ----------------
+   * }
+ * + * @since 1.14.0 + * @param col The input column to get last value + * @return column object from last function. + */ + public static Column last(Column col) { + return new Column(functions.last(col.toScalaColumn())); + } + /** * Calls a user-defined function (UDF) by name. * diff --git a/src/main/scala/com/snowflake/snowpark/functions.scala b/src/main/scala/com/snowflake/snowpark/functions.scala index 1cd3eff0..662a00c4 100644 --- a/src/main/scala/com/snowflake/snowpark/functions.scala +++ b/src/main/scala/com/snowflake/snowpark/functions.scala @@ -2,12 +2,8 @@ package com.snowflake.snowpark import com.snowflake.snowpark.internal.analyzer._ import com.snowflake.snowpark.internal.ScalaFunctions._ -import com.snowflake.snowpark.internal.{ - ErrorMessage, - OpenTelemetry, - UDXRegistrationHandler, - Utils -} +import com.snowflake.snowpark.internal.{ErrorMessage, OpenTelemetry, UDXRegistrationHandler, Utils} +import com.snowflake.snowpark.types.TimestampType import scala.reflect.runtime.universe.TypeTag import scala.util.Random @@ -3207,6 +3203,110 @@ object functions { */ def size(c: Column): Column = array_size(c) + /** + * Creates a [[Column]] expression from raw SQL text. + * + * Note that the function does not interpret or check the SQL text. + * + * Example: + * {{{ + * val df = session.createDataFrame(Seq(Array(1, 2, 3))).toDF("id") + * df.filter(expr("id > 2")).show() + * + * -------- + * |"ID" | + * -------- + * |3 | + * -------- + * }}} + * + * @since 1.14.0 + * @param s SQL Expression as text. + * @return Converted SQL Expression. + */ + def expr(s: String): Column = sqlExpr(s) + + /** + * Returns an ARRAY constructed from zero, one, or more inputs. 
+ * + * Example: + * {{{ + * val df = session.createDataFrame(Seq((1, 2, 3), (4, 5, 6))).toDF("id") + * df.select(array(col("a"), col("b")).as("id")).show() + * + * -------- + * |"ID" | + * -------- + * |[ | + * | 1, | + * | 2 | + * |] | + * |[ | + * | 4, | + * | 5 | + * |] | + * -------- + * }}} + * + * @since 1.14.0 + * @param c Columns to build the array. + * @return The array. + */ + def array(c: Column*): Column = array_construct(c: _*) + + /** + * Converts an input expression into the corresponding date in the specified date format. + * Example: + * {{{ + * val df = Seq("2023-10-10", "2022-05-15", null.asInstanceOf[String]).toDF("date") + * df.select(date_format(col("date"), "YYYY/MM/DD").as("formatted_date")).show() + * + * -------------------- + * |"FORMATTED_DATE" | + * -------------------- + * |2023/10/10 | + * |2022/05/15 | + * |NULL | + * -------------------- + * + * }}} + * + * @since 1.14.0 + * @param c Column to format to date. + * @param s Date format. + * @return Column object. + */ + def date_format(c: Column, s: String): Column = + builtin("to_varchar")(c.cast(TimestampType), s.replace("mm", "mi")) + + /** + * Returns the last value of the column in a group. + * Example + * {{{ + * val df = session.createDataFrame(Seq((5, "a", 10), + * (5, "b", 20), + * (3, "d", 15), + * (3, "e", 40))).toDF("grade", "name", "score") + * val window = Window.partitionBy(col("grade")).orderBy(col("score").desc) + * df.select(last(col("name")).over(window)).show() + * + * --------------------- + * |"LAST_SCORE_NAME" | + * --------------------- + * |a | + * |a | + * |d | + * |d | + * --------------------- + * }}} + * + * @since 1.14.0 + * @param c Column to obtain last value. + * @return Column object. + */ + def last(c: Column): Column = + builtin("LAST_VALUE")(c) + /** * Invokes a built-in snowflake function with the specified name and arguments. 
* Arguments can be of two types diff --git a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java index 2b3b4fc9..edddec2e 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java @@ -2790,4 +2790,35 @@ public void test_size() { checkAnswer(df.select(Functions.size(Functions.col("arr"))), expected, false); } + + @Test + public void test_expr() { + DataFrame df = getSession().sql("select * from values(1), (2), (3) as T(a)"); + Row[] expected = {Row.create(3)}; + checkAnswer(df.filter(Functions.expr("a > 2")), expected, false); + } + + @Test + public void test_array() { + DataFrame df = getSession().sql("select * from values(1,2,3) as T(a,b,c)"); + Row[] expected = {Row.create("[\n 1,\n 2,\n 3\n]")}; + checkAnswer(df.select(Functions.array(df.col("a"), df.col("b"), df.col("c"))), expected, false); + } + + @Test + public void date_format() { + DataFrame df = getSession().sql("select * from values ('2023-10-10'), ('2022-05-15') as T(a)"); + Row[] expected = {Row.create("2023/10/10"), Row.create("2022/05/15")}; + + checkAnswer(df.select(Functions.date_format(df.col("a"), "YYYY/MM/DD")), expected, false); + } + + @Test + public void last() { + DataFrame df = getSession().sql("select * from values (5, 'a', 10), (5, 'b', 20),\n" + + " (3, 'd', 15), (3, 'e', 40) as T(grade,name,score)"); + + Row[] expected = {Row.create("a"), Row.create("a"), Row.create("d"), Row.create("d")}; + checkAnswer(df.select(Functions.last(df.col("name")).over(Window.partitionBy(df.col("grade")).orderBy(df.col("score").desc()))), expected, false); + } } diff --git a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala index 770e7c7d..806a6ff8 100644 --- a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala +++ 
b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala @@ -2206,6 +2206,40 @@ trait FunctionSuite extends TestData { val expected = Seq((3)).toDF("size") checkAnswer(input.select(size(col("size"))), expected, sort = false) } + + test("expr function") { + + val input = Seq(1, 2, 3).toDF("id") + val expected = Seq((3)).toDF("id") + checkAnswer(input.filter(expr("id > 2")), expected, sort = false) + } + + test("array function") { + + val input = Seq((1, 2, 3), (4, 5, 6)).toDF("a", "b", "c") + val expected = Seq(Array(1, 2), Array(4, 5)).toDF("id") + checkAnswer(input.select(array(col("a"), col("b")).as("id")), expected, sort = false) + } + + test("date format function") { + + val input = Seq("2023-10-10", "2022-05-15").toDF("date") + val expected = Seq("2023/10/10", "2022/05/15").toDF("formatted_date") + + checkAnswer(input.select(date_format(col("date"), "YYYY/MM/DD").as("formatted_date")), + expected, sort = false) + } + + test("last function") { + + val input = Seq((5, "a", 10), (5, "b", 20), + (3, "d", 15), (3, "e", 40)).toDF("grade", "name", "score") + val window = Window.partitionBy(col("grade")).orderBy(col("score").desc) + val expected = Seq("a", "a", "d", "d").toDF("last_score_name") + + checkAnswer(input.select(last(col("name")).over(window).as("last_score_name")), + expected, sort = false) + } } From 14770b71bf8f8fd822adbdde3bdeb8d293d57ce5 Mon Sep 17 00:00:00 2001 From: Bing Li <63471091+sfc-gh-bli@users.noreply.github.com> Date: Thu, 8 Aug 2024 17:20:29 -0700 Subject: [PATCH 3/5] SNOW-1619170 Create New Github Action to Check the Code Format (#140) * add format checker * chmod * create new github action * reformat --- .github/workflows/code-format-check.yml | 20 +++++++++++++++++++ scripts/format_checker.sh | 11 ++++++++++ .../snowflake/snowpark_java/Functions.java | 7 ++++--- .../com/snowflake/snowpark/functions.scala | 11 +++++++--- .../snowpark_test/JavaFunctionSuite.java | 19 ++++++++++++------ .../snowpark_test/FunctionSuite.scala | 18 
++++++++++------- 6 files changed, 67 insertions(+), 19 deletions(-) create mode 100644 .github/workflows/code-format-check.yml create mode 100755 scripts/format_checker.sh diff --git a/.github/workflows/code-format-check.yml b/.github/workflows/code-format-check.yml new file mode 100644 index 00000000..ac84376d --- /dev/null +++ b/.github/workflows/code-format-check.yml @@ -0,0 +1,20 @@ +name: Code Format Check +on: + push: + branches: [ main ] + pull_request: + branches: '**' + +jobs: + build: + runs-on: ubuntu-latest + steps: + - name: Checkout Code + uses: actions/checkout@v2 + - name: Install Java + uses: actions/setup-java@v1 + with: + java-version: 1.8 + - name: Check Format + run: scripts/format_checker.sh + \ No newline at end of file diff --git a/scripts/format_checker.sh b/scripts/format_checker.sh new file mode 100755 index 00000000..3b4f9af4 --- /dev/null +++ b/scripts/format_checker.sh @@ -0,0 +1,11 @@ +#!/bin/bash -ex + +mvn clean compile + +if [ -z "$(git status --porcelain)" ]; then + echo "Code Format Check: Passed!" +else + echo "Code Format Check: Failed!" + echo "Run 'mvn clean compile' to reformat" + exit 1 +fi diff --git a/src/main/java/com/snowflake/snowpark_java/Functions.java b/src/main/java/com/snowflake/snowpark_java/Functions.java index 8daaf9fc..65f50020 100644 --- a/src/main/java/com/snowflake/snowpark_java/Functions.java +++ b/src/main/java/com/snowflake/snowpark_java/Functions.java @@ -3986,7 +3986,7 @@ public static Column expr(String s) { *
{@code
    * DataFrame df = getSession().sql("select * from values(1,2,3) as T(a,b,c)");
    * df.select(Functions.array(df.col("a"), df.col("b"), df.col("c")).as("array")).show();
-   *-----------
+   * -----------
    * |"ARRAY"  |
    * -----------
    * |[        |
@@ -4001,10 +4001,11 @@ public static Column expr(String s) {
    * @param cols The input column names
    * @return Column object as array.
    */
-  public static Column array(Column... cols) { return array_construct(cols); }
+  public static Column array(Column... cols) {
+    return array_construct(cols);
+  }
 
   /**
-   *
    * Converts an input expression into the corresponding date in the specified date format.
    *
    * 
{@code
diff --git a/src/main/scala/com/snowflake/snowpark/functions.scala b/src/main/scala/com/snowflake/snowpark/functions.scala
index 662a00c4..5c6f599f 100644
--- a/src/main/scala/com/snowflake/snowpark/functions.scala
+++ b/src/main/scala/com/snowflake/snowpark/functions.scala
@@ -2,7 +2,12 @@ package com.snowflake.snowpark
 
 import com.snowflake.snowpark.internal.analyzer._
 import com.snowflake.snowpark.internal.ScalaFunctions._
-import com.snowflake.snowpark.internal.{ErrorMessage, OpenTelemetry, UDXRegistrationHandler, Utils}
+import com.snowflake.snowpark.internal.{
+  ErrorMessage,
+  OpenTelemetry,
+  UDXRegistrationHandler,
+  Utils
+}
 import com.snowflake.snowpark.types.TimestampType
 
 import scala.reflect.runtime.universe.TypeTag
@@ -3151,7 +3156,7 @@ object functions {
    * |1     |
    * --------
    * }}}
- *
+   *
    * @since 1.14.0
    * @param colName Column name.
    * @return Column object ordered in a descending manner.
@@ -3299,7 +3304,7 @@ object functions {
    * |d                  |
    * ---------------------
    * }}}
- *
+   *
    * @since 1.14.0
    * @param c Column to obtain last value.
    * @return Column object.
diff --git a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java
index edddec2e..624ea481 100644
--- a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java
+++ b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java
@@ -2783,9 +2783,8 @@ public void test_desc() {
 
   @Test
   public void test_size() {
-    DataFrame df = getSession()
-            .sql(
-                    "select array_construct(a,b,c) as arr from values(1,2,3) as T(a,b,c)");
+    DataFrame df =
+        getSession().sql("select array_construct(a,b,c) as arr from values(1,2,3) as T(a,b,c)");
     Row[] expected = {Row.create(3)};
 
     checkAnswer(df.select(Functions.size(Functions.col("arr"))), expected, false);
@@ -2815,10 +2814,18 @@ public void date_format() {
 
   @Test
   public void last() {
-    DataFrame df = getSession().sql("select * from values (5, 'a', 10), (5, 'b', 20),\n" +
-            "    (3, 'd', 15), (3, 'e', 40) as T(grade,name,score)");
+    DataFrame df =
+        getSession()
+            .sql(
+                "select * from values (5, 'a', 10), (5, 'b', 20),\n"
+                    + "    (3, 'd', 15), (3, 'e', 40) as T(grade,name,score)");
 
     Row[] expected = {Row.create("a"), Row.create("a"), Row.create("d"), Row.create("d")};
-    checkAnswer(df.select(Functions.last(df.col("name")).over(Window.partitionBy(df.col("grade")).orderBy(df.col("score").desc()))), expected, false);
+    checkAnswer(
+        df.select(
+            Functions.last(df.col("name"))
+                .over(Window.partitionBy(df.col("grade")).orderBy(df.col("score").desc()))),
+        expected,
+        false);
   }
 }
diff --git a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala
index 806a6ff8..8a89d87b 100644
--- a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala
@@ -2206,7 +2206,7 @@ trait FunctionSuite extends TestData {
     val expected = Seq((3)).toDF("size")
     checkAnswer(input.select(size(col("size"))), expected, sort = false)
   }
-  
+
   test("expr function") {
 
     val input = Seq(1, 2, 3).toDF("id")
@@ -2226,19 +2226,23 @@ trait FunctionSuite extends TestData {
     val input = Seq("2023-10-10", "2022-05-15").toDF("date")
     val expected = Seq("2023/10/10", "2022/05/15").toDF("formatted_date")
 
-    checkAnswer(input.select(date_format(col("date"), "YYYY/MM/DD").as("formatted_date")),
-                expected, sort = false)
+    checkAnswer(
+      input.select(date_format(col("date"), "YYYY/MM/DD").as("formatted_date")),
+      expected,
+      sort = false)
   }
 
   test("last function") {
 
-    val input = Seq((5, "a", 10), (5, "b", 20),
-                    (3, "d", 15), (3, "e", 40)).toDF("grade", "name", "score")
+    val input =
+      Seq((5, "a", 10), (5, "b", 20), (3, "d", 15), (3, "e", 40)).toDF("grade", "name", "score")
     val window = Window.partitionBy(col("grade")).orderBy(col("score").desc)
     val expected = Seq("a", "a", "d", "d").toDF("last_score_name")
 
-    checkAnswer(input.select(last(col("name")).over(window).as("last_score_name")),
-      expected, sort = false)
+    checkAnswer(
+      input.select(last(col("name")).over(window).as("last_score_name")),
+      expected,
+      sort = false)
   }
 
 }

From 833ef6d2526026788fa19820a5d715c25da7327b Mon Sep 17 00:00:00 2001
From: Ganesh Mahadevan 
Date: Thu, 15 Aug 2024 14:33:15 -0500
Subject: [PATCH 4/5] Snow 802269 add log10, log1p, base64 and unbase64
 functions (#143)

* add log functions

* add scala log function and test case

* add java log function and test case

* update docs

* fix format

* add base64 and unbase64 scala and java function
---
 .../snowflake/snowpark_java/Functions.java    | 127 ++++++++++++++++++
 .../com/snowflake/snowpark/functions.scala    | 121 +++++++++++++++++
 .../snowpark_test/JavaFunctionSuite.java      |  46 +++++++
 .../snowpark_test/FunctionSuite.scala         |  37 ++++-
 4 files changed, 330 insertions(+), 1 deletion(-)

diff --git a/src/main/java/com/snowflake/snowpark_java/Functions.java b/src/main/java/com/snowflake/snowpark_java/Functions.java
index 65f50020..1d18a91c 100644
--- a/src/main/java/com/snowflake/snowpark_java/Functions.java
+++ b/src/main/java/com/snowflake/snowpark_java/Functions.java
@@ -4053,6 +4053,133 @@ public static Column last(Column col) {
     return new Column(functions.last(col.toScalaColumn()));
   }
 
+  /**
+   * Computes the logarithm of the given value in base 10.
+   *
+   * 
{@code
+   * DataFrame df = getSession().sql("select * from values (100) as T(a)");
+   * df.select(Functions.log10(df.col("a")).as("log10")).show();
+   * -----------
+   * |"LOG10"  |
+   * -----------
+   * |2.0      |
+   * -----------
+   * }
+ * + * @since 1.14.0 + * @param col The input column to get logarithm value + * @return column object from logarithm function. + */ + public static Column log10(Column col) { + return new Column(functions.log10(col.toScalaColumn())); + } + + /** + * Computes the logarithm of the given value in base 10. + * + *
{@code
+   * DataFrame df = getSession().sql("select * from values (100) as T(a)");
+   * df.select(Functions.log10("a").as("log10")).show();
+   * -----------
+   * |"LOG10"  |
+   * -----------
+   * |2.0      |
+   * -----------
+   * }
+ * + * @since 1.14.0 + * @param s The input columnName in string to get logarithm value + * @return column object from logarithm function. + */ + public static Column log10(String s) { + return new Column(functions.log10(s)); + } + + /** + * Computes the logarithm of the given value in base 10. + * + *
{@code
+   * DataFrame df = getSession().sql("select * from values (0.1) as T(a)");
+   * df.select(Functions.log1p(df.col("a")).as("log1p")).show();
+   * -----------------------
+   * |"LOG1P"              |
+   * -----------------------
+   * |0.09531017980432493  |
+   * -----------------------
+   * }
+ * + * @since 1.14.0 + * @param col The input column to get logarithm value + * @return column object from logarithm function. + */ + public static Column log1p(Column col) { + return new Column(functions.log1p(col.toScalaColumn())); + } + + /** + * Computes the logarithm of the given value in base 10. + * + *
{@code
+   * DataFrame df = getSession().sql("select * from values (0.1) as T(a)");
+   * df.select(Functions.log1p("a").as("log1p")).show();
+   * -----------------------
+   * |"LOG1P"              |
+   * -----------------------
+   * |0.09531017980432493  |
+   * -----------------------
+   * }
+ * + * @since 1.14.0 + * @param s The input columnName in string to get logarithm value + * @return column object from logarithm function. + */ + public static Column log1p(String s) { + return new Column(functions.log1p(s)); + } + + /** + * Computes the BASE64 encoding of a column and returns it as a string column. This is the reverse + * of unbase64. + * + *
{@code
+   * DataFrame df = getSession().sql("select * from values ('test') as T(a)");
+   * df.select(Functions.base64(Functions.col("a")).as("base64")).show();
+   * ------------
+   * |"BASE64"  |
+   * ------------
+   * |dGVzdA==  |
+   * ------------
+   * }
+ * + * @since 1.14.0 + * @param c ColumnName to apply base64 operation + * @return base64 encoded value of the given input column. + */ + public static Column base64(Column c) { + return new Column(functions.base64(c.toScalaColumn())); + } + + /** + * Decodes a BASE64 encoded string column and returns it as a column. + * + *
{@code
+   * DataFrame df = getSession().sql("select * from values ('dGVzdA==') as T(a)");
+   * df.select(Functions.unbase64(Functions.col("a")).as("unbase64")).show();
+   * --------------
+   * |"UNBASE64"  |
+   * --------------
+   * |test        |
+   * --------------
+   * }
+ * + * @since 1.14.0 + * @param c ColumnName to apply unbase64 operation + * @return the decoded value of the given encoded value. + */ + public static Column unbase64(Column c) { + return new Column(functions.unbase64(c.toScalaColumn())); + } + /** * Calls a user-defined function (UDF) by name. * diff --git a/src/main/scala/com/snowflake/snowpark/functions.scala b/src/main/scala/com/snowflake/snowpark/functions.scala index 5c6f599f..160c3112 100644 --- a/src/main/scala/com/snowflake/snowpark/functions.scala +++ b/src/main/scala/com/snowflake/snowpark/functions.scala @@ -3312,6 +3312,127 @@ object functions { def last(c: Column): Column = builtin("LAST_VALUE")(c) + /** + * Computes the logarithm of the given value in base 10. + * Example + * {{{ + * val df = session.createDataFrame(Seq(100)).toDF("a") + * df.select(log10(col("a"))).show() + * + * ----------- + * |"LOG10" | + * ----------- + * |2.0 | + * ----------- + * }}} + * + * @since 1.14.0 + * @param c Column to apply logarithm operation + * @return log10 of the given column + */ + def log10(c: Column): Column = builtin("LOG")(10, c) + + /** + * Computes the logarithm of the given column in base 10. + * Example + * {{{ + * val df = session.createDataFrame(Seq(100)).toDF("a") + * df.select(log10("a"))).show() + * ----------- + * |"LOG10" | + * ----------- + * |2.0 | + * ----------- + * + * }}} + * + * @since 1.14.0 + * @param columnName ColumnName in String to apply logarithm operation + * @return log10 of the given column + */ + def log10(columnName: String): Column = builtin("LOG")(10, col(columnName)) + + /** + * Computes the natural logarithm of the given value plus one. 
+ *Example + * {{{ + * val df = session.createDataFrame(Seq(0.1)).toDF("a") + * df.select(log1p(col("a")).as("log1p")).show() + * ----------------------- + * |"LOG1P" | + * ----------------------- + * |0.09531017980432493 | + * ----------------------- + * + * }}} + * + * @since 1.14.0 + * @param c Column to apply logarithm operation + * @return the natural logarithm of the given value plus one. + */ + def log1p(c: Column): Column = callBuiltin("ln", lit(1) + c) + + /** + * Computes the natural logarithm of the given value plus one. + *Example + * {{{ + * val df = session.createDataFrame(Seq(0.1)).toDF("a") + * df.select(log1p("a").as("log1p")).show() + * ----------------------- + * |"LOG1P" | + * ----------------------- + * |0.09531017980432493 | + * ----------------------- + * + * }}} + * + * @since 1.14.0 + * @param columnName ColumnName in String to apply logarithm operation + * @return the natural logarithm of the given value plus one. + */ + def log1p(columnName: String): Column = callBuiltin("ln", lit(1) + col(columnName)) + + /** + * Computes the BASE64 encoding of a column and returns it as a string column. + * This is the reverse of unbase64. + *Example + * {{{ + * val df = session.createDataFrame(Seq("test")).toDF("a") + * df.select(base64(col("a")).as("base64")).show() + * ------------ + * |"BASE64" | + * ------------ + * |dGVzdA== | + * ------------ + * + * }}} + * + * @since 1.14.0 + * @param columnName ColumnName to apply base64 operation + * @return base64 encoded value of the given input column. + */ + def base64(col: Column): Column = callBuiltin("BASE64_ENCODE", col) + + /** + * Decodes a BASE64 encoded string column and returns it as a column. 
+ *Example + * {{{ + * val df = session.createDataFrame(Seq("dGVzdA==")).toDF("a") + * df.select(unbase64(col("a")).as("unbase64")).show() + * -------------- + * |"UNBASE64" | + * -------------- + * |test | + * -------------- + * + * }}} + * + * @since 1.14.0 + * @param columnName ColumnName to apply unbase64 operation + * @return the decoded value of the given encoded value. + */ + def unbase64(col: Column): Column = callBuiltin("BASE64_DECODE_STRING", col) + /** * Invokes a built-in snowflake function with the specified name and arguments. * Arguments can be of two types diff --git a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java index 624ea481..05e38211 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java @@ -2828,4 +2828,50 @@ public void last() { expected, false); } + + @Test + public void log10_col() { + DataFrame df = getSession().sql("select * from values (100) as T(a)"); + Row[] expected = {Row.create(2.0)}; + + checkAnswer(df.select(Functions.log10(df.col("a"))), expected, false); + } + + @Test + public void log10_str() { + DataFrame df = getSession().sql("select * from values (100) as T(a)"); + Row[] expected = {Row.create(2.0)}; + + checkAnswer(df.select(Functions.log10("a")), expected, false); + } + + @Test + public void log1p_col() { + DataFrame df = getSession().sql("select * from values (0.1) as T(a)"); + Row[] expected = {Row.create(0.09531017980432493)}; + + checkAnswer(df.select(Functions.log1p(df.col("a"))), expected, false); + } + + @Test + public void log1p_str() { + DataFrame df = getSession().sql("select * from values (0.1) as T(a)"); + Row[] expected = {Row.create(0.09531017980432493)}; + + checkAnswer(df.select(Functions.log1p("a")), expected, false); + } + + @Test + public void base64() { + DataFrame df = getSession().sql("select * from values ('test') as 
T(a)"); + Row[] expected = {Row.create("dGVzdA==")}; + checkAnswer(df.select(Functions.base64(Functions.col("a"))), expected, false); + } + + @Test + public void unbase64() { + DataFrame df = getSession().sql("select * from values ('dGVzdA==') as T(a)"); + Row[] expected = {Row.create("test")}; + checkAnswer(df.select(Functions.unbase64(Functions.col("a"))), expected, false); + } } diff --git a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala index 8a89d87b..3db8fd02 100644 --- a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala @@ -2233,7 +2233,6 @@ trait FunctionSuite extends TestData { } test("last function") { - val input = Seq((5, "a", 10), (5, "b", 20), (3, "d", 15), (3, "e", 40)).toDF("grade", "name", "score") val window = Window.partitionBy(col("grade")).orderBy(col("score").desc) @@ -2245,6 +2244,42 @@ trait FunctionSuite extends TestData { sort = false) } + test("log10 Column function") { + val input = session.createDataFrame(Seq(100)).toDF("a") + val expected = Seq(2.0).toDF("log10") + checkAnswer(input.select(log10(col("a")).as("log10")), expected, sort = false) + } + + test("log10 String function") { + val input = session.createDataFrame(Seq("100")).toDF("a") + val expected = Seq(2.0).toDF("log10") + checkAnswer(input.select(log10("a").as("log10")), expected, sort = false) + } + + test("log1p Column function") { + val input = session.createDataFrame(Seq(0.1)).toDF("a") + val expected = Seq(0.09531017980432493).toDF("log1p") + checkAnswer(input.select(log1p(col("a")).as("log10")), expected, sort = false) + } + + test("log1p String function") { + val input = session.createDataFrame(Seq(0.1)).toDF("a") + val expected = Seq(0.09531017980432493).toDF("log1p") + checkAnswer(input.select(log1p("a").as("log1p")), expected, sort = false) + } + + test("base64 function") { + val input = 
session.createDataFrame(Seq("test")).toDF("a") + val expected = Seq("dGVzdA==").toDF("base64") + checkAnswer(input.select(base64(col("a")).as("base64")), expected, sort = false) + } + + test("unbase64 function") { + val input = session.createDataFrame(Seq("dGVzdA==")).toDF("a") + val expected = Seq("test").toDF("unbase64") + checkAnswer(input.select(unbase64(col("a")).as("unbase64")), expected, sort = false) + } + } class EagerFunctionSuite extends FunctionSuite with EagerSession From a1babb3ed7ce3bbcab16717e68f3a9b7ccf07c7b Mon Sep 17 00:00:00 2001 From: Shyamala Jayabalan Date: Tue, 20 Aug 2024 12:47:36 -0400 Subject: [PATCH 5/5] SNOW-802269 Add regextract signum subindex collectlist functions (#142) * Sfc gh sjayabalan sma regextract signum subindex collectlist (#141) * Added regexp_extract,signum,substring_index,collect_list 1) Added regexp_extract,signum,substring_index,collect_list to functions.scala . 2) Added test cases for the same * Added examples and updated the description * Fixed format * formatted the comments * Added java functions and unit test cases for java * Added sign function * Modified the alignment * Added examples * adjusted comments * Update Functions.java --------- Co-authored-by: sfc-gh-mrojas * Reformatted * Modified version * added comment * modified description * modified description * Modified comment section and changed regexp in substring_index * Modified test cases --------- Co-authored-by: sfc-gh-mrojas --- .../snowflake/snowpark_java/Functions.java | 122 ++++++++++- .../com/snowflake/snowpark/functions.scala | 194 +++++++++++++++++- .../snowpark_test/JavaFunctionSuite.java | 53 +++++ .../snowpark_test/FunctionSuite.scala | 44 ++++ 4 files changed, 411 insertions(+), 2 deletions(-) diff --git a/src/main/java/com/snowflake/snowpark_java/Functions.java b/src/main/java/com/snowflake/snowpark_java/Functions.java index 1d18a91c..ead78cb4 100644 --- a/src/main/java/com/snowflake/snowpark_java/Functions.java +++ 
b/src/main/java/com/snowflake/snowpark_java/Functions.java @@ -3882,7 +3882,127 @@ public static Column listagg(Column col) { } /** - * Returns a Column expression with values sorted in descending order. + * Signature - snowflake.snowpark.functions.regexp_extract (value: Union[Column, str], regexp: + * Union[Column, str], idx: int) Column Extract a specific group matched by a regex, from the + * specified string column. If the regex did not match, or the specified group did not match, an + * empty string is returned. Example: + * + *
{@code
+   * DataFrame df = getSession().sql("select * from values('id_20_30'),('id_40_50') as T(id)");
+   * df.select(Functions.regexp_extract(df.col("id"), "(\\d+)", 1, 1, 1)
+   *     .as("RES")).show();
+   *     ---------
+   *     |"RES"  |
+   *     ---------
+   *     |20     |
+   *     |40     |
+   *     ---------
+   * }
+ * + * @since 1.14.0 + * @param col Column. + * @param exp String + * @param position Integer. + * @param Occurences Integer. + * @param grpIdx Integer. + * @return Column object. + */ + public static Column regexp_extract( + Column col, String exp, Integer position, Integer Occurences, Integer grpIdx) { + return new Column( + com.snowflake.snowpark.functions.regexp_extract( + col.toScalaColumn(), exp, position, Occurences, grpIdx)); + } + + /** + * Returns the sign of its argument: + * + *

- -1 if the argument is negative. - 1 if it is positive. - 0 if it is 0. + * + *

Args: col: The column to evaluate its sign Example:: * + * + *

{@code df =
+   * getSession().sql("select * from values(-2, 2, 0) as T(a, b, c)");
+   * df.select(Functions.signum(df.col("a")).as("a_sign"), Functions.signum(df.col("b")).as("b_sign"),
+   * Functions.signum(df.col("c")).as("c_sign")).show();
+   *     ----------------------------------
+   *     |"A_SIGN"  |"B_SIGN"  |"C_SIGN"  |
+   *     ----------------------------------
+   *     |-1        |1         |0         |
+   *     ----------------------------------
+   * }
+ * + * @since 1.14.0 + * @param col Column to calculate the sign. + * @return Column object. + */ + public static Column signum(Column col) { + return new Column(com.snowflake.snowpark.functions.signum(col.toScalaColumn())); + } + + /** + * Returns the sign of its argument: + * + *

- -1 if the argument is negative. - 1 if it is positive. - 0 if it is 0. + * + *

Args: col: The column to evaluate its sign Example:: + * + *

{@code df =
+   * getSession().sql("select * from values(-2, 2, 0) as T(a, b, c)");
+   * df.select(Functions.sign(df.col("a")).as("a_sign"), Functions.sign(df.col("b")).as("b_sign"),
+   * Functions.sign(df.col("c")).as("c_sign")).show();
+   *     ----------------------------------
+   *     |"A_SIGN"  |"B_SIGN"  |"C_SIGN"  |
+   *     ----------------------------------
+   *     |-1        |1         |0         |
+   *     ----------------------------------
+   * }
+ * + * @since 1.14.0 + * @param col Column to calculate the sign. + * @return Column object. + */ + public static Column sign(Column col) { + return new Column(com.snowflake.snowpark.functions.sign(col.toScalaColumn())); + } + + /** + * Returns the substring from string str before count occurrences of the delimiter delim. If count + * is positive, everything the left of the final delimiter (counting from left) is returned. If + * count is negative, every to the right of the final delimiter (counting from the right) is + * returned. substring_index performs a case-sensitive match when searching for delim. + * + * @param col String. + * @param delim String + * @param count Integer. + * @return Column object. + * @since 1.14.0 + */ + public static Column substring_index(String col, String delim, Integer count) { + return new Column(com.snowflake.snowpark.functions.substring_index(col, delim, count)); + } + + /** + * Returns the input values, pivoted into an ARRAY. If the input is empty, an empty ARRAY is + * returned. + * + *

Example:: + * + *

{@code
+   * DataFrame df = getSession().sql("select * from values(1),(2),(3) as T(a)");
+   * df.select(Functions.collect_list(df.col("a")).as("result")).show();
+   * "RESULT" [ 1, 2, 3 ]
+   * }
+ * + * @since 1.14.0 + * @param c Column to be collect. + * @return The array. + */ + public static Column collect_list(Column c) { + return new Column(com.snowflake.snowpark.functions.collect_list(c.toScalaColumn())); + } + + /* Returns a Column expression with values sorted in descending order. * *

Example: order column values in descending * diff --git a/src/main/scala/com/snowflake/snowpark/functions.scala b/src/main/scala/com/snowflake/snowpark/functions.scala index 160c3112..48bbadc6 100644 --- a/src/main/scala/com/snowflake/snowpark/functions.scala +++ b/src/main/scala/com/snowflake/snowpark/functions.scala @@ -3142,7 +3142,199 @@ object functions { def listagg(col: Column): Column = listagg(col, "", isDistinct = false) /** - * Returns a Column expression with values sorted in descending order. + + * Signature - snowflake.snowpark.functions.regexp_extract + * (value: Union[Column, str], regexp: Union[Column, str], idx: int) + * Column + * Extract a specific group matched by a regex, from the specified string + * column. If the regex did not match, or the specified group did not match, + * an empty string is returned. + * Example: + * from snowflake.snowpark.functions import regexp_extract + * df = session.createDataFrame([["id_20_30", 10], ["id_40_50", 30]], + * ["id", "age"]) + * df.select(regexp_extract("id", r"(\d+)", 1).alias("RES")).show() + * + * + * --------- + * |"RES" | + * --------- + * |20 | + * |40 | + * --------- + * + * Note: non-greedy tokens such as are not supported + * @since 1.14.0 + * @return Column object. + */ + def regexp_extract( + colName: Column, + exp: String, + position: Int, + Occurences: Int, + grpIdx: Int): Column = { + when(colName.is_null, lit(null)) + .otherwise( + coalesce( + builtin("REGEXP_SUBSTR")( + colName, + lit(exp), + lit(position), + lit(Occurences), + lit("ce"), + lit(grpIdx)), + lit(""))) + } + + /** + * Returns the sign of its argument as mentioned : + * + * - -1 if the argument is negative. + * - 1 if it is positive. + * - 0 if it is 0. 
+ * + * Args: + * col: The column to evaluate its sign + * + * Example:: + * >>> df = session.create_dataframe([(-2, 2, 0)], ["a", "b", "c"]) + * >>> df.select(sign("a").alias("a_sign"), sign("b").alias("b_sign"), + * sign("c").alias("c_sign")).show() + * ---------------------------------- + * |"A_SIGN" |"B_SIGN" |"C_SIGN" | + * ---------------------------------- + * |-1 |1 |0 | + * ---------------------------------- + * + * @since 1.14.0 + * @param e Column to calculate the sign. + * @return Column object. + */ + def sign(colName: Column): Column = { + builtin("SIGN")(colName) + } + + /** + * Returns the sign of its argument: + * + * - -1 if the argument is negative. + * - 1 if it is positive. + * - 0 if it is 0. + * + * Args: + * col: The column to evaluate its sign + * + * Example:: + * >>> df = session.create_dataframe([(-2, 2, 0)], ["a", "b", "c"]) + * >>> df.select(sign("a").alias("a_sign"), sign("b").alias("b_sign"), + * sign("c").alias("c_sign")).show() + * ---------------------------------- + * |"A_SIGN" |"B_SIGN" |"C_SIGN" | + * ---------------------------------- + * |-1 |1 |0 | + * ---------------------------------- + * + * @since 1.14.0 + * @param e Column to calculate the sign. + * @return Column object. + */ + def signum(colName: Column): Column = { + builtin("SIGN")(colName) + } + + /** + * Returns the sign of the given column. Returns either 1 for positive, + * 0 for 0 or + * NaN, -1 for negative and null for null. + * NOTE: if string values are provided snowflake will attempts to cast. + * If it casts correctly, returns the calculation, + * if not an error will be thrown + * @since 1.14.0 + * @param columnName Name of the column to calculate the sign. + * @return Column object. + */ + def signum(columnName: String): Column = { + signum(col(columnName)) + } + + /** + * Returns the substring from string str before count occurrences + * of the delimiter delim. 
If count is positive, + * everything the left of the final delimiter (counting from left) + * is returned. If count is negative, every to the right of the + * final delimiter (counting from the right) is returned. + * substring_index performs a case-sensitive match when searching for delim. + * @since 1.14.0 + */ + def substring_index(str: String, delim: String, count: Int): Column = { + when( + lit(count) < lit(0), + callBuiltin( + "substring", + lit(str), + callBuiltin( + "regexp_instr", + sqlExpr(s"reverse('${str}')"), + lit(delim), + 1, + abs(lit(count)), + lit(0)))) + .otherwise( + callBuiltin( + "substring", + lit(str), + 1, + callBuiltin("regexp_instr", lit(str), lit(delim), 1, lit(count), 1))) + } + + /** + * + * Returns the input values, pivoted into an ARRAY. If the input is empty, an empty + * ARRAY is returned. + * + * Example:: + * >>> df = session.create_dataframe([[1], [2], [3], [1]], schema=["a"]) + * >>> df.select(array_agg("a", True).alias("result")).show() + * ------------ + * |"RESULT" | + * ------------ + * |[ | + * | 1, | + * | 2, | + * | 3 | + * |] | + * ------------ + * + * @since 1.14.0 + * @param c Column to be collect. + * @return The array. + */ + def collect_list(c: Column): Column = array_agg(c) + + /** + * + * Returns the input values, pivoted into an ARRAY. If the input is empty, an empty + * ARRAY is returned. + * + * Example:: + * >>> df = session.create_dataframe([[1], [2], [3], [1]], schema=["a"]) + * >>> df.select(array_agg("a", True).alias("result")).show() + * ------------ + * |"RESULT" | + * ------------ + * |[ | + * | 1, | + * | 2, | + * | 3 | + * |] | + * ------------ + * @since 1.14.0 + * @param s Column name to be collected. + * @return The array. + */ + def collect_list(s: String): Column = array_agg(col(s)) + + /* Returns a Column expression with values sorted in descending order. 
* Example: * {{{ * val df = session.createDataFrame(Seq(1, 2, 3)).toDF("id") diff --git a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java index 05e38211..00cdbd2b 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaFunctionSuite.java @@ -2766,6 +2766,59 @@ public void any_value() { } @Test + public void regexp_extract() { + DataFrame df = getSession().sql("select * from values('A MAN A PLAN A CANAL') as T(a)"); + Row[] expected = {Row.create("MAN")}; + checkAnswer( + df.select(Functions.regexp_extract(df.col("a"), "A\\W+(\\w+)", 1, 1, 1)), expected, false); + Row[] expected2 = {Row.create("PLAN")}; + checkAnswer( + df.select(Functions.regexp_extract(df.col("a"), "A\\W+(\\w+)", 1, 2, 1)), expected2, false); + Row[] expected3 = {Row.create("CANAL")}; + checkAnswer( + df.select(Functions.regexp_extract(df.col("a"), "A\\W+(\\w+)", 1, 3, 1)), expected3, false); + } + + @Test + public void signum() { + DataFrame df = getSession().sql("select * from values(1) as T(a)"); + checkAnswer(df.select(Functions.signum(df.col("a"))), new Row[] {Row.create(1)}, false); + DataFrame df1 = getSession().sql("select * from values(-2) as T(a)"); + checkAnswer(df1.select(Functions.signum(df1.col("a"))), new Row[] {Row.create(-1)}, false); + DataFrame df2 = getSession().sql("select * from values(0) as T(a)"); + checkAnswer(df2.select(Functions.signum(df2.col("a"))), new Row[] {Row.create(0)}, false); + } + + @Test + public void sign() { + DataFrame df = getSession().sql("select * from values(1) as T(a)"); + checkAnswer(df.select(Functions.signum(df.col("a"))), new Row[] {Row.create(1)}, false); + DataFrame df1 = getSession().sql("select * from values(-2) as T(a)"); + checkAnswer(df1.select(Functions.signum(df1.col("a"))), new Row[] {Row.create(-1)}, false); + DataFrame df2 = getSession().sql("select * from values(0) as T(a)"); + 
checkAnswer(df2.select(Functions.signum(df2.col("a"))), new Row[] {Row.create(0)}, false); + } + + @Test + public void collect_list() { + DataFrame df = getSession().sql("select * from values(1), (2), (3) as T(a)"); + df.select(Functions.collect_list(df.col("a"))).show(); + } + + @Test + public void substring_index() { + DataFrame df = + getSession() + .sql( + "select * from values ('It was the best of times,it was the worst of times') as T(a)"); + checkAnswer( + df.select( + Functions.substring_index( + "It was the best of times,it was the worst of times", "was", 1)), + new Row[] {Row.create("It was ")}, + false); + } + public void test_asc() { DataFrame df = getSession().sql("select * from values(3),(1),(2) as t(a)"); Row[] expected = {Row.create(1), Row.create(2), Row.create(3)}; diff --git a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala index 3db8fd02..9658006e 100644 --- a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala @@ -1090,6 +1090,7 @@ trait FunctionSuite extends TestData { .collect()(0) .getTimestamp(0) .toString == "2020-10-28 13:35:47.001234567") + } test("timestamp_ltz_from_parts") { @@ -2177,6 +2178,49 @@ trait FunctionSuite extends TestData { expected, sort = false) } + test("regexp_extract") { + val data = Seq("A MAN A PLAN A CANAL").toDF("a") + var expected = Seq(Row("MAN")) + checkAnswer( + data.select(regexp_extract(col("a"), "A\\W+(\\w+)", 1, 1, 1)), + expected, + sort = false) + expected = Seq(Row("PLAN")) + checkAnswer( + data.select(regexp_extract(col("a"), "A\\W+(\\w+)", 1, 2, 1)), + expected, + sort = false) + expected = Seq(Row("CANAL")) + checkAnswer( + data.select(regexp_extract(col("a"), "A\\W+(\\w+)", 1, 3, 1)), + expected, + sort = false) + + } + test("signum") { + val df = Seq(1).toDF("a") + checkAnswer(df.select(sign(col("a"))), Seq(Row(1)), sort = false) + val df1 = 
Seq(-2).toDF("a") + checkAnswer(df1.select(sign(col("a"))), Seq(Row(-1)), sort = false) + val df2 = Seq(0).toDF("a") + checkAnswer(df2.select(sign(col("a"))), Seq(Row(0)), sort = false) + } + test("sign") { + val df = Seq(1).toDF("a") + checkAnswer(df.select(sign(col("a"))), Seq(Row(1)), sort = false) + val df1 = Seq(-2).toDF("a") + checkAnswer(df1.select(sign(col("a"))), Seq(Row(-1)), sort = false) + val df2 = Seq(0).toDF("a") + checkAnswer(df2.select(sign(col("a"))), Seq(Row(0)), sort = false) + } + + test("substring_index") { + val df = Seq("It was the best of times, it was the worst of times").toDF("a") + checkAnswer( + df.select(substring_index("It was the best of times, it was the worst of times", "was", 1)), + Seq(Row("It was ")), + sort = false) + } test("desc column order") { val input = Seq(1, 2, 3).toDF("data")