Skip to content

Commit

Permalink
Add missing Java methods
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-lfallasavendano committed Sep 26, 2024
1 parent 152b98c commit cbc9239
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 4 deletions.
12 changes: 12 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/Row.java
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,18 @@ public Row getObject(int index) {
return (Row) get(index);
}

/**
* Returns the index of the field with the specified name.
*
* @param fieldName the name of the field.
* @return the index of the specified field.
* @throws UnsupportedOperationException if schema information is not available.
* @since 1.15.0
*/
public int fieldIndex(String fieldName) {
return this.scalaRow.fieldIndex(fieldName);
}

/**
* Returns the value at the specified column index and casts it to the desired type {@code T}.
*
Expand Down
12 changes: 12 additions & 0 deletions src/main/java/com/snowflake/snowpark_java/types/StructType.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,18 @@ private static com.snowflake.snowpark.types.StructField[] toScalaFieldsArray(
return result;
}

/**
* Return the index of the specified field.
*
* @param fieldName the name of the field.
* @return the index of the field with the specified name.
* @throws IllegalArgumentException if the given field name does not exist in the schema.
* @since 1.15.0
*/
public int fieldIndex(String fieldName) {
return this.scalaStructType.fieldIndex(fieldName);
}

/**
* Retrieves the names of StructField.
*
Expand Down
6 changes: 3 additions & 3 deletions src/main/scala/com/snowflake/snowpark/Row.scala
Original file line number Diff line number Diff line change
Expand Up @@ -373,15 +373,15 @@ class Row protected (values: Array[Any], schema: Option[StructType]) extends Ser
/**
* Returns the index of the field with the specified name.
*
* @param name the name of the field.
* @param fieldName the name of the field.
* @return the index of the specified field.
* @throws UnsupportedOperationException if schema information is not available.
* @since 1.15.0
*/
def fieldIndex(name: String): Int = {
def fieldIndex(fieldName: String): Int = {
var schema = this.schema.getOrElse(
throw new UnsupportedOperationException("Cannot get field index for row without schema"))
schema.fieldIndex(name)
schema.fieldIndex(fieldName)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ case class StructType(fields: Array[StructField] = Array())
* Return the index of the specified field.
*
* @param fieldName the name of the field.
* @returns the index of the field with the specified name.
* @return the index of the field with the specified name.
* @throws IllegalArgumentException if the given field name does not exist in the schema.
* @since 1.15.0
*/
Expand Down
17 changes: 17 additions & 0 deletions src/test/java/com/snowflake/snowpark_test/JavaRowSuite.java
Original file line number Diff line number Diff line change
Expand Up @@ -629,4 +629,21 @@ public void getAsWithFieldName() {
UnsupportedOperationException.class,
() -> rowWithoutSchema.getAs("NonExistingColumn", Integer.class));
}

@Test
public void fieldIndex() {
StructType schema =
StructType.create(
new StructField("EmpName", DataTypes.StringType),
new StructField("NumVal", DataTypes.IntegerType));

Row[] data = {Row.create("abcd", 10), Row.create("efgh", 20)};

DataFrame df = getSession().createDataFrame(data, schema);
Row row = df.collect()[0];

assert (row.fieldIndex("EmpName") == 0);
assert (row.fieldIndex("NumVal") == 1);
assertThrows(IllegalArgumentException.class, () -> row.fieldIndex("NonExistingColumn"));
}
}
8 changes: 8 additions & 0 deletions src/test/scala/com/snowflake/snowpark_test/RowSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,14 @@ class RowSuite extends SNTestBase {
rowWithoutSchema.getAs[Integer]("NonExistingColumn"));
}

test("fieldIndex") {
val schema =
StructType(Seq(StructField("EmpName", StringType), StructField("NumVal", IntegerType)))
assert(schema.fieldIndex("EmpName") == 0)
assert(schema.fieldIndex("NumVal") == 1)
assertThrows[IllegalArgumentException](schema.fieldIndex("NonExistingColumn"))
}

test("hashCode") {
val row1 = Row(1, 2, 3)
val row2 = Row("str", null, 3)
Expand Down

0 comments on commit cbc9239

Please sign in to comment.