Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

multi-aggregate and group by functionality #154

Open
wants to merge 9 commits into
base: develop
Choose a base branch
from
90 changes: 90 additions & 0 deletions core/mapper_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package core

import (
"database/sql/driver"
"math/big"
"os"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

var (
schema *Table
conn *Session
tableCache *TableCacheStruct
)

// setup initializes the package-level fixtures shared by the tests in this
// package: a fresh table cache, the "cities" table loaded from ./testdata,
// and a Session whose TableBuffers map holds a single buffer for that table.
//
// When the table cannot be loaded the failure is written to stderr and schema
// and conn are left nil. Tests that depend on them guard with require.NotNil,
// so they fail cleanly instead of the whole test binary panicking on a nil
// dereference inside setup.
func setup() {
	tableCache = NewTableCacheStruct()

	table, err := LoadTable(tableCache, "./testdata", nil, "cities", nil)
	if err != nil {
		os.Stderr.WriteString("setup: loading table \"cities\": " + err.Error() + "\n")
		return
	}
	if table == nil {
		os.Stderr.WriteString("setup: LoadTable returned no table for \"cities\"\n")
		return
	}
	schema = table

	tbuf := make(map[string]*TableBuffer, 1)
	tbuf[schema.Name] = &TableBuffer{Table: schema}
	conn = &Session{TableBuffers: tbuf}
}

// teardown is the cleanup counterpart to setup; TestMain calls it only after a
// run that exits with status 0. It currently has nothing to release, so the
// body is intentionally empty.
func teardown() {
}

// TestMain drives the package's test run. It builds the shared fixtures via
// setup, executes the tests, and exits the process with the run's status code.
// teardown is invoked only when that status code is 0.
func TestMain(m *testing.M) {
	setup()

	if exitCode := m.Run(); exitCode != 0 {
		// Failed run: skip teardown and report the failure code as-is.
		os.Exit(exitCode)
	}
	teardown()
	os.Exit(0)
}

// TestMapperFactory resolves the mapper for the "military" attribute of the
// shared schema fixture and checks that mapping the value true yields 1.
//
// Every step that produces a value used by the next step is checked with
// require, so a failed lookup stops the test with a clear message instead of
// panicking on a nil attribute or mapper.
func TestMapperFactory(t *testing.T) {
	require.NotNil(t, schema)

	attr, err := schema.GetAttribute("military")
	require.NoError(t, err)
	require.NotNil(t, attr)

	mapper, err := ResolveMapper(attr)
	require.NoError(t, err)
	require.NotNil(t, mapper)

	value, err := mapper.MapValue(attr, true, nil, false)
	require.NoError(t, err)
	assert.Equal(t, big.NewInt(1), value)
}

// TestBuiltinMappers loads the "cities" table, maps one sample value per
// attribute through that attribute's mapper, and records the mapped values as
// uint64. It then checks that the boolean "military" value mapped to 1.
func TestBuiltinMappers(t *testing.T) {

	require.NotNil(t, schema)
	require.NotNil(t, conn)

	// One sample input per attribute name.
	data := map[string]driver.Value{
		"id":         "1840034016",
		"name":       "John",
		"county":     "King",
		"latitude":   123.5,
		"longitude":  -365.5,
		"population": 10000,
		"density":    100,
		"military":   true,
		"ranking":    99,
	}

	table, loadErr := LoadTable(tableCache, "./testdata", nil, "cities", nil)
	assert.Nil(t, loadErr)
	if !assert.NotNil(t, table) {
		return
	}

	mapped := make(map[string]uint64)
	for field, raw := range data {
		if field == "longitude" { // FIXME: repair error here.
			continue
		}
		attr, attrErr := table.GetAttribute(field)
		if !assert.Nil(t, attrErr) {
			continue
		}
		mappedValue, mapErr := attr.MapValue(raw, nil, false)
		if !assert.Nil(t, mapErr) {
			continue
		}
		mapped[field] = mappedValue.Uint64()
	}

	assert.Equal(t, "1840034016", data["id"])
	assert.Equal(t, uint64(1), mapped["military"])
}
148 changes: 148 additions & 0 deletions core/projector.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ package core
import (
"database/sql/driver"
"fmt"
"math"
"math/big"
"sort"
"strings"
"sync"
Expand Down Expand Up @@ -52,6 +54,31 @@ type BitmapFieldRow struct {
bm *roaring64.Bitmap
}

// AggregateOp identifies the aggregate function to apply to a field.
type AggregateOp int

// Supported aggregate operations. Numbering starts at 1, so the zero value of
// AggregateOp does not name any operation.
const (
	// COUNT reports the number of entries in the found set.
	COUNT AggregateOp = iota + 1
	// SUM reports the total of the field's values.
	SUM
	// AVG reports the sum of the field's values divided by their count.
	AVG
	// MIN reports the smallest field value.
	MIN
	// MAX reports the largest field value.
	MAX
)

// Aggregate describes a single aggregate column requested from a projection:
// which operation to apply to which table field, and how to format the result.
type Aggregate struct {
// Table is the name of the table owning the field; it is also the key used
// to look up that table's found set.
Table string
// Field is the name of the field whose values are aggregated.
Field string
// Op selects the aggregate operation (COUNT, SUM, AVG, MIN or MAX).
Op AggregateOp
// Scale is the number of decimal places of the field: non-COUNT results are
// divided by 10^Scale and rendered with Scale fractional digits.
Scale int
// NOTE(review): presumably the position of this aggregate's group-by column;
// not referenced by the code visible here - confirm against callers.
GroupIdx int // -1 = Not grouped
}

// NewProjection - Construct a Projection.
func NewProjection(s *Session, foundSets map[string]*roaring64.Bitmap, joinNames, projNames []string,
child, left string, fromTime, toTime int64, joinTypes map[string]bool, negate bool) (*Projector, error) {
Expand Down Expand Up @@ -958,6 +985,127 @@ func (p *Projector) getAggregateResult(table, field string) (result *roaring64.B
return
}

// AggregateAndGroup - Compute the requested aggregates, optionally grouped.
//
// The bitmap and BSI results for the projection's attributes are retrieved
// first and stored on the projector. Without groups the result is a single row
// holding one value per aggregate. With groups, one row is produced per group
// combination; each row holds the len(groups) group values followed by the
// len(aggregates) aggregate values.
func (p *Projector) AggregateAndGroup(aggregates []*Aggregate, groups []*Attribute) (rows [][]driver.Value,
	err error) {

	p.bsiResults, p.bitmapResults, err = p.retrieveBitmapResults(p.foundSets, p.projAttributes, false)
	if err != nil {
		return
	}

	rows = make([][]driver.Value, 0)
	if len(groups) > 0 {
		// Grouped: the recursive walk appends the rows itself.
		err = p.nestedLoops(0, aggregates, groups, nil, &rows)
		return
	}

	// Ungrouped: exactly one row, appended even if aggregation reports an error.
	single := make([]driver.Value, len(aggregates)+len(groups))
	err = p.aggregateRow(aggregates, nil, single, 0)
	rows = append(rows, single)
	return
}

// nestedLoops recursively expands the group-by attributes into result rows,
// using one recursion level per entry in groups.
//
// At level cgrp it walks every field row of groups[cgrp], intersects that
// row's bitmap with the found set inherited from the enclosing level, and
// skips values whose intersection is empty. For each remaining value it
// selects the output row (a fresh row at level 0, otherwise the most recently
// appended row or a new row carrying the enclosing group values), stores the
// reverse-mapped group value at position cgrp and recurses. Once cgrp reaches
// len(groups) the aggregates are written into the most recently appended row
// starting at position len(groups).
//
// foundSet may be nil on the initial call; it is then taken from the found
// set registered for the group attribute's table. The first error raised by
// any level stops the walk and is returned.
func (p *Projector) nestedLoops(cgrp int, aggs []*Aggregate, groups []*Attribute,
	foundSet *roaring64.Bitmap, rows *[][]driver.Value) (err error) {

	if cgrp == len(groups) {
		return p.aggregateRow(aggs, foundSet, (*rows)[len(*rows)-1], len(groups))
	}

	grAttr := groups[cgrp]
	r, ok := p.bitmapResults[grAttr.Parent.Name][grAttr.FieldName]
	if !ok {
		return fmt.Errorf("cant find group result for %s.%s", grAttr.Parent.Name, grAttr.FieldName)
	}

	for _, br := range r.fieldRows {

		if foundSet == nil {
			tableSet, ok := p.foundSets[grAttr.Parent.Name]
			if !ok {
				return fmt.Errorf("cant locate foundSet for %s", grAttr.Parent.Name)
			}
			foundSet = tableSet
		}
		rs := roaring64.And(foundSet, br.bm)
		if rs.GetCardinality() == 0 {
			continue
		}

		var row []driver.Value
		if cgrp == 0 {
			row = make([]driver.Value, len(aggs)+len(groups))
			*rows = append(*rows, row)
		} else {
			row = (*rows)[len(*rows)-1]
			if row[cgrp] != nil {
				// The last row already holds a value for this level, so start a
				// new one. Only the enclosing group values are carried over:
				// copying deeper levels too would leave stale values behind,
				// which makes the next level append yet another row and strands
				// this one with values from a previous combination.
				newRow := make([]driver.Value, len(aggs)+len(groups))
				copy(newRow, row[:cgrp])
				*rows = append(*rows, newRow)
				row = newRow
			}
		}
		row[cgrp], err = grAttr.MapValueReverse(br.rowID, p.connection)
		if err != nil {
			return fmt.Errorf("nestedLoops.MapValueReverse error for field '%s' - %v", grAttr.FieldName, err)
		}
		// Stop at the first failure; otherwise the next iteration would
		// overwrite err and the failure would be lost.
		if err = p.nestedLoops(cgrp+1, aggs, groups, rs, rows); err != nil {
			return err
		}
	}
	return nil
}

// aggregateRow computes every aggregate in aggs and stores the formatted
// results in row, starting at index startPos (one slot per aggregate, in
// order). It assumes row was allocated with enough room.
//
// foundSet restricts the values that take part in the aggregation. When it is
// nil, each aggregate uses the found set registered for its own table, and an
// error is returned if that table has none.
//
// COUNT yields the found set's cardinality formatted with "%10d". All other
// operations yield a fixed-point decimal string with Scale fractional digits:
// the stored values are integers scaled by 10^Scale, so the integer result is
// divided by 10^Scale before formatting.
func (p *Projector) aggregateRow(aggs []*Aggregate, foundSet *roaring64.Bitmap,
	row []driver.Value, startPos int) error {

	for i, v := range aggs {
		pos := startPos + i

		// Resolve the found set per aggregate so that a nil foundSet never
		// makes one table's set leak into an aggregate over another table.
		fs := foundSet
		if fs == nil {
			tableSet, ok := p.foundSets[v.Table]
			if !ok {
				return fmt.Errorf("cant locate foundSet for '%s'", v.Table)
			}
			fs = tableSet
		}
		if v.Op == COUNT {
			row[pos] = fmt.Sprintf("%10d", fs.GetCardinality())
			continue
		}

		r, err := p.getAggregateResult(v.Table, v.Field)
		if err != nil {
			return err
		}

		// Precision is left at 0 here: SetInt then picks a mantissa wide enough
		// to hold the integer exactly. Scale is a count of decimal digits and
		// must not be used as the mantissa width in bits.
		val := new(big.Float)
		switch v.Op {
		case SUM:
			sum, _ := r.SumBigValues(fs)
			val.SetInt(sum)
		case AVG:
			sum, count := r.SumBigValues(fs)
			if count != 0 {
				val.SetInt(sum.Div(sum, big.NewInt(int64(count))))
			}
		case MIN:
			val.SetInt(r.MinMaxBig(0, roaring64.MIN, fs))
		case MAX:
			val.SetInt(r.MinMaxBig(0, roaring64.MAX, fs))
		}
		if v.Scale > 0 {
			// Widen the mantissa before dividing so the quotient keeps enough
			// fractional bits to print Scale decimal digits correctly.
			val.SetPrec(val.Prec() + uint(4*v.Scale) + 64)
			val.Quo(val, new(big.Float).SetFloat64(math.Pow10(v.Scale)))
		}
		row[pos] = val.Text('f', v.Scale)
	}
	return nil
}


// Handle boundary condition where a range of column IDs could span multiple partitions.
func (p *Projector) getPartitionedStrings(attr *Attribute, colIDs []uint64) (map[interface{}]interface{}, error) {

Expand Down
53 changes: 53 additions & 0 deletions core/table_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package core

import (
"testing"

"github.com/stretchr/testify/assert"
)

// TestLoadTable loads the "cities" table from ./testdata and checks its
// attribute count, the mapper configuration of "region_list", and that the
// "name" attribute is a BSI.
//
// Equality assertions pass the expected value first, as assert.Equal requires,
// so failure output labels expected and actual correctly.
func TestLoadTable(t *testing.T) {
	tcs := NewTableCacheStruct()
	table, err := LoadTable(tcs, "./testdata", nil, "cities", nil)
	assert.NoError(t, err)
	if !assert.NotNil(t, table) {
		return
	}

	assert.NotNil(t, table.BasicTable)
	assert.Equal(t, 15, len(table.Attributes))

	regionList, err2 := table.GetAttribute("region_list")
	assert.NoError(t, err2)
	if assert.NotNil(t, regionList) {
		assert.NotNil(t, regionList.MapperConfig)
		assert.Equal(t, ",", regionList.MapperConfig["delim"])
		assert.Equal(t, StringEnum, MapperTypeFromString(regionList.MappingStrategy))
		assert.NotNil(t, regionList.mapperInstance)
	}

	name, err3 := table.GetAttribute("name")
	assert.NoError(t, err3)
	if assert.NotNil(t, name) {
		assert.True(t, name.IsBSI())
	}
}

// TestLoadTableWithPK loads the "cityzip" table and checks that its primary
// key information has two entries.
//
// The test returns early when the table cannot be loaded, so a load failure
// is reported as a failed assertion rather than a nil-pointer panic.
func TestLoadTableWithPK(t *testing.T) {
	tcs := NewTableCacheStruct()
	table, err := LoadTable(tcs, "./testdata", nil, "cityzip", nil)
	if !assert.NoError(t, err) || !assert.NotNil(t, table) {
		return
	}

	pki, err2 := table.GetPrimaryKeyInfo()
	assert.NoError(t, err2)
	assert.NotNil(t, pki)
	assert.Equal(t, 2, len(pki))
}

// TestLoadTableWithRelation loads the "cityzip" table and checks that the
// "city_id" attribute yields a foreign-key specification.
//
// Each lookup whose result is dereferenced afterwards is guarded, so a missing
// table or attribute is reported as a failed assertion rather than a
// nil-pointer panic.
func TestLoadTableWithRelation(t *testing.T) {
	tcs := NewTableCacheStruct()
	table, err := LoadTable(tcs, "./testdata", nil, "cityzip", nil)
	if !assert.NoError(t, err) || !assert.NotNil(t, table) {
		return
	}

	fka, err2 := table.GetAttribute("city_id")
	if !assert.NoError(t, err2) || !assert.NotNil(t, fka) {
		return
	}

	tab, spec, err3 := fka.GetFKSpec()
	assert.NoError(t, err3)
	assert.NotNil(t, tab)
	assert.NotNil(t, spec)
}
67 changes: 67 additions & 0 deletions source/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ func decorateRow(row []driver.Value, proj *rel.Projection, rowCols map[string]in
}
} else if strings.HasSuffix(v.As, "@rownum") {
newRow[i] = fmt.Sprintf("%d", columnID)
} else {
newRow[i] = cpyRow[i]
continue
}
if v.Col.Expr.NodeType() != "Func" {
continue
Expand Down Expand Up @@ -198,6 +201,33 @@ func outputProjection(outCh exec.MessageChan, sigChan exec.SigChan, proj *core.P
return nil
}


// outputAggregateProjection decorates each aggregate result row and sends it
// to outCh as a SQL driver message, using the row's index as both the
// decoration column ID and the message ID.
//
// Before each send sigChan is polled without blocking: a value delivered on it
// ends output with a "timed out." error, while a closed sigChan ends output
// with a nil error. The send itself blocks until outCh accepts the message.
func outputAggregateProjection(outCh exec.MessageChan, sigChan exec.SigChan, rows [][]driver.Value,
	colNames, rowCols map[string]int, pro *rel.Projection) error {

	for idx := range rows {
		rowID := uint64(idx)
		decorated := decorateRow(rows[idx], pro, rowCols, rowID)
		msg := datasource.NewSqlDriverMessageMap(rowID, decorated, colNames)

		select {
		case _, received := <-sigChan:
			// received is true when a value arrived and false when the
			// channel was closed.
			if received {
				return fmt.Errorf("timed out.")
			}
			return nil
		default:
		}

		outCh <- msg
	}
	return nil
}


func createFinalProjectionFromMaps(orig *rel.SqlSelect, aliasMap map[string]*rel.SqlSource, allTables []string,
sch *schema.Schema, driverTable string) (*rel.Projection, map[string]int, map[string]int, []string,
[]string, error) {
Expand Down Expand Up @@ -532,3 +562,40 @@ func createRowCols(ret *rel.Projection, tableMap map[string]*schema.Table, alias

return ret, rowCols
}

// outputRank builds a projection over the single field tableName.fieldName
// restricted to results, asks it for the top-n ranking of that field, and
// sends each ranking row to outCh as a SQL driver message.
//
// The messages use three columns: "topn_<fieldName>", "topn_count" and
// "topn_percent", at positions 0, 1 and 2. Output stops with a nil error as
// soon as sigChan becomes readable. Errors from building the projection or
// from ranking are returned unchanged.
func outputRank(tableName, fieldName string, conn *core.Session, outCh exec.MessageChan,
	sigChan exec.SigChan, results *roaring64.Bitmap, fromTime, toTime time.Time, topn int) error {

	// Column name -> position within each output row.
	colIndex := make(map[string]int, 3)
	colIndex["topn_"+fieldName] = 0
	colIndex["topn_count"] = 1
	colIndex["topn_percent"] = 2

	foundSet := map[string]*roaring64.Bitmap{tableName: results}
	projFields := []string{tableName + "." + fieldName}

	proj, err := core.NewProjection(conn, foundSet, nil, projFields, "", "",
		fromTime.UnixNano(), toTime.UnixNano(), nil, false)
	if err != nil {
		return err
	}

	rows, err := proj.Rank(tableName, fieldName, topn)
	if err != nil {
		return err
	}

	for i, rankRow := range rows {
		msg := datasource.NewSqlDriverMessageMap(uint64(i), rankRow, colIndex)
		select {
		case <-sigChan:
			return nil
		case outCh <- msg:
			// delivered; move on to the next row
		}
	}
	return nil
}
Loading
Loading