Skip to content

Commit

Permalink
Add tests scenarios for structured maps and structured arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-fgonzalezmendez committed Aug 30, 2024
1 parent 9c512f0 commit e91aaa0
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 0 deletions.
61 changes: 61 additions & 0 deletions src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
import java.sql.Date;
import java.sql.Time;
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TimeZone;
import org.junit.Test;

public class JavaRowSuite extends TestBase {
Expand Down Expand Up @@ -532,4 +534,63 @@ public void getAs() {
assert row.getAs(13, Timestamp.class) == null;
assert row.getAs(14, Variant.class) == null;
}

@Test
public void getAsWithStructuredMap() {
structuredTypeTest(
() -> {
String query =
"SELECT "
+ "{'a':1,'b':2}::MAP(VARCHAR, NUMBER) as map1,"
+ "{'1':'a','2':'b'}::MAP(NUMBER, VARCHAR) as map2,"
+ "{'1':{'a':1,'b':2},'2':{'c':3}}::MAP(NUMBER, MAP(VARCHAR, NUMBER)) as map3";

DataFrame df = getSession().sql(query);
Row row = df.collect()[0];

Map<?, ?> map1 = row.getAs(0, Map.class);
assert (Long) map1.get("a") == 1L;
assert (Long) map1.get("b") == 2L;

Map<?, ?> map2 = row.getAs(1, Map.class);
assert map2.get(1L).equals("a");
assert map2.get(2L).equals("b");

Map<?, ?> map3 = row.getAs(2, Map.class);
assert map3.get(1L).equals(Map.of("a", 1L, "b", 2L));
assert map3.get(2L).equals(Map.of("c", 3L));
},
getSession());
}

@Test
public void getAsWithStructuredArray() {
structuredTypeTest(
() -> {
TimeZone.setDefault(TimeZone.getTimeZone("US/Pacific"));

String query =
"SELECT "
+ "[1,2,3]::ARRAY(NUMBER) AS arr1,"
+ "['a','b']::ARRAY(VARCHAR) AS arr2,"
+ "[parse_json(31000000)::timestamp_ntz]::ARRAY(TIMESTAMP_NTZ) AS arr3,"
+ "[[1,2]]::ARRAY(ARRAY) AS arr4";

DataFrame df = getSession().sql(query);
Row row = df.collect()[0];

ArrayList<?> array1 = row.getAs(0, ArrayList.class);
assert array1.equals(List.of(1L, 2L, 3L));

ArrayList<?> array2 = row.getAs(1, ArrayList.class);
assert array2.equals(List.of("a", "b"));

ArrayList<?> array3 = row.getAs(2, ArrayList.class);
assert array3.equals(List.of(new Timestamp(31000000000L)));

ArrayList<?> array4 = row.getAs(3, ArrayList.class);
assert array4.equals(List.of("[\n 1,\n 2\n]"));
},
getSession());
}
}
56 changes: 56 additions & 0 deletions src/test/scala/com/snowflake/snowpark_test/RowSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import com.snowflake.snowpark.{Row, SNTestBase, SnowparkClientException}
import java.sql.{Date, Time, Timestamp}
import java.time.{Instant, LocalDate}
import java.util
import java.util.TimeZone

class RowSuite extends SNTestBase {

Expand Down Expand Up @@ -343,6 +344,61 @@ class RowSuite extends SNTestBase {
assert(row.getAs[Variant](14) == null)
}

test("getAs with structured map") {
structuredTypeTest {
val query =
"""SELECT
| {'a':1,'b':2}::MAP(VARCHAR, NUMBER) as map1,
| {'1':'a','2':'b'}::MAP(NUMBER, VARCHAR) as map2,
| {'1':{'a':1,'b':2},'2':{'c':3}}::MAP(NUMBER, MAP(VARCHAR, NUMBER)) as map3
|""".stripMargin

val df = session.sql(query)
val row = df.collect()(0)

val map1 = row.getAs[Map[String, Long]](0)
assert(map1("a") == 1L)
assert(map1("b") == 2L)

val map2 = row.getAs[Map[Long, String]](1)
assert(map2(1) == "a")
assert(map2(2) == "b")

val map3 = row.getAs[Map[Long, Map[String, Long]]](2)
assert(map3(1) == Map("a" -> 1, "b" -> 2))
assert(map3(2) == Map("c" -> 3))
}
}

test("getAs with structured array") {
structuredTypeTest {
TimeZone.setDefault(TimeZone.getTimeZone("US/Pacific"))

val query =
"""SELECT
| [1,2,3]::ARRAY(NUMBER) AS arr1,
| ['a','b']::ARRAY(VARCHAR) AS arr2,
| [parse_json(31000000)::timestamp_ntz]::ARRAY(TIMESTAMP_NTZ) AS arr3,
| [[1,2]]::ARRAY(ARRAY) AS arr4
|""".stripMargin

val df = session.sql(query)
val row = df.collect()(0)

val array1 = row.getAs[Array[Object]](0)
assert(array1 sameElements Array(1, 2, 3))

val array2 = row.getAs[Array[Object]](1)
assert(array2 sameElements Array("a", "b"))

val array3 = row.getAs[Array[Object]](2)
assert(array3 sameElements Array(new Timestamp(31000000000L)))

val array4 = row.getAs[Array[Object]](3)
assert(array4 sameElements Array("[\n 1,\n 2\n]"))
}
}

test("hashCode") {
val row1 = Row(1, 2, 3)
val row2 = Row("str", null, 3)
Expand Down

0 comments on commit e91aaa0

Please sign in to comment.