diff --git a/pkg/driver/driver.go b/pkg/driver/driver.go index f529db445..ec215598d 100644 --- a/pkg/driver/driver.go +++ b/pkg/driver/driver.go @@ -32,7 +32,13 @@ func (c RetryConnector) Connect(ctx context.Context) (driver.Conn, error) { err := errors.Wrap(retry.WithBackoff( ctx, func(ctx context.Context) (err error) { - conn, err = c.Connector.Connect(ctx) + if conn, err = c.Connector.Connect(ctx); err == nil && c.driver.initConn != nil { + if err = c.driver.initConn(ctx, conn); err != nil { + _ = conn.Close() + conn = nil + } + } + return }, shouldRetry, @@ -67,7 +73,9 @@ func (c RetryConnector) Driver() driver.Driver { // Driver wraps a driver.Driver that also must implement driver.DriverContext with logging capabilities and provides our RetryConnector. type Driver struct { ctxDriver - Logger *logging.Logger + + Logger *logging.Logger + initConn func(context.Context, driver.Conn) error } // OpenConnector implements the DriverContext interface. @@ -85,7 +93,7 @@ func (d Driver) OpenConnector(name string) (driver.Connector, error) { // Register makes our database Driver available under the name "icingadb-*sql". func Register(logger *logging.Logger) { - sql.Register(MySQL, &Driver{ctxDriver: &mysql.MySQLDriver{}, Logger: logger}) + sql.Register(MySQL, &Driver{ctxDriver: &mysql.MySQLDriver{}, Logger: logger, initConn: setGaleraOpts}) sql.Register(PostgreSQL, &Driver{ctxDriver: PgSQLDriver{}, Logger: logger}) _ = mysql.SetLogger(mysqlLogger(func(v ...interface{}) { logger.Debug(v...) })) sqlx.BindDriver(PostgreSQL, sqlx.DOLLAR) diff --git a/pkg/driver/mysql.go b/pkg/driver/mysql.go new file mode 100644 index 000000000..ac65daebf --- /dev/null +++ b/pkg/driver/mysql.go @@ -0,0 +1,41 @@ +package driver + +import ( + "context" + "database/sql/driver" + "github.com/go-sql-driver/mysql" + "github.com/pkg/errors" +) + +var errUnknownSysVar = &mysql.MySQLError{Number: 1193} + +// setGaleraOpts tries SET SESSION wsrep_sync_wait=4. +// +// This ensures causality checks will take place before executing INSERT or REPLACE, +// ensuring that the statement is executed on a fully synced node. +// https://mariadb.com/kb/en/galera-cluster-system-variables/#wsrep_sync_wait +// +// It prevents running into foreign key errors while inserting into linked tables on different MySQL nodes. +// Error 1193 "Unknown system variable" is ignored to support MySQL single nodes. +func setGaleraOpts(ctx context.Context, conn driver.Conn) error { + const galeraOpts = "SET SESSION wsrep_sync_wait=4" + + stmt, err := conn.(driver.ConnPrepareContext).PrepareContext(ctx, galeraOpts) + if err != nil { + err = errors.Wrap(err, "can't prepare "+galeraOpts) + } else if _, err = stmt.(driver.StmtExecContext).ExecContext(ctx, nil); err != nil { + err = errors.Wrap(err, "can't execute "+galeraOpts) + } + + if err != nil && errors.Is(err, errUnknownSysVar) { + err = nil + } + + if stmt != nil { + if errClose := stmt.Close(); errClose != nil && err == nil { + err = errors.Wrap(errClose, "can't close statement "+galeraOpts) + } + } + + return err +} diff --git a/pkg/driver/mysql_test.go b/pkg/driver/mysql_test.go new file mode 100644 index 000000000..38afc531a --- /dev/null +++ b/pkg/driver/mysql_test.go @@ -0,0 +1,178 @@ +package driver + +import ( + "context" + "database/sql/driver" + "github.com/go-sql-driver/mysql" + "github.com/stretchr/testify/assert" + "io" + "os" + "testing" +) + +func TestSetGaleraOpts(t *testing.T) { + tolerated := &mysql.MySQLError{ + Number: errUnknownSysVar.Number, + SQLState: [5]byte{255, 0, 42, 23, 7}, // Shall not confuse error comparison + Message: "This unusual text shall not confuse error comparison.", + } + + almostTolerated := &mysql.MySQLError{} + *almostTolerated = *tolerated + almostTolerated.Number-- + + notTolerated := io.EOF + ignoredCodeLocation := os.ErrPermission + + subtests := []struct { + name string + input testConn + output error + }{{ + name: "Conn PrepareContext returns error", + input: testConn{prepareError: notTolerated}, + output: notTolerated, + }, { + name: "Conn PrepareContext returns MySQLError", + input: testConn{prepareError: almostTolerated}, + output: almostTolerated, + }, { + name: "Conn PrepareContext returns MySQLError 1193", + input: testConn{prepareError: tolerated}, + output: nil, + }, { + name: "Stmt ExecContext returns error", + input: testConn{preparedStmt: &testStmt{ + execError: notTolerated, + }}, + output: notTolerated, + }, { + name: "Stmt ExecContext and Stmt Close return error", + input: testConn{preparedStmt: &testStmt{ + execError: notTolerated, + closeError: ignoredCodeLocation, + }}, + output: notTolerated, + }, { + name: "Stmt ExecContext returns MySQLError", + input: testConn{preparedStmt: &testStmt{ + execError: almostTolerated, + }}, + output: almostTolerated, + }, { + name: "Stmt ExecContext returns MySQLError and Stmt Close returns error", + input: testConn{preparedStmt: &testStmt{ + execError: almostTolerated, + closeError: ignoredCodeLocation, + }}, + output: almostTolerated, + }, { + name: "Stmt ExecContext returns MySQLError 1193", + input: testConn{preparedStmt: &testStmt{ + execError: tolerated, + }}, + output: nil, + }, { + name: "Stmt ExecContext and Stmt Close return MySQLError 1193", + input: testConn{preparedStmt: &testStmt{ + execError: tolerated, + closeError: tolerated, + }}, + output: tolerated, + }, { + name: "Stmt Close returns MySQLError 1193", + input: testConn{preparedStmt: &testStmt{ + execResult: driver.ResultNoRows, + closeError: tolerated, + }}, + output: tolerated, + }, { + name: "no errors", + input: testConn{preparedStmt: &testStmt{ + execResult: driver.ResultNoRows, + }}, + output: nil, + }} + + for _, st := range subtests { + t.Run(st.name, func(t *testing.T) { + assert.ErrorIs(t, setGaleraOpts(context.Background(), &st.input), st.output) + assert.GreaterOrEqual(t, st.input.prepareCalls, uint8(1)) + + if ts, ok := st.input.preparedStmt.(*testStmt); ok { + assert.GreaterOrEqual(t, ts.execCalls, st.input.prepareCalls) + assert.GreaterOrEqual(t, ts.closeCalls, st.input.prepareCalls) + } + }) + } +} + +type testStmt struct { + execResult driver.Result + execError error + execCalls uint8 + closeError error + closeCalls uint8 +} + +// Close implements the driver.Stmt interface. +func (ts *testStmt) Close() error { + ts.closeCalls++ + return ts.closeError +} + +// NumInput implements the driver.Stmt interface. +func (*testStmt) NumInput() int { + panic("don't call me") +} + +// Exec implements the driver.Stmt interface. +func (*testStmt) Exec([]driver.Value) (driver.Result, error) { + panic("don't call me") +} + +// Query implements the driver.Stmt interface. +func (*testStmt) Query([]driver.Value) (driver.Rows, error) { + panic("don't call me") +} + +// ExecContext implements the driver.StmtExecContext interface. +func (ts *testStmt) ExecContext(context.Context, []driver.NamedValue) (driver.Result, error) { + ts.execCalls++ + return ts.execResult, ts.execError +} + +type testConn struct { + preparedStmt driver.Stmt + prepareError error + prepareCalls uint8 +} + +// Prepare implements the driver.Conn interface. +func (*testConn) Prepare(string) (driver.Stmt, error) { + panic("don't call me") +} + +// Close implements the driver.Conn interface. +func (*testConn) Close() error { + panic("don't call me") +} + +// Begin implements the driver.Conn interface. +func (*testConn) Begin() (driver.Tx, error) { + panic("don't call me") +} + +// PrepareContext implements the driver.ConnPrepareContext interface. +func (tc *testConn) PrepareContext(context.Context, string) (driver.Stmt, error) { + tc.prepareCalls++ + return tc.preparedStmt, tc.prepareError +} + +// Assert interface compliance. +var ( + _ driver.Conn = (*testConn)(nil) + _ driver.ConnPrepareContext = (*testConn)(nil) + _ driver.Stmt = (*testStmt)(nil) + _ driver.StmtExecContext = (*testStmt)(nil) +)