Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SIT-2037 Add support for com.snowflake.snowpark.Row.getAs function #148

Merged
merged 7 commits into from
Sep 10, 2024
Merged
58 changes: 58 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/Row.java
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,64 @@ public Row getObject(int index) {
return (Row) get(index);
}

/**
* Returns the value at the specified column index and casts it to the desired type {@code T}.
*
* <p>Example:
*
* <pre>{@code
* Row row = Row.create(1, "Alice", 95.5);
* row.getAs(0, Integer.class); // Returns 1 as an Int
* row.getAs(1, String.class); // Returns "Alice" as a String
* row.getAs(2, Double.class); // Returns 95.5 as a Double
* }</pre>
*
* @param index the zero-based column index within the row.
* @param clazz the {@code Class} object representing the type {@code T}.
* @param <T> the expected type of the value at the specified column index.
* @return the value at the specified column index cast to type {@code T}.
* @throws ClassCastException if the value at the given index cannot be cast to type {@code T}.
* @throws ArrayIndexOutOfBoundsException if the column index is out of bounds.
* @since 1.15.0
*/
@SuppressWarnings("unchecked")
public <T> T getAs(int index, Class<T> clazz)
throws ClassCastException, ArrayIndexOutOfBoundsException {
if (isNullAt(index)) {
return (T) get(index);
}

if (clazz == Byte.class) {
return (T) (Object) getByte(index);
}

if (clazz == Double.class) {
return (T) (Object) getDouble(index);
}

if (clazz == Float.class) {
return (T) (Object) getFloat(index);
}

if (clazz == Integer.class) {
return (T) (Object) getInt(index);
}

if (clazz == Long.class) {
return (T) (Object) getLong(index);
}

if (clazz == Short.class) {
return (T) (Object) getShort(index);
}

if (clazz == Variant.class) {
return (T) getVariant(index);
}

return (T) get(index);
}

/**
* Generates a string value to represent the content of this row.
*
Expand Down
40 changes: 35 additions & 5 deletions src/main/scala/com/snowflake/snowpark/Row.scala
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,39 @@ class Row protected (values: Array[Any]) extends Serializable {
getAs[Map[T, U]](index)
}

/**
* Returns the value at the specified column index and casts it to the desired type `T`.
*
* Example:
* {{{
* val row = Row(1, "Alice", 95.5)
* row.getAs[Int](0) // Returns 1 as an Int
* row.getAs[String](1) // Returns "Alice" as a String
* row.getAs[Double](2) // Returns 95.5 as a Double
* }}}
*
* @param index the zero-based column index within the row.
* @tparam T the expected type of the value at the specified column index.
* @return the value at the specified column index cast to type `T`.
* @throws ClassCastException if the value at the given index cannot be cast to type `T`.
* @throws ArrayIndexOutOfBoundsException if the column index is out of bounds.
* @group getter
* @since 1.15.0
*/
def getAs[T](index: Int)(implicit classTag: ClassTag[T]): T = {
classTag.runtimeClass match {
case _ if isNullAt(index) => get(index).asInstanceOf[T]
case c if c == classOf[Byte] => getByte(index).asInstanceOf[T]
case c if c == classOf[Double] => getDouble(index).asInstanceOf[T]
case c if c == classOf[Float] => getFloat(index).asInstanceOf[T]
case c if c == classOf[Int] => getInt(index).asInstanceOf[T]
case c if c == classOf[Long] => getLong(index).asInstanceOf[T]
case c if c == classOf[Short] => getShort(index).asInstanceOf[T]
case c if c == classOf[Variant] => getVariant(index).asInstanceOf[T]
case _ => get(index).asInstanceOf[T]
}
}

