Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-bli committed Nov 13, 2023
1 parent 11f0266 commit d925e1b
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 22 deletions.
10 changes: 8 additions & 2 deletions src/main/scala/com/snowflake/snowpark/DataFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@ import com.snowflake.snowpark.internal.{Logging, Utils}
import com.snowflake.snowpark.internal.analyzer._
import com.snowflake.snowpark.types._
import com.github.vertical_blank.sqlformatter.SqlFormatter
import com.snowflake.snowpark.internal.Utils.{TempObjectType, getTableFunctionExpression, randomNameForTempObject}
import com.snowflake.snowpark.internal.Utils.{
TempObjectType,
getTableFunctionExpression,
randomNameForTempObject
}

import javax.xml.bind.DatatypeConverter
import scala.collection.JavaConverters._
Expand Down Expand Up @@ -1909,7 +1913,9 @@ class DataFrame private[snowpark] (

// TODO: add a test that exercises this join overload with a UDTF
def join(func: Column, partitionBy: Seq[Column], orderBy: Seq[Column]): DataFrame = withPlan {
TableFunctionJoin(this.plan, getTableFunctionExpression(func),
TableFunctionJoin(
this.plan,
getTableFunctionExpression(func),
Some(Window.partitionBy(partitionBy: _*).orderBy(orderBy: _*).getWindowSpecDefinition))
}

Expand Down
12 changes: 10 additions & 2 deletions src/main/scala/com/snowflake/snowpark/Session.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,17 @@ import com.snowflake.snowpark.internal._
import com.snowflake.snowpark.internal.analyzer.{TableFunction => TFunction}
import com.snowflake.snowpark.types._
import com.snowflake.snowpark.functions._
import com.snowflake.snowpark.internal.ErrorMessage.{UDF_CANNOT_ACCEPT_MANY_DF_COLS, UDF_UNEXPECTED_COLUMN_ORDER}
import com.snowflake.snowpark.internal.ErrorMessage.{
UDF_CANNOT_ACCEPT_MANY_DF_COLS,
UDF_UNEXPECTED_COLUMN_ORDER
}
import com.snowflake.snowpark.internal.ParameterUtils.ClosureCleanerMode
import com.snowflake.snowpark.internal.Utils.{TempObjectNamePattern, TempObjectType, getTableFunctionExpression, randomNameForTempObject}
import com.snowflake.snowpark.internal.Utils.{
TempObjectNamePattern,
TempObjectType,
getTableFunctionExpression,
randomNameForTempObject
}
import net.snowflake.client.jdbc.{SnowflakeConnectionV1, SnowflakeDriver, SnowflakeSQLException}

import scala.concurrent.{ExecutionContext, Future}
Expand Down
8 changes: 6 additions & 2 deletions src/main/scala/com/snowflake/snowpark/tableFunctions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,12 @@ object tableFunctions {

/**
 * Builds a [[Column]] from the result of `flatten.apply(input)`.
 *
 * NOTE(review): the `flatten` receiver on the right-hand side is presumably a
 * table-function member of this object declared outside the visible hunk — confirm.
 *
 * @param input the column forwarded unchanged to `flatten.apply`
 * @return a Column wrapping whatever `flatten.apply(input)` produces
 */
def flatten(input: Column): Column = Column(flatten.apply(input))

def flatten(input: Column,
path: String, outer: Boolean, recursive: Boolean, mode: String): Column =
def flatten(
input: Column,
path: String,
outer: Boolean,
recursive: Boolean,
mode: String): Column =
Column(
flatten.apply(
Map(
Expand Down
5 changes: 2 additions & 3 deletions src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -819,8 +819,7 @@ class ErrorMessageSuite extends FunSuite {
val ex = ErrorMessage.MISC_INVALID_TABLE_FUNCTION_INPUT()
assert(ex.telemetryMessage.equals(ErrorMessage.getMessage("0423")))
assert(
ex.message.startsWith(
"Error Code: 0423, Error message: Invalid input argument, " +
"Session.tableFunction only supports table function arguments"))
ex.message.startsWith("Error Code: 0423, Error message: Invalid input argument, " +
"Session.tableFunction only supports table function arguments"))
}
}
89 changes: 76 additions & 13 deletions src/test/scala/com/snowflake/snowpark_test/TableFunctionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -204,30 +204,78 @@ class TableFunctionSuite extends TestData {
test("Argument in table function: flatten2") {
val df1 = Seq("{\"a\":1, \"b\":[77, 88]}").toDF("col")
checkAnswer(
df1.join(tableFunctions.flatten(input = parse_json(df1("col")),
path = "b", outer = true, recursive = true, mode = "both")).select("value"),
df1
.join(
tableFunctions.flatten(
input = parse_json(df1("col")),
path = "b",
outer = true,
recursive = true,
mode = "both"))
.select("value"),
Seq(Row("77"), Row("88")))

val df2 = Seq("[]").toDF("col")
checkAnswer(df2.join(tableFunctions.flatten(input = parse_json(df1("col")),
path = "", outer = true, recursive = true, mode = "both")).select("value"),
checkAnswer(
df2
.join(
tableFunctions.flatten(
input = parse_json(df1("col")),
path = "",
outer = true,
recursive = true,
mode = "both"))
.select("value"),
Seq(Row(null)))

assert(df1.join(tableFunctions.flatten(input = parse_json(df1("col")),
path = "", outer = true, recursive = true, mode = "both")).count() == 4)
assert(df1.join(tableFunctions.flatten(input = parse_json(df1("col")),
path = "", outer = true, recursive = false, mode = "both")).count() == 2)
assert(df1.join(tableFunctions.flatten(input = parse_json(df1("col")),
path = "", outer = true, recursive = true, mode = "array")).count() == 1)
assert(df1.join(tableFunctions.flatten(input = parse_json(df1("col")),
path = "", outer = true, recursive = true, mode = "object")).count() == 2)
assert(
df1
.join(
tableFunctions.flatten(
input = parse_json(df1("col")),
path = "",
outer = true,
recursive = true,
mode = "both"))
.count() == 4)
assert(
df1
.join(
tableFunctions.flatten(
input = parse_json(df1("col")),
path = "",
outer = true,
recursive = false,
mode = "both"))
.count() == 2)
assert(
df1
.join(
tableFunctions.flatten(
input = parse_json(df1("col")),
path = "",
outer = true,
recursive = true,
mode = "array"))
.count() == 1)
assert(
df1
.join(
tableFunctions.flatten(
input = parse_json(df1("col")),
path = "",
outer = true,
recursive = true,
mode = "object"))
.count() == 2)
}

test("Argument in table function: flatten - session") {
val df = Seq(
(1, Array(1, 2, 3), Map("a" -> "b", "c" -> "d")),
(2, Array(11, 22, 33), Map("a1" -> "b1", "c1" -> "d1"))).toDF("idx", "arr", "map")
checkAnswer(session.tableFunction(tableFunctions.flatten(df("arr"))).select("value"),
checkAnswer(
session.tableFunction(tableFunctions.flatten(df("arr"))).select("value"),
Seq(Row("1"), Row("2"), Row("3"), Row("11"), Row("22"), Row("33")))
// error if it is not a table function
val error1 = intercept[SnowparkClientException] {
Expand All @@ -237,4 +285,19 @@ class TableFunctionSuite extends TestData {
error1.message.contains("Invalid input argument, " +
"Session.tableFunction only supports table function arguments"))
}

test("Argument in table function: flatten - session 2") {
  // Single-row frame holding a JSON document whose "b" field is a two-element array.
  val source = Seq("{\"a\":1, \"b\":[77, 88]}").toDF("col")

  // Build the flatten call with every option spelled out by name, targeting path "b".
  val flattened = tableFunctions.flatten(
    input = parse_json(source("col")),
    path = "b",
    outer = true,
    recursive = true,
    mode = "both")

  // Pass the flatten column to session.tableFunction and keep only the "value" column.
  val values = session.tableFunction(flattened).select("value")

  // Expect one row per element of the "b" array.
  checkAnswer(values, Seq(Row("77"), Row("88")))
}
}

0 comments on commit d925e1b

Please sign in to comment.