From bb4fc64d66bc87ced5dc2ad9b0f2e039e511fc5d Mon Sep 17 00:00:00 2001 From: Dan Williams Date: Fri, 14 Oct 2022 16:03:22 -0500 Subject: [PATCH] adsfasdf --- mapper/info.go | 15 +---- mapper/mapper.go | 3 + mapper/mapper_test.go | 61 +++++++++++++----- ovsdb/bindings.go | 18 ++++++ ovsdb/error.go | 2 +- server/server_integration_test.go | 1 + server/testdata/ovslite.json | 13 +++- server/transact.go | 61 ++++++++++-------- server/transact_test.go | 104 +++++++++++++++++++++++------- 9 files changed, 198 insertions(+), 80 deletions(-) diff --git a/mapper/info.go b/mapper/info.go index d7a8ee23..f6ac01c9 100644 --- a/mapper/info.go +++ b/mapper/info.go @@ -71,20 +71,11 @@ func (i *Info) SetField(column string, value interface{}) error { if colSchema == nil { return fmt.Errorf("SetField: column %s schema not found", column) } - - // Validate set length requirements - newVal := reflect.ValueOf(value) - if colSchema.Type == ovsdb.TypeSet || colSchema.Type == ovsdb.TypeEnum { - maxVal := colSchema.TypeObj.Max() - minVal := colSchema.TypeObj.Min() - if maxVal > 1 && newVal.Len() > maxVal { - return fmt.Errorf("SetField: column %s overflow: %d new elements but max is %d", column, newVal.Len(), maxVal) - } else if minVal > 0 && newVal.Len() < minVal { - return fmt.Errorf("SetField: column %s underflow: %d new elements but min is %d", column, newVal.Len(), minVal) - } + if err := ovsdb.ValidateColumnConstraints(colSchema, value); err != nil { + return fmt.Errorf("SetField: column %s failed validation: %v", column, err) } - fieldValue.Set(newVal) + fieldValue.Set(reflect.ValueOf(value)) return nil } diff --git a/mapper/mapper.go b/mapper/mapper.go index 0d1e2938..d4dbcd60 100644 --- a/mapper/mapper.go +++ b/mapper/mapper.go @@ -118,6 +118,9 @@ func (m Mapper) NewRow(data *Info, fields ...interface{}) (ovsdb.Row, error) { if len(fields) == 0 && ovsdb.IsDefaultValue(column, nativeElem) { continue } + if err := ovsdb.ValidateColumnConstraints(column, nativeElem); err != nil { + return nil, fmt.Errorf("column %s assignment failed: %w", column, err) + } ovsElem, err := ovsdb.NativeToOvs(column, nativeElem) if err != nil { return nil, fmt.Errorf("table %s, column %s: failed to generate ovs element. %s", data.Metadata.TableName, name, err.Error()) diff --git a/mapper/mapper_test.go b/mapper/mapper_test.go index 7bd38871..c5ca299c 100644 --- a/mapper/mapper_test.go +++ b/mapper/mapper_test.go @@ -40,6 +40,8 @@ var ( 42.0, } + aFloatSetTooBig = []float64{1.0, 2.0, 3.0, 4.0, 5.0, 6.0} + aMap = map[string]string{ "key1": "value1", "key2": "value2", @@ -115,7 +117,7 @@ var testSchema = []byte(`{ "type": "real" }, "min": 0, - "max": 10 + "max": 5 } }, "aEmptySet": { @@ -216,28 +218,49 @@ func TestMapperGetData(t *testing.T) { NonTagged: "something", } - ovsRow := getOvsTestRow(t) /* Code under test */ var schema ovsdb.DatabaseSchema if err := json.Unmarshal(testSchema, &schema); err != nil { t.Error(err) } - mapper := NewMapper(schema) - test := ormTestType{ - NonTagged: "something", - } - testInfo, err := NewInfo("TestTable", schema.Table("TestTable"), &test) - assert.NoError(t, err) - - err = mapper.GetRowData(&ovsRow, testInfo) - assert.NoError(t, err) - /*End code under test*/ + tests := []struct { + name string + setup func() ovsdb.Row + expectErr bool + }{{ + name: "basic", + setup: func() ovsdb.Row { + return getOvsTestRow(t) + }, + }, { + name: "too big array", + setup: func() ovsdb.Row { + testRow := getOvsTestRow(t) + testRow["aFloatSet"] = test.MakeOvsSet(t, ovsdb.TypeReal, aFloatSetTooBig) + return testRow + }, + expectErr: true, + }} + for _, test := range tests { + t.Run(fmt.Sprintf("GetData: %s", test.name), func(t *testing.T) { + mapper := NewMapper(schema) + tt := ormTestType{ + NonTagged: "something", + } + testInfo, err := NewInfo("TestTable", schema.Table("TestTable"), &tt) + assert.NoError(t, err) - if err != nil { - t.Error(err) + ovsRow := test.setup() + err = mapper.GetRowData(&ovsRow, testInfo) + if test.expectErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, expected, tt) + } + }) } - assert.Equal(t, expected, test) } func TestMapperNewRow(t *testing.T) { @@ -315,6 +338,14 @@ func TestMapperNewRow(t *testing.T) { MyFloatSet: aFloatSet, }, expectedRow: ovsdb.Row(map[string]interface{}{"aFloatSet": test.MakeOvsSet(t, ovsdb.TypeReal, aFloatSet)}), + }, { + name: "aFloatSet too big", + objInput: &struct { + MyFloatSet []float64 `ovsdb:"aFloatSet"` + }{ + MyFloatSet: aFloatSetTooBig, + }, + shoulderr: true, }, { name: "Enum", objInput: &struct { diff --git a/ovsdb/bindings.go b/ovsdb/bindings.go index fb5965b1..ab6cb433 100644 --- a/ovsdb/bindings.go +++ b/ovsdb/bindings.go @@ -399,3 +399,21 @@ func isDefaultBaseValue(elem interface{}, etype ExtendedType) bool { return false } } + +// ValidateColumnConstraints validates the native value against any constraints +// of a given column. +func ValidateColumnConstraints(column *ColumnSchema, nativeValue interface{}) error { + switch column.Type { + case TypeSet, TypeEnum: + // Validate set length requirements + newVal := reflect.ValueOf(nativeValue) + maxVal := column.TypeObj.Max() + minVal := column.TypeObj.Min() + if maxVal > 1 && newVal.Len() > maxVal { + return fmt.Errorf("slice would overflow (%d elements but %d allowed)", newVal.Len(), maxVal) + } else if minVal > 0 && newVal.Len() < minVal { + return fmt.Errorf("slice would underflow (%d elements but %d required)", newVal.Len(), minVal) + } + } + return nil +} diff --git a/ovsdb/error.go b/ovsdb/error.go index f5bb136f..e0b2ec61 100644 --- a/ovsdb/error.go +++ b/ovsdb/error.go @@ -117,7 +117,7 @@ type ConstraintViolation struct { } // Error implements the error interface -func (e *ConstraintViolation) Error() string { +func (e ConstraintViolation) Error() string { msg := constraintViolation if e.details != "" { msg += ": " + e.details diff --git a/server/server_integration_test.go b/server/server_integration_test.go index 536fd0ff..4abecce5 100644 --- a/server/server_integration_test.go +++ b/server/server_integration_test.go @@ -29,6 +29,7 @@ type bridgeType struct { ExternalIds map[string]string `ovsdb:"external_ids"` Ports []string `ovsdb:"ports"` Status map[string]string `ovsdb:"status"` + FloodVLANs []int `ovsdb:"flood_vlans"` } // ovsType is the simplified ORM model of the Bridge table diff --git a/server/testdata/ovslite.json b/server/testdata/ovslite.json index d8c402f1..cd3323a9 100644 --- a/server/testdata/ovslite.json +++ b/server/testdata/ovslite.json @@ -69,6 +69,17 @@ "min": 0, "max": "unlimited" } + }, + "flood_vlans": { + "type": { + "key": { + "type": "integer", + "minInteger": 0, + "maxInteger": 4095 + }, + "min": 0, + "max": 3 + } } }, "indexes": [ @@ -78,4 +89,4 @@ ] } } -} \ No newline at end of file +} diff --git a/server/transact.go b/server/transact.go index 53e89576..0a44282a 100644 --- a/server/transact.go +++ b/server/transact.go @@ -48,7 +48,10 @@ func (o *OvsdbServer) transact(name string, operations []ovsdb.Operation) ([]ovs r := transaction.Select(op.Table, op.Where, op.Columns) results = append(results, r) case ovsdb.OperationUpdate: - r, tu := transaction.Update(name, op.Table, op.Where, op.Row) + r, tu, err := transaction.Update(name, op.Table, op.Where, op.Row) + if err != nil { + panic(err) + } results = append(results, r) if tu != nil { if err := updates.Merge(tu); err != nil { @@ -241,7 +244,17 @@ func (t *Transaction) Select(table string, where []ovsdb.Condition, columns []st } } -func (t *Transaction) Update(database, table string, where []ovsdb.Condition, row ovsdb.Row) (ovsdb.OperationResult, ovsdb.TableUpdates2) { +func opResultError(err error, table, detailFmt string, args ...interface{}) ovsdb.OperationResult { + detail := fmt.Sprintf(detailFmt, args...) + return ovsdb.OperationResult{ + Error: err.Error(), + Details: fmt.Sprintf("table %q: %s", table, detail), + } +} + +// Update updates the cache with the new row and returns the result of the operation +// and the updates, or an error +func (t *Transaction) Update(database, table string, where []ovsdb.Condition, row ovsdb.Row) (ovsdb.OperationResult, ovsdb.TableUpdates2, error) { dbModel := t.Model m := dbModel.Mapper schema := dbModel.Schema.Table(table) @@ -249,9 +262,7 @@ func (t *Transaction) Update(database, table string, where []ovsdb.Condition, ro rows, err := t.rowsFromTransactionCacheAndDatabase(table, where) if err != nil { - return ovsdb.OperationResult{ - Error: err.Error(), - }, nil + return opResultError(err, table, "failed to get rows from database"), nil, nil } for uuid, old := range rows { @@ -259,23 +270,23 @@ func (t *Transaction) Update(database, table string, where []ovsdb.Condition, ro oldRow, err := m.NewRow(oldInfo) if err != nil { - panic(err) + return ovsdb.OperationResult{}, nil, err } new, err := dbModel.NewModel(table) if err != nil { - panic(err) + return ovsdb.OperationResult{}, nil, err } newInfo, err := dbModel.NewModelInfo(new) if err != nil { - panic(err) + return ovsdb.OperationResult{}, nil, err } err = m.GetRowData(&oldRow, newInfo) if err != nil { - panic(err) + return ovsdb.OperationResult{}, nil, err } err = newInfo.SetField("_uuid", uuid) if err != nil { - panic(err) + return ovsdb.OperationResult{}, nil, err } rowDelta := ovsdb.NewRow() @@ -283,26 +294,20 @@ func (t *Transaction) Update(database, table string, where []ovsdb.Condition, ro colSchema := schema.Column(column) if colSchema == nil { e := ovsdb.ConstraintViolation{} - return ovsdb.OperationResult{ - Error: e.Error(), - Details: fmt.Sprintf("%s is not a valid column in the %s table", column, table), - }, nil + return opResultError(e, table, "%q is not a valid column", column), nil, nil } if !colSchema.Mutable() { e := ovsdb.ConstraintViolation{} - return ovsdb.OperationResult{ - Error: e.Error(), - Details: fmt.Sprintf("column %s is of table %s not mutable", column, table), - }, nil + return opResultError(e, table, "column %q is not mutable", column), nil, nil } old, err := newInfo.FieldByColumn(column) if err != nil { - panic(err) + return ovsdb.OperationResult{}, nil, err } native, err := ovsdb.OvsToNative(colSchema, value) if err != nil { - panic(err) + return ovsdb.OperationResult{}, nil, err } if reflect.DeepEqual(old, native) { @@ -316,17 +321,17 @@ func (t *Transaction) Update(database, table string, where []ovsdb.Condition, ro err = newInfo.SetField(column, native) if err != nil { - panic(err) + return ovsdb.OperationResult{}, nil, err } // convert the native to an ovs value // since the value in the RowUpdate hasn't been normalized newValue, err := ovsdb.NativeToOvs(colSchema, native) if err != nil { - panic(err) + return ovsdb.OperationResult{}, nil, err } diff, err := diff(colSchema, oldValue, newValue) if err != nil { - panic(err) + return ovsdb.OperationResult{}, nil, err } if diff != nil { rowDelta[column] = diff @@ -335,7 +340,7 @@ func (t *Transaction) Update(database, table string, where []ovsdb.Condition, ro newRow, err := m.NewRow(newInfo) if err != nil { - panic(err) + return ovsdb.OperationResult{}, nil, err } // check for index conflicts @@ -345,11 +350,11 @@ func (t *Transaction) Update(database, table string, where []ovsdb.Condition, ro return ovsdb.OperationResult{ Error: e.Error(), Details: newIndexExistsDetails(*indexExists), - }, nil + }, nil, nil } return ovsdb.OperationResult{ Error: err.Error(), - }, nil + }, nil, nil } err = tableUpdate.AddRowUpdate(uuid, &ovsdb.RowUpdate2{ @@ -358,7 +363,7 @@ func (t *Transaction) Update(database, table string, where []ovsdb.Condition, ro New: &newRow, }) if err != nil { - panic(err) + return ovsdb.OperationResult{}, nil, err } } // FIXME: We need to filter the returned columns @@ -366,7 +371,7 @@ func (t *Transaction) Update(database, table string, where []ovsdb.Condition, ro Count: len(rows), }, ovsdb.TableUpdates2{ table: tableUpdate, - } + }, nil } func (t *Transaction) Mutate(database, table string, where []ovsdb.Condition, mutations []ovsdb.Mutation) (ovsdb.OperationResult, ovsdb.TableUpdates2) { diff --git a/server/transact_test.go b/server/transact_test.go index 08d1b434..8650fe46 100644 --- a/server/transact_test.go +++ b/server/transact_test.go @@ -575,18 +575,60 @@ func TestOvsdbServerInsert(t *testing.T) { err = ovsDB.Commit("Open_vSwitch", uuid.New(), updates) assert.NoError(t, err) - bridge.UUID = bridgeUUID - br, err := o.db.Get("Open_vSwitch", "Bridge", bridgeUUID) - assert.NoError(t, err) - assert.Equal(t, &bridge, br) - assert.Equal(t, ovsdb.TableUpdates2{ - "Bridge": { - bridgeUUID: &ovsdb.RowUpdate2{ - Insert: &bridgeRow, - New: &bridgeRow, + tests := []struct { + name string + row ovsdb.Row + expected *ovsdb.RowUpdate2 + expectErr bool + }{ + { + "update single field", + ovsdb.Row{"datapath_type": "waldo"}, + &ovsdb.RowUpdate2{ + Modify: &ovsdb.Row{ + "datapath_type": "waldo", + }, }, + false, }, - }, updates) + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + bridge.UUID = bridgeUUID + br, err := o.db.Get("Open_vSwitch", "Bridge", bridgeUUID) + assert.NoError(t, err) + assert.Equal(t, &bridge, br) + assert.Equal(t, ovsdb.TableUpdates2{ + "Bridge": { + bridgeUUID: &ovsdb.RowUpdate2{ + Insert: &bridgeRow, + New: &bridgeRow, + }, + }, + }, updates) + + + res, updates, err := transaction.Update( + "Open_vSwitch", "Bridge", + []ovsdb.Condition{{ + Column: "_uuid", Function: ovsdb.ConditionEqual, Value: ovsdb.UUID{GoUUID: bridgeUUID}, + }}, tt.row) + if tt.expectErr { + require.Error(t, err) + } else { + errs, err := ovsdb.CheckOperationResults([]ovsdb.OperationResult{res}, []ovsdb.Operation{{Op: "update"}}) + require.NoErrorf(t, err, "%+v", errs) + + bridge.UUID = bridgeUUID + row, err := o.db.Get("Open_vSwitch", "Bridge", bridgeUUID) + assert.NoError(t, err) + br := row.(*bridgeType) + assert.NotEqual(t, br, bridgeRow) + assert.Equal(t, tt.expected.Modify, updates["Bridge"][bridgeUUID].Modify) + } + }) + } + } func TestOvsdbServerUpdate(t *testing.T) { @@ -634,10 +676,12 @@ func TestOvsdbServerUpdate(t *testing.T) { halloween := test.MakeOvsSet(t, ovsdb.TypeString, []string{"halloween"}) emptySet := test.MakeOvsSet(t, ovsdb.TypeString, []string{}) + floodVlanSet := test.MakeOvsSet(t, ovsdb.TypeInteger, []int{1, 2, 3, 4, 5, 6, 7}) tests := []struct { - name string - row ovsdb.Row - expected *ovsdb.RowUpdate2 + name string + row ovsdb.Row + expected *ovsdb.RowUpdate2 + expectErr bool }{ { "update single field", @@ -647,6 +691,13 @@ func TestOvsdbServerUpdate(t *testing.T) { "datapath_type": "waldo", }, }, + false, + }, + { + "update single field with too-large array", + ovsdb.Row{"flood_vlans": floodVlanSet}, + nil, + true, }, { "update single optional field, with direct value", @@ -656,6 +707,7 @@ func TestOvsdbServerUpdate(t *testing.T) { "datapath_id": halloween, }, }, + false, }, { "update single optional field, with set", @@ -665,6 +717,7 @@ func TestOvsdbServerUpdate(t *testing.T) { "datapath_id": halloween, }, }, + false, }, { "unset single optional field", @@ -674,24 +727,29 @@ func TestOvsdbServerUpdate(t *testing.T) { "datapath_id": emptySet, }, }, + false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - res, updates := transaction.Update( + res, updates, err := transaction.Update( "Open_vSwitch", "Bridge", []ovsdb.Condition{{ Column: "_uuid", Function: ovsdb.ConditionEqual, Value: ovsdb.UUID{GoUUID: bridgeUUID}, }}, tt.row) - errs, err := ovsdb.CheckOperationResults([]ovsdb.OperationResult{res}, []ovsdb.Operation{{Op: "update"}}) - require.NoErrorf(t, err, "%+v", errs) - - bridge.UUID = bridgeUUID - row, err := o.db.Get("Open_vSwitch", "Bridge", bridgeUUID) - assert.NoError(t, err) - br := row.(*bridgeType) - assert.NotEqual(t, br, bridgeRow) - assert.Equal(t, tt.expected.Modify, updates["Bridge"][bridgeUUID].Modify) + if tt.expectErr { + require.Error(t, err) + } else { + errs, err := ovsdb.CheckOperationResults([]ovsdb.OperationResult{res}, []ovsdb.Operation{{Op: "update"}}) + require.NoErrorf(t, err, "%+v", errs) + + bridge.UUID = bridgeUUID + row, err := o.db.Get("Open_vSwitch", "Bridge", bridgeUUID) + assert.NoError(t, err) + br := row.(*bridgeType) + assert.NotEqual(t, br, bridgeRow) + assert.Equal(t, tt.expected.Modify, updates["Bridge"][bridgeUUID].Modify) + } }) } }