diff --git a/internal/function_aggregate.go b/internal/function_aggregate.go index 8095999..a983741 100644 --- a/internal/function_aggregate.go +++ b/internal/function_aggregate.go @@ -81,32 +81,20 @@ func (f *ARRAY_AGG) Step(v Value, opt *AggregatorOption) error { func (f *ARRAY_AGG) Done() (Value, error) { if f.opt != nil && len(f.opt.OrderBy) != 0 { - for orderBy := 0; orderBy < len(f.opt.OrderBy); orderBy++ { - if f.opt.OrderBy[orderBy].IsAsc { - sort.Slice(f.values, func(i, j int) bool { - if f.values[i].OrderBy[orderBy].Value == nil { - return true - } - - if f.values[j].OrderBy[orderBy].Value == nil { - return false - } - v, _ := f.values[i].OrderBy[orderBy].Value.LT(f.values[j].OrderBy[orderBy].Value) - return v - }) - } else { - sort.Slice(f.values, func(i, j int) bool { - if f.values[i].OrderBy[orderBy].Value == nil { - return true - } - if f.values[j].OrderBy[orderBy].Value == nil { - return false - } - v, _ := f.values[i].OrderBy[orderBy].Value.GT(f.values[j].OrderBy[orderBy].Value) - return v - }) + sort.Slice(f.values, func(i, j int) bool { + for orderBy := 0; orderBy < len(f.opt.OrderBy); orderBy++ { + isAsc := f.opt.OrderBy[orderBy].IsAsc + result, areEqual, _ := shouldComeBefore( + f.values[i].OrderBy[orderBy].Value, + f.values[j].OrderBy[orderBy].Value, + isAsc, + ) + if !areEqual { + return result + } } - } + return false + }) } if f.opt != nil && f.opt.Limit != nil { minLen := int64(len(f.values)) @@ -146,31 +134,20 @@ func (f *ARRAY_CONCAT_AGG) Step(v *ArrayValue, opt *AggregatorOption) error { func (f *ARRAY_CONCAT_AGG) Done() (Value, error) { if f.opt != nil && len(f.opt.OrderBy) != 0 { - for orderBy := 0; orderBy < len(f.opt.OrderBy); orderBy++ { - if f.opt.OrderBy[orderBy].IsAsc { - sort.Slice(f.values, func(i, j int) bool { - if f.values[i].OrderBy[orderBy].Value == nil { - return true - } - if f.values[j].OrderBy[orderBy].Value == nil { - return false - } - v, _ := f.values[i].OrderBy[orderBy].Value.LT(f.values[j].OrderBy[orderBy].Value) - return v - }) - } else { - sort.Slice(f.values, func(i, j int) bool { - if f.values[i].OrderBy[orderBy].Value == nil { - return true - } - if f.values[j].OrderBy[orderBy].Value == nil { - return false - } - v, _ := f.values[i].OrderBy[orderBy].Value.GT(f.values[j].OrderBy[orderBy].Value) - return v - }) + sort.Slice(f.values, func(i, j int) bool { + for orderBy := 0; orderBy < len(f.opt.OrderBy); orderBy++ { + isAsc := f.opt.OrderBy[orderBy].IsAsc + result, areEqual, _ := shouldComeBefore( + f.values[i].OrderBy[orderBy].Value, + f.values[j].OrderBy[orderBy].Value, + isAsc, + ) + if !areEqual { + return result + } } - } + return false + }) } if f.opt != nil && f.opt.Limit != nil { minLen := int64(len(f.values)) @@ -491,19 +468,20 @@ func (f *STRING_AGG) Step(v Value, delim string, opt *AggregatorOption) error { func (f *STRING_AGG) Done() (Value, error) { if f.opt != nil && len(f.opt.OrderBy) != 0 { - for orderBy := 0; orderBy < len(f.opt.OrderBy); orderBy++ { - if f.opt.OrderBy[orderBy].IsAsc { - sort.Slice(f.values, func(i, j int) bool { - v, _ := f.values[i].OrderBy[orderBy].Value.LT(f.values[j].OrderBy[orderBy].Value) - return v - }) - } else { - sort.Slice(f.values, func(i, j int) bool { - v, _ := f.values[i].OrderBy[orderBy].Value.GT(f.values[j].OrderBy[orderBy].Value) - return v - }) + sort.Slice(f.values, func(i, j int) bool { + for orderBy := 0; orderBy < len(f.opt.OrderBy); orderBy++ { + isAsc := f.opt.OrderBy[orderBy].IsAsc + result, areEqual, _ := shouldComeBefore( + f.values[i].OrderBy[orderBy].Value, + f.values[j].OrderBy[orderBy].Value, + isAsc, + ) + if !areEqual { + return result + } } - } + return false + }) } if f.opt != nil && f.opt.Limit != nil { minLen := int64(len(f.values)) diff --git a/internal/function_window_option.go b/internal/function_window_option.go index 19da038..0bb5157 100644 --- a/internal/function_window_option.go +++ b/internal/function_window_option.go @@ -406,31 +406,22 @@ func (s *WindowFuncAggregatedStatus) Done(cb func([]Value, int, int) error) erro sortedValues := make([]*WindowOrderedValue, len(values)) copy(sortedValues, values) if len(sortedValues) != 0 { - for orderBy := 0; orderBy < len(sortedValues[0].OrderBy); orderBy++ { - isAsc := sortedValues[0].OrderBy[orderBy].IsAsc - if isAsc { - sort.Slice(sortedValues, func(i, j int) bool { - if sortedValues[i].OrderBy[orderBy].Value == nil { - return true + orderByObjects := sortedValues[0].OrderBy + if orderByObjects != nil && len(orderByObjects) != 0 { + sort.Slice(sortedValues, func(i, j int) bool { + for orderBy := 0; orderBy < len(orderByObjects); orderBy++ { + isAsc := orderByObjects[orderBy].IsAsc + result, areEqual, _ := shouldComeBefore( + sortedValues[i].OrderBy[orderBy].Value, + sortedValues[j].OrderBy[orderBy].Value, + isAsc, + ) + if !areEqual { + return result } - if sortedValues[j].OrderBy[orderBy].Value == nil { - return false - } - cond, _ := sortedValues[i].OrderBy[orderBy].Value.LT(sortedValues[j].OrderBy[orderBy].Value) - return cond - }) - } else { - sort.Slice(sortedValues, func(i, j int) bool { - if sortedValues[i].OrderBy[orderBy].Value == nil { - return true - } - if sortedValues[j].OrderBy[orderBy].Value == nil { - return false - } - cond, _ := sortedValues[i].OrderBy[orderBy].Value.GT(sortedValues[j].OrderBy[orderBy].Value) - return cond - }) - } + } + return false + }) } } s.SortedValues = sortedValues diff --git a/internal/util.go b/internal/util.go index 378bb49..92f796a 100644 --- a/internal/util.go +++ b/internal/util.go @@ -82,3 +82,42 @@ func modifyTimeZone(t time.Time, loc *time.Location) (time.Time, error) { func timeFromUnixNano(unixNano int64) time.Time { return time.Unix(0, unixNano) } + +// checkOrderBy checks if value2 should come before value1 (first bool value). +// Second bool is to determine if values are equal +func shouldComeBefore(value1, value2 Value, isAscending bool) (bool, bool, error) { + + if value1 == nil && value2 == nil { + return false, true, nil + } + + if value1 == nil { + return true, false, nil + } + if value2 == nil { + return false, false, nil + } + + v, err := value1.EQ(value2) + if err != nil { + return false, false, err + } + + if v { + return false, true, nil + } + + if isAscending { + v, err = value1.LT(value2) + if err != nil { + return false, false, err + } + return v, false, nil + } + + v, err = value1.GT(value2) + if err != nil { + return false, false, err + } + return v, false, nil +}