protected def convertValueToString(value: Any): String =
value match {
case null => "null"
Expand Down Expand Up @@ -400,10 +433,7 @@ class Row protected (values: Array[Any]) extends Serializable {
.map(convertValueToString)
.mkString("Row[", ",", "]")

private def getAs[T](index: Int): T = get(index).asInstanceOf[T]

private def getAnyValAs[T <: AnyVal](index: Int): T =
private def getAnyValAs[T <: AnyVal](index: Int)(implicit classTag: ClassTag[T]): T =
if (isNullAt(index)) throw new NullPointerException(s"Value at index $index is null")
else getAs[T](index)

else getAs[T](index)(classTag)
}
173 changes: 173 additions & 0 deletions src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
package com.snowflake.snowpark_test;

import static org.junit.Assert.assertThrows;

import com.snowflake.snowpark_java.DataFrame;
import com.snowflake.snowpark_java.Row;
import com.snowflake.snowpark_java.types.*;
import java.math.BigDecimal;
import java.sql.Date;
import java.sql.Time;
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TimeZone;
import org.junit.Test;

public class JavaRowSuite extends TestBase {
Expand Down Expand Up @@ -429,4 +434,172 @@ public void testGetRow() {
},
getSession());
}

@Test
public void getAs() {
long milliseconds = System.currentTimeMillis();

StructType schema =
sfc-gh-mrojas marked this conversation as resolved.
Show resolved Hide resolved
StructType.create(
new StructField("c01", DataTypes.BinaryType),
new StructField("c02", DataTypes.BooleanType),
new StructField("c03", DataTypes.ByteType),
new StructField("c04", DataTypes.DateType),
new StructField("c05", DataTypes.DoubleType),
new StructField("c06", DataTypes.FloatType),
new StructField("c07", DataTypes.GeographyType),
new StructField("c08", DataTypes.GeometryType),
new StructField("c09", DataTypes.IntegerType),
new StructField("c10", DataTypes.LongType),
new StructField("c11", DataTypes.ShortType),
new StructField("c12", DataTypes.StringType),
new StructField("c13", DataTypes.TimeType),
new StructField("c14", DataTypes.TimestampType),
new StructField("c15", DataTypes.VariantType));

Row[] data = {
Row.create(
new byte[] {1, 2},
true,
Byte.MIN_VALUE,
Date.valueOf("2024-01-01"),
Double.MIN_VALUE,
Float.MIN_VALUE,
Geography.fromGeoJSON("POINT(30 10)"),
Geometry.fromGeoJSON("POINT(20 40)"),
Integer.MIN_VALUE,
Long.MIN_VALUE,
Short.MIN_VALUE,
"string",
Time.valueOf("16:23:04"),
new Timestamp(milliseconds),
new Variant(1))
};

DataFrame df = getSession().createDataFrame(data, schema);
Row row = df.collect()[0];

assert Arrays.equals(row.getAs(0, byte[].class), new byte[] {1, 2});
assert row.getAs(1, Boolean.class);
assert row.getAs(2, Byte.class) == Byte.MIN_VALUE;
assert row.getAs(3, Date.class).equals(Date.valueOf("2024-01-01"));
assert row.getAs(4, Double.class) == Double.MIN_VALUE;
assert row.getAs(5, Float.class) == Float.MIN_VALUE;
assert row.getAs(6, Geography.class)
.equals(
Geography.fromGeoJSON(
"{\n \"coordinates\": [\n 30,\n 10\n ],\n \"type\": \"Point\"\n}"));
assert row.getAs(7, Geometry.class)
.equals(
Geometry.fromGeoJSON(
"{\n \"coordinates\": [\n 2.000000000000000e+01,\n 4.000000000000000e+01\n ],\n \"type\": \"Point\"\n}"));
assert row.getAs(8, Integer.class) == Integer.MIN_VALUE;
assert row.getAs(9, Long.class) == Long.MIN_VALUE;
assert row.getAs(10, Short.class) == Short.MIN_VALUE;
assert row.getAs(11, String.class).equals("string");
assert row.getAs(12, Time.class).equals(Time.valueOf("16:23:04"));
assert row.getAs(13, Timestamp.class).equals(new Timestamp(milliseconds));
assert row.getAs(14, Variant.class).equals(new Variant(1));

Row finalRow = row;
assertThrows(
ClassCastException.class,
() -> {
Boolean b = finalRow.getAs(0, Boolean.class);
});
assertThrows(ArrayIndexOutOfBoundsException.class, () -> finalRow.getAs(-1, Boolean.class));

data =
new Row[] {
Row.create(
null, null, null, null, null, null, null, null, null, null, null, null, null, null,
null)
};

df = getSession().createDataFrame(data, schema);
row = df.collect()[0];

assert row.getAs(0, byte[].class) == null;
assert row.getAs(1, Boolean.class) == null;
assert row.getAs(2, Byte.class) == null;
assert row.getAs(3, Date.class) == null;
assert row.getAs(4, Double.class) == null;
assert row.getAs(5, Float.class) == null;
assert row.getAs(6, Geography.class) == null;
assert row.getAs(7, Geometry.class) == null;
assert row.getAs(8, Integer.class) == null;
assert row.getAs(9, Long.class) == null;
assert row.getAs(10, Short.class) == null;
assert row.getAs(11, String.class) == null;
assert row.getAs(12, Time.class) == null;
assert row.getAs(13, Timestamp.class) == null;
assert row.getAs(14, Variant.class) == null;
}

