feat: DH-18351: Add CumCountWhere() and RollingCountWhere() features to UpdateBy #6566

Merged · 12 commits · Jan 21, 2025
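For context, a minimal usage sketch of the two new operations (the factory names come from the PR title; the exact signatures and the `source` table are assumed for illustration, not taken from this diff):

    import io.deephaven.api.updateby.UpdateByOperation;
    import io.deephaven.engine.table.Table;
    import java.util.List;

    // Hypothetical example: count qualifying rows cumulatively and over a trailing
    // 10-row window, given a table `source` with a Value column.
    final Table result = source.updateBy(List.of(
            UpdateByOperation.CumCountWhere("TotalHigh", "Value > 100"),
            UpdateByOperation.RollingCountWhere(10, "RecentHigh", "Value > 100")));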
@@ -33,7 +33,7 @@
* {@link UpdateByOperator#initializeRolling(Context, RowSet)} (Context)} for windowed operators</li>
* <li>{@link UpdateByOperator.Context#accumulateCumulative(RowSequence, Chunk[], LongChunk, int)} for cumulative
* operators or
- * {@link UpdateByOperator.Context#accumulateRolling(RowSequence, Chunk[], LongChunk, LongChunk, IntChunk, IntChunk, int)}
+ * {@link UpdateByOperator.Context#accumulateRolling(RowSequence, Chunk[], LongChunk, LongChunk, IntChunk, IntChunk, int, int)}
* for windowed operators</li>
* <li>{@link #finishUpdate(UpdateByOperator.Context)}</li>
* </ol>
@@ -99,18 +99,48 @@ protected void pop(int count) {
throw new UnsupportedOperationException("pop() must be overridden by rolling operators");
}

- public abstract void accumulateCumulative(RowSequence inputKeys,
+ /**
+ * For cumulative operators only, this method will be called to pass the input chunk data to the operator and
+ * produce the output data values.
+ *
+ * @param inputKeys the keys for the input data rows (these also match the output keys)
+ * @param valueChunkArr the input data chunks needed by the operator for internal calculations
+ * @param tsChunk the timestamp chunk for the input data (if applicable)
+ * @param len the number of items in the input data chunks
+ */
+ public abstract void accumulateCumulative(
+ RowSequence inputKeys,
Chunk<? extends Values>[] valueChunkArr,
LongChunk<? extends Values> tsChunk,
int len);
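As a reader aid (not part of the diff): a minimal sketch of how a cumulative operator context might satisfy this contract, assuming hypothetical helper names (setValueChunks, push, writeToOutputChunk, writeToOutputColumn) modeled on the base operator contexts:

    @Override
    public void accumulateCumulative(
            final RowSequence inputKeys,
            final Chunk<? extends Values>[] valueChunkArr,
            final LongChunk<? extends Values> tsChunk,
            final int len) {
        setValueChunks(valueChunkArr); // stage the input chunks for push()
        for (int ii = 0; ii < len; ii++) {
            push(ii, 1); // fold row ii into the running state
            writeToOutputChunk(ii); // record the cumulative result for row ii
        }
        writeToOutputColumn(inputKeys); // copy the output chunk to the column
    }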

- public abstract void accumulateRolling(RowSequence inputKeys,
+ /**
+ * For windowed operators only, this method will be called to pass the input chunk data to the operator and
+ * produce the output data values. Note that the influencer (input) and affected (output) chunks will generally
+ * not be the same size. Both sizes are passed explicitly because some operators (such as
+ * {@link io.deephaven.engine.table.impl.updateby.countwhere.CountWhereOperator} with zero input columns)
+ * receive no input chunks but must still process the exact number of input rows.
+ *
+ * @param inputKeys the keys for the input data rows (these also match the output keys)
+ * @param influencerValueChunkArr the input data chunks needed by the operator for internal calculations; these
+ * values will be pushed into and popped from the current window
+ * @param affectedPosChunk the row positions of the affected rows
+ * @param influencerPosChunk the row positions of the influencer rows
+ * @param pushChunk a chunk containing the push instructions for each output row to be calculated
+ * @param popChunk a chunk containing the pop instructions for each output row to be calculated
+ * @param affectedCount how many affected (output) rows are being computed
+ * @param influencerCount how many influencer (input) rows are needed for the computation
+ */
+ public abstract void accumulateRolling(
+ RowSequence inputKeys,
Chunk<? extends Values>[] influencerValueChunkArr,
LongChunk<OrderedRowKeys> affectedPosChunk,
LongChunk<OrderedRowKeys> influencerPosChunk,
IntChunk<? extends Values> pushChunk,
IntChunk<? extends Values> popChunk,
- int len);
+ int affectedCount,
+ int influencerCount);
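Again as a reader aid: a sketch of the push/pop loop a windowed operator might run under this contract, with the same assumed helper names. The loop iterates once per affected row; influencerCount bounds the total number of rows pushed, which is what lets an operator with zero input columns process the right number of input rows:

    @Override
    public void accumulateRolling(
            final RowSequence inputKeys,
            final Chunk<? extends Values>[] influencerValueChunkArr,
            final LongChunk<OrderedRowKeys> affectedPosChunk,
            final LongChunk<OrderedRowKeys> influencerPosChunk,
            final IntChunk<? extends Values> pushChunk,
            final IntChunk<? extends Values> popChunk,
            final int affectedCount,
            final int influencerCount) {
        setValueChunks(influencerValueChunkArr);
        int pushIndex = 0;
        for (int ii = 0; ii < affectedCount; ii++) {
            final int popCount = popChunk.get(ii);
            final int pushCount = pushChunk.get(ii);
            if (popCount > 0) {
                pop(popCount); // retire rows leaving the window
            }
            if (pushCount > 0) {
                push(pushIndex, pushCount); // admit rows entering the window
                pushIndex += pushCount;
            }
            writeToOutputChunk(ii); // emit the value for this output row
        }
        // pushIndex should now equal influencerCount: every input row was consumed.
        writeToOutputColumn(inputKeys);
    }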

/**
* Write the current value for this row to the output chunk
@@ -10,12 +10,21 @@
import io.deephaven.api.updateby.UpdateByControl;
import io.deephaven.api.updateby.UpdateByOperation;
import io.deephaven.api.updateby.spec.*;
import io.deephaven.base.verify.Require;
import io.deephaven.engine.rowset.RowSetFactory;
import io.deephaven.engine.table.ColumnDefinition;
import io.deephaven.engine.table.ColumnSource;
import io.deephaven.engine.table.Table;
import io.deephaven.engine.table.TableDefinition;
import io.deephaven.engine.table.impl.MatchPair;
import io.deephaven.engine.table.impl.QueryCompilerRequestProcessor;
import io.deephaven.engine.table.impl.QueryTable;
import io.deephaven.engine.table.impl.select.FormulaColumn;
import io.deephaven.engine.table.impl.select.SelectColumn;
import io.deephaven.engine.table.impl.select.WhereFilter;
import io.deephaven.engine.table.impl.sources.NullValueColumnSource;
import io.deephaven.engine.table.impl.sources.ReinterpretUtils;
import io.deephaven.engine.table.impl.updateby.countwhere.CountWhereOperator;
import io.deephaven.engine.table.impl.updateby.delta.*;
import io.deephaven.engine.table.impl.updateby.em.*;
import io.deephaven.engine.table.impl.updateby.emstd.*;
@@ -45,6 +54,7 @@
import java.time.Instant;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import static io.deephaven.util.BooleanUtils.NULL_BOOLEAN_AS_BYTE;
@@ -414,6 +424,12 @@ public Void visit(CumProdSpec cps) {
return null;
}

@Override
public Void visit(CumCountWhereSpec spec) {
ops.add(makeCountWhereOperator(tableDef, spec));
return null;
}

@Override
public Void visit(@NotNull final DeltaSpec spec) {
Arrays.stream(pairs)
@@ -537,6 +553,12 @@ public Void visit(@NotNull final RollingCountSpec spec) {
return null;
}

@Override
public Void visit(@NotNull final RollingCountWhereSpec spec) {
ops.add(makeCountWhereOperator(tableDef, spec));
return null;
}

@Override
public Void visit(@NotNull final RollingFormulaSpec spec) {
final boolean isTimeBased = spec.revWindowScale().isTimeBased();
@@ -1240,6 +1262,130 @@ private UpdateByOperator makeRollingCountOperator(@NotNull final MatchPair pair,
}
}

/**
* Creates the operator backing both the cumulative ({@code CumCountWhere}) and windowed
* ({@code RollingCountWhere}) count-where operations.
*/
private UpdateByOperator makeCountWhereOperator(
@NotNull final TableDefinition tableDef,
@NotNull final UpdateBySpec spec) {

Require.eqTrue(spec instanceof CumCountWhereSpec || spec instanceof RollingCountWhereSpec,
"spec instanceof CumCountWhereSpec || spec instanceof RollingCountWhereSpec");

final boolean isCumulative = spec instanceof CumCountWhereSpec;

final WhereFilter[] whereFilters = isCumulative
? WhereFilter.fromInternal(((CumCountWhereSpec) spec).filter())
: WhereFilter.fromInternal(((RollingCountWhereSpec) spec).filter());

final List<String> inputColumnNameList = new ArrayList<>();
final Map<String, Integer> inputColumnMap = new HashMap<>();
final List<int[]> filterInputColumnIndicesList = new ArrayList<>();

// Initialize the filters against the table definition, verify they are usable, and compute which
// input columns each filter needs.
for (final WhereFilter whereFilter : whereFilters) {
whereFilter.init(tableDef);
if (whereFilter.isRefreshing()) {
throw new UnsupportedOperationException("CountWhere does not support refreshing filters");
}

// Compute which input sources this filter will use.
final List<String> filterColumnNames = whereFilter.getColumns();
final int inputColumnCount = filterColumnNames.size();
final int[] inputColumnIndices = new int[inputColumnCount];
for (int ii = 0; ii < inputColumnCount; ++ii) {
final String inputColumnName = filterColumnNames.get(ii);
final int inputColumnIndex = inputColumnMap.computeIfAbsent(inputColumnName, k -> {
inputColumnNameList.add(inputColumnName);
return inputColumnNameList.size() - 1;
});
inputColumnIndices[ii] = inputColumnIndex;
}
filterInputColumnIndicesList.add(inputColumnIndices);
}

// Gather the input column type info and create a dummy table we can use to construct and initialize
// the count filters.
final String[] inputColumnNames = inputColumnNameList.toArray(String[]::new);
final ColumnSource<?>[] originalColumnSources = new ColumnSource[inputColumnNames.length];
final ColumnSource<?>[] reinterpretedColumnSources = new ColumnSource[inputColumnNames.length];

final Map<String, ColumnSource<?>> columnSourceMap = new LinkedHashMap<>();
for (int i = 0; i < inputColumnNames.length; i++) {
final String col = inputColumnNames[i];
final ColumnDefinition<?> def = tableDef.getColumn(col);
// Create a representative column source of the correct type for the filter.
final ColumnSource<?> nullSource =
NullValueColumnSource.getInstance(def.getDataType(), def.getComponentType());
// Create a reinterpreted version of the column source.
final ColumnSource<?> maybeReinterpretedSource = ReinterpretUtils.maybeConvertToPrimitive(nullSource);
// Record the original source only when reinterpretation replaced it, so later code can detect
// which columns need original-typed chunks.
if (nullSource != maybeReinterpretedSource) {
originalColumnSources[i] = nullSource;
}
columnSourceMap.put(col, maybeReinterpretedSource);
reinterpretedColumnSources[i] = maybeReinterpretedSource;
}
final Table dummyTable = new QueryTable(RowSetFactory.empty().toTracking(), columnSourceMap);

final CountWhereOperator.CountFilter[] countFilters =
CountWhereOperator.CountFilter.createCountFilters(whereFilters, dummyTable,
filterInputColumnIndicesList);

// If any filter is a ConditionFilter or ChunkFilter that uses a reinterpreted column, we need to
// produce original-typed chunks.
final boolean originalChunksRequired = Arrays.stream(countFilters)
.anyMatch(filter -> (filter.chunkFilter() != null || filter.conditionFilter() != null)
&& IntStream.of(filter.inputColumnIndices())
.anyMatch(i -> originalColumnSources[i] != null));

// If any filter is a standard WhereFilter, or we must produce original-typed chunks, we need a chunk
// source table.
final boolean chunkSourceTableRequired = originalChunksRequired ||
Arrays.stream(countFilters).anyMatch(filter -> filter.whereFilter() != null);

// Create a new column pair with the same name for the left and right columns
final String columnName = isCumulative
? ((CumCountWhereSpec) spec).column().name()
: ((RollingCountWhereSpec) spec).column().name();
final MatchPair pair = new MatchPair(columnName, columnName);

// Create and return the operator.
if (isCumulative) {
return new CountWhereOperator(
pair,
countFilters,
inputColumnNames,
originalColumnSources,
reinterpretedColumnSources,
chunkSourceTableRequired,
originalChunksRequired);
} else {
final RollingCountWhereSpec rs = (RollingCountWhereSpec) spec;

final String[] affectingColumns;
if (rs.revWindowScale().timestampCol() == null) {
affectingColumns = inputColumnNames;
} else {
affectingColumns = ArrayUtils.add(inputColumnNames, rs.revWindowScale().timestampCol());
}

final long prevWindowScaleUnits = rs.revWindowScale().getTimeScaleUnits();
final long fwdWindowScaleUnits = rs.fwdWindowScale().getTimeScaleUnits();

return new CountWhereOperator(
pair,
affectingColumns,
rs.revWindowScale().timestampCol(),
prevWindowScaleUnits,
fwdWindowScaleUnits,
countFilters,
inputColumnNames,
originalColumnSources,
reinterpretedColumnSources,
chunkSourceTableRequired,
originalChunksRequired);
}
}
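A concrete case for the zero-input-column path handled above (hypothetical usage; the factory signature and the constant filter expression are assumed): a filter that references no columns means the operator receives no input chunks and must rely entirely on the explicit row counts:

    // Counts every row in the trailing 100-row window; the filter "true" uses no columns.
    UpdateByOperation.RollingCountWhere(100, "WindowSize", "true");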

private UpdateByOperator makeRollingStdOperator(@NotNull final MatchPair pair,
@NotNull final TableDefinition tableDef,
@NotNull final RollingStdSpec rs) {
@@ -204,7 +204,8 @@ void processWindowBucketOperatorSet(final UpdateByWindowBucketContext context,
influencePosChunk,
ctx.pushChunks[affectedChunkOffset],
ctx.popChunks[affectedChunkOffset],
- affectedChunkSize);
+ affectedChunkSize,
+ influencerCount);
}

affectedChunkOffset++;