From 80f807c8ad4ec8a0a4684dea32e5e89dc7a6ecd1 Mon Sep 17 00:00:00 2001 From: Jille Timmermans Date: Thu, 20 Apr 2023 17:39:25 +0200 Subject: [PATCH] MySQL LOAD DATA INFILE: First version This enables the :copyfrom query annotation for people using go-sql-driver/mysql that transforms it into a LOAD DATA LOCAL INFILE. issue #2179 --- internal/codegen/golang/gen.go | 4 +- internal/codegen/golang/imports.go | 7 ++ internal/codegen/golang/query.go | 38 +++++++++- .../go-sql-driver-mysql/copyfromCopy.tmpl | 60 +++++++++++++++ .../golang/templates/pgx/copyfromCopy.tmpl | 4 +- .../codegen/golang/templates/template.tmpl | 2 + .../testdata/copyfrom/mysql/go/copyfrom.go | 73 +++++++++++++++++++ .../endtoend/testdata/copyfrom/mysql/go/db.go | 31 ++++++++ .../testdata/copyfrom/mysql/go/models.go | 16 ++++ .../testdata/copyfrom/mysql/go/query.sql.go | 25 +++++++ .../testdata/copyfrom/mysql/query.sql | 7 ++ .../testdata/copyfrom/mysql/sqlc.json | 14 ++++ internal/endtoend/testdata/go.mod | 2 + internal/endtoend/testdata/go.sum | 4 + 14 files changed, 281 insertions(+), 6 deletions(-) create mode 100644 internal/codegen/golang/templates/go-sql-driver-mysql/copyfromCopy.tmpl create mode 100644 internal/endtoend/testdata/copyfrom/mysql/go/copyfrom.go create mode 100644 internal/endtoend/testdata/copyfrom/mysql/go/db.go create mode 100644 internal/endtoend/testdata/copyfrom/mysql/go/models.go create mode 100644 internal/endtoend/testdata/copyfrom/mysql/go/query.sql.go create mode 100644 internal/endtoend/testdata/copyfrom/mysql/query.sql create mode 100644 internal/endtoend/testdata/copyfrom/mysql/sqlc.json diff --git a/internal/codegen/golang/gen.go b/internal/codegen/golang/gen.go index dac9d47862..44e5aa825a 100644 --- a/internal/codegen/golang/gen.go +++ b/internal/codegen/golang/gen.go @@ -138,8 +138,8 @@ func generate(req *plugin.CodeGenRequest, enums []Enum, structs []Struct, querie SqlcVersion: req.SqlcVersion, } - if tctx.UsesCopyFrom && !tctx.SQLDriver.IsPGX() { - return nil, errors.New(":copyfrom is only supported by pgx") + if tctx.UsesCopyFrom && !tctx.SQLDriver.IsPGX() && golang.SqlDriver != SQLDriverGoSQLDriverMySQL { + return nil, errors.New(":copyfrom is only supported by pgx and github.com/go-sql-driver/mysql") } if tctx.UsesBatch && !tctx.SQLDriver.IsPGX() { diff --git a/internal/codegen/golang/imports.go b/internal/codegen/golang/imports.go index 65096d7879..7d134ca743 100644 --- a/internal/codegen/golang/imports.go +++ b/internal/codegen/golang/imports.go @@ -414,6 +414,13 @@ func (i *importer) copyfromImports() fileImports { }) std["context"] = struct{}{} + if i.Settings.Go.SqlDriver == SQLDriverGoSQLDriverMySQL { + std["io"] = struct{}{} + std["fmt"] = struct{}{} + std["sync/atomic"] = struct{}{} + std["github.com/go-sql-driver/mysql"] = struct{}{} + std["github.com/hexon/mysqltsv"] = struct{}{} + } return sortedImports(std, pkg) } diff --git a/internal/codegen/golang/query.go b/internal/codegen/golang/query.go index 0e553b64ff..c4f648cf82 100644 --- a/internal/codegen/golang/query.go +++ b/internal/codegen/golang/query.go @@ -131,7 +131,18 @@ func (v QueryValue) Params() string { return "\n" + strings.Join(out, ",\n") } -func (v QueryValue) ColumnNames() string { +func (v QueryValue) ColumnNames() []string { + if v.Struct == nil { + return []string{v.DBName} + } + names := make([]string, len(v.Struct.Fields)) + for i, f := range v.Struct.Fields { + names[i] = f.DBName + } + return names +} + +func (v QueryValue) ColumnNamesAsGoSlice() string { if v.Struct == nil { return fmt.Sprintf("[]string{%q}", v.DBName) } @@ -189,6 +200,19 @@ func (v QueryValue) Scan() string { return "\n" + strings.Join(out, ",\n") } +func (v QueryValue) Fields() []Field { + if v.Struct != nil { + return v.Struct.Fields + } + return []Field{ + { + Name: v.Name, + DBName: v.DBName, + Type: v.Typ, + }, + } +} + // A struct used to generate methods and fields on the Queries struct type Query struct { Cmd string @@ -210,7 +234,7 @@ func (q Query) hasRetType() bool { return scanned && !q.Ret.isEmpty() } -func (q Query) TableIdentifier() string { +func (q Query) TableIdentifierAsGoSlice() string { escapedNames := make([]string, 0, 3) for _, p := range []string{q.Table.Catalog, q.Table.Schema, q.Table.Name} { if p != "" { @@ -219,3 +243,13 @@ func (q Query) TableIdentifier() string { } return "[]string{" + strings.Join(escapedNames, ", ") + "}" } + +func (q Query) TableIdentifierForMySQL() string { + escapedNames := make([]string, 0, 3) + for _, p := range []string{q.Table.Catalog, q.Table.Schema, q.Table.Name} { + if p != "" { + escapedNames = append(escapedNames, fmt.Sprintf("`%s`", p)) + } + } + return strings.Join(escapedNames, ".") +} diff --git a/internal/codegen/golang/templates/go-sql-driver-mysql/copyfromCopy.tmpl b/internal/codegen/golang/templates/go-sql-driver-mysql/copyfromCopy.tmpl new file mode 100644 index 0000000000..7778458d1a --- /dev/null +++ b/internal/codegen/golang/templates/go-sql-driver-mysql/copyfromCopy.tmpl @@ -0,0 +1,60 @@ +{{define "copyfromCodeGoSqlDriver"}} +{{range .GoQueries}} +{{if eq .Cmd ":copyfrom" }} +var readerHandlerSequenceFor{{.MethodName}} uint32 = 1 + +func convertRowsFor{{.MethodName}}(w *io.PipeWriter, {{.Arg.SlicePair}}) { + e := mysqltsv.NewEncoder(w, {{ len .Arg.Fields }}, nil) + for _, row := range {{.Arg.Name}} { +{{- with $arg := .Arg }} +{{- range $arg.Fields}} +{{- if eq .Type "string"}} + e.AppendString({{if eq (len $arg.Fields) 1}}row{{else}}row.{{.Name}}{{end}}) +{{- else if eq .Type "[]byte"}} + e.AppendBytes({{if eq (len $arg.Fields) 1}}row{{else}}row.{{.Name}}{{end}}) +{{- else}} + e.AppendValue({{if eq (len $arg.Fields) 1}}row{{else}}row.{{.Name}}{{end}}) +{{- end}} +{{- end}} +{{- end}} + } + w.CloseWithError(e.Close()) +} + +{{range .Comments}}//{{.}} +{{end -}} +// {{.MethodName}} uses MySQL's LOAD DATA LOCAL INFILE and is not atomic. Errors and duplicate keys are treated as warnings and insertion will continue, even without an error for some cases. +// Use this in a transaction and use SHOW WARNINGS to check for any problems and roll back if you want to. +// This is a MySQL limitation, not sqlc. Check the documentation for more information: https://dev.mysql.com/doc/refman/8.0/en/load-data.html#load-data-error-handling +{{- if $.EmitMethodsWithDBArgument}} +func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.SlicePair}}) (int64, error) { + pr, pw := io.Pipe() + defer pr.Close() + rh := fmt.Sprintf("{{.MethodName}}_%d", atomic.AddUint32(&readerHandlerSequenceFor{{.MethodName}}, 1)) + mysql.RegisterReaderHandler(rh, func() io.Reader { return pr }) + defer mysql.DeregisterReaderHandler(rh) + go convertRowsFor{{.MethodName}}(pw, {{.Arg.Name}}) + result, err := db.ExecContext(ctx, fmt.Sprintf("LOAD DATA LOCAL INFILE '%s' INTO TABLE {{.TableIdentifierForMySQL}} %s ({{range $index, $name := .Arg.ColumnNames}}{{if gt $index 0}}, {{end}}{{$name}}{{end}})", "Reader::" + rh, mysqltsv.Escaping)) + if err != nil { + return 0, err + } + return result.RowsAffected() +{{- else}} +func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.SlicePair}}) (int64, error) { + pr, pw := io.Pipe() + defer pr.Close() + rh := fmt.Sprintf("{{.MethodName}}_%d", atomic.AddUint32(&readerHandlerSequenceFor{{.MethodName}}, 1)) + mysql.RegisterReaderHandler(rh, func() io.Reader { return pr }) + defer mysql.DeregisterReaderHandler(rh) + go convertRowsFor{{.MethodName}}(pw, {{.Arg.Name}}) + result, err := q.db.ExecContext(ctx, fmt.Sprintf("LOAD DATA LOCAL INFILE '%s' INTO TABLE {{.TableIdentifierForMySQL}} %s ({{range $index, $name := .Arg.ColumnNames}}{{if gt $index 0}}, {{end}}{{$name}}{{end}})", "Reader::" + rh, mysqltsv.Escaping)) + if err != nil { + return 0, err + } + return result.RowsAffected() +{{- end}} +} + +{{end}} +{{end}} +{{end}} diff --git a/internal/codegen/golang/templates/pgx/copyfromCopy.tmpl b/internal/codegen/golang/templates/pgx/copyfromCopy.tmpl index 5d1c66f866..c1cfa68d1d 100644 --- a/internal/codegen/golang/templates/pgx/copyfromCopy.tmpl +++ b/internal/codegen/golang/templates/pgx/copyfromCopy.tmpl @@ -39,10 +39,10 @@ func (r iteratorFor{{.MethodName}}) Err() error { {{end -}} {{- if $.EmitMethodsWithDBArgument -}} func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.SlicePair}}) (int64, error) { - return db.CopyFrom(ctx, {{.TableIdentifier}}, {{.Arg.ColumnNames}}, &iteratorFor{{.MethodName}}{rows: {{.Arg.Name}}}) + return db.CopyFrom(ctx, {{.TableIdentifierAsGoSlice}}, {{.Arg.ColumnNamesAsGoSlice}}, &iteratorFor{{.MethodName}}{rows: {{.Arg.Name}}}) {{- else -}} func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.SlicePair}}) (int64, error) { - return q.db.CopyFrom(ctx, {{.TableIdentifier}}, {{.Arg.ColumnNames}}, &iteratorFor{{.MethodName}}{rows: {{.Arg.Name}}}) + return q.db.CopyFrom(ctx, {{.TableIdentifierAsGoSlice}}, {{.Arg.ColumnNamesAsGoSlice}}, &iteratorFor{{.MethodName}}{rows: {{.Arg.Name}}}) {{- end}} } diff --git a/internal/codegen/golang/templates/template.tmpl b/internal/codegen/golang/templates/template.tmpl index 9b730db073..f0391e965d 100644 --- a/internal/codegen/golang/templates/template.tmpl +++ b/internal/codegen/golang/templates/template.tmpl @@ -186,6 +186,8 @@ import ( {{define "copyfromCode"}} {{if .SQLDriver.IsPGX }} {{- template "copyfromCodePgx" .}} +{{else}} + {{- template "copyfromCodeGoSqlDriver" .}} {{end}} {{end}} diff --git a/internal/endtoend/testdata/copyfrom/mysql/go/copyfrom.go b/internal/endtoend/testdata/copyfrom/mysql/go/copyfrom.go new file mode 100644 index 0000000000..da4534ca3e --- /dev/null +++ b/internal/endtoend/testdata/copyfrom/mysql/go/copyfrom.go @@ -0,0 +1,73 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.17.2 +// source: copyfrom.go + +package querytest + +import ( + "context" + "database/sql" + "fmt" + "github.com/go-sql-driver/mysql" + "github.com/hexon/mysqltsv" + "io" + "sync/atomic" +) + +var readerHandlerSequenceForInsertSingleValue uint32 = 1 + +func convertRowsForInsertSingleValue(w *io.PipeWriter, a []sql.NullString) { + e := mysqltsv.NewEncoder(w, 1, nil) + for _, row := range a { + e.AppendValue(row) + } + w.CloseWithError(e.Close()) +} + +// InsertSingleValue uses MySQL's LOAD DATA LOCAL INFILE and is not atomic. Errors and duplicate keys are treated as warnings and insertion will continue, even without an error for some cases. +// Use this in a transaction and use SHOW WARNINGS to check for any problems and roll back if you want to. +// This is a MySQL limitation, not sqlc. Check the documentation for more information: https://dev.mysql.com/doc/refman/8.0/en/load-data.html#load-data-error-handling +func (q *Queries) InsertSingleValue(ctx context.Context, a []sql.NullString) (int64, error) { + pr, pw := io.Pipe() + defer pr.Close() + rh := fmt.Sprintf("InsertSingleValue_%d", atomic.AddUint32(&readerHandlerSequenceForInsertSingleValue, 1)) + mysql.RegisterReaderHandler(rh, func() io.Reader { return pr }) + defer mysql.DeregisterReaderHandler(rh) + go convertRowsForInsertSingleValue(pw, a) + result, err := q.db.ExecContext(ctx, fmt.Sprintf("LOAD DATA LOCAL INFILE '%s' INTO TABLE `foo` %s (a)", "Reader::"+rh, mysqltsv.Escaping)) + if err != nil { + return 0, err + } + return result.RowsAffected() +} + +var readerHandlerSequenceForInsertValues uint32 = 1 + +func convertRowsForInsertValues(w *io.PipeWriter, arg []InsertValuesParams) { + e := mysqltsv.NewEncoder(w, 4, nil) + for _, row := range arg { + e.AppendValue(row.A) + e.AppendValue(row.B) + e.AppendValue(row.C) + e.AppendValue(row.D) + } + w.CloseWithError(e.Close()) +} + +// InsertValues uses MySQL's LOAD DATA LOCAL INFILE and is not atomic. Errors and duplicate keys are treated as warnings and insertion will continue, even without an error for some cases. +// Use this in a transaction and use SHOW WARNINGS to check for any problems and roll back if you want to. +// This is a MySQL limitation, not sqlc. Check the documentation for more information: https://dev.mysql.com/doc/refman/8.0/en/load-data.html#load-data-error-handling +func (q *Queries) InsertValues(ctx context.Context, arg []InsertValuesParams) (int64, error) { + pr, pw := io.Pipe() + defer pr.Close() + rh := fmt.Sprintf("InsertValues_%d", atomic.AddUint32(&readerHandlerSequenceForInsertValues, 1)) + mysql.RegisterReaderHandler(rh, func() io.Reader { return pr }) + defer mysql.DeregisterReaderHandler(rh) + go convertRowsForInsertValues(pw, arg) + result, err := q.db.ExecContext(ctx, fmt.Sprintf("LOAD DATA LOCAL INFILE '%s' INTO TABLE `foo` %s (a, b, c, d)", "Reader::"+rh, mysqltsv.Escaping)) + if err != nil { + return 0, err + } + return result.RowsAffected() +} diff --git a/internal/endtoend/testdata/copyfrom/mysql/go/db.go b/internal/endtoend/testdata/copyfrom/mysql/go/db.go new file mode 100644 index 0000000000..02974bda59 --- /dev/null +++ b/internal/endtoend/testdata/copyfrom/mysql/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.17.2 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/copyfrom/mysql/go/models.go b/internal/endtoend/testdata/copyfrom/mysql/go/models.go new file mode 100644 index 0000000000..1ef3a1e3f2 --- /dev/null +++ b/internal/endtoend/testdata/copyfrom/mysql/go/models.go @@ -0,0 +1,16 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.17.2 + +package querytest + +import ( + "database/sql" +) + +type Foo struct { + A sql.NullString + B sql.NullInt32 + C sql.NullTime + D sql.NullTime +} diff --git a/internal/endtoend/testdata/copyfrom/mysql/go/query.sql.go b/internal/endtoend/testdata/copyfrom/mysql/go/query.sql.go new file mode 100644 index 0000000000..dd6965fc4f --- /dev/null +++ b/internal/endtoend/testdata/copyfrom/mysql/go/query.sql.go @@ -0,0 +1,25 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.17.2 +// source: query.sql + +package querytest + +import ( + "database/sql" +) + +const insertSingleValue = `-- name: InsertSingleValue :copyfrom +INSERT INTO foo (a) VALUES (?) +` + +const insertValues = `-- name: InsertValues :copyfrom +INSERT INTO foo (a, b, c, d) VALUES (?, ?, ?, ?) +` + +type InsertValuesParams struct { + A sql.NullString + B sql.NullInt32 + C sql.NullTime + D sql.NullTime +} diff --git a/internal/endtoend/testdata/copyfrom/mysql/query.sql b/internal/endtoend/testdata/copyfrom/mysql/query.sql new file mode 100644 index 0000000000..577655c4aa --- /dev/null +++ b/internal/endtoend/testdata/copyfrom/mysql/query.sql @@ -0,0 +1,7 @@ +CREATE TABLE foo (a text, b integer, c DATETIME, d DATE); + +-- name: InsertValues :copyfrom +INSERT INTO foo (a, b, c, d) VALUES (?, ?, ?, ?); + +-- name: InsertSingleValue :copyfrom +INSERT INTO foo (a) VALUES (?); diff --git a/internal/endtoend/testdata/copyfrom/mysql/sqlc.json b/internal/endtoend/testdata/copyfrom/mysql/sqlc.json new file mode 100644 index 0000000000..5ae94271b1 --- /dev/null +++ b/internal/endtoend/testdata/copyfrom/mysql/sqlc.json @@ -0,0 +1,14 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "sql_package": "database/sql", + "sql_driver": "github.com/go-sql-driver/mysql", + "engine": "mysql", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/endtoend/testdata/go.mod b/internal/endtoend/testdata/go.mod index a480880caf..911dfd65d4 100644 --- a/internal/endtoend/testdata/go.mod +++ b/internal/endtoend/testdata/go.mod @@ -3,8 +3,10 @@ module github.com/kyleconroy/sqlc/endtoend go 1.18 require ( + github.com/go-sql-driver/mysql v1.7.0 github.com/gofrs/uuid v4.0.0+incompatible github.com/google/uuid v1.3.0 + github.com/hexon/mysqltsv v0.1.0 github.com/jackc/pgconn v1.5.1-0.20200601181101-fa742c524853 github.com/jackc/pgtype v1.6.2 github.com/jackc/pgx/v4 v4.6.1-0.20200606145419-4e5062306904 diff --git a/internal/endtoend/testdata/go.sum b/internal/endtoend/testdata/go.sum index bdaacd373a..6a33fd7e57 100644 --- a/internal/endtoend/testdata/go.sum +++ b/internal/endtoend/testdata/go.sum @@ -9,6 +9,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/friendsofgo/errors v0.9.2 h1:X6NYxef4efCBdwI7BgS820zFaN7Cphrmb+Pljdzjtgk= github.com/friendsofgo/errors v0.9.2/go.mod h1:yCvFW5AkDIL9qn7suHVLiI/gH228n7PC4Pn44IGoTOI= +github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= +github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= @@ -16,6 +18,8 @@ github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRx github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hexon/mysqltsv v0.1.0 h1:48wYQlsPH8ZEkKAVCdsOYzMYAlEoevw8ZWD8rqYPdlg= +github.com/hexon/mysqltsv v0.1.0/go.mod h1:p3vPBkpxebjHWF1bWKYNcXx5pFu+yAG89QZQEKSvVrY= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8=