diff --git a/retry.go b/retry.go index d0f2d2e..cbeca94 100644 --- a/retry.go +++ b/retry.go @@ -32,14 +32,12 @@ import ( // gorm database, and retries the transaction if it is aborted by Spanner. func RunTransaction(ctx context.Context, db *gorm.DB, fc func(tx *gorm.DB) error, opts ...*sql.TxOptions) error { // Disable internal (checksum-based) retries on the Spanner database/SQL connection. - var opt *sql.TxOptions // Note: gorm also only uses the first option, so it is safe to pick just the first element in the slice. - if len(opts) > 0 { - opt = opts[0] + if len(opts) > 0 && opts[0] != nil { + opts[0].Isolation = spannerdriver.WithDisableRetryAborts(opts[0].Isolation) } - opt.Isolation = spannerdriver.WithDisableRetryAborts(opt.Isolation) for { - err := db.Transaction(fc, opt) + err := db.Transaction(fc, opts...) if err == nil { return nil } diff --git a/spanner_test.go b/spanner_test.go index 0321910..81fc1f0 100644 --- a/spanner_test.go +++ b/spanner_test.go @@ -245,7 +245,7 @@ func TestRunTransaction(t *testing.T) { return err } return nil - }, &sql.TxOptions{}); err != nil { + }); err != nil { t.Fatal(err) } // Verify that the insert was only executed once. @@ -282,6 +282,20 @@ func TestRunTransaction(t *testing.T) { } } +func TestRunTransactionWithNilAsOptions(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db, _, teardown := setupTestGormConnection(t) + defer teardown() + + if err := RunTransaction(ctx, db, func(tx *gorm.DB) error { + return nil + }, nil); err != nil { + t.Fatal(err) + } +} + func filter(requests []interface{}, sql string) (ret []*spannerpb.ExecuteSqlRequest) { for _, i := range requests { if req, ok := i.(*spannerpb.ExecuteSqlRequest); ok {