Skip to content

Commit

Permalink
Add aggregate function count_if for relational table
Browse files Browse the repository at this point in the history
  • Loading branch information
yx-zhang authored Jan 6, 2025
1 parent eb5d1df commit 2af3d02
Show file tree
Hide file tree
Showing 6 changed files with 527 additions and 0 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedAvgAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedCountAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedCountIfAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedExtremeAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedFirstAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedFirstByAccumulator;
Expand Down Expand Up @@ -155,6 +156,8 @@ private static GroupedAccumulator createBuiltinGroupedAccumulator(
switch (aggregationType) {
case COUNT:
return new GroupedCountAccumulator();
case COUNT_IF:
return new GroupedCountIfAccumulator();
case AVG:
return new GroupedAvgAccumulator(inputDataTypes.get(0));
case SUM:
Expand Down Expand Up @@ -207,6 +210,8 @@ public static TableAccumulator createBuiltinAccumulator(
switch (aggregationType) {
case COUNT:
return new CountAccumulator();
case COUNT_IF:
return new CountIfAccumulator();
case AVG:
return new AvgAccumulator(inputDataTypes.get(0));
case SUM:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation;

import org.apache.tsfile.block.column.Column;
import org.apache.tsfile.block.column.ColumnBuilder;
import org.apache.tsfile.file.metadata.statistics.Statistics;
import org.apache.tsfile.utils.RamUsageEstimator;

public class CountIfAccumulator implements TableAccumulator {
private static final long INSTANCE_SIZE =
RamUsageEstimator.shallowSizeOfInstance(CountIfAccumulator.class);

private long countState = 0;

@Override
public long getEstimatedSize() {
return INSTANCE_SIZE;
}

@Override
public TableAccumulator copy() {
return new CountIfAccumulator();
}

@Override
public void addInput(Column[] arguments) {
int count = arguments[0].getPositionCount();
for (int i = 0; i < count; i++) {
if (!arguments[0].isNull(i) && arguments[0].getBoolean(i)) {
countState++;
}
}
}

@Override
public void addIntermediate(Column argument) {
for (int i = 0; i < argument.getPositionCount(); i++) {
if (argument.isNull(i)) {
continue;
}
countState += argument.getLong(i);
}
}

@Override
public void evaluateIntermediate(ColumnBuilder columnBuilder) {
columnBuilder.writeLong(countState);
}

@Override
public void evaluateFinal(ColumnBuilder columnBuilder) {
columnBuilder.writeLong(countState);
}

@Override
public boolean hasFinalResult() {
return false;
}

@Override
public void removeInput(Column[] arguments) {
for (int i = 0; i < arguments[0].getPositionCount(); i++) {
if (!arguments[0].isNull(i) && arguments[0].getBoolean(i)) {
countState--;
}
}
}

@Override
public boolean removable() {
return true;
}

@Override
public void addStatistics(Statistics[] statistics) {
throw new UnsupportedOperationException("CountIfAccumulator does not support statistics");
}

@Override
public void reset() {
countState = 0;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped;

import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.LongBigArray;

import org.apache.tsfile.block.column.Column;
import org.apache.tsfile.block.column.ColumnBuilder;
import org.apache.tsfile.utils.RamUsageEstimator;

public class GroupedCountIfAccumulator implements GroupedAccumulator {
private final LongBigArray countValues = new LongBigArray(0L);

private static final long INSTANCE_SIZE =
RamUsageEstimator.shallowSizeOfInstance(GroupedCountIfAccumulator.class);

@Override
public long getEstimatedSize() {
return INSTANCE_SIZE + countValues.sizeOf();
}

@Override
public void setGroupCount(long groupCount) {
countValues.ensureCapacity(groupCount);
}

@Override
public void addInput(int[] groupIds, Column[] arguments) {
for (int i = 0; i < groupIds.length; i++) {
if (!arguments[0].isNull(i) && arguments[0].getBoolean(i)) {
countValues.increment(groupIds[i]);
}
}
}

@Override
public void addIntermediate(int[] groupIds, Column argument) {
for (int i = 0; i < groupIds.length; i++) {
if (!argument.isNull(i)) {
countValues.add(groupIds[i], argument.getLong(i));
}
}
}

@Override
public void evaluateIntermediate(int groupId, ColumnBuilder columnBuilder) {
columnBuilder.writeLong(countValues.get(groupId));
}

@Override
public void evaluateFinal(int groupId, ColumnBuilder columnBuilder) {
columnBuilder.writeLong(countValues.get(groupId));
}

@Override
public void prepareFinal() {}

@Override
public void reset() {
countValues.reset();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,14 @@ && isIntegerNumber(argumentTypes.get(2)))) {
"Aggregate functions [%s] should only have one argument", functionName));
}
break;
case SqlConstant.COUNT_IF:
if (argumentTypes.size() != 1 || !isBool(argumentTypes.get(0))) {
throw new SemanticException(
String.format(
"Aggregate functions [%s] should only have one boolean expression as argument",
functionName));
}
break;
case SqlConstant.FIRST_AGGREGATION:
case SqlConstant.LAST_AGGREGATION:
if (argumentTypes.size() != 2) {
Expand Down Expand Up @@ -615,6 +623,7 @@ && isIntegerNumber(argumentTypes.get(2)))) {
// get return type
switch (functionName.toLowerCase(Locale.ENGLISH)) {
case SqlConstant.COUNT:
case SqlConstant.COUNT_IF:
return INT64;
case SqlConstant.FIRST_AGGREGATION:
case SqlConstant.LAST_AGGREGATION:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
public enum TableBuiltinAggregationFunction {
SUM("sum"),
COUNT("count"),
COUNT_IF("count_if"),
AVG("avg"),
EXTREME("extreme"),
MAX("max"),
Expand Down Expand Up @@ -81,6 +82,7 @@ public static Type getIntermediateType(String name, List<Type> originalArgumentT
final String functionName = name.toLowerCase();
switch (functionName) {
case "count":
case "count_if":
return INT64;
case "sum":
return DOUBLE;
Expand Down

0 comments on commit 2af3d02

Please sign in to comment.