Skip to content

Commit

Permalink
Widen returned types for UpdateBy floating point operations. (#5371)
Browse files Browse the repository at this point in the history
* Initial commit of changes for widening update_by returned types (Float -> Double)

* Fixed NULL_FLOAT related bugs, adjusted tests to use correct datatypes.

* Added all NULL tests to UpdateBy cumulative operations.
  • Loading branch information
lbooker42 authored May 8, 2024
1 parent f2dc4fe commit 25e0cb1
Show file tree
Hide file tree
Showing 11 changed files with 197 additions and 147 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import io.deephaven.engine.table.impl.updateby.internal.BaseDoubleUpdateByOperator;
import org.jetbrains.annotations.NotNull;

import static io.deephaven.util.QueryConstants.NULL_DOUBLE;
import static io.deephaven.util.QueryConstants.NULL_DOUBLE;

public class DoubleCumProdOperator extends BaseDoubleUpdateByOperator {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,17 @@
import io.deephaven.chunk.attributes.Values;
import io.deephaven.engine.table.impl.MatchPair;
import io.deephaven.engine.table.impl.updateby.UpdateByOperator;
import io.deephaven.engine.table.impl.updateby.internal.BaseFloatUpdateByOperator;
import io.deephaven.engine.table.impl.updateby.internal.BaseDoubleUpdateByOperator;
import org.jetbrains.annotations.NotNull;

import static io.deephaven.util.QueryConstants.NULL_DOUBLE;
import static io.deephaven.util.QueryConstants.NULL_FLOAT;

public class FloatCumProdOperator extends BaseFloatUpdateByOperator {
public class FloatCumProdOperator extends BaseDoubleUpdateByOperator {
// region extra-fields
// endregion extra-fields

protected class Context extends BaseFloatUpdateByOperator.Context {
protected class Context extends BaseDoubleUpdateByOperator.Context {
public FloatChunk<? extends Values> floatValueChunk;

protected Context(final int chunkSize) {
Expand All @@ -37,7 +38,7 @@ public void push(int pos, int count) {
final float val = floatValueChunk.get(pos);

if (val != NULL_FLOAT) {
curVal = curVal == NULL_FLOAT ? val : curVal * val;
curVal = curVal == NULL_DOUBLE ? val : curVal * val;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import static io.deephaven.util.QueryConstants.NULL_DOUBLE;
import static io.deephaven.util.QueryConstants.NULL_DOUBLE;

public class DoubleRollingSumOperator extends BaseDoubleUpdateByOperator {
Expand Down Expand Up @@ -60,11 +61,13 @@ public void push(int pos, int count) {
aggSum.ensureRemaining(count);

for (int ii = 0; ii < count; ii++) {
double val = doubleInfluencerValuesChunk.get(pos + ii);
aggSum.addUnsafe(val);
final double val = doubleInfluencerValuesChunk.get(pos + ii);

if (val == NULL_DOUBLE) {
nullCount++;
aggSum.addUnsafe(NULL_DOUBLE);
} else {
aggSum.addUnsafe(val);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,37 +3,38 @@
//
package io.deephaven.engine.table.impl.updateby.rollingsum;

import io.deephaven.base.ringbuffer.AggregatingFloatRingBuffer;
import io.deephaven.base.ringbuffer.AggregatingDoubleRingBuffer;
import io.deephaven.base.verify.Assert;
import io.deephaven.chunk.Chunk;
import io.deephaven.chunk.FloatChunk;
import io.deephaven.chunk.attributes.Values;
import io.deephaven.engine.table.impl.MatchPair;
import io.deephaven.engine.table.impl.updateby.UpdateByOperator;
import io.deephaven.engine.table.impl.updateby.internal.BaseFloatUpdateByOperator;
import io.deephaven.engine.table.impl.updateby.internal.BaseDoubleUpdateByOperator;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import static io.deephaven.util.QueryConstants.NULL_DOUBLE;
import static io.deephaven.util.QueryConstants.NULL_FLOAT;

public class FloatRollingSumOperator extends BaseFloatUpdateByOperator {
public class FloatRollingSumOperator extends BaseDoubleUpdateByOperator {
private static final int BUFFER_INITIAL_SIZE = 64;

protected class Context extends BaseFloatUpdateByOperator.Context {
protected class Context extends BaseDoubleUpdateByOperator.Context {
protected FloatChunk<? extends Values> floatInfluencerValuesChunk;
protected AggregatingFloatRingBuffer aggSum;
protected AggregatingDoubleRingBuffer aggSum;

protected Context(final int chunkSize) {
super(chunkSize);
aggSum = new AggregatingFloatRingBuffer(BUFFER_INITIAL_SIZE,
aggSum = new AggregatingDoubleRingBuffer(BUFFER_INITIAL_SIZE,
0,
Float::sum, // tree function
Double::sum, // tree function
(a, b) -> { // value function
if (a == NULL_FLOAT && b == NULL_FLOAT) {
if (a == NULL_DOUBLE && b == NULL_DOUBLE) {
return 0; // identity val
} else if (a == NULL_FLOAT) {
} else if (a == NULL_DOUBLE) {
return b;
} else if (b == NULL_FLOAT) {
} else if (b == NULL_DOUBLE) {
return a;
}
return a + b;
Expand All @@ -56,11 +57,13 @@ public void push(int pos, int count) {
aggSum.ensureRemaining(count);

for (int ii = 0; ii < count; ii++) {
float val = floatInfluencerValuesChunk.get(pos + ii);
aggSum.addUnsafe(val);
final float val = floatInfluencerValuesChunk.get(pos + ii);

if (val == NULL_FLOAT) {
nullCount++;
aggSum.addUnsafe(NULL_DOUBLE);
} else {
aggSum.addUnsafe(val);
}
}
}
Expand All @@ -70,9 +73,9 @@ public void pop(int count) {
Assert.geq(aggSum.size(), "aggSum.size()", count);

for (int ii = 0; ii < count; ii++) {
float val = aggSum.removeUnsafe();
double val = aggSum.removeUnsafe();

if (val == NULL_FLOAT) {
if (val == NULL_DOUBLE) {
nullCount--;
}
}
Expand All @@ -81,7 +84,7 @@ public void pop(int count) {
@Override
public void writeToOutputChunk(int outIdx) {
if (aggSum.size() == nullCount) {
outputValues.set(outIdx, NULL_FLOAT);
outputValues.set(outIdx, NULL_DOUBLE);
} else {
outputValues.set(outIdx, aggSum.evaluate());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import io.deephaven.engine.table.impl.updateby.internal.BaseDoubleUpdateByOperator;
import org.jetbrains.annotations.NotNull;

import static io.deephaven.util.QueryConstants.NULL_DOUBLE;
import static io.deephaven.util.QueryConstants.NULL_DOUBLE;

public class DoubleCumSumOperator extends BaseDoubleUpdateByOperator {
Expand All @@ -37,12 +38,10 @@ public void push(int pos, int count) {
Assert.eq(count, "push count", 1);

// read the value from the values chunk
final double currentVal = doubleValueChunk.get(pos);
final double val = doubleValueChunk.get(pos);

if (curVal == NULL_DOUBLE) {
curVal = currentVal;
} else if (currentVal != NULL_DOUBLE) {
curVal += currentVal;
if (val != NULL_DOUBLE) {
curVal = curVal == NULL_DOUBLE ? val : curVal + val;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
import io.deephaven.chunk.attributes.Values;
import io.deephaven.engine.table.impl.MatchPair;
import io.deephaven.engine.table.impl.updateby.UpdateByOperator;
import io.deephaven.engine.table.impl.updateby.internal.BaseFloatUpdateByOperator;
import io.deephaven.engine.table.impl.updateby.internal.BaseDoubleUpdateByOperator;
import org.jetbrains.annotations.NotNull;

import static io.deephaven.util.QueryConstants.NULL_DOUBLE;
import static io.deephaven.util.QueryConstants.NULL_FLOAT;

public class FloatCumSumOperator extends BaseFloatUpdateByOperator {
public class FloatCumSumOperator extends BaseDoubleUpdateByOperator {

protected class Context extends BaseFloatUpdateByOperator.Context {
protected class Context extends BaseDoubleUpdateByOperator.Context {
public FloatChunk<? extends Values> floatValueChunk;

protected Context(final int chunkSize) {
Expand All @@ -33,12 +34,10 @@ public void push(int pos, int count) {
Assert.eq(count, "push count", 1);

// read the value from the values chunk
final float currentVal = floatValueChunk.get(pos);
final float val = floatValueChunk.get(pos);

if (curVal == NULL_FLOAT) {
curVal = currentVal;
} else if (currentVal != NULL_FLOAT) {
curVal += currentVal;
if (val != NULL_FLOAT) {
curVal = curVal == NULL_DOUBLE ? val : curVal + val;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,42 @@ static CreateResult createTestTable(int tableSize, boolean includeSym, boolean i
CollectionUtil.ZERO_LENGTH_STRING_ARRAY, new TestDataGenerator[0]);
}

static CreateResult createTestTable(
int tableSize,
boolean includeSym,
boolean includeGroups,
boolean isRefreshing,
int seed,
String[] extraNames,
TestDataGenerator[] extraGenerators) {
return createTestTable(tableSize, includeSym, includeGroups, isRefreshing, seed, extraNames, extraGenerators,
0.1);
}

@SuppressWarnings({"rawtypes"})
static CreateResult createTestTable(int tableSize,
static CreateResult createTestTableAllNull(
int tableSize,
boolean includeSym,
boolean includeGroups,
boolean isRefreshing,
int seed,
String[] extraNames,
TestDataGenerator[] extraGenerators) {

return createTestTable(tableSize, includeSym, includeGroups, isRefreshing, seed, extraNames, extraGenerators,
1.0);
}

@SuppressWarnings({"rawtypes"})
static CreateResult createTestTable(
int tableSize,
boolean includeSym,
boolean includeGroups,
boolean isRefreshing,
int seed,
String[] extraNames,
TestDataGenerator[] extraGenerators,
double nullFraction) {
if (includeGroups && !includeSym) {
throw new IllegalArgumentException();
}
Expand All @@ -68,15 +96,15 @@ static CreateResult createTestTable(int tableSize,

colsList.addAll(Arrays.asList("byteCol", "shortCol", "intCol", "longCol", "floatCol", "doubleCol", "boolCol",
"bigIntCol", "bigDecimalCol"));
generators.addAll(Arrays.asList(new ByteGenerator((byte) -127, (byte) 127, .1),
new ShortGenerator((short) -6000, (short) 65535, .1),
new IntGenerator(10, 100, .1),
new LongGenerator(10, 100, .1),
new FloatGenerator(10.1F, 20.1F, .1),
new DoubleGenerator(10.1, 20.1, .1),
new BooleanGenerator(.5, .1),
new BigIntegerGenerator(new BigInteger("-10"), new BigInteger("10"), .1),
new BigDecimalGenerator(new BigInteger("1"), new BigInteger("2"), 5, .1)));
generators.addAll(Arrays.asList(new ByteGenerator((byte) -127, (byte) 127, nullFraction),
new ShortGenerator((short) -6000, (short) 65535, nullFraction),
new IntGenerator(10, 100, nullFraction),
new LongGenerator(10, 100, nullFraction),
new FloatGenerator(10.1F, 20.1F, nullFraction),
new DoubleGenerator(10.1, 20.1, nullFraction),
new BooleanGenerator(.5, nullFraction),
new BigIntegerGenerator(new BigInteger("-10"), new BigInteger("10"), nullFraction),
new BigDecimalGenerator(new BigInteger("1"), new BigInteger("2"), 5, nullFraction)));

final Random random = new Random(seed);
final ColumnInfo[] columnInfos = initColumnInfos(colsList.toArray(CollectionUtil.ZERO_LENGTH_STRING_ARRAY),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package io.deephaven.engine.table.impl.updateby;

import io.deephaven.api.updateby.UpdateByOperation;
import io.deephaven.datastructures.util.CollectionUtil;
import io.deephaven.engine.context.ExecutionContext;
import io.deephaven.engine.table.PartitionedTable;
import io.deephaven.engine.table.Table;
Expand All @@ -12,6 +13,7 @@
import io.deephaven.engine.testutil.EvalNugget;
import io.deephaven.engine.table.impl.QueryTable;
import io.deephaven.engine.testutil.TstUtils;
import io.deephaven.engine.testutil.generator.TestDataGenerator;
import io.deephaven.function.Numeric;
import io.deephaven.test.types.OutOfBandTest;
import org.jetbrains.annotations.NotNull;
Expand Down Expand Up @@ -52,6 +54,29 @@ public void testStaticZeroKey() {
}
}

@Test
public void testStaticZeroKeyAllNulls() {
final QueryTable t = createTestTableAllNull(100000, false, false, false, 0x31313131,
CollectionUtil.ZERO_LENGTH_STRING_ARRAY, new TestDataGenerator[0]).t;

final Table result = t.updateBy(List.of(
UpdateByOperation.CumMin("byteColMin=byteCol", "shortColMin=shortCol", "intColMin=intCol",
"longColMin=longCol", "floatColMin=floatCol", "doubleColMin=doubleCol",
"bigIntColMin=bigIntCol", "bigDecimalColMin=bigDecimalCol"),
UpdateByOperation.CumMax("byteColMax=byteCol", "shortColMax=shortCol", "intColMax=intCol",
"longColMax=longCol", "floatColMax=floatCol", "doubleColMax=doubleCol",
"bigIntColMax=bigIntCol", "bigDecimalColMax=bigDecimalCol")));
for (String col : t.getDefinition().getColumnNamesArray()) {
if ("boolCol".equals(col)) {
continue;
}
assertWithCumMin(DataAccessHelpers.getColumn(t, col).getDirect(),
DataAccessHelpers.getColumn(result, col + "Min").getDirect());
assertWithCumMax(DataAccessHelpers.getColumn(t, col).getDirect(),
DataAccessHelpers.getColumn(result, col + "Max").getDirect());
}
}

// endregion

// region Bucketed Tests
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package io.deephaven.engine.table.impl.updateby;

import io.deephaven.api.updateby.UpdateByControl;
import io.deephaven.datastructures.util.CollectionUtil;
import io.deephaven.engine.context.ExecutionContext;
import io.deephaven.engine.table.PartitionedTable;
import io.deephaven.engine.table.Table;
Expand All @@ -14,6 +15,7 @@
import io.deephaven.engine.testutil.GenerateTableUpdates;
import io.deephaven.engine.testutil.EvalNugget;
import io.deephaven.engine.testutil.TstUtils;
import io.deephaven.engine.testutil.generator.TestDataGenerator;
import io.deephaven.function.Numeric;
import io.deephaven.test.types.OutOfBandTest;
import org.jetbrains.annotations.NotNull;
Expand Down Expand Up @@ -53,6 +55,21 @@ public void testStaticZeroKey() {
}
}

@Test
public void testStaticZeroKeyAllNulls() {
final QueryTable t = createTestTableAllNull(100000, false, false, false, 0x31313131,
CollectionUtil.ZERO_LENGTH_STRING_ARRAY, new TestDataGenerator[0]).t;
final Table result = t.updateBy(UpdateByOperation.CumProd());
for (String col : t.getDefinition().getColumnNamesArray()) {
if ("boolCol".equals(col)) {
continue;
}
assertWithCumProd(DataAccessHelpers.getColumn(t, col).getDirect(),
DataAccessHelpers.getColumn(result, col).getDirect(),
DataAccessHelpers.getColumn(result, col).getType());
}
}

// endregion

// region Bucketed Tests
Expand Down
Loading

0 comments on commit 25e0cb1

Please sign in to comment.