diff --git a/cache/cache.go b/cache/cache.go index 1bb52a8e..d8add17f 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -1,8 +1,13 @@ package cache import ( + "bytes" + "crypto/sha256" + "encoding/gob" + "encoding/hex" "fmt" "reflect" + "strings" "sync" "log" @@ -13,37 +18,233 @@ import ( ) const ( - updateEvent = "update" - addEvent = "add" - deleteEvent = "delete" - bufferSize = 65536 + updateEvent = "update" + addEvent = "add" + deleteEvent = "delete" + bufferSize = 65536 + columnDelimiter = "," ) +type IndexExistsError struct { + table string + value interface{} + column string + new string + existing string +} + +func (i *IndexExistsError) Error() string { + return fmt.Sprintf("operation would cause rows in the \"%s\" table to have identical values (%v) for index on column \"%s\". First row, with UUID %s, was inserted by this transaction. Second row, with UUID %s, existed in the database before this operation and was not modified", + i.table, + i.value, + i.column, + i.new, + i.existing, + ) +} + +func NewIndexExistsError(table string, value interface{}, column, new, existing string) *IndexExistsError { + return &IndexExistsError{ + table, value, column, new, existing, + } +} + +// map of unique values to uuids +type valueToUUID map[interface{}]string + +// map of column name(s) to a unique values, to UUIDs +type columnToValue map[string]valueToUUID + // RowCache is a collections of Models hashed by UUID type RowCache struct { - cache map[string]model.Model - mutex sync.RWMutex + name string + schema ovsdb.TableSchema + dataType reflect.Type + cache map[string]model.Model + indexes columnToValue + mutex sync.RWMutex } // Row returns one model from the cache by UUID func (r *RowCache) Row(uuid string) model.Model { r.mutex.RLock() defer r.mutex.RUnlock() + return r.row(uuid) +} + +func (r *RowCache) row(uuid string) model.Model { if row, ok := r.cache[uuid]; ok { return row.(model.Model) } return nil } -// Set writes the provided content to the cache -// WARNING: Do not use Set outside of testing -// as it may case cache corruption if, for example, -// you write a model.Model that isn't part of the -// model.DBModel -func (r *RowCache) Set(uuid string, m model.Model) { +// Create writes the provided content to the cache +func (r *RowCache) Create(uuid string, m model.Model) error { + r.mutex.Lock() + defer r.mutex.Unlock() + return r.create(uuid, m) +} + +func (r *RowCache) create(uuid string, m model.Model) error { + if _, ok := r.cache[uuid]; ok { + return fmt.Errorf("row %s already exists", uuid) + } + if reflect.TypeOf(m) != r.dataType { + return fmt.Errorf("expected data of type %s, but got %s", r.dataType.String(), reflect.TypeOf(m).String()) + } + info, err := mapper.NewMapperInfo(&r.schema, m) + if err != nil { + return err + } + newIndexes := newColumnToValue(r.schema.Indexes) + var errs []error + for columnStr := range r.indexes { + columns := strings.Split(columnStr, columnDelimiter) + var val interface{} + var err error + if len(columns) > 1 { + val, err = hashColumnValues(info, columns) + if err != nil { + return err + } + } else { + column := columns[0] + val, err = info.FieldByColumn(column) + if err != nil { + return err + } + } + if existing, ok := r.indexes[columnStr][val]; ok { + errs = append(errs, + NewIndexExistsError(r.name, val, columnStr, uuid, existing)) + } + newIndexes[columnStr][val] = uuid + } + if len(errs) != 0 { + return fmt.Errorf("%v", errs) + } + // write indexes + for k1, v1 := range newIndexes { + for k2, v2 := range v1 { + r.indexes[k1][k2] = v2 + } + } + r.cache[uuid] = m + return nil +} + +// Update updates the content in the cache +func (r *RowCache) Update(uuid string, m model.Model) error { r.mutex.Lock() defer r.mutex.Unlock() + return r.update(uuid, m) +} + +func (r *RowCache) update(uuid string, m model.Model) error { + if _, ok := r.cache[uuid]; !ok { + return fmt.Errorf("row %s does not exist", uuid) + } + oldRow := r.cache[uuid] + oldInfo, err := mapper.NewMapperInfo(&r.schema, oldRow) + if err != nil { + return err + } + newInfo, err := mapper.NewMapperInfo(&r.schema, m) + if err != nil { + return err + } + newIndexes := newColumnToValue(r.schema.Indexes) + oldIndexes := newColumnToValue(r.schema.Indexes) + var errs []error + for columnStr := range r.indexes { + columns := strings.Split(columnStr, columnDelimiter) + var oldVal interface{} + var newVal interface{} + var err error + if len(columns) > 1 { + oldVal, err = hashColumnValues(oldInfo, columns) + if err != nil { + return err + } + newVal, err = hashColumnValues(newInfo, columns) + if err != nil { + return err + } + } else { + column := columns[0] + oldVal, err = oldInfo.FieldByColumn(column) + if err != nil { + return err + } + newVal, err = newInfo.FieldByColumn(column) + if err != nil { + return err + } + } + // if old and new values are the same, don't worry + if oldVal == newVal { + continue + } + // old and new values are NOT the same + + // check that there are no conflicts + if conflict, ok := r.indexes[columnStr][newVal]; ok && conflict != uuid { + errs = append(errs, NewIndexExistsError( + r.name, + newVal, + columnStr, + uuid, + conflict, + )) + } + newIndexes[columnStr][newVal] = uuid + oldIndexes[columnStr][oldVal] = "" + } + if len(errs) > 0 { + return fmt.Errorf("%+v", errs) + } + // write indexes + for k1, v1 := range newIndexes { + for k2, v2 := range v1 { + r.indexes[k1][k2] = v2 + } + } + // delete old indexes + for k1, v1 := range oldIndexes { + for k2 := range v1 { + delete(r.indexes[k1], k2) + } + } r.cache[uuid] = m + return nil +} + +// Delete deletes a row from the cache +func (r *RowCache) Delete(uuid string) error { + r.mutex.Lock() + defer r.mutex.Unlock() + return r.delete(uuid) +} + +func (r *RowCache) delete(uuid string) error { + if _, ok := r.cache[uuid]; !ok { + return fmt.Errorf("row %s does not exist", uuid) + } + oldRow := r.cache[uuid] + oldInfo, err := mapper.NewMapperInfo(&r.schema, oldRow) + if err != nil { + return err + } + for column := range r.indexes { + oldVal, err := oldInfo.FieldByColumn(column) + if err != nil { + return err + } + delete(r.indexes[column], oldVal) + } + delete(r.cache, uuid) + return nil } // Rows returns a list of row UUIDs as strings @@ -64,16 +265,14 @@ func (r *RowCache) Len() int { return len(r.cache) } -// NewRowCache creates a new row cache with the provided data -// if the data is nil, and empty RowCache will be created -func NewRowCache(data map[string]model.Model) *RowCache { - if data == nil { - data = make(map[string]model.Model) - } - return &RowCache{ - cache: data, - mutex: sync.RWMutex{}, +func (r *RowCache) Index(column string) (map[interface{}]string, error) { + r.mutex.RLock() + defer r.mutex.RUnlock() + index, ok := r.indexes[column] + if !ok { + return nil, fmt.Errorf("%s is not an index", column) } + return index, nil } // EventHandler can handle events when the contents of the cache changes @@ -116,20 +315,37 @@ func (e *EventHandlerFuncs) OnDelete(table string, row model.Model) { // and an array of EventHandlers that respond to cache updates type TableCache struct { cache map[string]*RowCache - cacheMutex sync.RWMutex eventProcessor *eventProcessor mapper *mapper.Mapper dbModel *model.DBModel } +// CacheData is the type for data that can be prepoulated in the cache +type CacheData map[string]map[string]model.Model + // NewTableCache creates a new TableCache -func NewTableCache(schema *ovsdb.DatabaseSchema, dbModel *model.DBModel) (*TableCache, error) { +func NewTableCache(schema *ovsdb.DatabaseSchema, dbModel *model.DBModel, data CacheData) (*TableCache, error) { if schema == nil || dbModel == nil { return nil, fmt.Errorf("tablecache without databasemodel cannot be populated") } eventProcessor := newEventProcessor(bufferSize) + cache := make(map[string]*RowCache) + tableTypes := dbModel.Types() + for name, tableSchema := range schema.Tables { + cache[name] = newRowCache(name, tableSchema, tableTypes[name]) + } + for table, rowData := range data { + if _, ok := schema.Tables[table]; !ok { + return nil, fmt.Errorf("table %s is not in schema", table) + } + for uuid, row := range rowData { + if err := cache[table].Create(uuid, row); err != nil { + return nil, err + } + } + } return &TableCache{ - cache: make(map[string]*RowCache), + cache: cache, eventProcessor: eventProcessor, mapper: mapper.NewMapper(schema), dbModel: dbModel, @@ -148,30 +364,14 @@ func (t *TableCache) DBModel() *model.DBModel { // Table returns the a Table from the cache with a given name func (t *TableCache) Table(name string) *RowCache { - t.cacheMutex.RLock() - defer t.cacheMutex.RUnlock() if table, ok := t.cache[name]; ok { return table } return nil } -// Set write the provided RowCache to the provided table name in the cache -// if the provided cache is nil, we'll initialize a new one -// WARNING: Do not use Set outside of testing -func (t *TableCache) Set(name string, rc *RowCache) { - if rc == nil { - rc = NewRowCache(nil) - } - t.cacheMutex.Lock() - defer t.cacheMutex.Unlock() - t.cache[name] = rc -} - // Tables returns a list of table names that are in the cache func (t *TableCache) Tables() []string { - t.cacheMutex.RLock() - defer t.cacheMutex.RUnlock() var result []string for k := range t.cache { result = append(result, k) @@ -204,30 +404,41 @@ func (t *TableCache) Echo([]interface{}) { func (t *TableCache) Disconnected() { } +// lock acquires a lock on all tables in the cache +func (t *TableCache) lock() { + for _, r := range t.cache { + r.mutex.Lock() + } +} + +// unlock releases a lock on all tables in the cache +func (t *TableCache) unlock() { + for _, r := range t.cache { + r.mutex.Unlock() + } +} + // Populate adds data to the cache and places an event on the channel func (t *TableCache) Populate(tableUpdates ovsdb.TableUpdates) { - t.cacheMutex.Lock() - defer t.cacheMutex.Unlock() + t.lock() + defer t.unlock() for table := range t.dbModel.Types() { updates, ok := tableUpdates[table] if !ok { continue } - var tCache *RowCache - if tCache, ok = t.cache[table]; !ok { - t.cache[table] = NewRowCache(nil) - tCache = t.cache[table] - } - tCache.mutex.Lock() + tCache := t.cache[table] for uuid, row := range updates { if row.New != nil { newModel, err := t.CreateModel(table, row.New, uuid) if err != nil { panic(err) } - if existing, ok := tCache.cache[uuid]; ok { + if existing := tCache.row(uuid); existing != nil { if !reflect.DeepEqual(newModel, existing) { - tCache.cache[uuid] = newModel + if err := tCache.update(uuid, newModel); err != nil { + panic(err) + } oldModel, err := t.CreateModel(table, row.Old, uuid) if err != nil { panic(err) @@ -237,7 +448,9 @@ func (t *TableCache) Populate(tableUpdates ovsdb.TableUpdates) { // no diff continue } - tCache.cache[uuid] = newModel + if err := tCache.create(uuid, newModel); err != nil { + panic(err) + } t.eventProcessor.AddEvent(addEvent, table, nil, newModel) continue } else { @@ -245,13 +458,13 @@ func (t *TableCache) Populate(tableUpdates ovsdb.TableUpdates) { if err != nil { panic(err) } - // delete from cache - delete(tCache.cache, uuid) + if err := tCache.delete(uuid); err != nil { + panic(err) + } t.eventProcessor.AddEvent(deleteEvent, table, oldModel, nil) continue } } - tCache.mutex.Unlock() } } @@ -265,7 +478,37 @@ func (t *TableCache) Run(stopCh <-chan struct{}) { t.eventProcessor.Run(stopCh) } -// event encapsualtes a cache event +// newRowCache creates a new row cache with the provided data +// if the data is nil, and empty RowCache will be created +func newRowCache(name string, schema ovsdb.TableSchema, dataType reflect.Type) *RowCache { + r := &RowCache{ + name: name, + schema: schema, + indexes: newColumnToValue(schema.Indexes), + dataType: dataType, + cache: make(map[string]model.Model), + mutex: sync.RWMutex{}, + } + return r +} + +func newColumnToValue(schemaIndexes [][]string) columnToValue { + // RFC 7047 says that Indexes is a [] and "Each is a set of + // columns whose values, taken together within any given row, must be + // unique within the table". We'll store the column names, separated by comma + // as we'll assuume (RFC is not clear), that comma isn't valid in a + var indexes []string + for i := range schemaIndexes { + indexes = append(indexes, strings.Join(schemaIndexes[i], columnDelimiter)) + } + c := make(columnToValue) + for _, index := range indexes { + c[index] = make(valueToUUID) + } + return c +} + +// event encapsulates a cache event type event struct { eventType string table string @@ -373,3 +616,21 @@ func (t *TableCache) CreateModel(tableName string, row *ovsdb.Row, uuid string) return model, nil } + +func hashColumnValues(info *mapper.MapperInfo, columns []string) (string, error) { + var buf bytes.Buffer + enc := gob.NewEncoder(&buf) + for _, column := range columns { + val, err := info.FieldByColumn(column) + if err != nil { + return "", err + } + err = enc.Encode(val) + if err != nil { + return "", err + } + } + h := sha256.New() + val := hex.EncodeToString(h.Sum(buf.Bytes())) + return val, nil +} diff --git a/cache/cache_test.go b/cache/cache_test.go index 91205df7..5ae41fee 100644 --- a/cache/cache_test.go +++ b/cache/cache_test.go @@ -5,14 +5,17 @@ import ( "encoding/json" + "github.com/ovn-org/libovsdb/mapper" "github.com/ovn-org/libovsdb/model" "github.com/ovn-org/libovsdb/ovsdb" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type testModel struct { UUID string `ovs:"_uuid"` Foo string `ovs:"foo"` + Bar string `ovs:"bar"` } func TestRowCache_Row(t *testing.T) { @@ -79,6 +82,375 @@ func TestRowCache_Rows(t *testing.T) { } } +func TestRowCacheCreate(t *testing.T) { + var schema ovsdb.DatabaseSchema + db, err := model.NewDBModel("Open_vSwitch", map[string]model.Model{"Open_vSwitch": &testModel{}}) + require.Nil(t, err) + err = json.Unmarshal([]byte(` + {"name": "TestDB", + "tables": { + "Open_vSwitch": { + "indexes": [["foo"]], + "columns": { + "foo": { + "type": "string" + }, + "bar": { + "type": "string" + } + } + } + } + } + `), &schema) + require.Nil(t, err) + testData := CacheData{ + "Open_vSwitch": map[string]model.Model{"bar": &testModel{Foo: "bar"}}, + } + tc, err := NewTableCache(&schema, db, testData) + require.Nil(t, err) + + tests := []struct { + name string + uuid string + model *testModel + wantErr bool + }{ + { + "inserts a new row", + "foo", + &testModel{Foo: "foo"}, + false, + }, + { + "error duplicate uuid", + "bar", + &testModel{Foo: "foo"}, + true, + }, + { + "error duplicate index", + "baz", + &testModel{Foo: "bar"}, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rc := tc.Table("Open_vSwitch") + require.NotNil(t, rc) + err := rc.Create(tt.uuid, tt.model) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.Nil(t, err) + assert.Equal(t, tt.uuid, rc.indexes["foo"][tt.model.Foo]) + } + }) + } +} + +func TestRowCacheCreateMultiIndex(t *testing.T) { + var schema ovsdb.DatabaseSchema + db, err := model.NewDBModel("Open_vSwitch", map[string]model.Model{"Open_vSwitch": &testModel{}}) + require.Nil(t, err) + err = json.Unmarshal([]byte(` + {"name": "TestDB", + "tables": { + "Open_vSwitch": { + "indexes": [["foo", "bar"]], + "columns": { + "foo": { + "type": "string" + }, + "bar": { + "type": "string" + } + } + } + } + } + `), &schema) + require.Nil(t, err) + tSchema := schema.Table("Open_vSwitch") + testData := CacheData{ + "Open_vSwitch": map[string]model.Model{"bar": &testModel{Foo: "bar", Bar: "bar"}}, + } + tc, err := NewTableCache(&schema, db, testData) + require.Nil(t, err) + tests := []struct { + name string + uuid string + model *testModel + wantErr bool + }{ + { + "inserts a new row", + "foo", + &testModel{Foo: "foo", Bar: "foo"}, + false, + }, + { + "error duplicate uuid", + "bar", + &testModel{Foo: "bar", Bar: "bar"}, + true, + }, + { + "error duplicate index", + "baz", + &testModel{Foo: "foo", Bar: "foo"}, + true, + }, + { + "new row with one duplicate value", + "baz", + &testModel{Foo: "foo", Bar: "bar"}, + false, + }, + { + "new row with other duplicate value", + "quux", + &testModel{Foo: "bar", Bar: "baz"}, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rc := tc.Table("Open_vSwitch") + require.NotNil(t, rc) + err := rc.Create(tt.uuid, tt.model) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.Nil(t, err) + mapperInfo, err := mapper.NewMapperInfo(tSchema, tt.model) + require.Nil(t, err) + h, err := hashColumnValues(mapperInfo, []string{"foo", "bar"}) + require.Nil(t, err) + assert.Equal(t, tt.uuid, rc.indexes["foo,bar"][h]) + } + }) + } +} + +func TestRowCacheUpdate(t *testing.T) { + var schema ovsdb.DatabaseSchema + db, err := model.NewDBModel("Open_vSwitch", map[string]model.Model{"Open_vSwitch": &testModel{}}) + require.Nil(t, err) + err = json.Unmarshal([]byte(` + {"name": "TestDB", + "tables": { + "Open_vSwitch": { + "indexes": [["foo"]], + "columns": { + "foo": { + "type": "string" + }, + "bar": { + "type": "string" + } + } + } + } + } + `), &schema) + require.Nil(t, err) + testData := CacheData{ + "Open_vSwitch": map[string]model.Model{ + "bar": &testModel{Foo: "bar"}, + "foobar": &testModel{Foo: "foobar"}, + }, + } + tc, err := NewTableCache(&schema, db, testData) + require.Nil(t, err) + + tests := []struct { + name string + uuid string + model *testModel + wantErr bool + }{ + { + "error if row does not exist", + "foo", + &testModel{Foo: "foo"}, + true, + }, + { + "update", + "bar", + &testModel{Foo: "baz"}, + false, + }, + { + "error new index would cause duplicate", + "baz", + &testModel{Foo: "foobar"}, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rc := tc.Table("Open_vSwitch") + require.NotNil(t, rc) + err := rc.Update(tt.uuid, tt.model) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.Nil(t, err) + assert.Equal(t, tt.uuid, rc.indexes["foo"][tt.model.Foo]) + } + }) + } +} + +func TestRowCacheUpdateMultiIndex(t *testing.T) { + var schema ovsdb.DatabaseSchema + db, err := model.NewDBModel("Open_vSwitch", map[string]model.Model{"Open_vSwitch": &testModel{}}) + require.Nil(t, err) + err = json.Unmarshal([]byte(` + {"name": "TestDB", + "tables": { + "Open_vSwitch": { + "indexes": [["foo", "bar"]], + "columns": { + "foo": { + "type": "string" + }, + "bar": { + "type": "string" + } + } + } + } + } + `), &schema) + tSchema := schema.Table("Open_vSwitch") + require.Nil(t, err) + testData := CacheData{ + "Open_vSwitch": map[string]model.Model{ + "bar": &testModel{Foo: "bar", Bar: "bar"}, + "foobar": &testModel{Foo: "foobar", Bar: "foobar"}, + }, + } + tc, err := NewTableCache(&schema, db, testData) + require.Nil(t, err) + + tests := []struct { + name string + uuid string + model *testModel + wantErr bool + }{ + { + "error if row does not exist", + "foo", + &testModel{Foo: "foo", Bar: "foo"}, + true, + }, + { + "update both index cols", + "bar", + &testModel{Foo: "baz", Bar: "baz"}, + false, + }, + { + "update single index col", + "bar", + &testModel{Foo: "baz", Bar: "quux"}, + false, + }, + { + "error new index would cause duplicate", + "baz", + &testModel{Foo: "foobar", Bar: "foobar"}, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rc := tc.Table("Open_vSwitch") + require.NotNil(t, rc) + err := rc.Update(tt.uuid, tt.model) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.Nil(t, err) + mapperInfo, err := mapper.NewMapperInfo(tSchema, tt.model) + require.Nil(t, err) + h, err := hashColumnValues(mapperInfo, []string{"foo", "bar"}) + require.Nil(t, err) + assert.Equal(t, tt.uuid, rc.indexes["foo,bar"][h]) + } + }) + } +} + +func TestRowCacheDelete(t *testing.T) { + var schema ovsdb.DatabaseSchema + db, err := model.NewDBModel("Open_vSwitch", map[string]model.Model{"Open_vSwitch": &testModel{}}) + require.Nil(t, err) + err = json.Unmarshal([]byte(` + {"name": "TestDB", + "tables": { + "Open_vSwitch": { + "indexes": [["foo"]], + "columns": { + "foo": { + "type": "string" + }, + "bar": { + "type": "string" + } + } + } + } + } + `), &schema) + require.Nil(t, err) + testData := CacheData{ + "Open_vSwitch": map[string]model.Model{ + "bar": &testModel{Foo: "bar"}, + }, + } + tc, err := NewTableCache(&schema, db, testData) + require.Nil(t, err) + + tests := []struct { + name string + uuid string + model *testModel + wantErr bool + }{ + { + "deletes a row", + "bar", + &testModel{Foo: "bar"}, + false, + }, + { + "error if row does not exist", + "foobar", + &testModel{Foo: "bar"}, + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rc := tc.Table("Open_vSwitch") + require.NotNil(t, rc) + err := rc.Delete(tt.uuid) + if tt.wantErr { + assert.Error(t, err) + } else { + require.Nil(t, err) + assert.Equal(t, "", rc.indexes["foo"][tt.model.Foo]) + } + }) + } +} + func TestEventHandlerFuncs_OnAdd(t *testing.T) { calls := 0 type fields struct { @@ -212,75 +584,62 @@ func TestEventHandlerFuncs_OnDelete(t *testing.T) { } } -func TestTableCache_Table(t *testing.T) { - type fields struct { - cache map[string]*RowCache - } - type args struct { - name string - } +func TestTableCacheTable(t *testing.T) { tests := []struct { - name string - fields fields - args args - want *RowCache + name string + cache map[string]*RowCache + table string + want *RowCache }{ { "returns nil for an empty table", - fields{ - cache: map[string]*RowCache{"bar": NewRowCache(nil)}, - }, - args{ - "foo", - }, + map[string]*RowCache{"bar": newRowCache("bar", ovsdb.TableSchema{}, nil)}, + "foo", nil, }, { "returns nil for an empty table", - fields{ - cache: map[string]*RowCache{"bar": NewRowCache(nil)}, - }, - args{ - "bar", - }, - NewRowCache(nil), + map[string]*RowCache{"bar": newRowCache("bar", ovsdb.TableSchema{}, nil)}, + "bar", + newRowCache("bar", ovsdb.TableSchema{}, nil), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tr := &TableCache{ - cache: tt.fields.cache, + cache: tt.cache, } - got := tr.Table(tt.args.name) + got := tr.Table(tt.table) assert.Equal(t, tt.want, got) }) } } -func TestTableCache_Tables(t *testing.T) { - type fields struct { - cache map[string]*RowCache - } +func TestTableCacheTables(t *testing.T) { tests := []struct { - name string - fields fields - want []string + name string + cache map[string]*RowCache + want []string }{ { "returns a table that exists", - fields{cache: map[string]*RowCache{"test1": NewRowCache(nil), "test2": NewRowCache(nil), "test3": NewRowCache(nil)}}, + map[string]*RowCache{ + "test1": newRowCache("test1", ovsdb.TableSchema{}, nil), + "test2": newRowCache("test2", ovsdb.TableSchema{}, nil), + "test3": newRowCache("test3", ovsdb.TableSchema{}, nil), + }, []string{"test1", "test2", "test3"}, }, { "returns an empty slice if no tables exist", - fields{cache: map[string]*RowCache{}}, + map[string]*RowCache{}, []string{}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tr := &TableCache{ - cache: tt.fields.cache, + cache: tt.cache, } got := tr.Tables() assert.ElementsMatch(t, tt.want, got) @@ -297,17 +656,21 @@ func TestTableCache_populate(t *testing.T) { {"name": "TestDB", "tables": { "Open_vSwitch": { + "indexes": [["foo"]], "columns": { "foo": { "type": "string" - } + }, + "bar": { + "type": "string" + } } } } } `), &schema) assert.Nil(t, err) - tc, err := NewTableCache(&schema, db) + tc, err := NewTableCache(&schema, db, nil) assert.Nil(t, err) testRow := ovsdb.Row(map[string]interface{}{"_uuid": "test", "foo": "bar"}) diff --git a/client/api_test.go b/client/api_test.go index 544aa773..067f7657 100644 --- a/client/api_test.go +++ b/client/api_test.go @@ -12,7 +12,7 @@ import ( ) func TestAPIListSimple(t *testing.T) { - tcache := apiTestCache(t) + lscacheList := []model.Model{ &testLogicalSwitch{ UUID: aUUID0, @@ -40,9 +40,10 @@ func TestAPIListSimple(t *testing.T) { for i := range lscacheList { lscache[lscacheList[i].(*testLogicalSwitch).UUID] = lscacheList[i] } - tcache.Set("Logical_Switch", cache.NewRowCache(lscache)) - tcache.Set("Logical_Switch_Port", nil) // empty - + testData := cache.CacheData{ + "Logical_Switch": lscache, + } + tcache := apiTestCache(t, testData) test := []struct { name string initialCap int @@ -121,7 +122,6 @@ func TestAPIListSimple(t *testing.T) { } func TestAPIListPredicate(t *testing.T) { - tcache := apiTestCache(t) lscacheList := []model.Model{ &testLogicalSwitch{ UUID: aUUID0, @@ -149,7 +149,10 @@ func TestAPIListPredicate(t *testing.T) { for i := range lscacheList { lscache[lscacheList[i].(*testLogicalSwitch).UUID] = lscacheList[i] } - tcache.Set("Logical_Switch", cache.NewRowCache(lscache)) + testData := cache.CacheData{ + "Logical_Switch": lscache, + } + tcache := apiTestCache(t, testData) test := []struct { name string @@ -214,7 +217,6 @@ func TestAPIListPredicate(t *testing.T) { } func TestAPIListFields(t *testing.T) { - tcache := apiTestCache(t) lspcacheList := []model.Model{ &testLogicalSwitchPort{ UUID: aUUID0, @@ -245,7 +247,10 @@ func TestAPIListFields(t *testing.T) { for i := range lspcacheList { lspcache[lspcacheList[i].(*testLogicalSwitchPort).UUID] = lspcacheList[i] } - tcache.Set("Logical_Switch_Port", cache.NewRowCache(lspcache)) + testData := cache.CacheData{ + "Logical_Switch_Port": lspcache, + } + tcache := apiTestCache(t, testData) testObj := testLogicalSwitchPort{} @@ -339,7 +344,7 @@ func TestConditionFromFunc(t *testing.T) { for _, tt := range test { t.Run(fmt.Sprintf("conditionFromFunc: %s", tt.name), func(t *testing.T) { - cache := apiTestCache(t) + cache := apiTestCache(t, nil) apiIface := newAPI(cache) condition := apiIface.(api).conditionFromFunc(tt.arg) if tt.err { @@ -398,7 +403,7 @@ func TestConditionFromModel(t *testing.T) { for _, tt := range test { t.Run(fmt.Sprintf("conditionFromModel: %s", tt.name), func(t *testing.T) { - cache := apiTestCache(t) + cache := apiTestCache(t, nil) apiIface := newAPI(cache) condition := apiIface.(api).conditionFromModel(false, tt.model, tt.conds...) if tt.err { @@ -416,7 +421,6 @@ func TestConditionFromModel(t *testing.T) { } func TestAPIGet(t *testing.T) { - tcache := apiTestCache(t) lsCacheList := []model.Model{} lspCacheList := []model.Model{ &testLogicalSwitchPort{ @@ -440,8 +444,11 @@ func TestAPIGet(t *testing.T) { for i := range lspCacheList { lspCache[lspCacheList[i].(*testLogicalSwitchPort).UUID] = lspCacheList[i] } - tcache.Set("Logical_Switch", cache.NewRowCache(lsCache)) - tcache.Set("Logical_Switch_Port", cache.NewRowCache(lspCache)) + testData := cache.CacheData{ + "Logical_Switch": lsCache, + "Logical_Switch_Port": lspCache, + } + tcache := apiTestCache(t, testData) test := []struct { name string @@ -496,7 +503,6 @@ func TestAPIGet(t *testing.T) { } func TestAPICreate(t *testing.T) { - tcache := apiTestCache(t) lsCacheList := []model.Model{} lspCacheList := []model.Model{ &testLogicalSwitchPort{ @@ -520,8 +526,11 @@ func TestAPICreate(t *testing.T) { for i := range lspCacheList { lspCache[lspCacheList[i].(*testLogicalSwitchPort).UUID] = lspCacheList[i] } - tcache.Set("Logical_Switch", cache.NewRowCache(lsCache)) - tcache.Set("Logical_Switch_Port", cache.NewRowCache(lspCache)) + testData := cache.CacheData{ + "Logical_Switch": lsCache, + "Logical_Switch_Port": lspCache, + } + tcache := apiTestCache(t, testData) rowFoo := ovsdb.Row(map[string]interface{}{"name": "foo"}) rowBar := ovsdb.Row(map[string]interface{}{"name": "bar"}) @@ -609,7 +618,6 @@ func TestAPICreate(t *testing.T) { } func TestAPIMutate(t *testing.T) { - tcache := apiTestCache(t) lspCache := map[string]model.Model{ aUUID0: &testLogicalSwitchPort{ UUID: aUUID0, @@ -634,7 +642,10 @@ func TestAPIMutate(t *testing.T) { Tag: []int{1}, }, } - tcache.Set("Logical_Switch_Port", cache.NewRowCache(lspCache)) + testData := cache.CacheData{ + "Logical_Switch_Port": lspCache, + } + tcache := apiTestCache(t, testData) testObj := testLogicalSwitchPort{} @@ -776,7 +787,6 @@ func TestAPIMutate(t *testing.T) { } func TestAPIUpdate(t *testing.T) { - tcache := apiTestCache(t) lspCache := map[string]model.Model{ aUUID0: &testLogicalSwitchPort{ UUID: aUUID0, @@ -802,7 +812,10 @@ func TestAPIUpdate(t *testing.T) { Tag: []int{1}, }, } - tcache.Set("Logical_Switch_Port", cache.NewRowCache(lspCache)) + testData := cache.CacheData{ + "Logical_Switch_Port": lspCache, + } + tcache := apiTestCache(t, testData) testObj := testLogicalSwitchPort{} testRow := ovsdb.Row(map[string]interface{}{"type": "somethingElse", "tag": testOvsSet(t, []int{6})}) @@ -1022,7 +1035,6 @@ func TestAPIUpdate(t *testing.T) { } func TestAPIDelete(t *testing.T) { - tcache := apiTestCache(t) lspCache := map[string]model.Model{ aUUID0: &testLogicalSwitchPort{ UUID: aUUID0, @@ -1048,7 +1060,10 @@ func TestAPIDelete(t *testing.T) { Tag: []int{1}, }, } - tcache.Set("Logical_Switch_Port", cache.NewRowCache(lspCache)) + testData := cache.CacheData{ + "Logical_Switch_Port": lspCache, + } + tcache := apiTestCache(t, testData) test := []struct { name string diff --git a/client/api_test_model.go b/client/api_test_model.go index a77f9367..1fa825e1 100644 --- a/client/api_test_model.go +++ b/client/api_test_model.go @@ -153,13 +153,13 @@ func (*testLogicalSwitchPort) Table() string { return "Logical_Switch_Port" } -func apiTestCache(t *testing.T) *cache.TableCache { +func apiTestCache(t *testing.T, data map[string]map[string]model.Model) *cache.TableCache { var schema ovsdb.DatabaseSchema err := json.Unmarshal(apiTestSchema, &schema) assert.Nil(t, err) db, err := model.NewDBModel("OVN_NorthBound", map[string]model.Model{"Logical_Switch": &testLogicalSwitch{}, "Logical_Switch_Port": &testLogicalSwitchPort{}}) assert.Nil(t, err) - cache, err := cache.NewTableCache(&schema, db) + cache, err := cache.NewTableCache(&schema, db, data) assert.Nil(t, err) return cache } diff --git a/client/client.go b/client/client.go index 2131f367..4076e99f 100644 --- a/client/client.go +++ b/client/client.go @@ -135,7 +135,7 @@ func newRPC2Client(conn net.Conn, database *model.DBModel) (*OvsdbClient, error) if err == nil { ovs.Schema = *schema - if cache, err := cache.NewTableCache(schema, database); err == nil { + if cache, err := cache.NewTableCache(schema, database, nil); err == nil { ovs.Cache = cache ovs.Register(ovs.Cache) ovs.api = newAPI(ovs.Cache) diff --git a/client/condition_test.go b/client/condition_test.go index 09ec7b6b..5ca45b72 100644 --- a/client/condition_test.go +++ b/client/condition_test.go @@ -11,7 +11,6 @@ import ( ) func TestEqualityConditional(t *testing.T) { - tcache := apiTestCache(t) lspcacheList := []model.Model{ &testLogicalSwitchPort{ UUID: aUUID0, @@ -42,7 +41,10 @@ func TestEqualityConditional(t *testing.T) { for i := range lspcacheList { lspcache[lspcacheList[i].(*testLogicalSwitchPort).UUID] = lspcacheList[i] } - tcache.Set("Logical_Switch_Port", cache.NewRowCache(lspcache)) + testData := cache.CacheData{ + "Logical_Switch_Port": lspcache, + } + tcache := apiTestCache(t, testData) test := []struct { name string @@ -149,7 +151,6 @@ func TestEqualityConditional(t *testing.T) { } func TestPredicateConditional(t *testing.T) { - tcache := apiTestCache(t) lspcacheList := []model.Model{ &testLogicalSwitchPort{ UUID: aUUID0, @@ -180,7 +181,10 @@ func TestPredicateConditional(t *testing.T) { for i := range lspcacheList { lspcache[lspcacheList[i].(*testLogicalSwitchPort).UUID] = lspcacheList[i] } - tcache.Set("Logical_Switch_Port", cache.NewRowCache(lspcache)) + testData := cache.CacheData{ + "Logical_Switch_Port": lspcache, + } + tcache := apiTestCache(t, testData) test := []struct { name string @@ -256,7 +260,6 @@ func TestPredicateConditional(t *testing.T) { } func TestExplicitConditional(t *testing.T) { - tcache := apiTestCache(t) lspcacheList := []model.Model{ &testLogicalSwitchPort{ UUID: aUUID0, @@ -287,7 +290,10 @@ func TestExplicitConditional(t *testing.T) { for i := range lspcacheList { lspcache[lspcacheList[i].(*testLogicalSwitchPort).UUID] = lspcacheList[i] } - tcache.Set("Logical_Switch_Port", cache.NewRowCache(lspcache)) + testData := cache.CacheData{ + "Logical_Switch_Port": lspcache, + } + tcache := apiTestCache(t, testData) testObj := &testLogicalSwitchPort{} diff --git a/cmd/stress/stress.go b/cmd/stress/stress.go index 0f6173ee..0b31748d 100644 --- a/cmd/stress/stress.go +++ b/cmd/stress/stress.go @@ -4,11 +4,15 @@ import ( "context" "flag" "fmt" + "io" "log" "os" "runtime" "runtime/pprof" + "sync" + "time" + "github.com/google/uuid" "github.com/ovn-org/libovsdb/cache" "github.com/ovn-org/libovsdb/client" "github.com/ovn-org/libovsdb/model" @@ -34,36 +38,91 @@ type ovsType struct { var ( cpuprofile = flag.String("cpuprofile", "", "write cpu profile to this file") memprofile = flag.String("memoryprofile", "", "write memory profile to this file") - nins = flag.Int("ninserts", 100, "insert this number of elements in the database") + nins = flag.Int("inserts", 100, "the number of insertions to make to the database (per client)") + nclients = flag.Int("clients", 1, "the number of clients to use") + parallel = flag.Bool("parallel", false, "run clients in parallel") verbose = flag.Bool("verbose", false, "Be verbose") connection = flag.String("ovsdb", "unix:/var/run/openvswitch/db.sock", "OVSDB connection string") dbModel *model.DBModel - - ready bool - rootUUID string - insertions int - deletions int ) -func run() { +type result struct { + insertions int + deletions int + transactTime []time.Duration + cacheTime []time.Duration +} + +func cleanup(ctx context.Context) { ovs, err := client.Connect(context.Background(), *connection, dbModel, nil) if err != nil { log.Fatal(err) } defer ovs.Disconnect() + + if err := ovs.MonitorAll(""); err != nil { + log.Fatal(err) + } + + var rootUUID string + // Get root UUID + for _, uuid := range ovs.Cache.Table("Open_vSwitch").Rows() { + rootUUID = uuid + log.Printf("rootUUID is %v", rootUUID) + } + + // Remove all existing bridges + var bridges []bridgeType + if err := ovs.List(&bridges); err == nil { + log.Printf("%d existing bridges found", len(bridges)) + for _, bridge := range bridges { + deleteBridge(ctx, ovs, rootUUID, &bridge) + } + } else { + if err != client.ErrNotFound { + log.Fatal(err) + } + } +} + +func run(ctx context.Context, resultsChan chan result, wg *sync.WaitGroup) { + defer wg.Done() + + result := result{} + ready := false + var rootUUID string + + ovs, err := client.Connect(context.Background(), *connection, dbModel, nil) + if err != nil { + log.Fatal(err) + } + defer ovs.Disconnect() + + var bridges []bridgeType + bridgeCh := make(map[string]chan bool) + for i := 0; i < *nins; i++ { + br := newBridge() + bridges = append(bridges, br) + bridgeCh[br.Name] = make(chan bool) + } + ovs.Cache.AddEventHandler( &cache.EventHandlerFuncs{ AddFunc: func(table string, model model.Model) { if ready && table == "Bridge" { - insertions++ - if *verbose { - fmt.Printf(".") + br := model.(*bridgeType) + var ch chan bool + var ok bool + if ch, ok = bridgeCh[br.Name]; !ok { + return } + close(ch) + result.insertions++ } }, DeleteFunc: func(table string, model model.Model) { if table == "Bridge" { - deletions++ + result.deletions++ } }, }, @@ -81,39 +140,36 @@ func run() { } } - // Remove all existing bridges - var bridges []bridgeType - if err := ovs.List(&bridges); err == nil { - for _, bridge := range bridges { - deleteBridge(ovs, &bridge) - } - } else { - if err != client.ErrNotFound { - log.Fatal(err) - } - } - ready = true + cacheWg := sync.WaitGroup{} for i := 0; i < *nins; i++ { - createBridge(ovs, i) + br := bridges[i] + ch := bridgeCh[br.Name] + log.Printf("create bridge: %s", br.Name) + cacheWg.Add(1) + go func(ctx context.Context, ch chan bool) { + defer cacheWg.Done() + <-ch + }(ctx, ch) + createBridge(ctx, ovs, rootUUID, br) } + cacheWg.Wait() + resultsChan <- result } -func transact(ovs *client.OvsdbClient, operations []ovsdb.Operation) (ok bool, uuid string) { +func transact(ctx context.Context, ovs *client.OvsdbClient, operations []ovsdb.Operation) (bool, string) { reply, err := ovs.Transact(operations...) if err != nil { - ok = false - return + return false, "" } if _, err := ovsdb.CheckOperationResults(reply, operations); err != nil { - ok = false - return + return false, "" } - uuid = reply[0].UUID.GoUUID - return + return true, reply[0].UUID.GoUUID } -func deleteBridge(ovs *client.OvsdbClient, bridge *bridgeType) { +func deleteBridge(ctx context.Context, ovs *client.OvsdbClient, rootUUID string, bridge *bridgeType) { + log.Printf("deleting bridge %s", bridge.Name) deleteOp, err := ovs.Where(bridge).Delete() if err != nil { log.Fatal(err) @@ -121,7 +177,6 @@ func deleteBridge(ovs *client.OvsdbClient, bridge *bridgeType) { ovsRow := ovsType{ UUID: rootUUID, } - mutateOp, err := ovs.Where(&ovsRow).Mutate(&ovsRow, model.Mutation{ Field: &ovsRow.Bridges, Mutator: ovsdb.MutateOperationDelete, @@ -130,20 +185,14 @@ func deleteBridge(ovs *client.OvsdbClient, bridge *bridgeType) { if err != nil { log.Fatal(err) } - operations := append(deleteOp, mutateOp...) - ok, _ := transact(ovs, operations) - if ok { - if *verbose { - fmt.Println("Bridge Deletion Successful : ", bridge.UUID) - } - } + _, _ = transact(ctx, ovs, operations) } -func createBridge(ovs *client.OvsdbClient, iter int) { - bridge := bridgeType{ +func newBridge() bridgeType { + return bridgeType{ UUID: "gopher", - Name: fmt.Sprintf("bridge-%d", iter), + Name: fmt.Sprintf("br-%s", uuid.NewString()), OtherConfig: map[string]string{ "foo": "bar", "fake": "config", @@ -153,6 +202,9 @@ func createBridge(ovs *client.OvsdbClient, iter int) { "key2": "val2", }, } +} + +func createBridge(ctx context.Context, ovs *client.OvsdbClient, rootUUID string, bridge bridgeType) { insertOp, err := ovs.Create(&bridge) if err != nil { log.Fatal(err) @@ -168,16 +220,12 @@ func createBridge(ovs *client.OvsdbClient, iter int) { } operations := append(insertOp, mutateOp...) - ok, uuid := transact(ovs, operations) - if ok { - if *verbose { - fmt.Println("Bridge Addition Successful : ", uuid) - } - } + _, _ = transact(ctx, ovs, operations) } func main() { flag.Parse() - var err error + ctx := context.Background() + if *cpuprofile != "" { f, err := os.Create(*cpuprofile) if err != nil { @@ -188,18 +236,52 @@ func main() { } defer pprof.StopCPUProfile() } + if !*verbose { + log.SetOutput(io.Discard) + } + var err error dbModel, err = model.NewDBModel("Open_vSwitch", map[string]model.Model{"Open_vSwitch": &ovsType{}, "Bridge": &bridgeType{}}) if err != nil { log.Fatal(err) } - run() + cleanup(ctx) + + var wg sync.WaitGroup + resultChan := make(chan result) + results := make([]result, *nclients) + go func() { + for result := range resultChan { + results = append(results, result) + } + }() + + for i := 0; i < *nclients; i++ { + wg.Add(1) + go run(ctx, resultChan, &wg) + if !*parallel { + wg.Wait() + } + } + log.Print("waiting for clients to complete") + // wait for all clients + wg.Wait() + // close the result channel to avoid leaking a goroutine + close(resultChan) + + result := result{} + for _, r := range results { + result.insertions += r.insertions + result.deletions += r.deletions + result.transactTime = append(result.transactTime, r.transactTime...) + result.cacheTime = append(result.transactTime, r.cacheTime...) + } fmt.Printf("\n\n\n") fmt.Printf("Summary:\n") - fmt.Printf("\tInsertions: %d\n", insertions) - fmt.Printf("\tDeletions: %d\n", deletions) + fmt.Printf("\tTotal Insertions: %d\n", result.insertions) + fmt.Printf("\tTotal Deletions: %d\n", result.deletions) if *memprofile != "" { f, err := os.Create(*memprofile) diff --git a/go.mod b/go.mod index 302c4110..56a7ffb9 100644 --- a/go.mod +++ b/go.mod @@ -6,5 +6,6 @@ require ( github.com/cenk/hub v1.0.1 // indirect github.com/cenkalti/hub v1.0.1 // indirect github.com/cenkalti/rpc2 v0.0.0-20210220005819-4a29bc83afe1 - github.com/stretchr/testify v1.4.0 + github.com/google/uuid v1.2.0 + github.com/stretchr/testify v1.6.1 ) diff --git a/go.sum b/go.sum index cac2fcf0..e01ad442 100644 --- a/go.sum +++ b/go.sum @@ -6,13 +6,14 @@ github.com/cenkalti/rpc2 v0.0.0-20210220005819-4a29bc83afe1 h1:aT9Ez2drLmrviqTnV github.com/cenkalti/rpc2 v0.0.0-20210220005819-4a29bc83afe1/go.mod h1:v2npkhrXyk5BCnkNIiPdRI23Uq6uWPUQGL2hnRcRr/M= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/uuid v1.2.0 h1:qJYtXnJRWmpe7m/3XlyhrsLrEURqHRM2kxzoxXqyUDs= +github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/mapper/mapper.go b/mapper/mapper.go index f75a1170..a68d37ec 100644 --- a/mapper/mapper.go +++ b/mapper/mapper.go @@ -119,8 +119,13 @@ func (m Mapper) NewRow(tableName string, data interface{}, fields ...interface{} return nil, err } - ovsRow := make(map[string]interface{}, len(table.Columns)) - for name, column := range table.Columns { + columns := make(map[string]*ovsdb.ColumnSchema) + for k, v := range table.Columns { + columns[k] = v + } + columns["_uuid"] = &ovsdb.UUIDColumn + ovsRow := make(map[string]interface{}, len(columns)) + for name, column := range columns { nativeElem, err := mapperInfo.FieldByColumn(name) if err != nil { // If provided struct does not have a field to hold this value, skip it @@ -144,7 +149,6 @@ func (m Mapper) NewRow(tableName string, data interface{}, fields ...interface{} continue } } - if len(fields) == 0 && ovsdb.IsDefaultValue(column, nativeElem) { continue } diff --git a/mapper/mapper_test.go b/mapper/mapper_test.go index 7eb6fb05..6afb04d8 100644 --- a/mapper/mapper_test.go +++ b/mapper/mapper_test.go @@ -957,7 +957,7 @@ func TestMapperMutation(t *testing.T) { name: "non-mutable", column: "unmutable", obj: testType{}, - mutator: ovsdb.MutateOperationSubstract, + mutator: ovsdb.MutateOperationSubtract, value: 2, err: true, }, diff --git a/ovsdb/bindings.go b/ovsdb/bindings.go index 1926996a..1d1bbb6c 100644 --- a/ovsdb/bindings.go +++ b/ovsdb/bindings.go @@ -230,14 +230,14 @@ func validateMutationAtomic(atype string, mutator Mutator, value interface{}) er return fmt.Errorf("atomictype %s does not support mutation", atype) case TypeReal: switch mutator { - case MutateOperationAdd, MutateOperationSubstract, MutateOperationMultiply, MutateOperationDivide: + case MutateOperationAdd, MutateOperationSubtract, MutateOperationMultiply, MutateOperationDivide: return nil default: return fmt.Errorf("wrong mutator for real type %s", mutator) } case TypeInteger: switch mutator { - case MutateOperationAdd, MutateOperationSubstract, MutateOperationMultiply, MutateOperationDivide, MutateOperationModulo: + case MutateOperationAdd, MutateOperationSubtract, MutateOperationMultiply, MutateOperationDivide, MutateOperationModulo: return nil default: return fmt.Errorf("wrong mutator for integer type: %s", mutator) @@ -257,6 +257,15 @@ func ValidateMutation(column *ColumnSchema, mutator Mutator, value interface{}) case TypeSet: switch mutator { case MutateOperationInsert, MutateOperationDelete: + // RFC7047 says a may be an with a single + // element. Check if we can store this value in our column + if reflect.TypeOf(value).Kind() != reflect.Slice { + if NativeType(column) != reflect.SliceOf(reflect.TypeOf(value)) { + return NewErrWrongType(fmt.Sprintf("Mutation %s of single value in to column %s", mutator, column), + NativeType(column).String(), reflect.SliceOf(reflect.TypeOf(value)).String()) + } + return nil + } if NativeType(column) != reflect.TypeOf(value) { return NewErrWrongType(fmt.Sprintf("Mutation %s of column %s", mutator, column), NativeType(column).String(), value) @@ -324,7 +333,7 @@ func isDefaultBaseValue(elem interface{}, etype ExtendedType) bool { switch etype { case TypeUUID: - return elem.(string) == "00000000-0000-0000-0000-000000000000" || elem.(string) == "" + return elem.(string) == "00000000-0000-0000-0000-000000000000" || elem.(string) == "" || isNamed(elem.(string)) case TypeMap, TypeSet: return value.IsNil() || value.Len() == 0 case TypeString: diff --git a/ovsdb/bindings_test.go b/ovsdb/bindings_test.go index a0d65adb..af901abb 100644 --- a/ovsdb/bindings_test.go +++ b/ovsdb/bindings_test.go @@ -747,35 +747,35 @@ func TestMutationValidation(t *testing.T) { { name: "string", column: []byte(`{"type":"string"}`), - mutators: []Mutator{MutateOperationAdd, MutateOperationAdd, MutateOperationSubstract, MutateOperationMultiply, MutateOperationDivide, MutateOperationModulo}, + mutators: []Mutator{MutateOperationAdd, MutateOperationAdd, MutateOperationSubtract, MutateOperationMultiply, MutateOperationDivide, MutateOperationModulo}, value: "foo", valid: false, }, { name: "string", column: []byte(`{"type":"uuid"}`), - mutators: []Mutator{MutateOperationAdd, MutateOperationAdd, MutateOperationSubstract, MutateOperationMultiply, MutateOperationDivide, MutateOperationModulo}, + mutators: []Mutator{MutateOperationAdd, MutateOperationAdd, MutateOperationSubtract, MutateOperationMultiply, MutateOperationDivide, MutateOperationModulo}, value: "foo", valid: false, }, { name: "boolean", column: []byte(`{"type":"boolean"}`), - mutators: []Mutator{MutateOperationAdd, MutateOperationAdd, MutateOperationSubstract, MutateOperationMultiply, MutateOperationDivide, MutateOperationModulo}, + mutators: []Mutator{MutateOperationAdd, MutateOperationAdd, MutateOperationSubtract, MutateOperationMultiply, MutateOperationDivide, MutateOperationModulo}, value: true, valid: false, }, { name: "integer", column: []byte(`{"type":"integer"}`), - mutators: []Mutator{MutateOperationAdd, MutateOperationAdd, MutateOperationSubstract, MutateOperationMultiply, MutateOperationDivide, MutateOperationModulo}, + mutators: []Mutator{MutateOperationAdd, MutateOperationAdd, MutateOperationSubtract, MutateOperationMultiply, MutateOperationDivide, MutateOperationModulo}, value: 4, valid: true, }, { name: "unmutable", column: []byte(`{"type":"integer", "mutable": false}`), - mutators: []Mutator{MutateOperationAdd, MutateOperationAdd, MutateOperationSubstract, MutateOperationMultiply, MutateOperationDivide, MutateOperationModulo}, + mutators: []Mutator{MutateOperationAdd, MutateOperationAdd, MutateOperationSubtract, MutateOperationMultiply, MutateOperationDivide, MutateOperationModulo}, value: 4, valid: false, }, @@ -789,14 +789,14 @@ func TestMutationValidation(t *testing.T) { { name: "integer wrong type", column: []byte(`{"type":"integer"}`), - mutators: []Mutator{MutateOperationAdd, MutateOperationAdd, MutateOperationSubstract, MutateOperationMultiply, MutateOperationDivide, MutateOperationModulo}, + mutators: []Mutator{MutateOperationAdd, MutateOperationAdd, MutateOperationSubtract, MutateOperationMultiply, MutateOperationDivide, MutateOperationModulo}, value: "foo", valid: false, }, { name: "real", column: []byte(`{"type":"real"}`), - mutators: []Mutator{MutateOperationAdd, MutateOperationAdd, MutateOperationSubstract, MutateOperationMultiply, MutateOperationDivide}, + mutators: []Mutator{MutateOperationAdd, MutateOperationAdd, MutateOperationSubtract, MutateOperationMultiply, MutateOperationDivide}, value: 4.0, valid: true, }, @@ -816,7 +816,7 @@ func TestMutationValidation(t *testing.T) { "min": 0 } }`), - mutators: []Mutator{MutateOperationAdd, MutateOperationAdd, MutateOperationSubstract, MutateOperationMultiply, MutateOperationDivide, MutateOperationModulo}, + mutators: []Mutator{MutateOperationAdd, MutateOperationAdd, MutateOperationSubtract, MutateOperationMultiply, MutateOperationDivide, MutateOperationModulo}, value: 4, valid: true, }, @@ -829,7 +829,7 @@ func TestMutationValidation(t *testing.T) { "min": 0 } }`), - mutators: []Mutator{MutateOperationAdd, MutateOperationAdd, MutateOperationSubstract, MutateOperationMultiply, MutateOperationDivide}, + mutators: []Mutator{MutateOperationAdd, MutateOperationAdd, MutateOperationSubtract, MutateOperationMultiply, MutateOperationDivide}, value: 4.0, valid: true, }, @@ -842,10 +842,36 @@ func TestMutationValidation(t *testing.T) { "min": 0 } }`), - mutators: []Mutator{MutateOperationAdd, MutateOperationAdd, MutateOperationSubstract, MutateOperationMultiply}, + mutators: []Mutator{MutateOperationAdd, MutateOperationAdd, MutateOperationSubtract, MutateOperationMultiply}, value: "foo", valid: false, }, + { + name: "string set insert single string", + column: []byte(`{ + "type": { + "key": "string", + "max": "unlimited", + "min": 0 + } + }`), + mutators: []Mutator{MutateOperationInsert}, + value: "foo", + valid: true, + }, + { + name: "string set insert single int", + column: []byte(`{ + "type": { + "key": "string", + "max": "unlimited", + "min": 0 + } + }`), + mutators: []Mutator{MutateOperationInsert}, + value: 42, + valid: false, + }, { name: "string set insert/delete", column: []byte(`{ diff --git a/ovsdb/condition.go b/ovsdb/condition.go index 6c7ba3ff..6f3a2fad 100644 --- a/ovsdb/condition.go +++ b/ovsdb/condition.go @@ -3,6 +3,7 @@ package ovsdb import ( "encoding/json" "fmt" + "reflect" ) type ConditionFunction string @@ -25,6 +26,10 @@ type Condition struct { Value interface{} } +func (c Condition) String() string { + return fmt.Sprintf("where column %s %s %v", c.Column, c.Function, c.Value) +} + // NewCondition returns a new condition func NewCondition(column string, function ConditionFunction, value interface{}) Condition { return Condition{ @@ -41,7 +46,7 @@ func (c Condition) MarshalJSON() ([]byte, error) { } // UnmarshalJSON converts a 3 element JSON array to a Condition -func (c Condition) UnmarshalJSON(b []byte) error { +func (c *Condition) UnmarshalJSON(b []byte) error { var v []interface{} err := json.Unmarshal(b, &v) if err != nil { @@ -53,18 +58,152 @@ func (c Condition) UnmarshalJSON(b []byte) error { c.Column = v[0].(string) function := ConditionFunction(v[1].(string)) switch function { + case ConditionEqual, + ConditionNotEqual, + ConditionIncludes, + ConditionExcludes, + ConditionGreaterThan, + ConditionGreaterThanOrEqual, + ConditionLessThan, + ConditionLessThanOrEqual: + c.Function = function + default: + return fmt.Errorf("%s is not a valid function", function) + } + vv, err := interfaceToOVSDBNotationInterface(reflect.ValueOf(v[2])) + if err != nil { + return err + } + c.Value = vv + return nil +} + +// Evaluate will evaluate the condition on the two provided values +// The conditions operately differently depending on the type of +// the provided values. The behavjour is as described in RFC7047 +func (c ConditionFunction) Evaluate(a interface{}, b interface{}) (bool, error) { + x := reflect.ValueOf(a) + y := reflect.ValueOf(b) + if x.Kind() != y.Kind() { + return false, fmt.Errorf("comparison between %s and %s not supported", x.Kind(), y.Kind()) + } + switch c { case ConditionEqual: + return reflect.DeepEqual(a, b), nil case ConditionNotEqual: + return !reflect.DeepEqual(a, b), nil case ConditionIncludes: + switch x.Kind() { + case reflect.Slice: + return sliceContains(x, y), nil + case reflect.Map: + return mapContains(x, y), nil + case reflect.Int, reflect.Float64, reflect.Bool, reflect.String: + return reflect.DeepEqual(a, b), nil + default: + return false, fmt.Errorf("condition not supported on %s", x.Kind()) + } case ConditionExcludes: + switch x.Kind() { + case reflect.Slice: + return !sliceContains(x, y), nil + case reflect.Map: + return !mapContains(x, y), nil + case reflect.Int, reflect.Float64, reflect.Bool, reflect.String: + return !reflect.DeepEqual(a, b), nil + default: + return false, fmt.Errorf("condition not supported on %s", x.Kind()) + } case ConditionGreaterThan: + switch x.Kind() { + case reflect.Int: + return x.Int() > y.Int(), nil + case reflect.Float64: + return x.Float() > y.Float(), nil + case reflect.Bool, reflect.String, reflect.Slice, reflect.Map: + default: + return false, fmt.Errorf("condition not supported on %s", x.Kind()) + } case ConditionGreaterThanOrEqual: + switch x.Kind() { + case reflect.Int: + return x.Int() >= y.Int(), nil + case reflect.Float64: + return x.Float() >= y.Float(), nil + case reflect.Bool, reflect.String, reflect.Slice, reflect.Map: + default: + return false, fmt.Errorf("condition not supported on %s", x.Kind()) + } case ConditionLessThan: + switch x.Kind() { + case reflect.Int: + return x.Int() < y.Int(), nil + case reflect.Float64: + return x.Float() < y.Float(), nil + case reflect.Bool, reflect.String, reflect.Slice, reflect.Map: + default: + return false, fmt.Errorf("condition not supported on %s", x.Kind()) + } case ConditionLessThanOrEqual: - c.Function = function + switch x.Kind() { + case reflect.Int: + return x.Int() <= y.Int(), nil + case reflect.Float64: + return x.Float() <= y.Float(), nil + case reflect.Bool, reflect.String, reflect.Slice, reflect.Map: + default: + return false, fmt.Errorf("condition not supported on %s", x.Kind()) + } default: - return fmt.Errorf("%s is not a valid function", function) + return false, fmt.Errorf("unsuported condition function %s", c) } - c.Value = v[2] - return nil + // we should never get here + return false, fmt.Errorf("unreachable condition") +} + +func sliceContains(x, y reflect.Value) bool { + for i := 0; i < y.Len(); i++ { + found := false + vy := y.Index(i) + for j := 0; j < x.Len(); j++ { + vx := x.Index(j) + if vy.Kind() == reflect.Interface { + if vy.Elem() == vx.Elem() { + found = true + break + } + } else { + if vy.Interface() == vx.Interface() { + found = true + break + } + } + } + if !found { + return false + } + } + return true +} + +func mapContains(x, y reflect.Value) bool { + iter := y.MapRange() + for iter.Next() { + k := iter.Key() + v := iter.Value() + vx := x.MapIndex(k) + if !vx.IsValid() { + return false + } + if v.Kind() != reflect.Interface { + if v.Interface() != vx.Interface() { + return false + } + } else { + if v.Elem() != vx.Elem() { + return false + } + } + } + return true } diff --git a/ovsdb/condition_test.go b/ovsdb/condition_test.go index e1940ec0..4eae9cb8 100644 --- a/ovsdb/condition_test.go +++ b/ovsdb/condition_test.go @@ -1,85 +1,120 @@ package ovsdb import ( + "encoding/json" + "reflect" "testing" "github.com/stretchr/testify/assert" ) -func TestConditionMarshalJSON(t *testing.T) { - type fields struct { - Column string - Function ConditionFunction - Value interface{} - } +func TestConditionMarshalUnmarshalJSON(t *testing.T) { tests := []struct { - name string - fields fields - want string - wantErr bool + name string + condition Condition + want string + wantErr bool }{ { "test <", - fields{"foo", ConditionLessThan, "bar"}, + Condition{"foo", ConditionLessThan, "bar"}, `[ "foo", "<", "bar" ]`, false, }, { "test <=", - fields{"foo", ConditionLessThanOrEqual, "bar"}, + Condition{"foo", ConditionLessThanOrEqual, "bar"}, `[ "foo", "<=", "bar" ]`, false, }, { "test >", - fields{"foo", ConditionGreaterThan, "bar"}, + Condition{"foo", ConditionGreaterThan, "bar"}, `[ "foo", ">", "bar" ]`, false, }, { "test >=", - fields{"foo", ConditionGreaterThanOrEqual, "bar"}, + Condition{"foo", ConditionGreaterThanOrEqual, "bar"}, `[ "foo", ">=", "bar" ]`, false, }, { "test ==", - fields{"foo", ConditionEqual, "bar"}, + Condition{"foo", ConditionEqual, "bar"}, `[ "foo", "==", "bar" ]`, false, }, { "test !=", - fields{"foo", ConditionNotEqual, "bar"}, + Condition{"foo", ConditionNotEqual, "bar"}, `[ "foo", "!=", "bar" ]`, false, }, { "test includes", - fields{"foo", ConditionIncludes, "bar"}, + Condition{"foo", ConditionIncludes, "bar"}, `[ "foo", "includes", "bar" ]`, false, }, { "test excludes", - fields{"foo", ConditionExcludes, "bar"}, + Condition{"foo", ConditionExcludes, "bar"}, `[ "foo", "excludes", "bar" ]`, false, }, + { + "test uuid", + Condition{"foo", ConditionExcludes, UUID{GoUUID: "foo"}}, + `[ "foo", "excludes", ["named-uuid", "foo"] ]`, + false, + }, + { + "test set", + Condition{"foo", ConditionExcludes, OvsSet{GoSet: []interface{}{"foo", "bar", "baz"}}}, + `[ "foo", "excludes", ["set",["foo", "bar", "baz"]] ]`, + false, + }, + { + "test map", + Condition{"foo", ConditionExcludes, OvsMap{GoMap: map[interface{}]interface{}{"foo": "bar", "baz": "quux"}}}, + `[ "foo", "excludes", ["map",[["foo", "bar"], ["baz", "quux"]]]]`, + false, + }, + { + "test uuid set", + Condition{"foo", ConditionExcludes, OvsSet{GoSet: []interface{}{UUID{GoUUID: "foo"}, UUID{GoUUID: "bar"}}}}, + `[ "foo", "excludes", ["set",[["named-uuid", "foo"], ["named-uuid", "bar"]]] ]`, + false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := Condition{ - Column: tt.fields.Column, - Function: tt.fields.Function, - Value: tt.fields.Value, + got, err := json.Marshal(tt.condition) + if err != nil { + t.Fatal(err) } - got, err := c.MarshalJSON() - if (err != nil) != tt.wantErr { - t.Errorf("Condition.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) - return + // testing JSON equality is flaky for ovsdb notated maps + // it's safe to skip this as we test from json->object later + if tt.name != "test map" { + assert.JSONEq(t, tt.want, string(got)) + } + var c Condition + if err := json.Unmarshal(got, &c); err != nil { + t.Fatal(err) + } + assert.Equal(t, tt.condition.Column, c.Column) + assert.Equal(t, tt.condition.Function, c.Function) + v := reflect.TypeOf(tt.condition.Value) + vv := reflect.ValueOf(c.Value) + if !vv.IsValid() { + t.Fatalf("c.Value is empty: %v", c.Value) + } + assert.Equal(t, v, vv.Type()) + assert.Equal(t, tt.condition.Value, vv.Convert(v).Interface()) + if vv.Kind() == reflect.String { + assert.Equal(t, tt.condition.Value, vv.String()) } - assert.JSONEq(t, tt.want, string(got)) }) } } @@ -131,3 +166,508 @@ func TestCondition_UnmarshalJSON(t *testing.T) { }) } } + +func TestConditionFunctionEvaluate(t *testing.T) { + tests := []struct { + name string + c ConditionFunction + a interface{} + b interface{} + want bool + wantErr bool + }{ + { + "equal string true", + ConditionEqual, + "foo", + "foo", + true, + false, + }, + { + "equal string false", + ConditionEqual, + "foo", + "bar", + false, + false, + }, + { + "equal int true", + ConditionEqual, + 1024, + 1024, + true, + false, + }, + { + "equal int false", + ConditionEqual, + 1024, + 2048, + false, + false, + }, + { + "equal real true", + ConditionEqual, + float64(42.0), + float64(42.0), + true, + false, + }, + { + "equal real false", + ConditionEqual, + float64(42.0), + float64(420.0), + false, + false, + }, + { + "equal map true", + ConditionEqual, + map[string]string{"foo": "bar"}, + map[string]string{"foo": "bar"}, + true, + false, + }, + { + "equal map false", + ConditionEqual, + map[string]string{"foo": "bar"}, + map[string]string{"bar": "baz"}, + false, + false, + }, + { + "equal slice true", + ConditionEqual, + []string{"foo", "bar"}, + []string{"foo", "bar"}, + true, + false, + }, + { + "equal slice false", + ConditionEqual, + []string{"foo", "bar"}, + []string{"foo", "baz"}, + false, + false, + }, + { + "notequal string true", + ConditionNotEqual, + "foo", + "bar", + true, + false, + }, + { + "notequal string false", + ConditionNotEqual, + "foo", + "foo", + false, + false, + }, + { + "notequal int true", + ConditionNotEqual, + 1024, + 2048, + true, + false, + }, + { + "notequal int false", + ConditionNotEqual, + 1024, + 1024, + false, + false, + }, + { + "notequal real true", + ConditionNotEqual, + float64(42.0), + float64(24.0), + true, + false, + }, + { + "notequal real false", + ConditionNotEqual, + float64(42.0), + float64(42.0), + false, + false, + }, + { + "notequal map true", + ConditionNotEqual, + map[string]string{"foo": "bar"}, + map[string]string{"bar": "baz"}, + true, + false, + }, + { + "notequal map false", + ConditionNotEqual, + map[string]string{"foo": "bar"}, + map[string]string{"foo": "bar"}, + false, + false, + }, + { + "notequal slice true", + ConditionNotEqual, + []string{"foo", "bar"}, + []string{"foo", "baz"}, + true, + false, + }, + { + "notequal slice false", + ConditionNotEqual, + []string{"foo", "bar"}, + []string{"foo", "bar"}, + false, + false, + }, + { + "includes string true", + ConditionIncludes, + "foo", + "foo", + true, + false, + }, + { + "includes string false", + ConditionIncludes, + "foo", + "bar", + false, + false, + }, + { + "incldes int true", + ConditionIncludes, + 1024, + 1024, + true, + false, + }, + { + "includes int false", + ConditionIncludes, + 1024, + 2048, + false, + false, + }, + { + "includes real true", + ConditionIncludes, + float64(42.0), + float64(42.0), + true, + false, + }, + { + "includes real false", + ConditionIncludes, + float64(42.0), + float64(420.0), + false, + false, + }, + { + "includes map true", + ConditionIncludes, + map[interface{}]interface{}{1: "bar", "bar": "baz", "baz": "quux"}, + map[interface{}]interface{}{1: "bar"}, + true, + false, + }, + { + "includes map false", + ConditionIncludes, + map[string]string{"foo": "bar", "bar": "baz", "baz": "quux"}, + map[string]string{"quux": "foobar"}, + false, + false, + }, + { + "includes slice true", + ConditionIncludes, + []string{"foo", "bar", "baz", "quux"}, + []string{"foo", "bar"}, + true, + false, + }, + { + "includes slice false", + ConditionIncludes, + []string{"foo", "bar", "baz", "quux"}, + []string{"foobar", "quux"}, + false, + false, + }, + { + "excludes string true", + ConditionExcludes, + "foo", + "bar", + true, + false, + }, + { + "excludes string false", + ConditionExcludes, + "foo", + "foo", + false, + false, + }, + { + "excludes int true", + ConditionExcludes, + 1024, + 2048, + true, + false, + }, + { + "excludes int false", + ConditionExcludes, + 1024, + 1024, + false, + false, + }, + { + "excludes real true", + ConditionExcludes, + float64(42.0), + float64(24.0), + true, + false, + }, + { + "excludes real false", + ConditionExcludes, + float64(42.0), + float64(42.0), + false, + false, + }, + { + "excludes map true", + ConditionExcludes, + map[interface{}]interface{}{1: "bar", "bar": "baz", "baz": "quux"}, + map[interface{}]interface{}{1: "foo"}, + true, + false, + }, + { + "excludes map false", + ConditionExcludes, + map[string]string{"foo": "bar", "bar": "baz", "baz": "quux"}, + map[string]string{"foo": "bar"}, + false, + false, + }, + { + "excludes slice true", + ConditionExcludes, + []string{"foo", "bar", "baz", "quux"}, + []string{"foobar"}, + true, + false, + }, + { + "excludes slice false", + ConditionExcludes, + []string{"foobar", "bar", "baz", "quux"}, + []string{"foobar", "quux"}, + false, + false, + }, + { + "lt unsuported", + ConditionLessThan, + "foo", + "foo", + false, + true, + }, + { + "lteq unsupported", + ConditionLessThanOrEqual, + []string{"foo"}, + []string{"foo"}, + false, + true, + }, + { + "gt unsupported", + ConditionGreaterThan, + map[string]string{"foo": "foo"}, + map[string]string{"foo": "foo"}, + false, + true, + }, + { + "gteq unsupported", + ConditionGreaterThanOrEqual, + true, + true, + false, + true, + }, + { + "lt true", + ConditionLessThan, + 0, + 42, + true, + false, + }, + { + "lteq true", + ConditionLessThanOrEqual, + 42, + 42, + true, + false, + }, + { + "gt true", + ConditionGreaterThan, + float64(420.0), + float64(42.0), + true, + false, + }, + { + "gteq true", + ConditionGreaterThanOrEqual, + float64(420.00), + float64(419.99), + true, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.c.Evaluate(tt.a, tt.b) + if (err != nil) != tt.wantErr { + t.Errorf("ConditionFunction.Evaluate() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("ConditionFunction.Evaluate() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSliceContains(t *testing.T) { + tests := []struct { + name string + a interface{} + b interface{} + want bool + }{ + { + "string slice", + []string{"foo", "bar", "baz"}, + []string{"foo", "bar"}, + true, + }, + { + "int slice", + []int{1, 2, 3}, + []int{2, 3}, + true, + }, + { + "real slice", + []float64{42.0, 42.0, 24.0}, + []float64{42.0, 24.0}, + true, + }, + { + "interface slice", + []interface{}{1, "bar", "baz"}, + []interface{}{1, "bar"}, + true, + }, + { + "no match", + []interface{}{1, "bar", "baz"}, + []interface{}{2, "bar"}, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + x := reflect.ValueOf(tt.a) + y := reflect.ValueOf(tt.b) + if got := sliceContains(x, y); got != tt.want { + t.Errorf("compareSlice() = %v, want %v", got, tt.want) + } + }) + } +} +func TestMapContains(t *testing.T) { + tests := []struct { + name string + a interface{} + b interface{} + want bool + }{ + { + "string map", + map[string]string{"foo": "bar", "bar": "baz"}, + map[string]string{"foo": "bar"}, + true, + }, + { + "int keys", + map[int]string{1: "bar", 2: "baz"}, + map[int]string{1: "bar"}, + true, + }, + { + "interface keys", + map[interface{}]interface{}{1: 1024, 2: "baz"}, + map[interface{}]interface{}{2: "baz"}, + true, + }, + { + "no key match", + map[string]string{"foo": "bar", "bar": "baz"}, + map[string]string{"quux": "bar"}, + false, + }, + { + "no value match", + map[string]string{"foo": "bar", "bar": "baz"}, + map[string]string{"foo": "quux"}, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + x := reflect.ValueOf(tt.a) + y := reflect.ValueOf(tt.b) + if got := mapContains(x, y); got != tt.want { + t.Errorf("mapContains() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/ovsdb/mutation.go b/ovsdb/mutation.go index 64441b47..cb752f6b 100644 --- a/ovsdb/mutation.go +++ b/ovsdb/mutation.go @@ -3,18 +3,19 @@ package ovsdb import ( "encoding/json" "fmt" + "reflect" ) type Mutator string const ( - MutateOperationDelete Mutator = "delete" - MutateOperationInsert Mutator = "insert" - MutateOperationAdd Mutator = "+=" - MutateOperationSubstract Mutator = "-=" - MutateOperationMultiply Mutator = "*=" - MutateOperationDivide Mutator = "/=" - MutateOperationModulo Mutator = "%=" + MutateOperationDelete Mutator = "delete" + MutateOperationInsert Mutator = "insert" + MutateOperationAdd Mutator = "+=" + MutateOperationSubtract Mutator = "-=" + MutateOperationMultiply Mutator = "*=" + MutateOperationDivide Mutator = "/=" + MutateOperationModulo Mutator = "%=" ) // Mutation is described in RFC 7047: 5.1 @@ -40,7 +41,7 @@ func (m Mutation) MarshalJSON() ([]byte, error) { } // UnmarshalJSON converts a 3 element JSON array to a Mutation -func (m Mutation) UnmarshalJSON(b []byte) error { +func (m *Mutation) UnmarshalJSON(b []byte) error { var v []interface{} err := json.Unmarshal(b, &v) if err != nil { @@ -60,17 +61,21 @@ func (m Mutation) UnmarshalJSON(b []byte) error { } mutator := Mutator(mutatorString) switch mutator { - case MutateOperationDelete: - case MutateOperationInsert: - case MutateOperationAdd: - case MutateOperationSubstract: - case MutateOperationMultiply: - case MutateOperationDivide: - case MutateOperationModulo: + case MutateOperationDelete, + MutateOperationInsert, + MutateOperationAdd, + MutateOperationSubtract, + MutateOperationMultiply, + MutateOperationDivide, + MutateOperationModulo: m.Mutator = mutator default: return fmt.Errorf("%s is not a valid mutator", mutator) } - m.Value = v[2] + vv, err := interfaceToOVSDBNotationInterface(reflect.ValueOf(v[2])) + if err != nil { + return err + } + m.Value = vv return nil } diff --git a/ovsdb/mutation_test.go b/ovsdb/mutation_test.go new file mode 100644 index 00000000..a70c835b --- /dev/null +++ b/ovsdb/mutation_test.go @@ -0,0 +1,114 @@ +package ovsdb + +import ( + "encoding/json" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMutationMarshalUnmarshalJSON(t *testing.T) { + tests := []struct { + name string + mutation Mutation + want string + wantErr bool + }{ + { + "test delete", + Mutation{"foo", MutateOperationDelete, "bar"}, + `[ "foo", "delete", "bar" ]`, + false, + }, + { + "test insert", + Mutation{"foo", MutateOperationInsert, "bar"}, + `[ "foo", "insert", "bar" ]`, + false, + }, + { + "test add", + Mutation{"foo", MutateOperationAdd, "bar"}, + `[ "foo", "+=", "bar" ]`, + false, + }, + { + "test subtract", + Mutation{"foo", MutateOperationSubtract, "bar"}, + `[ "foo", "-=", "bar" ]`, + false, + }, + { + "test multiply", + Mutation{"foo", MutateOperationMultiply, "bar"}, + `[ "foo", "*=", "bar" ]`, + false, + }, + { + "test divide", + Mutation{"foo", MutateOperationDivide, "bar"}, + `[ "foo", "/=", "bar" ]`, + false, + }, + { + "test modulo", + Mutation{"foo", MutateOperationModulo, "bar"}, + `[ "foo", "%=", "bar" ]`, + false, + }, + { + "test uuid", + Mutation{"foo", MutateOperationInsert, UUID{GoUUID: "foo"}}, + `[ "foo", "insert", ["named-uuid", "foo"] ]`, + false, + }, + { + "test set", + Mutation{"foo", MutateOperationInsert, OvsSet{GoSet: []interface{}{"foo", "bar", "baz"}}}, + `[ "foo", "insert", ["set",["foo", "bar", "baz"]] ]`, + false, + }, + { + "test map", + Mutation{"foo", MutateOperationInsert, OvsMap{GoMap: map[interface{}]interface{}{"foo": "bar", "baz": "quux"}}}, + `[ "foo", "insert", ["map",[["foo", "bar"], ["baz", "quux"]]]]`, + false, + }, + { + "test uuid set", + Mutation{"foo", MutateOperationInsert, OvsSet{GoSet: []interface{}{UUID{GoUUID: "foo"}, UUID{GoUUID: "bar"}}}}, + `[ "foo", "insert", ["set",[["named-uuid", "foo"], ["named-uuid", "bar"]]] ]`, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := json.Marshal(tt.mutation) + if err != nil { + t.Fatal(err) + } + // testing JSON equality is flaky for ovsdb notated maps + // it's safe to skip this as we test from json->object later + if tt.name != "test map" { + assert.JSONEq(t, tt.want, string(got)) + } + var c Mutation + if err := json.Unmarshal(got, &c); err != nil { + t.Fatal(err) + } + assert.Equal(t, tt.mutation.Column, c.Column) + assert.Equal(t, tt.mutation.Mutator, c.Mutator) + v := reflect.TypeOf(tt.mutation.Value) + vv := reflect.ValueOf(c.Value) + if !vv.IsValid() { + t.Fatalf("c.Value is empty: %v", c.Value) + } + assert.Equal(t, v, vv.Type()) + assert.Equal(t, tt.mutation.Value, vv.Convert(v).Interface()) + if vv.Kind() == reflect.String { + assert.Equal(t, tt.mutation.Value, vv.String()) + } + }) + } +} diff --git a/ovsdb/notation.go b/ovsdb/notation.go index 90f03752..40ab6a19 100644 --- a/ovsdb/notation.go +++ b/ovsdb/notation.go @@ -2,6 +2,8 @@ package ovsdb import ( "encoding/json" + "fmt" + "reflect" ) const ( @@ -121,3 +123,74 @@ func ovsSliceToGoNotation(val interface{}) (interface{}, error) { } return val, nil } + +// interfaceToSetMapOrUUIDInterface takes a reflect.Value and converts it to +// the correct OVSDB Notation (Set, Map, UUID) using reflection +func interfaceToOVSDBNotationInterface(v reflect.Value) (interface{}, error) { + // if value is a scalar value, it will be an interface that can + // be type asserted back to string, float64, int etc... + if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { + return v.Interface(), nil + } + // if its a set, map or uuid here we need to convert it to the correct type, not []interface{} + s := v.Slice(0, v.Len()) + first := s.Index(0) + // assert that our first element is a string value + if first.Elem().Kind() != reflect.String { + return nil, fmt.Errorf("first element of array/slice is not a string: %v %s", first, first.Kind().String()) + } + switch first.Elem().String() { + case "uuid", "named-uuid": + uuid := s.Index(1).Elem().String() + return UUID{GoUUID: uuid}, nil + case "set": + // second is the second element of the slice + second := s.Index(1).Elem() + // in a set, it must be a slice + if second.Kind() != reflect.Slice && second.Kind() != reflect.Array { + return nil, fmt.Errorf("second element of set is not a slice") + } + ss := second.Slice(0, second.Len()) + + // check first index of second element + // if it's not a slice or array this is a set of scalar values + if ss.Index(0).Elem().Kind() != reflect.Slice && ss.Index(0).Elem().Kind() != reflect.Array { + si := second.Interface() + set := OvsSet{GoSet: si.([]interface{})} + return set, nil + } + innerSet := []interface{}{} + // iterate over the slice and extract the uuid, adding a UUID object to our innerSet + for i := 0; i < ss.Len(); i++ { + uuid := ss.Index(i).Elem().Index(1).Elem().String() + innerSet = append(innerSet, UUID{GoUUID: uuid}) + } + return OvsSet{GoSet: innerSet}, nil + case "map": + ovsMap := OvsMap{GoMap: make(map[interface{}]interface{})} + second := s.Index(1).Elem() + for i := 0; i < second.Len(); i++ { + pair := second.Index(i).Elem().Slice(0, 2) + var key interface{} + // check if key is slice or array, in which case we can infer that it's a UUUID + if pair.Index(0).Elem().Kind() == reflect.Slice || pair.Index(0).Elem().Kind() == reflect.Array { + uuid := pair.Index(0).Elem().Index(1).Elem().String() + key = UUID{GoUUID: uuid} + } else { + key = pair.Index(0).Interface() + } + // check if value is slice or array, in which case we can infer that it's a UUUID + var value interface{} + if pair.Index(1).Elem().Kind() == reflect.Slice || pair.Index(1).Elem().Kind() == reflect.Array { + uuid := pair.Index(1).Elem().Index(1).Elem().String() + value = UUID{GoUUID: uuid} + } else { + value = pair.Index(1).Elem().Interface() + } + ovsMap.GoMap[key] = value + } + return ovsMap, nil + default: + return nil, fmt.Errorf("unsupported notation. expected ,, or . got %v", v) + } +} diff --git a/ovsdb/notation_test.go b/ovsdb/notation_test.go index 59f0ab10..7acb4995 100644 --- a/ovsdb/notation_test.go +++ b/ovsdb/notation_test.go @@ -3,7 +3,10 @@ package ovsdb import ( "encoding/json" "log" + "reflect" "testing" + + "github.com/stretchr/testify/assert" ) func TestOpRowSerialization(t *testing.T) { @@ -172,3 +175,86 @@ func TestNewMutation(t *testing.T) { t.Error("mutation is not correctly formatted") } } + +func TestOperationsMarshalUnmarshalJSON(t *testing.T) { + in := []byte(`{"op":"mutate","table":"Open_vSwitch","mutations":[["bridges","insert",["named-uuid","foo"]]],"where":[["_uuid","==",["named-uuid","ovs"]]]}`) + var op Operation + err := json.Unmarshal(in, &op) + if err != nil { + t.Fatal(err) + } + assert.Equal(t, OperationMutate, op.Op) + assert.Equal(t, "Open_vSwitch", op.Table) + assert.Equal(t, 1, len(op.Mutations)) + assert.Equal(t, Mutation{ + Column: "bridges", + Mutator: OperationInsert, + Value: UUID{GoUUID: "foo"}, + }, op.Mutations[0]) +} + +func TestInterfaceToOVSDBNotationInterface(t *testing.T) { + tests := []struct { + name string + value interface{} + want interface{} + wantErr bool + }{ + { + "scalar value", + "foo", + "foo", + false, + }, + { + "set", + []interface{}{"set", []interface{}{"foo", "bar", "baz"}}, + OvsSet{GoSet: []interface{}{"foo", "bar", "baz"}}, + false, + }, + { + "uuid set", + []interface{}{"set", []interface{}{[]interface{}{"named-uuid", "foo"}, []interface{}{"named-uuid", "bar"}}}, + OvsSet{GoSet: []interface{}{UUID{GoUUID: "foo"}, UUID{GoUUID: "bar"}}}, + false, + }, + { + "map", + []interface{}{"map", []interface{}{[]interface{}{"foo", "bar"}, []interface{}{"baz", "quux"}}}, + OvsMap{GoMap: map[interface{}]interface{}{"foo": "bar", "baz": "quux"}}, + false, + }, + { + "map uuid values", + []interface{}{"map", []interface{}{[]interface{}{"foo", []interface{}{"named-uuid", "bar"}}, []interface{}{"baz", []interface{}{"named-uuid", "quux"}}}}, + OvsMap{GoMap: map[interface{}]interface{}{"foo": UUID{GoUUID: "bar"}, "baz": UUID{GoUUID: "quux"}}}, + false, + }, + { + "map uuid keys", + []interface{}{"map", []interface{}{[]interface{}{[]interface{}{"named-uuid", "bar"}, "foo"}, []interface{}{[]interface{}{"named-uuid", "quux"}, "baz"}}}, + OvsMap{GoMap: map[interface{}]interface{}{UUID{GoUUID: "bar"}: "foo", UUID{GoUUID: "quux"}: "baz"}}, + false, + }, + { + "map uuid keys and values", + []interface{}{"map", []interface{}{[]interface{}{[]interface{}{"named-uuid", "bar"}, "foo"}, []interface{}{[]interface{}{"named-uuid", "quux"}, "baz"}}}, + OvsMap{GoMap: map[interface{}]interface{}{UUID{GoUUID: "bar"}: "foo", UUID{GoUUID: "quux"}: "baz"}}, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + v := reflect.ValueOf(tt.value) + got, err := interfaceToOVSDBNotationInterface(v) + if (err != nil) != tt.wantErr { + t.Errorf("interfaceToOVSDBNotationInterface() error = %v, wantErr %v", err, tt.wantErr) + return + } + wantValue := reflect.ValueOf(tt.want) + gotValue := reflect.ValueOf(got) + assert.Equal(t, wantValue.Type(), gotValue.Type()) + assert.Equal(t, wantValue.Interface(), gotValue.Interface()) + }) + } +} diff --git a/ovsdb/uuid.go b/ovsdb/uuid.go index f58e323b..9caba43d 100644 --- a/ovsdb/uuid.go +++ b/ovsdb/uuid.go @@ -6,6 +6,8 @@ import ( "regexp" ) +var validUUID = regexp.MustCompile(`^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`) + // UUID is a UUID according to RFC7047 type UUID struct { GoUUID string `json:"uuid"` @@ -38,11 +40,13 @@ func (u UUID) validateUUID() error { return fmt.Errorf("uuid exceeds 36 characters") } - var validUUID = regexp.MustCompile(`^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`) - if !validUUID.MatchString(u.GoUUID) { return fmt.Errorf("uuid does not match regexp") } return nil } + +func isNamed(uuid string) bool { + return len(uuid) > 0 && !validUUID.MatchString(uuid) +} diff --git a/ovsdb/uuid_test.go b/ovsdb/uuid_test.go new file mode 100644 index 00000000..7ea96787 --- /dev/null +++ b/ovsdb/uuid_test.go @@ -0,0 +1,34 @@ +package ovsdb + +import "testing" + +func TestUUIDIsNamed(t *testing.T) { + tests := []struct { + name string + uuid string + want bool + }{ + { + "named", + "foo", + true, + }, + { + "named", + aUUID0, + false, + }, + { + "empty", + "", + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isNamed(tt.uuid); got != tt.want { + t.Errorf("UUID.Named() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/server/database.go b/server/database.go new file mode 100644 index 00000000..bb9ad9b7 --- /dev/null +++ b/server/database.go @@ -0,0 +1,579 @@ +package server + +import ( + "errors" + "fmt" + "reflect" + "sync" + + "github.com/google/uuid" + "github.com/ovn-org/libovsdb/cache" + "github.com/ovn-org/libovsdb/mapper" + "github.com/ovn-org/libovsdb/model" + "github.com/ovn-org/libovsdb/ovsdb" +) + +var ( + ErrNotImplemented = errors.New("not implemented") +) + +// Database abstracts database operations from ovsdb +type Database interface { + CreateDatabase(name string, model *ovsdb.DatabaseSchema) error + Exists(name string) bool + Transact(database string, operations []ovsdb.Operation) ([]ovsdb.OperationResult, ovsdb.TableUpdates) + Select(database string, table string, where []ovsdb.Condition, columns []string) ovsdb.OperationResult + Insert(database string, table string, uuidName string, row ovsdb.Row) (ovsdb.OperationResult, ovsdb.TableUpdates) + Update(database, table string, where []ovsdb.Condition, row ovsdb.Row) (ovsdb.OperationResult, ovsdb.TableUpdates) + Mutate(database, table string, where []ovsdb.Condition, mutations []ovsdb.Mutation) (ovsdb.OperationResult, ovsdb.TableUpdates) + Delete(database, table string, where []ovsdb.Condition) (ovsdb.OperationResult, ovsdb.TableUpdates) + Wait(database, table string, timeout int, conditions []ovsdb.Condition, columns []string, until string, rows []ovsdb.Row) ovsdb.OperationResult + Commit(database, table string, durable bool) ovsdb.OperationResult + Abort(database, table string) ovsdb.OperationResult + Comment(database, table string, comment string) ovsdb.OperationResult + Assert(database, table, lock string) ovsdb.OperationResult +} + +type inMemoryDatabase struct { + databases map[string]*cache.TableCache + models map[string]*model.DBModel + mutex sync.RWMutex +} + +func NewInMemoryDatabase(models map[string]*model.DBModel) Database { + return &inMemoryDatabase{ + databases: make(map[string]*cache.TableCache), + models: models, + mutex: sync.RWMutex{}, + } +} + +func (db *inMemoryDatabase) CreateDatabase(name string, schema *ovsdb.DatabaseSchema) error { + db.mutex.Lock() + defer db.mutex.Unlock() + var mo *model.DBModel + var ok bool + if mo, ok = db.models[schema.Name]; !ok { + return fmt.Errorf("no db model provided for schema with name %s", name) + } + database, err := cache.NewTableCache(schema, mo, nil) + if err != nil { + return nil + } + db.databases[name] = database + return nil +} + +func (db *inMemoryDatabase) Exists(name string) bool { + db.mutex.RLock() + defer db.mutex.RUnlock() + _, ok := db.databases[name] + return ok +} + +func (db *inMemoryDatabase) Transact(name string, operations []ovsdb.Operation) ([]ovsdb.OperationResult, ovsdb.TableUpdates) { + db.mutex.Lock() + defer db.mutex.Unlock() + results := []ovsdb.OperationResult{} + updates := make(ovsdb.TableUpdates) + for _, op := range operations { + switch op.Op { + case ovsdb.OperationInsert: + r, tu := db.Insert(name, op.Table, op.UUIDName, op.Row) + results = append(results, r) + if tu != nil { + updates.Merge(tu) + } + case ovsdb.OperationSelect: + r := db.Select(name, op.Table, op.Where, op.Columns) + results = append(results, r) + case ovsdb.OperationUpdate: + r, tu := db.Update(name, op.Table, op.Where, op.Row) + results = append(results, r) + if tu != nil { + updates.Merge(tu) + } + case ovsdb.OperationMutate: + r, tu := db.Mutate(name, op.Table, op.Where, op.Mutations) + results = append(results, r) + if tu != nil { + updates.Merge(tu) + } + case ovsdb.OperationDelete: + r, tu := db.Delete(name, op.Table, op.Where) + results = append(results, r) + if tu != nil { + updates.Merge(tu) + } + case ovsdb.OperationWait: + r := db.Wait(name, op.Table, op.Timeout, op.Where, op.Columns, op.Until, op.Rows) + results = append(results, r) + case ovsdb.OperationCommit: + durable := op.Durable + r := db.Commit(name, op.Table, *durable) + results = append(results, r) + case ovsdb.OperationAbort: + r := db.Abort(name, op.Table) + results = append(results, r) + case ovsdb.OperationComment: + r := db.Comment(name, op.Table, *op.Comment) + results = append(results, r) + case ovsdb.OperationAssert: + r := db.Assert(name, op.Table, *op.Lock) + results = append(results, r) + default: + return nil, updates + } + } + return results, updates +} + +func (db *inMemoryDatabase) Insert(database string, table string, rowUUID string, row ovsdb.Row) (ovsdb.OperationResult, ovsdb.TableUpdates) { + var targetDb *cache.TableCache + var ok bool + if targetDb, ok = db.databases[database]; !ok { + return ovsdb.OperationResult{ + Error: "database does not exist", + }, nil + } + if rowUUID == "" { + rowUUID = uuid.NewString() + } + model, err := targetDb.CreateModel(table, &row, rowUUID) + if err != nil { + return ovsdb.OperationResult{ + Error: err.Error(), + }, nil + } + + // insert in to db + if err := targetDb.Table(table).Create(rowUUID, model); err != nil { + panic(err) + } + + resultRow, err := targetDb.Mapper().NewRow(table, model) + if err != nil { + return ovsdb.OperationResult{ + Error: err.Error(), + }, nil + } + + result := ovsdb.OperationResult{ + UUID: ovsdb.UUID{GoUUID: rowUUID}, + } + return result, ovsdb.TableUpdates{ + table: { + rowUUID: { + New: &resultRow, + Old: nil, + }, + }, + } +} + +func (db *inMemoryDatabase) Select(database string, table string, where []ovsdb.Condition, columns []string) ovsdb.OperationResult { + var targetDb *cache.TableCache + var ok bool + if targetDb, ok = db.databases[database]; !ok { + return ovsdb.OperationResult{ + Error: "database does not exist", + } + } + + var results []ovsdb.Row + rows, err := matchCondition(targetDb, table, where) + if err != nil { + panic(err) + } + for _, row := range rows { + resultRow, err := targetDb.Mapper().NewRow(table, row) + if err != nil { + panic(err) + } + results = append(results, resultRow) + } + return ovsdb.OperationResult{ + Rows: results, + } +} + +func (db *inMemoryDatabase) Update(database, table string, where []ovsdb.Condition, row ovsdb.Row) (ovsdb.OperationResult, ovsdb.TableUpdates) { + var targetDb *cache.TableCache + var ok bool + if targetDb, ok = db.databases[database]; !ok { + return ovsdb.OperationResult{ + Error: "database does not exist", + }, nil + } + + schema := targetDb.Mapper().Schema.Table(table) + tableUpdate := make(ovsdb.TableUpdate) + rows, err := matchCondition(targetDb, table, where) + if err != nil { + return ovsdb.OperationResult{ + Error: err.Error(), + }, nil + } + for _, old := range rows { + info, _ := mapper.NewMapperInfo(schema, old) + uuid, _ := info.FieldByColumn("_uuid") + oldRow, err := targetDb.Mapper().NewRow(table, old) + if err != nil { + panic(err) + } + newRow, err := targetDb.Mapper().NewRow(table, row) + if err != nil { + panic(err) + } + if err = targetDb.Table(table).Update(uuid.(string), row); err != nil { + panic(err) + } + tableUpdate.AddRowUpdate(uuid.(string), &ovsdb.RowUpdate{ + Old: &oldRow, + New: &newRow, + }) + } + // FIXME: We need to filter the returned columns + return ovsdb.OperationResult{ + Count: len(rows), + }, ovsdb.TableUpdates{ + table: tableUpdate, + } +} + +func (db *inMemoryDatabase) Mutate(database, table string, where []ovsdb.Condition, mutations []ovsdb.Mutation) (ovsdb.OperationResult, ovsdb.TableUpdates) { + var targetDb *cache.TableCache + var ok bool + if targetDb, ok = db.databases[database]; !ok { + return ovsdb.OperationResult{ + Error: "database does not exist", + }, nil + } + + schema := targetDb.Mapper().Schema.Table(table) + tableUpdate := make(ovsdb.TableUpdate) + + rows, err := matchCondition(targetDb, table, where) + if err != nil { + panic(err) + } + + for _, old := range rows { + info, err := mapper.NewMapperInfo(schema, old) + if err != nil { + panic(err) + } + uuid, _ := info.FieldByColumn("_uuid") + oldRow, err := targetDb.Mapper().NewRow(table, old) + if err != nil { + panic(err) + } + for _, m := range mutations { + column := schema.Column(m.Column) + nativeValue, err := ovsdb.OvsToNative(column, m.Value) + if err != nil { + panic(err) + } + if err := ovsdb.ValidateMutation(column, m.Mutator, nativeValue); err != nil { + panic(err) + } + info, err := mapper.NewMapperInfo(schema, old) + if err != nil { + panic(err) + } + current, err := info.FieldByColumn(m.Column) + if err != nil { + panic(err) + } + new := mutate(current, m.Mutator, nativeValue) + if err := info.SetField(m.Column, new); err != nil { + panic(err) + } + // the field in old has been set, write back to db + err = targetDb.Table(table).Update(uuid.(string), old) + if err != nil { + panic(err) + } + newRow, err := targetDb.Mapper().NewRow(table, old) + if err != nil { + panic(err) + } + tableUpdate.AddRowUpdate(uuid.(string), &ovsdb.RowUpdate{ + Old: &oldRow, + New: &newRow, + }) + } + } + return ovsdb.OperationResult{ + Count: len(rows), + }, ovsdb.TableUpdates{ + table: tableUpdate, + } +} + +func (db *inMemoryDatabase) Delete(database, table string, where []ovsdb.Condition) (ovsdb.OperationResult, ovsdb.TableUpdates) { + var targetDb *cache.TableCache + var ok bool + if targetDb, ok = db.databases[database]; !ok { + return ovsdb.OperationResult{ + Error: "database does not exist", + }, nil + } + + schema := targetDb.Mapper().Schema.Table(table) + tableUpdate := make(ovsdb.TableUpdate) + rows, err := matchCondition(targetDb, table, where) + if err != nil { + panic(err) + } + for _, row := range rows { + info, _ := mapper.NewMapperInfo(schema, row) + uuid, _ := info.FieldByColumn("_uuid") + oldRow, err := targetDb.Mapper().NewRow(table, row) + if err != nil { + panic(err) + } + if err := targetDb.Table(table).Delete(uuid.(string)); err != nil { + panic(err) + } + tableUpdate.AddRowUpdate(uuid.(string), &ovsdb.RowUpdate{ + Old: &oldRow, + New: nil, + }) + } + return ovsdb.OperationResult{ + Count: len(rows), + }, ovsdb.TableUpdates{ + table: tableUpdate, + } +} + +func (db *inMemoryDatabase) Wait(database, table string, timeout int, conditions []ovsdb.Condition, columns []string, until string, rows []ovsdb.Row) ovsdb.OperationResult { + return ovsdb.OperationResult{Error: ErrNotImplemented.Error()} +} + +func (db *inMemoryDatabase) Commit(database, table string, durable bool) ovsdb.OperationResult { + return ovsdb.OperationResult{Error: ErrNotImplemented.Error()} +} + +func (db *inMemoryDatabase) Abort(database, table string) ovsdb.OperationResult { + return ovsdb.OperationResult{Error: ErrNotImplemented.Error()} +} + +func (db *inMemoryDatabase) Comment(database, table string, comment string) ovsdb.OperationResult { + return ovsdb.OperationResult{Error: ErrNotImplemented.Error()} +} + +func (db *inMemoryDatabase) Assert(database, table, lock string) ovsdb.OperationResult { + return ovsdb.OperationResult{Error: ErrNotImplemented.Error()} +} + +func mutate(current interface{}, mutator ovsdb.Mutator, value interface{}) interface{} { + switch current.(type) { + case bool, string: + return current + } + switch mutator { + case ovsdb.MutateOperationInsert: + switch current.(type) { + case int, float64: + return current + } + vc := reflect.ValueOf(current) + vv := reflect.ValueOf(value) + if vc.Kind() == reflect.Slice && vc.Type() == reflect.SliceOf(vv.Type()) { + v := reflect.Append(vc, vv) + return v.Interface() + } + if vc.Kind() == reflect.Slice && vv.Kind() == reflect.Slice { + v := reflect.AppendSlice(vc, vv) + return v.Interface() + } + case ovsdb.MutateOperationDelete: + switch current.(type) { + case int, float64: + return current + } + vc := reflect.ValueOf(current) + vv := reflect.ValueOf(value) + if vc.Kind() == reflect.Slice && vc.Type() == reflect.SliceOf(vv.Type()) { + v := removeFromSlice(vc, vv) + return v.Interface() + } + if vc.Kind() == reflect.Slice && vv.Kind() == reflect.Slice { + v := vc + for i := 0; i < vv.Len(); i++ { + v = removeFromSlice(v, vv.Index(i)) + } + return v.Interface() + } + case ovsdb.MutateOperationAdd: + if i, ok := current.(int); ok { + v := value.(int) + return i + v + } + if i, ok := current.(float64); ok { + v := value.(float64) + return i + v + } + if is, ok := current.([]int); ok { + v := value.(int) + for i, j := range is { + is[i] = j + v + } + return is + } + if is, ok := current.([]float64); ok { + v := value.(float64) + for i, j := range is { + is[i] = j + v + } + return is + } + case ovsdb.MutateOperationSubtract: + if i, ok := current.(int); ok { + v := value.(int) + return i - v + } + if i, ok := current.(float64); ok { + v := value.(float64) + return i - v + } + if is, ok := current.([]int); ok { + v := value.(int) + for i, j := range is { + is[i] = j - v + } + return is + } + if is, ok := current.([]float64); ok { + v := value.(float64) + for i, j := range is { + is[i] = j - v + } + return is + } + case ovsdb.MutateOperationMultiply: + if i, ok := current.(int); ok { + v := value.(int) + return i * v + } + if i, ok := current.(float64); ok { + v := value.(float64) + return i * v + } + if is, ok := current.([]int); ok { + v := value.(int) + for i, j := range is { + is[i] = j * v + } + return is + } + if is, ok := current.([]float64); ok { + v := value.(float64) + for i, j := range is { + is[i] = j * v + } + return is + } + case ovsdb.MutateOperationDivide: + if i, ok := current.(int); ok { + v := value.(int) + return i / v + } + if i, ok := current.(float64); ok { + v := value.(float64) + return i / v + } + if is, ok := current.([]int); ok { + v := value.(int) + for i, j := range is { + is[i] = j / v + } + return is + } + if is, ok := current.([]float64); ok { + v := value.(float64) + for i, j := range is { + is[i] = j / v + } + return is + } + case ovsdb.MutateOperationModulo: + if i, ok := current.(int); ok { + v := value.(int) + return i % v + } + if is, ok := current.([]int); ok { + v := value.(int) + for i, j := range is { + is[i] = j % v + } + return is + } + } + return current +} + +func removeFromSlice(a, b reflect.Value) reflect.Value { + for i := 0; i < a.Len(); i++ { + if a.Index(i).Interface() == b.Interface() { + v := reflect.AppendSlice(a.Slice(0, i), a.Slice(i+1, a.Len())) + return v + } + } + return a +} + +func matchCondition(targetDb *cache.TableCache, table string, conditions []ovsdb.Condition) ([]model.Model, error) { + var results []model.Model + if len(conditions) == 0 { + uuids := targetDb.Table(table).Rows() + for _, uuid := range uuids { + row := targetDb.Table(table).Row(uuid) + results = append(results, row) + } + return results, nil + } + + for _, condition := range conditions { + if condition.Column == "_uuid" { + ovsdbUUID, ok := condition.Value.(ovsdb.UUID) + if !ok { + panic(fmt.Sprintf("%+v is not an ovsdb uuid", ovsdbUUID)) + } + uuid := ovsdbUUID.GoUUID + for _, k := range targetDb.Table(table).Rows() { + ok, err := condition.Function.Evaluate(k, uuid) + if err != nil { + return nil, err + } + if ok { + row := targetDb.Table(table).Row(k) + results = append(results, row) + } + } + } else { + index, err := targetDb.Table(table).Index(condition.Column) + if err != nil { + return nil, fmt.Errorf("conditions on non-index fields not supported") + } + for k, v := range index { + tSchema := targetDb.Mapper().Schema.Tables[table].Columns[condition.Column] + nativeValue, err := ovsdb.OvsToNative(tSchema, condition.Value) + if err != nil { + return nil, err + } + ok, err := condition.Function.Evaluate(k, nativeValue) + if err != nil { + return nil, err + } + if ok { + row := targetDb.Table(table).Row(v) + results = append(results, row) + } + } + } + } + return results, nil +} diff --git a/server/database_test.go b/server/database_test.go new file mode 100644 index 00000000..08dd36de --- /dev/null +++ b/server/database_test.go @@ -0,0 +1,250 @@ +package server + +import ( + "reflect" + "testing" + + "github.com/google/uuid" + "github.com/ovn-org/libovsdb/mapper" + "github.com/ovn-org/libovsdb/model" + "github.com/ovn-org/libovsdb/ovsdb" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMutate(t *testing.T) { + tests := []struct { + name string + current interface{} + mutator ovsdb.Mutator + value interface{} + want interface{} + }{ + { + "add int", + 1, + ovsdb.MutateOperationAdd, + 1, + 2, + }, + { + "add float", + 1.0, + ovsdb.MutateOperationAdd, + 1.0, + 2.0, + }, + { + "add float set", + []float64{1.0, 2.0, 3.0}, + ovsdb.MutateOperationAdd, + 1.0, + []float64{2.0, 3.0, 4.0}, + }, + { + "add int set float", + []int{1, 2, 3}, + ovsdb.MutateOperationAdd, + 1, + []int{2, 3, 4}, + }, + { + "subtract int", + 1, + ovsdb.MutateOperationSubtract, + 1, + 0, + }, + { + "subtract float", + 1.0, + ovsdb.MutateOperationSubtract, + 1.0, + 0.0, + }, + { + "subtract float set", + []float64{1.0, 2.0, 3.0}, + ovsdb.MutateOperationSubtract, + 1.0, + []float64{0.0, 1.0, 2.0}, + }, + { + "subtract int set", + []int{1, 2, 3}, + ovsdb.MutateOperationSubtract, + 1, + []int{0, 1, 2}, + }, + { + "multiply int", + 1, + ovsdb.MutateOperationMultiply, + 2, + 2, + }, + { + "multiply float", + 1.0, + ovsdb.MutateOperationMultiply, + 2.0, + 2.0, + }, + { + "multiply float set", + []float64{1.0, 2.0, 3.0}, + ovsdb.MutateOperationMultiply, + 2.0, + []float64{2.0, 4.0, 6.0}, + }, + { + "multiply int set", + []int{1, 2, 3}, + ovsdb.MutateOperationMultiply, + 2, + []int{2, 4, 6}, + }, + { + "divide int", + 10, + ovsdb.MutateOperationDivide, + 2, + 5, + }, + { + "divide float", + 1.0, + ovsdb.MutateOperationDivide, + 2.0, + 0.5, + }, + { + "divide float set", + []float64{1.0, 2.0, 4.0}, + ovsdb.MutateOperationDivide, + 2.0, + []float64{0.5, 1.0, 2.0}, + }, + { + "divide int set", + []int{10, 20, 30}, + ovsdb.MutateOperationDivide, + 5, + []int{2, 4, 6}, + }, + { + "modulo int", + 3, + ovsdb.MutateOperationModulo, + 2, + 1, + }, + { + "modulo int set", + []int{3, 5, 7}, + ovsdb.MutateOperationModulo, + 2, + []int{1, 1, 1}, + }, + { + "insert single string", + []string{"foo", "bar"}, + ovsdb.MutateOperationInsert, + "baz", + []string{"foo", "bar", "baz"}, + }, + { + "insert multiple string", + []string{"foo", "bar"}, + ovsdb.MutateOperationInsert, + []string{"baz", "quux"}, + []string{"foo", "bar", "baz", "quux"}, + }, + { + "delete single string", + []string{"foo", "bar"}, + ovsdb.MutateOperationDelete, + "bar", + []string{"foo"}, + }, + { + "delete multiple string", + []string{"foo", "bar", "baz"}, + ovsdb.MutateOperationDelete, + []string{"bar", "baz"}, + []string{"foo"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := mutate(tt.current, tt.mutator, tt.value) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("mutate() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMutateOp(t *testing.T) { + defDB, err := model.NewDBModel("Open_vSwitch", map[string]model.Model{ + "Open_vSwitch": &ovsType{}, + "Bridge": &bridgeType{}}) + require.Nil(t, err) + db := NewInMemoryDatabase(map[string]*model.DBModel{"Open_vSwitch": defDB}) + schema, err := getSchema() + require.Nil(t, err) + + err = db.CreateDatabase("Open_vSwitch", schema) + require.Nil(t, err) + + ovsUUID := uuid.NewString() + bridgeUUID := uuid.NewString() + + m := mapper.NewMapper(schema) + + ovs := ovsType{} + ovsRow, err := m.NewRow("Open_vSwitch", &ovs) + require.Nil(t, err) + + bridge := bridgeType{ + Name: "foo", + } + bridgeRow, err := m.NewRow("Bridge", &bridge) + require.Nil(t, err) + + res, _ := db.Insert("Open_vSwitch", "Open_vSwitch", ovsUUID, ovsRow) + _, err = ovsdb.CheckOperationResults([]ovsdb.OperationResult{res}, []ovsdb.Operation{{Op: "insert"}}) + require.Nil(t, err) + + res, _ = db.Insert("Open_vSwitch", "Bridge", bridgeUUID, bridgeRow) + _, err = ovsdb.CheckOperationResults([]ovsdb.OperationResult{res}, []ovsdb.Operation{{Op: "insert"}}) + require.Nil(t, err) + + gotResult, gotUpdate := db.Mutate( + "Open_vSwitch", + "Open_vSwitch", + []ovsdb.Condition{ + ovsdb.NewCondition("_uuid", ovsdb.ConditionEqual, ovsdb.UUID{GoUUID: ovsUUID}), + }, + []ovsdb.Mutation{ + *ovsdb.NewMutation("bridges", ovsdb.MutateOperationInsert, ovsdb.UUID{GoUUID: bridgeUUID}), + }, + ) + assert.Equal(t, ovsdb.OperationResult{Count: 1}, gotResult) + + bridgeSet, err := ovsdb.NewOvsSet([]ovsdb.UUID{{GoUUID: bridgeUUID}}) + assert.Nil(t, err) + assert.Equal(t, ovsdb.TableUpdates{ + "Open_vSwitch": ovsdb.TableUpdate{ + ovsUUID: &ovsdb.RowUpdate{ + Old: &ovsdb.Row{ + "_uuid": ovsdb.UUID{GoUUID: ovsUUID}, + }, + New: &ovsdb.Row{ + "_uuid": ovsdb.UUID{GoUUID: ovsUUID}, + "bridges": bridgeSet, + }, + }, + }, + }, gotUpdate) +} diff --git a/server/doc.go b/server/doc.go new file mode 100644 index 00000000..a4af0953 --- /dev/null +++ b/server/doc.go @@ -0,0 +1,8 @@ +/* +Package server provides an alpha-quality implementation of an OVSDB Server + +It is designed only to be used for testing the functionality of the client +library such that assertions can be made on the cache that backs the +client's monitor or the server +*/ +package server diff --git a/server/monitor.go b/server/monitor.go new file mode 100644 index 00000000..70cc3c75 --- /dev/null +++ b/server/monitor.go @@ -0,0 +1,123 @@ +package server + +import ( + "fmt" + "log" + "sync" + + "github.com/cenkalti/rpc2" + "github.com/ovn-org/libovsdb/ovsdb" +) + +// connectionMonitors maps a connection to a map or monitors +type connectionMonitors struct { + monitors map[string]*monitor + mu sync.RWMutex +} + +func newConnectionMonitors() *connectionMonitors { + return &connectionMonitors{ + monitors: make(map[string]*monitor), + mu: sync.RWMutex{}, + } +} + +// monitor represents a connection to a client where db changes +// will be reflected +type monitor struct { + id string + request map[string]*ovsdb.MonitorRequest + client *rpc2.Client + updates chan ovsdb.TableUpdates + stopCh chan struct{} +} + +func newMonitor(id string, request map[string]*ovsdb.MonitorRequest, client *rpc2.Client) *monitor { + m := &monitor{ + id: id, + request: request, + client: client, + updates: make(chan ovsdb.TableUpdates), + stopCh: make(chan struct{}, 1), + } + go m.sendUpdates() + return m +} + +func (m *monitor) sendUpdates() { + for { + select { + case update := <-m.updates: + args := []interface{}{m.id, update} + var reply interface{} + err := m.client.Call("update", args, &reply) + if err != nil { + log.Printf("client error handling update rpc: %v", err) + } + case <-m.stopCh: + return + } + } +} + +// Enqueue will enqueue an update if it matches the tables and monitor select arguments +// we take the update by value (not reference) so we can mutate it in place before +// queuing it for dispatch +func (m *monitor) Enqueue(update ovsdb.TableUpdates) { + // remove updates for tables that we aren't watching + if len(m.request) != 0 { + m.filter(update) + } + if len(update) == 0 { + return + } + m.updates <- update +} + +func (m *monitor) filter(update ovsdb.TableUpdates) { + // remove updates for tables that we aren't watching + if len(m.request) != 0 { + for table, u := range update { + if _, ok := m.request[table]; !ok { + fmt.Println("dropping table update") + delete(update, table) + continue + } + for uuid, row := range u { + switch { + case row.Insert() && m.request[table].Select.Insert(): + fallthrough + case row.Modify() && m.request[table].Select.Modify(): + fallthrough + case row.Delete() && m.request[table].Select.Delete(): + if len(m.request[table].Columns) > 0 { + cols := make(map[string]bool) + for _, c := range m.request[table].Columns { + cols[c] = true + } + if row.New != nil { + new := *row.New + for k := range new { + if _, ok := cols[k]; !ok { + delete(new, k) + } + } + update[table][uuid].New = &new + } + if row.Old != nil { + old := *row.Old + for k := range old { + if _, ok := cols[k]; !ok { + delete(old, k) + } + } + update[table][uuid].Old = &old + } + } + default: + delete(u, uuid) + } + } + } + } +} diff --git a/server/monitor_test.go b/server/monitor_test.go new file mode 100644 index 00000000..ffac715a --- /dev/null +++ b/server/monitor_test.go @@ -0,0 +1,85 @@ +package server + +import ( + "testing" + + "github.com/ovn-org/libovsdb/ovsdb" + "github.com/stretchr/testify/assert" +) + +func TestMonitorFilter(t *testing.T) { + monitor := monitor{ + request: map[string]*ovsdb.MonitorRequest{ + "Bridge": { + Columns: []string{"name"}, + Select: ovsdb.NewDefaultMonitorSelect(), + }, + }, + } + bridgeRow := ovsdb.Row{ + "_uuid": "foo", + "name": "bar", + } + bridgeRowWithIDs := ovsdb.Row{ + "_uuid": "foo", + "name": "bar", + "external_ids": map[string]string{"foo": "bar"}, + } + tests := []struct { + name string + update ovsdb.TableUpdates + expected ovsdb.TableUpdates + }{ + { + "not filtered", + ovsdb.TableUpdates{ + "Bridge": ovsdb.TableUpdate{ + "foo": &ovsdb.RowUpdate{ + Old: nil, New: &bridgeRow, + }, + }, + }, + ovsdb.TableUpdates{ + "Bridge": ovsdb.TableUpdate{ + "foo": &ovsdb.RowUpdate{ + Old: nil, New: &bridgeRow, + }, + }, + }, + }, + { + "removed table", + ovsdb.TableUpdates{ + "Open_vSwitch": ovsdb.TableUpdate{ + "foo": &ovsdb.RowUpdate{ + Old: nil, New: &bridgeRow, + }, + }, + }, + ovsdb.TableUpdates{}, + }, + { + "removed column", + ovsdb.TableUpdates{ + "Bridge": ovsdb.TableUpdate{ + "foo": &ovsdb.RowUpdate{ + Old: nil, New: &bridgeRowWithIDs, + }, + }, + }, + ovsdb.TableUpdates{ + "Bridge": ovsdb.TableUpdate{ + "foo": &ovsdb.RowUpdate{ + Old: nil, New: &bridgeRow, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + monitor.filter(tt.update) + assert.Equal(t, tt.expected, tt.update) + }) + } +} diff --git a/server/server.go b/server/server.go new file mode 100644 index 00000000..2599767e --- /dev/null +++ b/server/server.go @@ -0,0 +1,308 @@ +package server + +import ( + "encoding/json" + "fmt" + "net" + "sync" + + "github.com/cenkalti/rpc2" + "github.com/cenkalti/rpc2/jsonrpc" + "github.com/google/uuid" + "github.com/ovn-org/libovsdb/model" + "github.com/ovn-org/libovsdb/ovsdb" +) + +// OvsdbServer is an ovsdb server +type OvsdbServer struct { + srv *rpc2.Server + done chan struct{} + db Database + dbUpdates chan ovsdb.TableUpdates + ready bool + readyMutex sync.RWMutex + models map[string]DatabaseModel + modelsMutex sync.RWMutex + monitors map[*rpc2.Client]*connectionMonitors + monitorMutex sync.RWMutex +} + +type DatabaseModel struct { + Model *model.DBModel + Schema *ovsdb.DatabaseSchema +} + +// NewOvsdbServer returns a new OvsdbServer +func NewOvsdbServer(db Database, models ...DatabaseModel) (*OvsdbServer, error) { + o := &OvsdbServer{ + done: make(chan struct{}, 1), + db: db, + models: make(map[string]DatabaseModel), + modelsMutex: sync.RWMutex{}, + monitors: make(map[*rpc2.Client]*connectionMonitors), + monitorMutex: sync.RWMutex{}, + dbUpdates: make(chan ovsdb.TableUpdates), + } + o.modelsMutex.Lock() + for _, model := range models { + o.models[model.Schema.Name] = model + } + o.modelsMutex.Unlock() + for database, model := range o.models { + if err := o.db.CreateDatabase(database, model.Schema); err != nil { + return nil, err + } + } + o.srv = rpc2.NewServer() + o.srv.Handle("list_dbs", o.ListDatabases) + o.srv.Handle("get_schema", o.GetSchema) + o.srv.Handle("transact", o.Transact) + o.srv.Handle("cancel", o.Cancel) + o.srv.Handle("monitor", o.Monitor) + o.srv.Handle("monitor_cancel", o.MonitorCancel) + o.srv.Handle("steal", o.Steal) + o.srv.Handle("unlock", o.Unlock) + o.srv.Handle("echo", o.Echo) + return o, nil +} + +// Serve starts the OVSDB server on the given path and protocol +func (o *OvsdbServer) Serve(protocol string, path string) error { + lis, err := net.Listen(protocol, path) + if err != nil { + return err + } + go o.dispatch() + o.readyMutex.Lock() + o.ready = true + o.readyMutex.Unlock() + for { + select { + case <-o.done: + return nil + default: + conn, err := lis.Accept() + if err != nil { + return err + } + // TODO: Need to cleanup when connection is closed + go o.srv.ServeCodec(jsonrpc.NewJSONCodec(conn)) + } + } +} + +// Close closes the OvsdbServer +func (o *OvsdbServer) Close() { + close(o.done) +} + +// Ready returns true if a server is ready to handle connections +func (o *OvsdbServer) Ready() bool { + o.readyMutex.RLock() + defer o.readyMutex.RUnlock() + return o.ready +} + +// ListDatabases lists the databases in the current system +func (o *OvsdbServer) ListDatabases(client *rpc2.Client, args []interface{}, reply *[]string) error { + dbs := []string{} + o.modelsMutex.RLock() + for _, db := range o.models { + dbs = append(dbs, db.Schema.Name) + } + o.modelsMutex.RUnlock() + *reply = dbs + return nil +} + +func (o *OvsdbServer) GetSchema(client *rpc2.Client, args []interface{}, reply *ovsdb.DatabaseSchema, +) error { + db, ok := args[0].(string) + if !ok { + return fmt.Errorf("database %v is not a string", args[0]) + } + o.modelsMutex.RLock() + model, ok := o.models[db] + if !ok { + return fmt.Errorf("database %s does not exist", db) + } + o.modelsMutex.RUnlock() + *reply = *model.Schema + return nil +} + +// Transact issues a new database transaction and returns the results +func (o *OvsdbServer) Transact(client *rpc2.Client, args []json.RawMessage, reply *[]ovsdb.OperationResult) error { + if len(args) < 2 { + return fmt.Errorf("not enough args") + } + var db string + err := json.Unmarshal(args[0], &db) + if err != nil { + return fmt.Errorf("database %v is not a string", args[0]) + } + if !o.db.Exists(db) { + return fmt.Errorf("db does not exist") + } + var ops []ovsdb.Operation + namedUUID := make(map[string]ovsdb.UUID) + for i := 1; i < len(args); i++ { + var op ovsdb.Operation + err = json.Unmarshal(args[i], &op) + if err != nil { + return err + } + if op.UUIDName != "" { + newUUID := uuid.NewString() + namedUUID[op.UUIDName] = ovsdb.UUID{GoUUID: newUUID} + op.UUIDName = newUUID + } + for i, condition := range op.Where { + op.Where[i].Value = expandNamedUUID(condition.Value, namedUUID) + } + for i, mutation := range op.Mutations { + op.Mutations[i].Value = expandNamedUUID(mutation.Value, namedUUID) + } + ops = append(ops, op) + } + response, update := o.db.Transact(db, ops) + *reply = response + o.dbUpdates <- update + return nil +} + +// Cancel cancels the last transaction +func (o *OvsdbServer) Cancel(client *rpc2.Client, args []interface{}, reply *[]interface{}) error { + return fmt.Errorf("not implemented") +} + +// Monitor montiors a given database table and provides updates to the client via an RPC callback +func (o *OvsdbServer) Monitor(client *rpc2.Client, args []json.RawMessage, reply *ovsdb.TableUpdates) error { + var db string + if err := json.Unmarshal(args[0], &db); err != nil { + return fmt.Errorf("database %v is not a string", args[0]) + } + if !o.db.Exists(db) { + return fmt.Errorf("db does not exist") + } + var value string + if err := json.Unmarshal(args[1], &value); err != nil { + return fmt.Errorf("values %v is not a string", args[1]) + } + var request map[string]*ovsdb.MonitorRequest + if err := json.Unmarshal(args[2], &request); err != nil { + return err + } + o.monitorMutex.Lock() + defer o.monitorMutex.Unlock() + clientMonitors, ok := o.monitors[client] + if !ok { + o.monitors[client] = newConnectionMonitors() + } else { + if _, ok := clientMonitors.monitors[value]; ok { + return fmt.Errorf("monitor with that value already exists") + } + } + tableUpdates := make(ovsdb.TableUpdates) + for t, request := range request { + rows := o.db.Select(db, t, nil, request.Columns) + for i := range rows.Rows { + tu := make(ovsdb.TableUpdate) + uuid := rows.Rows[i]["_uuid"].(ovsdb.UUID).GoUUID + tu[uuid] = &ovsdb.RowUpdate{ + New: &rows.Rows[i], + } + tableUpdates.AddTableUpdate(t, tu) + } + } + *reply = tableUpdates + o.monitors[client].monitors[value] = newMonitor(value, request, client) + return nil +} + +// MonitorCancel cancels a monitor on a given table +func (o *OvsdbServer) MonitorCancel(client *rpc2.Client, args []interface{}, reply *[]interface{}) error { + return fmt.Errorf("not implemented") +} + +// Lock acquires a lock on a table for a the client +func (o *OvsdbServer) Lock(client *rpc2.Client, args []interface{}, reply *[]interface{}) error { + return fmt.Errorf("not implemented") +} + +// Steal steals a lock for a client +func (o *OvsdbServer) Steal(client *rpc2.Client, args []interface{}, reply *[]interface{}) error { + return fmt.Errorf("not implemented") +} + +// Unlock releases a lock for a client +func (o *OvsdbServer) Unlock(client *rpc2.Client, args []interface{}, reply *[]interface{}) error { + return fmt.Errorf("not implemented") +} + +// Echo tests the liveness of the connection +func (o *OvsdbServer) Echo(client *rpc2.Client, args []interface{}, reply *[]interface{}) error { + echoReply := make([]interface{}, len(args)) + copy(echoReply, args) + *reply = echoReply + return nil +} + +func (o *OvsdbServer) dispatch() { + for { + select { + case update := <-o.dbUpdates: + o.monitorMutex.RLock() + for _, c := range o.monitors { + for _, m := range c.monitors { + m.Enqueue(update) + } + } + o.monitorMutex.RUnlock() + case <-o.done: + o.monitorMutex.RLock() + for _, c := range o.monitors { + for _, m := range c.monitors { + close(m.stopCh) + } + } + o.monitorMutex.RUnlock() + return + } + } +} + +func expandNamedUUID(value interface{}, namedUUID map[string]ovsdb.UUID) interface{} { + if uuid, ok := value.(ovsdb.UUID); ok { + if newUUID, ok := namedUUID[uuid.GoUUID]; ok { + return newUUID + } + } + if set, ok := value.(ovsdb.OvsSet); ok { + for i, s := range set.GoSet { + if _, ok := s.(ovsdb.UUID); !ok { + return value + } + uuid := s.(ovsdb.UUID) + if newUUID, ok := namedUUID[uuid.GoUUID]; ok { + set.GoSet[i] = newUUID + } + } + } + if m, ok := value.(ovsdb.OvsMap); ok { + for k, v := range m.GoMap { + if uuid, ok := v.(ovsdb.UUID); ok { + if newUUID, ok := namedUUID[uuid.GoUUID]; ok { + m.GoMap[k] = newUUID + } + } + if uuid, ok := k.(ovsdb.UUID); ok { + if newUUID, ok := namedUUID[uuid.GoUUID]; ok { + m.GoMap[newUUID] = m.GoMap[k] + delete(m.GoMap, uuid) + } + } + } + } + return value +} diff --git a/server/server_integration_test.go b/server/server_integration_test.go new file mode 100644 index 00000000..2653b287 --- /dev/null +++ b/server/server_integration_test.go @@ -0,0 +1,317 @@ +package server + +import ( + "context" + "fmt" + "math/rand" + "os" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/ovn-org/libovsdb/cache" + "github.com/ovn-org/libovsdb/client" + "github.com/ovn-org/libovsdb/model" + "github.com/ovn-org/libovsdb/ovsdb" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// bridgeType is the simplified ORM model of the Bridge table +type bridgeType struct { + UUID string `ovs:"_uuid"` + Name string `ovs:"name"` + OtherConfig map[string]string `ovs:"other_config"` + ExternalIds map[string]string `ovs:"external_ids"` + Ports []string `ovs:"ports"` + Status map[string]string `ovs:"status"` +} + +// ovsType is the simplified ORM model of the Bridge table +type ovsType struct { + UUID string `ovs:"_uuid"` + Bridges []string `ovs:"bridges"` +} + +func getSchema() (*ovsdb.DatabaseSchema, error) { + wd, err := os.Getwd() + if err != nil { + return nil, err + } + path := filepath.Join(wd, "testdata", "ovslite.json") + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + schema, err := ovsdb.SchemaFromFile(f) + if err != nil { + return nil, err + } + return schema, nil +} + +func TestClientServerEcho(t *testing.T) { + defDB, err := model.NewDBModel("Open_vSwitch", map[string]model.Model{ + "Open_vSwitch": &ovsType{}, + "Bridge": &bridgeType{}}) + require.Nil(t, err) + + schema, err := getSchema() + require.Nil(t, err) + + ovsDB := NewInMemoryDatabase(map[string]*model.DBModel{"Open_vSwitch": defDB}) + + rand.Seed(time.Now().UnixNano()) + tmpfile := fmt.Sprintf("/tmp/ovsdb-%d.sock", rand.Intn(10000)) + defer os.Remove(tmpfile) + server, err := NewOvsdbServer(ovsDB, DatabaseModel{ + Model: defDB, + Schema: schema, + }) + require.Nil(t, err) + + go func(t *testing.T, o *OvsdbServer) { + if err := o.Serve("unix", tmpfile); err != nil { + t.Error(err) + } + }(t, server) + defer server.Close() + require.Eventually(t, func() bool { + return server.Ready() + }, 1*time.Second, 10*time.Millisecond) + + ovs, err := client.Connect(context.Background(), fmt.Sprintf("unix:%s", tmpfile), defDB, nil) + require.Nil(t, err) + + err = ovs.Echo() + assert.Nil(t, err) +} + +func TestClientServerInsert(t *testing.T) { + defDB, err := model.NewDBModel("Open_vSwitch", map[string]model.Model{ + "Open_vSwitch": &ovsType{}, + "Bridge": &bridgeType{}}) + require.Nil(t, err) + + schema, err := getSchema() + require.Nil(t, err) + + ovsDB := NewInMemoryDatabase(map[string]*model.DBModel{"Open_vSwitch": defDB}) + rand.Seed(time.Now().UnixNano()) + tmpfile := fmt.Sprintf("/tmp/ovsdb-%d.sock", rand.Intn(10000)) + defer os.Remove(tmpfile) + server, err := NewOvsdbServer(ovsDB, DatabaseModel{ + Model: defDB, + Schema: schema, + }) + assert.Nil(t, err) + + go func(t *testing.T, o *OvsdbServer) { + if err := o.Serve("unix", tmpfile); err != nil { + t.Error(err) + } + }(t, server) + defer server.Close() + require.Eventually(t, func() bool { + return server.Ready() + }, 1*time.Second, 10*time.Millisecond) + + ovs, err := client.Connect(context.Background(), fmt.Sprintf("unix:%s", tmpfile), defDB, nil) + require.Nil(t, err) + + bridgeRow := &bridgeType{ + Name: "foo", + ExternalIds: map[string]string{"go": "awesome", "docker": "made-for-each-other"}, + } + + ops, err := ovs.Create(bridgeRow) + require.Nil(t, err) + reply, err := ovs.Transact(ops...) + assert.Nil(t, err) + _, err = ovsdb.CheckOperationResults(reply, ops) + assert.Nil(t, err) +} + +func TestClientServerMonitor(t *testing.T) { + defDB, err := model.NewDBModel("Open_vSwitch", map[string]model.Model{ + "Open_vSwitch": &ovsType{}, + "Bridge": &bridgeType{}}) + if err != nil { + t.Fatal(err) + } + + schema, err := getSchema() + if err != nil { + t.Fatal(err) + } + + ovsDB := NewInMemoryDatabase(map[string]*model.DBModel{"Open_vSwitch": defDB}) + rand.Seed(time.Now().UnixNano()) + tmpfile := fmt.Sprintf("/tmp/ovsdb-%d.sock", rand.Intn(10000)) + defer os.Remove(tmpfile) + server, err := NewOvsdbServer(ovsDB, DatabaseModel{ + Model: defDB, + Schema: schema, + }) + assert.Nil(t, err) + + go func(t *testing.T, o *OvsdbServer) { + if err := o.Serve("unix", tmpfile); err != nil { + t.Error(err) + } + }(t, server) + defer server.Close() + require.Eventually(t, func() bool { + return server.Ready() + }, 1*time.Second, 10*time.Millisecond) + + ovs, err := client.Connect(context.Background(), fmt.Sprintf("unix:%s", tmpfile), defDB, nil) + require.Nil(t, err) + + ovsRow := &ovsType{ + UUID: "ovs", + } + bridgeRow := &bridgeType{ + UUID: "foo", + Name: "foo", + ExternalIds: map[string]string{"go": "awesome", "docker": "made-for-each-other"}, + } + + seenMutex := sync.RWMutex{} + seenInsert := false + seenMutation := false + seenInitialOvs := false + ovs.Cache.AddEventHandler(&cache.EventHandlerFuncs{ + AddFunc: func(table string, model model.Model) { + if table == "Bridge" { + br := model.(*bridgeType) + assert.Equal(t, bridgeRow.Name, br.Name) + assert.Equal(t, bridgeRow.ExternalIds, br.ExternalIds) + seenMutex.Lock() + seenInsert = true + seenMutex.Unlock() + } + if table == "Open_vSwitch" { + seenMutex.Lock() + seenInitialOvs = true + seenMutex.Unlock() + } + }, + UpdateFunc: func(table string, old, new model.Model) { + fmt.Println("got an update") + if table == "Open_vSwitch" { + ov := new.(*ovsType) + assert.Equal(t, 1, len(ov.Bridges)) + seenMutex.Lock() + seenMutation = true + seenMutex.Unlock() + } + }, + }) + + var ops []ovsdb.Operation + ovsOps, err := ovs.Create(ovsRow) + require.Nil(t, err) + reply, err := ovs.Transact(ovsOps...) + require.Nil(t, err) + _, err = ovsdb.CheckOperationResults(reply, ovsOps) + require.Nil(t, err) + require.NotEmpty(t, reply[0].UUID.GoUUID) + ovsRow.UUID = reply[0].UUID.GoUUID + + err = ovs.MonitorAll("test") + require.Nil(t, err) + require.Eventually(t, func() bool { + seenMutex.RLock() + defer seenMutex.RUnlock() + return seenInitialOvs + }, 1*time.Second, 10*time.Millisecond) + + bridgeOps, err := ovs.Create(bridgeRow) + require.Nil(t, err) + ops = append(ops, bridgeOps...) + + mutateOps, err := ovs.Where(ovsRow).Mutate(ovsRow, model.Mutation{ + Field: &ovsRow.Bridges, + Mutator: ovsdb.MutateOperationInsert, + Value: []string{"foo"}, + }) + require.Nil(t, err) + ops = append(ops, mutateOps...) + + reply, err = ovs.Transact(ops...) + require.Nil(t, err) + + _, err = ovsdb.CheckOperationResults(reply, ops) + assert.Nil(t, err) + assert.Equal(t, 1, reply[1].Count) + + assert.Eventually(t, func() bool { + seenMutex.RLock() + defer seenMutex.RUnlock() + return seenInsert + }, 1*time.Second, 10*time.Millisecond) + assert.Eventually(t, func() bool { + seenMutex.RLock() + defer seenMutex.RUnlock() + return seenMutation + }, 1*time.Second, 10*time.Millisecond) +} + +func TestClientServerInsertAndDelete(t *testing.T) { + defDB, err := model.NewDBModel("Open_vSwitch", map[string]model.Model{ + "Open_vSwitch": &ovsType{}, + "Bridge": &bridgeType{}}) + require.Nil(t, err) + + schema, err := getSchema() + require.Nil(t, err) + + ovsDB := NewInMemoryDatabase(map[string]*model.DBModel{"Open_vSwitch": defDB}) + rand.Seed(time.Now().UnixNano()) + tmpfile := fmt.Sprintf("/tmp/ovsdb-%d.sock", rand.Intn(10000)) + defer os.Remove(tmpfile) + server, err := NewOvsdbServer(ovsDB, DatabaseModel{ + Model: defDB, + Schema: schema, + }) + assert.Nil(t, err) + + go func(t *testing.T, o *OvsdbServer) { + if err := o.Serve("unix", tmpfile); err != nil { + t.Error(err) + } + }(t, server) + defer server.Close() + require.Eventually(t, func() bool { + return server.Ready() + }, 1*time.Second, 10*time.Millisecond) + + ovs, err := client.Connect(context.Background(), fmt.Sprintf("unix:%s", tmpfile), defDB, nil) + require.Nil(t, err) + + bridgeRow := &bridgeType{ + Name: "foo", + ExternalIds: map[string]string{"go": "awesome", "docker": "made-for-each-other"}, + } + + ops, err := ovs.Create(bridgeRow) + require.Nil(t, err) + reply, err := ovs.Transact(ops...) + require.Nil(t, err) + _, err = ovsdb.CheckOperationResults(reply, ops) + require.Nil(t, err) + + bridgeRow.UUID = reply[0].UUID.GoUUID + + deleteOp, err := ovs.Where(bridgeRow).Delete() + require.Nil(t, err) + + reply, err = ovs.Transact(deleteOp...) + assert.Nil(t, err) + _, err = ovsdb.CheckOperationResults(reply, ops) + assert.Nil(t, err) + assert.Equal(t, 1, reply[0].Count) +} diff --git a/server/server_test.go b/server/server_test.go new file mode 100644 index 00000000..37a90933 --- /dev/null +++ b/server/server_test.go @@ -0,0 +1,144 @@ +package server + +import ( + "encoding/json" + "testing" + + "github.com/google/uuid" + "github.com/ovn-org/libovsdb/model" + "github.com/ovn-org/libovsdb/ovsdb" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestExpandNamedUUID(t *testing.T) { + testUUID := uuid.NewString() + testUUID1 := uuid.NewString() + tests := []struct { + name string + namedUUIDs map[string]ovsdb.UUID + value interface{} + expected interface{} + }{ + { + "uuid", + map[string]ovsdb.UUID{"foo": {GoUUID: testUUID}}, + ovsdb.UUID{GoUUID: "foo"}, + ovsdb.UUID{GoUUID: testUUID}, + }, + { + "set", + map[string]ovsdb.UUID{"foo": {GoUUID: testUUID}}, + ovsdb.OvsSet{GoSet: []interface{}{ovsdb.UUID{GoUUID: "foo"}}}, + ovsdb.OvsSet{GoSet: []interface{}{ovsdb.UUID{GoUUID: testUUID}}}, + }, + { + "set multiple", + map[string]ovsdb.UUID{"foo": {GoUUID: testUUID}, "bar": {GoUUID: testUUID1}}, + ovsdb.OvsSet{GoSet: []interface{}{ovsdb.UUID{GoUUID: "foo"}, ovsdb.UUID{GoUUID: "bar"}, ovsdb.UUID{GoUUID: "baz"}}}, + ovsdb.OvsSet{GoSet: []interface{}{ovsdb.UUID{GoUUID: testUUID}, ovsdb.UUID{GoUUID: testUUID1}, ovsdb.UUID{GoUUID: "baz"}}}, + }, + { + "map key", + map[string]ovsdb.UUID{"foo": {GoUUID: testUUID}}, + ovsdb.OvsMap{GoMap: map[interface{}]interface{}{ovsdb.UUID{GoUUID: "foo"}: "foo"}}, + ovsdb.OvsMap{GoMap: map[interface{}]interface{}{ovsdb.UUID{GoUUID: testUUID}: "foo"}}, + }, + { + "map values", + map[string]ovsdb.UUID{"foo": {GoUUID: testUUID}}, + ovsdb.OvsMap{GoMap: map[interface{}]interface{}{"foo": ovsdb.UUID{GoUUID: "foo"}}}, + ovsdb.OvsMap{GoMap: map[interface{}]interface{}{"foo": ovsdb.UUID{GoUUID: testUUID}}}, + }, + { + "map key and values", + map[string]ovsdb.UUID{"foo": {GoUUID: testUUID}, "bar": {GoUUID: testUUID1}}, + ovsdb.OvsMap{GoMap: map[interface{}]interface{}{ovsdb.UUID{GoUUID: "foo"}: ovsdb.UUID{GoUUID: "bar"}}}, + ovsdb.OvsMap{GoMap: map[interface{}]interface{}{ovsdb.UUID{GoUUID: testUUID}: ovsdb.UUID{GoUUID: testUUID1}}}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := expandNamedUUID(tt.value, tt.namedUUIDs) + assert.Equal(t, tt.expected, got) + }) + } +} + +func TestOvsdbServerMonitor(t *testing.T) { + defDB, err := model.NewDBModel("Open_vSwitch", map[string]model.Model{ + "Open_vSwitch": &ovsType{}, + "Bridge": &bridgeType{}}) + if err != nil { + t.Fatal(err) + } + schema, err := getSchema() + if err != nil { + t.Fatal(err) + } + ovsDB := NewInMemoryDatabase(map[string]*model.DBModel{"Open_vSwitch": defDB}) + o, err := NewOvsdbServer(ovsDB, DatabaseModel{ + Model: defDB, Schema: schema}) + require.Nil(t, err) + requests := make(map[string]ovsdb.MonitorRequest) + for table, tableSchema := range schema.Tables { + var columns []string + for column := range tableSchema.Columns { + columns = append(columns, column) + } + requests[table] = ovsdb.MonitorRequest{ + Columns: columns, + Select: ovsdb.NewDefaultMonitorSelect(), + } + } + + fooUUID := uuid.NewString() + barUUID := uuid.NewString() + bazUUID := uuid.NewString() + quuxUUID := uuid.NewString() + + _, _ = o.db.Insert("Open_vSwitch", "Bridge", fooUUID, ovsdb.Row{"name": "foo"}) + _, _ = o.db.Insert("Open_vSwitch", "Bridge", barUUID, ovsdb.Row{"name": "bar"}) + _, _ = o.db.Insert("Open_vSwitch", "Bridge", bazUUID, ovsdb.Row{"name": "baz"}) + _, _ = o.db.Insert("Open_vSwitch", "Bridge", quuxUUID, ovsdb.Row{"name": "quux"}) + + db, err := json.Marshal("Open_vSwitch") + require.Nil(t, err) + value, err := json.Marshal("foo") + require.Nil(t, err) + rJSON, err := json.Marshal(requests) + require.Nil(t, err) + args := []json.RawMessage{db, value, rJSON} + reply := &ovsdb.TableUpdates{} + err = o.Monitor(nil, args, reply) + require.Nil(t, err) + expected := &ovsdb.TableUpdates{ + "Bridge": { + fooUUID: &ovsdb.RowUpdate{ + New: &ovsdb.Row{ + "_uuid": ovsdb.UUID{GoUUID: fooUUID}, + "name": "foo", + }, + }, + barUUID: &ovsdb.RowUpdate{ + New: &ovsdb.Row{ + "_uuid": ovsdb.UUID{GoUUID: barUUID}, + "name": "bar", + }, + }, + bazUUID: &ovsdb.RowUpdate{ + New: &ovsdb.Row{ + "_uuid": ovsdb.UUID{GoUUID: bazUUID}, + "name": "baz", + }, + }, + quuxUUID: &ovsdb.RowUpdate{ + New: &ovsdb.Row{ + "_uuid": ovsdb.UUID{GoUUID: quuxUUID}, + "name": "quux", + }, + }, + }, + } + assert.Equal(t, expected, reply) +} diff --git a/server/testdata/ovslite.json b/server/testdata/ovslite.json new file mode 100644 index 00000000..2c38f55c --- /dev/null +++ b/server/testdata/ovslite.json @@ -0,0 +1,70 @@ +{ + "name": "Open_vSwitch", + "version": "0.0.1", + "tables": { + "Open_vSwitch": { + "columns": { + "bridges": { + "type": { + "key": { + "type": "uuid", + "refTable": "Bridge" + }, + "min": 0, + "max": "unlimited" + } + } + }, + "isRoot": true, + "maxRows": 1 + }, + "Bridge": { + "columns": { + "name": { + "type": "string", + "mutable": false + }, + "ports": { + "type": { + "key": { + "type": "uuid", + "refTable": "Port" + }, + "min": 0, + "max": "unlimited" + } + }, + "status": { + "type": { + "key": "string", + "value": "string", + "min": 0, + "max": "unlimited" + }, + "ephemeral": true + }, + "other_config": { + "type": { + "key": "string", + "value": "string", + "min": 0, + "max": "unlimited" + } + }, + "external_ids": { + "type": { + "key": "string", + "value": "string", + "min": 0, + "max": "unlimited" + } + } + }, + "indexes": [ + [ + "name" + ] + ] + } + } +} \ No newline at end of file