Skip to content

Commit

Permalink
MLDB-1792 enh aggregator args count err messages
Browse files Browse the repository at this point in the history
  • Loading branch information
Mich authored Jul 12, 2016
1 parent b06a16f commit 80ba348
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 89 deletions.
65 changes: 35 additions & 30 deletions sql/builtin_aggregators.cc
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
/** builtin_aggregators.cc
Jeremy Barnes, 14 June 2015
Copyright (c) 2015 Datacratic Inc. All rights reserved.
This file is part of MLDB. Copyright 2015 Datacratic. All rights reserved.
Builtin aggregators for SQL.
*/

#include "sql_expression.h"
#include "builtin_functions.h"
#include "mldb/http/http_exception.h"
#include "mldb/jml/stats/distribution.h"
#include "mldb/jml/utils/csv.h"
Expand All @@ -22,7 +21,8 @@ namespace Datacratic {
namespace MLDB {
namespace Builtins {

typedef BoundAggregator (&BuiltinAggregator) (const std::vector<BoundSqlExpression> &);
typedef BoundAggregator (&BuiltinAggregator) (const std::vector<BoundSqlExpression> &,
const string & name);

struct RegisterAggregator {
template<typename... Names>
Expand All @@ -39,12 +39,12 @@ struct RegisterAggregator {
void doRegister(BuiltinAggregator aggregator, std::string name,
Names&&... names)
{
auto fn = [&] (const Utf8String & str,
auto fn = [&aggregator, name] (const Utf8String & str,
const std::vector<BoundSqlExpression> & args,
SqlBindingScope & context)
-> BoundAggregator
{
return std::move(aggregator(args));
return std::move(aggregator(args, name));
};
handles.push_back(registerAggregator(Utf8String(name), fn));
doRegister(aggregator, std::forward<Names>(names)...);
Expand All @@ -59,20 +59,21 @@ struct AggregatorT {
an aggregator. It meets the interface used by the normal aggregator
registration functionality.
*/
static BoundAggregator entry(const std::vector<BoundSqlExpression> & args)
static BoundAggregator entry(const std::vector<BoundSqlExpression> & args,
const string & name)
{
// These take the number of arguments given in the State class
ExcAssertEqual(args.size(), State::nargs);
checkArgsSize(args.size(), State::nargs, name);
ExcAssert(args[0].info);

if (args[0].info->isRow()) {
return enterRow(args);
return enterRow(args, name);
}
else if (args[0].info->isScalar()) {
return enterScalar(args);
return enterScalar(args, name);
}
else {
return enterAmbiguous(args);
return enterAmbiguous(args, name);
}
}

Expand Down Expand Up @@ -107,7 +108,8 @@ struct AggregatorT {
/** Entry point for when we are called with the first argument as a scalar.
This does a normal SQL aggregation.
*/
static BoundAggregator enterScalar(const std::vector<BoundSqlExpression> & args)
static BoundAggregator enterScalar(const std::vector<BoundSqlExpression> & args,
const string & name)
{
return { scalarInit, scalarProcess, scalarExtract, scalarMerge, State::info(args) };
}
Expand All @@ -126,7 +128,7 @@ struct AggregatorT {

void process(const ExpressionValue * args, size_t nargs)
{
ExcAssertEqual(nargs, 1);
checkArgsSize(nargs, 1);
const ExpressionValue & val = args[0];

// This must be a row...
Expand Down Expand Up @@ -200,7 +202,7 @@ struct AggregatorT {

void process(const ExpressionValue * args, size_t nargs)
{
ExcAssertEqual(nargs, 1);
checkArgsSize(nargs, 1);

if (fallback.get()) {
fallback->process(args, nargs);
Expand Down Expand Up @@ -321,14 +323,14 @@ struct AggregatorT {
row. This does an aggregation per column in the row.
*/
static BoundAggregator
enterRow(const std::vector<BoundSqlExpression> & args)
enterRow(const std::vector<BoundSqlExpression> & args, const string & name)
{
// Analyzes the input arguments for a row, and figures out:
// a) what kind of output will be produced
// b) what is the best way to implement the query
// First output: information about the row
// Second output: is it dense (in other words, all rows are the same)?
ExcAssertEqual(args.size(), 1);
checkArgsSize(args.size(), 1, name);
ExcAssert(args[0].info);

// Create a value info object for the output. It has the same
Expand Down Expand Up @@ -403,7 +405,7 @@ struct AggregatorT {

if (!state->isDetermined) {
state->isDetermined = true;
ExcAssertEqual(nargs, 1);
checkArgsSize(nargs, 1);
state->isRow = args[0].isRow();
}

Expand Down Expand Up @@ -445,7 +447,8 @@ struct AggregatorT {
/** Entry point where we don't know whether the argument is a row or a scalar
will be determined on the first row aggregated
*/
static BoundAggregator enterAmbiguous(const std::vector<BoundSqlExpression> & args)
static BoundAggregator enterAmbiguous(const std::vector<BoundSqlExpression> & args,
const string & name)
{
return { ambiguousStateInit, ambiguousProcess, ambiguousExtract, ambiguousMerge, std::make_shared<AnyValueInfo>() };
}
Expand Down Expand Up @@ -476,7 +479,7 @@ struct AverageAccum {

void process(const ExpressionValue * args, size_t nargs)
{
ExcAssertEqual(nargs, 1);
checkArgsSize(nargs, 1);
const ExpressionValue & val = args[0];
if (val.empty())
return;
Expand Down Expand Up @@ -520,7 +523,7 @@ struct ValueAccum {

void process(const ExpressionValue * args, size_t nargs)
{
ExcAssertEqual(nargs, 1);
checkArgsSize(nargs, 1);
const ExpressionValue & val = args[0];
if (val.empty())
return;
Expand Down Expand Up @@ -561,7 +564,7 @@ struct StringAggAccum {

void process(const ExpressionValue * args, size_t nargs)
{
ExcAssertEqual(nargs, 2);
checkArgsSize(nargs, 2);
const ExpressionValue & val = args[0];
const ExpressionValue & separator = args[1];

Expand Down Expand Up @@ -630,7 +633,7 @@ struct MinMaxAccum {
void process(const ExpressionValue * args,
size_t nargs)
{
ExcAssertEqual(nargs, 1);
checkArgsSize(nargs, 1);
const ExpressionValue & val = args[0];
//cerr << "processing " << jsonEncode(val) << endl;
if (val.empty())
Expand Down Expand Up @@ -694,7 +697,7 @@ struct CountAccum {
void process (const ExpressionValue * args,
size_t nargs)
{
ExcAssertEqual(nargs, 1);
checkArgsSize(nargs, 1);
const ExpressionValue & val = args[0];
if (val.empty())
return;
Expand Down Expand Up @@ -736,7 +739,7 @@ struct DistinctAccum {
void process (const ExpressionValue * args,
size_t nargs)
{
ExcAssertEqual(nargs, 1);
checkArgsSize(nargs, 1);
const ExpressionValue & val = args[0];
if (val.empty())
return;
Expand Down Expand Up @@ -772,18 +775,19 @@ struct LikelihoodRatioAccum {
std::unordered_map<ColumnName, std::array<uint64_t, 2> > counts;
};

BoundAggregator lr(const std::vector<BoundSqlExpression> & args)
BoundAggregator lr(const std::vector<BoundSqlExpression> & args,
const string & name)
{
auto init = [] () -> std::shared_ptr<void>
{
return std::make_shared<LikelihoodRatioAccum>();
};

auto process = [] (const ExpressionValue * args,
auto process = [name] (const ExpressionValue * args,
size_t nargs,
void * data)
{
ExcAssertEqual(nargs, 2);
checkArgsSize(nargs, 2, name);
const ExpressionValue & val = args[0];
bool conv = args[1].isTrue();
LikelihoodRatioAccum & accum = *(LikelihoodRatioAccum *)data;
Expand Down Expand Up @@ -859,20 +863,21 @@ struct PivotAccum {
StructValue vals;
};

BoundAggregator pivot(const std::vector<BoundSqlExpression> & args)
BoundAggregator pivot(const std::vector<BoundSqlExpression> & args,
const string & name)
{
auto init = [] () -> std::shared_ptr<void>
{
return std::make_shared<PivotAccum>();
};

auto process = [] (const ExpressionValue * args,
auto process = [name] (const ExpressionValue * args,
size_t nargs,
void * data)
{
PivotAccum & accum = *(PivotAccum *)data;

ExcAssertEqual(nargs, 2);
checkArgsSize(nargs, 2, name);
const ExpressionValue & col = args[0];
const ExpressionValue & val = args[1];

Expand Down Expand Up @@ -920,7 +925,7 @@ struct EarliestLatestAccum {
void process(const ExpressionValue * args,
size_t nargs)
{
ExcAssertEqual(nargs, 1);
checkArgsSize(nargs, 1);
const ExpressionValue & val = args[0];
//cerr << "processing " << jsonEncode(val) << endl;
if (val.empty())
Expand Down
Loading

0 comments on commit 80ba348

Please sign in to comment.