Skip to content

Commit

Permalink
Merge branch 'feature/gob' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
ostafen committed May 10, 2022
2 parents 6b6eac2 + c271ca9 commit e50ea8c
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 29 deletions.
65 changes: 37 additions & 28 deletions criteria.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package clover

import (
"regexp"
"strings"

"github.com/ostafen/clover/encoding"
)
Expand Down Expand Up @@ -62,13 +63,12 @@ func (f *field) IsNilOrNotExists() *Criteria {
}

func (f *field) Eq(value interface{}) *Criteria {
normalizedValue, err := encoding.Normalize(value)
if err != nil {
return &falseCriteria
}

return &Criteria{
p: func(doc *Document) bool {
normalizedValue, err := encoding.Normalize(getFieldOrValue(doc, value))
if err != nil {
return false
}
if !doc.Has(f.name) {
return false
}
Expand All @@ -78,52 +78,48 @@ func (f *field) Eq(value interface{}) *Criteria {
}

func (f *field) Gt(value interface{}) *Criteria {
normValue, err := encoding.Normalize(value)
if err != nil {
return &falseCriteria
}

return &Criteria{
p: func(doc *Document) bool {
normValue, err := encoding.Normalize(getFieldOrValue(doc, value))
if err != nil {
return false
}
return compareValues(doc.Get(f.name), normValue) > 0
},
}
}

func (f *field) GtEq(value interface{}) *Criteria {
normValue, err := encoding.Normalize(value)
if err != nil {
return &falseCriteria
}

return &Criteria{
p: func(doc *Document) bool {
normValue, err := encoding.Normalize(getFieldOrValue(doc, value))
if err != nil {
return false
}
return compareValues(doc.Get(f.name), normValue) >= 0
},
}
}

func (f *field) Lt(value interface{}) *Criteria {
normValue, err := encoding.Normalize(value)
if err != nil {
return &falseCriteria
}

return &Criteria{
p: func(doc *Document) bool {
normValue, err := encoding.Normalize(getFieldOrValue(doc, value))
if err != nil {
return false
}
return compareValues(doc.Get(f.name), normValue) < 0
},
}
}

func (f *field) LtEq(value interface{}) *Criteria {
normValue, err := encoding.Normalize(value)
if err != nil {
return &falseCriteria
}

return &Criteria{
p: func(doc *Document) bool {
normValue, err := encoding.Normalize(getFieldOrValue(doc, value))
if err != nil {
return false
}
return compareValues(doc.Get(f.name), normValue) <= 0
},
}
Expand All @@ -142,8 +138,9 @@ func (f *field) In(values ...interface{}) *Criteria {
return &Criteria{
p: func(doc *Document) bool {
docValue := doc.Get(f.name)
for _, value := range normValues.([]interface{}) {
if compareValues(value, docValue) == 0 {
for _, v := range values {
normValue, err := encoding.Normalize(getFieldOrValue(doc, v))
if err == nil && compareValues(normValue, docValue) == 0 {
return true
}
}
Expand All @@ -164,7 +161,7 @@ func (f *field) Contains(elems ...interface{}) *Criteria {

for _, elem := range elems {
found := false
normElem, err := encoding.Normalize(elem)
normElem, err := encoding.Normalize(getFieldOrValue(doc, elem))

if err == nil {
for _, val := range slice {
Expand Down Expand Up @@ -241,3 +238,15 @@ func (c *Criteria) Not() *Criteria {
p: negatePredicate(c.p),
}
}

// getFieldOrValue returns dereferenced value if value denotes another document field,
// otherwise returns the value itself directly
func getFieldOrValue(doc *Document, value interface{}) interface{} {
if cmpField, ok := value.(*field); ok {
value = doc.Get(cmpField.name)
} else if fStr, ok := value.(string); ok && strings.HasPrefix(fStr, "$") {
fieldName := strings.TrimLeft(fStr, "$")
value = doc.Get(fieldName)
}
return value
}
2 changes: 1 addition & 1 deletion db.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func (db *DB) Save(collectionName string, doc *Document) error {
// InsertOne inserts a single document to an existing collection. It returns the id of the inserted document.
func (db *DB) InsertOne(collectionName string, doc *Document) (string, error) {
err := db.Insert(collectionName, doc)
return doc.Get(objectIdField).(string), err
return doc.ObjectId(), err
}

// Open opens a new clover database on the supplied path. If such a folder doesn't exist, it is automatically created.
Expand Down
38 changes: 38 additions & 0 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,17 @@ func TestInCriteria(t *testing.T) {
require.Fail(t, "userId is not in the correct range")
}
}

criteria := c.Field("userId").In(c.Field("id"), 6)
docs, err = db.Query("todos").Where(criteria).FindAll()
require.NoError(t, err)

require.Greater(t, len(docs), 0)
for _, doc := range docs {
userId := doc.Get("userId").(int64)
id := doc.Get("id").(uint64)
require.True(t, uint64(userId) == id || userId == 6)
}
})
}

Expand Down Expand Up @@ -1327,3 +1338,30 @@ func TestCompareObjects3(t *testing.T) {
require.Equal(t, docs[0].Get("data.SomeString"), "aStr")
})
}

func TestCompareDocumentFields(t *testing.T) {
runCloverTest(t, airlinesPath, nil, func(t *testing.T, db *c.DB) {
criteria := c.Field("Statistics.Flights.Diverted").Gt(c.Field("Statistics.Flights.Cancelled"))
docs, err := db.Query("airlines").Where(criteria).FindAll()
require.NoError(t, err)

require.Greater(t, len(docs), 0)
for _, doc := range docs {
diverted := doc.Get("Statistics.Flights.Diverted").(float64)
cancelled := doc.Get("Statistics.Flights.Cancelled").(float64)
require.Greater(t, diverted, cancelled)
}

//alternative syntax using $
criteria = c.Field("Statistics.Flights.Diverted").Gt("$Statistics.Flights.Cancelled")
docs, err = db.Query("airlines").Where(criteria).FindAll()
require.NoError(t, err)

require.Greater(t, len(docs), 0)
for _, doc := range docs {
diverted := doc.Get("Statistics.Flights.Diverted").(float64)
cancelled := doc.Get("Statistics.Flights.Cancelled").(float64)
require.Greater(t, diverted, cancelled)
}
})
}

0 comments on commit e50ea8c

Please sign in to comment.