diff --git a/engine/function/src/templates/Numeric.ftl b/engine/function/src/templates/Numeric.ftl index 30cfc23084e..ccedc0d79d0 100644 --- a/engine/function/src/templates/Numeric.ftl +++ b/engine/function/src/templates/Numeric.ftl @@ -2599,6 +2599,7 @@ public class Numeric { } long vsum = 0; + long nullCount = 0; try ( final ${pt.vectorIterator} vi = values.iterator(); @@ -2610,10 +2611,16 @@ public class Numeric { if (!isNull(c) && !isNull(w)) { vsum += c * (long) w; + } else { + nullCount++; } } } + if (nullCount == values.size()) { + return NULL_LONG; + } + return vsum; } <#else> @@ -2629,6 +2636,7 @@ public class Numeric { } double vsum = 0; + long nullCount = 0; try ( final ${pt.vectorIterator} vi = values.iterator(); @@ -2660,10 +2668,16 @@ public class Numeric { <#else> vsum += c * (double) w; + } else { + nullCount++; } } } + if (nullCount == values.size()) { + return NULL_DOUBLE; + } + return vsum; } @@ -2733,6 +2747,7 @@ public class Numeric { double vsum = 0; double wsum = 0; + long nullCount = 0; try ( final ${pt.vectorIterator} vi = values.iterator(); @@ -2750,10 +2765,16 @@ public class Numeric { if (!isNull(c) && !isNull(w)) { vsum += c * w; wsum += w; + } else { + nullCount++; } } } + if (nullCount == values.size()) { + return NULL_DOUBLE; + } + return vsum / wsum; } diff --git a/engine/function/src/templates/TestNumeric.ftl b/engine/function/src/templates/TestNumeric.ftl index 9bf99f52c0b..1805eb3a396 100644 --- a/engine/function/src/templates/TestNumeric.ftl +++ b/engine/function/src/templates/TestNumeric.ftl @@ -1146,6 +1146,10 @@ public class TestNumeric extends BaseArrayTestCase { assertEquals(NULL_LONG, wsum((${pt.primitive}[])null, new ${pt2.primitive}[]{4,5,6})); assertEquals(NULL_LONG, wsum(new ${pt.primitive}[]{1,2,3}, (${pt2.primitive}[])null)); + assertEquals(NULL_LONG, wsum(new ${pt.primitive}[]{${pt.null},${pt.null},${pt.null}}, new ${pt2.primitive}[]{${pt2.null},${pt2.null},${pt2.null}})); + assertEquals(NULL_LONG, wsum(new ${pt.primitive}[]{1,2,3}, new ${pt2.primitive}[]{${pt2.null},${pt2.null},${pt2.null}})); + assertEquals(NULL_LONG, wsum(new ${pt.primitive}[]{${pt.null},${pt.null},${pt.null}}, new ${pt2.primitive}[]{1,2,3})); + assertEquals(1*4+2*5+3*6, wsum(new ${pt.vectorDirect}(new ${pt.primitive}[]{1,2,3,${pt.null},5}), new ${pt2.primitive}[]{4,5,6,7,${pt2.null}})); assertEquals(NULL_LONG, wsum((${pt.vector}) null, new ${pt2.primitive}[]{4,5,6})); assertEquals(NULL_LONG, wsum(new ${pt.vectorDirect}(new ${pt.primitive}[]{1,2,3}), (${pt2.primitive}[])null)); @@ -1169,6 +1173,10 @@ public class TestNumeric extends BaseArrayTestCase { assertEquals(NULL_DOUBLE, wsum((${pt.primitive}[])null, new ${pt2.primitive}[]{4,5,6})); assertEquals(NULL_DOUBLE, wsum(new ${pt.primitive}[]{1,2,3}, (${pt2.primitive}[])null)); + assertEquals(NULL_DOUBLE, wsum(new ${pt.primitive}[]{${pt.null},${pt.null},${pt.null}}, new ${pt2.primitive}[]{${pt2.null},${pt2.null},${pt2.null}})); + assertEquals(NULL_DOUBLE, wsum(new ${pt.primitive}[]{1,2,3}, new ${pt2.primitive}[]{${pt2.null},${pt2.null},${pt2.null}})); + assertEquals(NULL_DOUBLE, wsum(new ${pt.primitive}[]{${pt.null},${pt.null},${pt.null}}, new ${pt2.primitive}[]{1,2,3})); + assertEquals(1.0*4.0+2.0*5.0+3.0*6.0, wsum(new ${pt.vectorDirect}(new ${pt.primitive}[]{1,2,3,${pt.null},5}), new ${pt2.primitive}[]{4,5,6,7,${pt2.null}})); assertEquals(NULL_DOUBLE, wsum((${pt.vector}) null, new ${pt2.primitive}[]{4,5,6})); assertEquals(NULL_DOUBLE, wsum(new ${pt.vectorDirect}(new ${pt.primitive}[]{1,2,3}), (${pt2.primitive}[])null)); @@ -1215,6 +1223,10 @@ public class TestNumeric extends BaseArrayTestCase { assertEquals(NULL_DOUBLE, wavg((${pt.primitive}[])null, new ${pt2.primitive}[]{4,5,6})); assertEquals(NULL_DOUBLE, wavg(new ${pt.primitive}[]{1,2,3}, (${pt2.primitive}[])null)); + assertEquals(NULL_DOUBLE, wavg(new ${pt.primitive}[]{${pt.null},${pt.null},${pt.null}}, new ${pt2.primitive}[]{${pt2.null},${pt2.null},${pt2.null}})); + assertEquals(NULL_DOUBLE, wavg(new ${pt.primitive}[]{1,2,3}, new ${pt2.primitive}[]{${pt2.null},${pt2.null},${pt2.null}})); + assertEquals(NULL_DOUBLE, wavg(new ${pt.primitive}[]{${pt.null},${pt.null},${pt.null}}, new ${pt2.primitive}[]{1,2,3})); + assertEquals((1.0*4.0+2.0*5.0+3.0*6.0)/(4.0+5.0+6.0), wavg(new ${pt.vectorDirect}(new ${pt.primitive}[]{1,2,3,${pt.null},5}), new ${pt2.primitive}[]{4,5,6,7,${pt2.null}})); assertEquals(NULL_DOUBLE, wavg((${pt.vector}) null, new ${pt2.primitive}[]{4,5,6})); assertEquals(NULL_DOUBLE, wavg(new ${pt.vectorDirect}(new ${pt.primitive}[]{1,2,3}), (${pt2.primitive}[])null)); diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/rollingwavg/BasePrimitiveRollingWAvgOperator.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/rollingwavg/BasePrimitiveRollingWAvgOperator.java index 67d9c7c8b4d..98bcdd7c2b0 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/rollingwavg/BasePrimitiveRollingWAvgOperator.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/updateby/rollingwavg/BasePrimitiveRollingWAvgOperator.java @@ -84,8 +84,7 @@ public void pop(int count) { @Override public void writeToOutputChunk(int outIdx) { if (windowValues.size() == nullCount) { - // Looks weird but returning NaN is consistent with Numeric#wavg and AggWAvg - outputValues.set(outIdx, Double.NaN); + outputValues.set(outIdx, NULL_DOUBLE); } else { final double weightedValSum = windowValues.evaluate(); final double weightSum = windowWeightValues.evaluate();