diff --git a/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java b/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java index c57ab9a2..0c046630 100644 --- a/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java +++ b/src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java @@ -9,10 +9,12 @@ import java.sql.Date; import java.sql.Time; import java.sql.Timestamp; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.TimeZone; import org.junit.Test; public class JavaRowSuite extends TestBase { @@ -532,4 +534,63 @@ public void getAs() { assert row.getAs(13, Timestamp.class) == null; assert row.getAs(14, Variant.class) == null; } + + @Test + public void getAsWithStructuredMap() { + structuredTypeTest( + () -> { + String query = + "SELECT " + + "{'a':1,'b':2}::MAP(VARCHAR, NUMBER) as map1," + + "{'1':'a','2':'b'}::MAP(NUMBER, VARCHAR) as map2," + + "{'1':{'a':1,'b':2},'2':{'c':3}}::MAP(NUMBER, MAP(VARCHAR, NUMBER)) as map3"; + + DataFrame df = getSession().sql(query); + Row row = df.collect()[0]; + + Map map1 = row.getAs(0, Map.class); + assert (Long) map1.get("a") == 1L; + assert (Long) map1.get("b") == 2L; + + Map map2 = row.getAs(1, Map.class); + assert map2.get(1L).equals("a"); + assert map2.get(2L).equals("b"); + + Map map3 = row.getAs(2, Map.class); + assert map3.get(1L).equals(Map.of("a", 1L, "b", 2L)); + assert map3.get(2L).equals(Map.of("c", 3L)); + }, + getSession()); + } + + @Test + public void getAsWithStructuredArray() { + structuredTypeTest( + () -> { + TimeZone.setDefault(TimeZone.getTimeZone("US/Pacific")); + + String query = + "SELECT " + + "[1,2,3]::ARRAY(NUMBER) AS arr1," + + "['a','b']::ARRAY(VARCHAR) AS arr2," + + "[parse_json(31000000)::timestamp_ntz]::ARRAY(TIMESTAMP_NTZ) AS arr3," + + "[[1,2]]::ARRAY(ARRAY) AS arr4"; + + DataFrame df = getSession().sql(query); + Row row = df.collect()[0]; + + ArrayList array1 = row.getAs(0, ArrayList.class); + assert array1.equals(List.of(1L, 2L, 3L)); + + ArrayList array2 = row.getAs(1, ArrayList.class); + assert array2.equals(List.of("a", "b")); + + ArrayList array3 = row.getAs(2, ArrayList.class); + assert array3.equals(List.of(new Timestamp(31000000000L))); + + ArrayList array4 = row.getAs(3, ArrayList.class); + assert array4.equals(List.of("[\n 1,\n 2\n]")); + }, + getSession()); + } } diff --git a/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala b/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala index 5491f6f1..5ba449dd 100644 --- a/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/RowSuite.scala @@ -6,6 +6,7 @@ import com.snowflake.snowpark.{Row, SNTestBase, SnowparkClientException} import java.sql.{Date, Time, Timestamp} import java.time.{Instant, LocalDate} import java.util +import java.util.TimeZone class RowSuite extends SNTestBase { @@ -343,6 +344,61 @@ class RowSuite extends SNTestBase { assert(row.getAs[Variant](14) == null) } + test("getAs with structured map") { + structuredTypeTest { + val query = + """SELECT + | {'a':1,'b':2}::MAP(VARCHAR, NUMBER) as map1, + | {'1':'a','2':'b'}::MAP(NUMBER, VARCHAR) as map2, + | {'1':{'a':1,'b':2},'2':{'c':3}}::MAP(NUMBER, MAP(VARCHAR, NUMBER)) as map3 + |""".stripMargin + + val df = session.sql(query) + val row = df.collect()(0) + + val map1 = row.getAs[Map[String, Long]](0) + assert(map1("a") == 1L) + assert(map1("b") == 2L) + + val map2 = row.getAs[Map[Long, String]](1) + assert(map2(1) == "a") + assert(map2(2) == "b") + + val map3 = row.getAs[Map[Long, Map[String, Long]]](2) + assert(map3(1) == Map("a" -> 1, "b" -> 2)) + assert(map3(2) == Map("c" -> 3)) + } + } + + test("getAs with structured array") { + structuredTypeTest { + TimeZone.setDefault(TimeZone.getTimeZone("US/Pacific")) + + val query = + """SELECT + | [1,2,3]::ARRAY(NUMBER) AS arr1, + | ['a','b']::ARRAY(VARCHAR) AS arr2, + | [parse_json(31000000)::timestamp_ntz]::ARRAY(TIMESTAMP_NTZ) AS arr3, + | [[1,2]]::ARRAY(ARRAY) AS arr4 + |""".stripMargin + + val df = session.sql(query) + val row = df.collect()(0) + + val array1 = row.getAs[Array[Object]](0) + assert(array1 sameElements Array(1, 2, 3)) + + val array2 = row.getAs[Array[Object]](1) + assert(array2 sameElements Array("a", "b")) + + val array3 = row.getAs[Array[Object]](2) + assert(array3 sameElements Array(new Timestamp(31000000000L))) + + val array4 = row.getAs[Array[Object]](3) + assert(array4 sameElements Array("[\n 1,\n 2\n]")) + } + } + test("hashCode") { val row1 = Row(1, 2, 3) val row2 = Row("str", null, 3)