Skip to content

Commit

Permalink
Mldb 1879 stddev variance (#591)
Browse files Browse the repository at this point in the history
* [MLDB-1879] Fixed variance/stddev in sql
* [MLDB-1879] Updated titanic notebook to reflect fix in stddev
* [MLDB-1879] PR comments
  • Loading branch information
mailletf authored Aug 4, 2016
1 parent 038d6eb commit d638e67
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 50 deletions.
69 changes: 39 additions & 30 deletions container_files/demos/Predicting Titanic Survival.ipynb

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,8 @@ The standard SQL aggregation functions operate 'vertically' down columns. MLDB d
- `vertical_count(<row>)` alias of `count()`, operates on columns.
- `vertical_sum(<row>)` alias of `sum()`, operates on columns.
- `vertical_avg(<row>)` alias of `avg()`, operates on columns.
- `vertical_stddev(<row>)` alias of `stddev()`, operates on columns.
- `vertical_variance(<row>)` alias of `variance()`, operates on columns.
- `vertical_min(<row>)` alias of `min()`, operates on columns.
- `vertical_max(<row>)` alias of `max()`, operates on columns.
- `vertical_latest(<row>)` alias of `latest()`, operates on columns.
Expand Down
38 changes: 28 additions & 10 deletions sql/builtin_aggregators.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1007,14 +1007,14 @@ struct LaterAccum {
static RegisterAggregatorT<EarliestLatestAccum<EarlierAccum> > registerEarliest("earliest", "vertical_earliest");
static RegisterAggregatorT<EarliestLatestAccum<LaterAccum> > registerLatest("latest", "vertical_latest");

struct StdAggAccum {
struct VarAccum {
static constexpr int nargs = 1;
int64_t n;
double mean;
double M2;
Date ts;

StdAggAccum() : n(0), mean(0), M2(0), ts(Date::negativeInfinity())
VarAccum() : n(0), mean(0), M2(0), ts(Date::negativeInfinity())
{
}

Expand All @@ -1040,16 +1040,21 @@ struct StdAggAccum {

ts.setMax(val.getEffectiveTimestamp());
}


/** Sample variance of the accumulated values, computed as M2 / (n - 1).

    Yields NaN while fewer than two values have been accumulated, because
    the (n - 1) denominator would then be zero or negative.
*/
double variance() const
{
    const bool hasEnoughSamples = n >= 2;
    return hasEnoughSamples ? M2 / (n - 1) : std::nan("");
}

ExpressionValue extract()
{
if (n < 2) {
return ExpressionValue(std::nan(""), ts);
}
return ExpressionValue(M2 / (n - 1), ts);
return ExpressionValue(variance(), ts);
}

void merge(StdAggAccum* src)
void merge(VarAccum* src)
{
double delta = src->mean - mean;
M2 = M2 + src->M2 + delta * delta * n * src->n / (n + src->n);
Expand All @@ -1059,8 +1064,21 @@ struct StdAggAccum {
}
};

static RegisterAggregatorT<StdAggAccum>
registerStdAgg("stddev", "vertical_stddev");
/** Accumulator behind the stddev() aggregator.

    Inherits VarAccum's state and accumulation logic unchanged and only
    replaces what is extracted: the square root of the variance.
*/
struct StdDevAccum : public VarAccum {

    StdDevAccum(): VarAccum()
    {
    }

    /** Returns sqrt(variance()) stamped with the accumulated timestamp ts.

        variance() is NaN for fewer than two values, and the square root
        of NaN is NaN, so the result is NaN in that case as well.
    */
    ExpressionValue extract()
    {
        // Qualified call: <cmath> only guarantees the std:: overloads.
        return ExpressionValue(std::sqrt(variance()), ts);
    }
};


// Aggregator registrations. Each accumulator is registered with two names:
// the SQL function name and its vertical_ alias (variance / vertical_variance,
// stddev / vertical_stddev).
// NOTE(review): RegisterAggregatorT is defined elsewhere; the two strings are
// taken to be the primary name and its alias -- confirm against its definition.
static RegisterAggregatorT<VarAccum> registerVarAgg("variance", "vertical_variance");
static RegisterAggregatorT<StdDevAccum> registerStdDevAgg("stddev", "vertical_stddev");


} // namespace Builtins
Expand Down
33 changes: 24 additions & 9 deletions testing/stddev_builtin_fct_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#
import unittest
import random
import math

mldb = mldb_wrapper.wrap(mldb) # noqa

Expand All @@ -13,19 +14,27 @@ class StdDevBuiltinFctTest(MldbUnitTest): # noqa
@classmethod
def setUpClass(cls):
ds = mldb.create_dataset({'id' : 'ds', 'type' : 'sparse.mutable'})
ds.record_row('1', [['a', 1, 0]])
ds.record_row('2', [['a', 2, 0]])
ds.record_row('3', [['a', 3, 0]])
ds.record_row('4', [['a', 10, 0]])
ds.record_row('5', [['b', 10, 0]])
for i in xrange(100):
ds.record_row('a%d-1' % i, [['a', 1, 0]])
ds.record_row('a%d-2' % i, [['a', 2, 0]])
ds.record_row('a%d-3' % i, [['a', 3, 0]])
ds.record_row('a%d-4' % i, [['a', 10, 0]])
ds.record_row('a%d-5' % i, [['a', 10, 0]])
ds.commit()

def test_base(self):
var = 15.791583166332668
res = mldb.query("SELECT variance(a) FROM ds")[1][1]
self.assertAlmostEqual(res, var)

res = mldb.query("SELECT vertical_variance(a) FROM ds")[1][1]
self.assertAlmostEqual(res, var)

res = mldb.query("SELECT stddev(a) FROM ds")[1][1]
self.assertAlmostEqual(res, 16.6666666667)
self.assertAlmostEqual(res, math.sqrt(var))

res = mldb.query("SELECT vertical_stddev(a) FROM ds")[1][1]
self.assertAlmostEqual(res, 16.6666666667)
self.assertAlmostEqual(res, math.sqrt(var))

def test_nan(self):
ds = mldb.create_dataset({'id' : 'null_ds', 'type' : 'tabular'})
Expand All @@ -37,6 +46,12 @@ def test_nan(self):
res = mldb.query("SELECT stddev(c) FROM null_ds")
self.assertEqual(res[1][1], "NaN")

res = mldb.query("SELECT variance(b) FROM null_ds")
self.assertEqual(res[1][1], "NaN")

res = mldb.query("SELECT variance(c) FROM null_ds")
self.assertEqual(res[1][1], "NaN")

@unittest.skip("Run manually if you want numpy comparison test")
def test_random_sequences(self):
"""
Expand All @@ -59,7 +74,7 @@ def test_random_sequences(self):
mldb.log(sequence)

mldb_res = mldb.query("SELECT stddev(a) FROM rand")[1][1]
numpy_res = float(numpy.var(sequence, ddof=1))
numpy_res = float(numpy.std(sequence, ddof=1))
if (numpy_res == 0):
mldb.log("Skipping case where numpy_re == 0")
else:
Expand Down Expand Up @@ -99,7 +114,7 @@ def test_pre_generated_sequence(self):

ds.commit()
res = mldb.query("SELECT stddev(col) FROM pre_gen_seq_input")
self.assertEqual(res[1][1], 62294040173.716774)
self.assertAlmostEqual(res[1][1], 249587.74043152996)

if __name__ == '__main__':
mldb.run_tests()
2 changes: 1 addition & 1 deletion testing/summary_stats_proc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_it(self):
"NaN", None, None],

["colA", "number", 0, 2, 10, 4, 1, 1, 1, 10, None, None, None, 2,
27, 1, None]
5.196152422706632, 1, None]
])

def test_dottest_col_names(self):
Expand Down

0 comments on commit d638e67

Please sign in to comment.