@Test
public void getAsWithStructuredMap() {
structuredTypeTest(
() -> {
String query =
"SELECT "
+ "{'a':1,'b':2}::MAP(VARCHAR, NUMBER) as map1,"
+ "{'1':'a','2':'b'}::MAP(NUMBER, VARCHAR) as map2,"
+ "{'1':{'a':1,'b':2},'2':{'c':3}}::MAP(NUMBER, MAP(VARCHAR, NUMBER)) as map3";

DataFrame df = getSession().sql(query);
Row row = df.collect()[0];

Map<?, ?> map1 = row.getAs(0, Map.class);
assert (Long) map1.get("a") == 1L;
assert (Long) map1.get("b") == 2L;

Map<?, ?> map2 = row.getAs(1, Map.class);
assert map2.get(1L).equals("a");
assert map2.get(2L).equals("b");

Map<?, ?> map3 = row.getAs(2, Map.class);
Map<String, Long> map3ExpectedInnerMap = new HashMap<>();
map3ExpectedInnerMap.put("a", 1L);
map3ExpectedInnerMap.put("b", 2L);
assert map3.get(1L).equals(map3ExpectedInnerMap);
assert map3.get(2L).equals(Collections.singletonMap("c", 3L));
},
getSession());
}

@Test
public void getAsWithStructuredArray() {
structuredTypeTest(
() -> {
TimeZone oldTimeZone = TimeZone.getDefault();
try {
TimeZone.setDefault(TimeZone.getTimeZone("US/Pacific"));

String query =
"SELECT "
+ "[1,2,3]::ARRAY(NUMBER) AS arr1,"
+ "['a','b']::ARRAY(VARCHAR) AS arr2,"
+ "[parse_json(31000000)::timestamp_ntz]::ARRAY(TIMESTAMP_NTZ) AS arr3,"
+ "[[1,2]]::ARRAY(ARRAY) AS arr4";

DataFrame df = getSession().sql(query);
Row row = df.collect()[0];

ArrayList<?> array1 = row.getAs(0, ArrayList.class);
assert array1.equals(Arrays.asList(1L, 2L, 3L));

ArrayList<?> array2 = row.getAs(1, ArrayList.class);
assert array2.equals(Arrays.asList("a", "b"));

ArrayList<?> array3 = row.getAs(2, ArrayList.class);
assert array3.equals(Collections.singletonList(new Timestamp(31000000000L)));

ArrayList<?> array4 = row.getAs(3, ArrayList.class);
assert array4.equals(Collections.singletonList("[\n 1,\n 2\n]"));
} finally {
TimeZone.setDefault(oldTimeZone);
}
},
getSession());
}
}
Loading
Loading