Skip to content

Commit

Permalink
Change interface
Browse files Browse the repository at this point in the history
Signed-off-by: Nghia Truong <[email protected]>
  • Loading branch information
ttnghia committed Nov 27, 2024
1 parent f8a458a commit aad7900
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 15 deletions.
14 changes: 5 additions & 9 deletions src/main/cpp/src/AggregationUtilsJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

extern "C" {

JNIEXPORT jlongArray JNICALL Java_com_nvidia_spark_rapids_jni_JSONUtils_createNativeTestHostUDF(
JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_AggregationUtils_createNativeTestHostUDF(
JNIEnv* env, jclass, jint agg_type)
{
try {
Expand All @@ -29,16 +29,12 @@ JNIEXPORT jlongArray JNICALL Java_com_nvidia_spark_rapids_jni_JSONUtils_createNa
switch (agg_type) {
case 0: return spark_rapids_jni::create_test_reduction_host_udf();
case 1: return spark_rapids_jni::create_test_segmented_reduction_host_udf();
case 2: return spark_rapids_jni::create_test_groupby_host_udf();
default:;
default: return spark_rapids_jni::create_test_groupby_host_udf();
}
}();
// The first value is pointer to host_udf instance,
// and the second value is its hash code.
auto out_handles = cudf::jni::native_jlongArray(env, 2);
out_handles[1] = static_cast<jlong>(udf_ptr->do_hash());
out_handles[0] = reinterpret_cast<jlong>(udf_ptr.release());
return out_handles.get_jArray();
CUDF_EXPECTS(udf_ptr != nullptr, "Invalid host udf instance.");

return reinterpret_cast<jlong>(udf_ptr.release());
}
CATCH_STD(env, 0);
}
Expand Down
2 changes: 1 addition & 1 deletion src/main/cpp/src/test_host_udf_agg.cu
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ struct test_udf_simple_type : cudf::host_udf_base {
template <typename T>
static constexpr bool is_valid_output_t()
{
return std::is_same_v<T, int64_t>;
return std::is_same_v<T, double> || std::is_same_v<T, int64_t>;
}

struct reduce_fn {
Expand Down
10 changes: 5 additions & 5 deletions src/main/java/com/nvidia/spark/rapids/jni/AggregationUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public class AggregationUtils {
NativeDepsLoader.loadNativeDeps();
}

enum AggregationType {
public enum AggregationType {
Reduction(0),
SegmentedReduction(1),
GroupByAggregation(2);
Expand All @@ -35,19 +35,19 @@ enum AggregationType {

/**
* Create a test host UDF for testing purposes.
*
*<p/>
* This will return two values: the first is the pointer to the host UDF, and the second is the
* hash code of the host UDF.
*
*<p/>
* To create a host UDF aggregation, do this:
* ```
* long[] udfAndHash = AggregationUtils.createTestHostUDF();
* new ai.rapids.cudf.HostUDFAggregation(udfAndHash[0], udfAndHash[1]);
* ```
*/
public static long[] createTestHostUDF(AggregationType type) {
public static long createTestHostUDF(AggregationType type) {
return createNativeTestHostUDF(type.nativeId);
}

private static native long[] createNativeTestHostUDF(int type);
private static native long createNativeTestHostUDF(int type);
}

0 comments on commit aad7900

Please sign in to comment.