Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: use native ScanType from driver and enhance RowBuffer to understand more types #18

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 99 additions & 33 deletions dump.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ import (
/*
Data struct to configure dump behavior

Out: Stream to wite to
Connection: Database connection to dump
IgnoreTables: Mark sensitive tables to ignore
MaxAllowedPacket: Sets the largest packet size to use in backups
LockTables: Lock all tables for the duration of the dump
Out: Stream to wite to
Connection: Database connection to dump
IgnoreTables: Mark sensitive tables to ignore
MaxAllowedPacket: Sets the largest packet size to use in backups
LockTables: Lock all tables for the duration of the dump
*/
type Data struct {
Out io.Writer
Expand Down Expand Up @@ -68,7 +68,7 @@ const headerTmpl = `-- Go SQL Dump {{ .DumpVersion }}
/*!40101 SET @OLD_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT */;
/*!40101 SET @OLD_CHARACTER_SET_RESULTS=@@CHARACTER_SET_RESULTS */;
/*!40101 SET @OLD_COLLATION_CONNECTION=@@COLLATION_CONNECTION */;
SET NAMES utf8mb4 ;
/*!50503 SET NAMES UTF8 */;
/*!40103 SET @OLD_TIME_ZONE=@@TIME_ZONE */;
/*!40103 SET TIME_ZONE='+00:00' */;
/*!40014 SET @OLD_UNIQUE_CHECKS=@@UNIQUE_CHECKS, UNIQUE_CHECKS=0 */;
Expand Down Expand Up @@ -99,7 +99,7 @@ const tableTmpl = `

DROP TABLE IF EXISTS {{ .NameEsc }};
/*!40101 SET @saved_cs_client = @@character_set_client */;
SET character_set_client = utf8mb4 ;
/*!50503 SET character_set_client = utf8mb4 */;
{{ .CreateSQL }};
/*!40101 SET character_set_client = @saved_cs_client */;

Expand Down Expand Up @@ -296,7 +296,7 @@ func (table *table) CreateSQL() (string, error) {
}

if tableReturn.String != table.Name {
return "", errors.New("Returned table is not the same as requested table")
return "", errors.New("returned table is not the same as requested table")
}

return tableSQL.String, nil
Expand Down Expand Up @@ -389,29 +389,19 @@ func (table *table) Init() error {
}

func reflectColumnType(tp *sql.ColumnType) reflect.Type {
// reflect for scanable
switch tp.ScanType().Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return reflect.TypeOf(sql.NullInt64{})
case reflect.Float32, reflect.Float64:
return reflect.TypeOf(sql.NullFloat64{})
case reflect.String:
return reflect.TypeOf(sql.NullString{})
}

// determine by name
// workaround https://github.com/go-sql-driver/mysql/pull/1424 till it's released
nullable, _ := tp.Nullable()
switch tp.DatabaseTypeName() {
case "BLOB", "BINARY":
return reflect.TypeOf(sql.RawBytes{})
case "VARCHAR", "TEXT", "DECIMAL", "JSON":
return reflect.TypeOf(sql.NullString{})
case "BIGINT", "TINYINT", "INT":
return reflect.TypeOf(sql.NullInt64{})
case "DOUBLE":
return reflect.TypeOf(sql.NullFloat64{})
case "TINYBLOB", "MEDIUMBLOB", "LONGBLOB", "BLOB",
"VARBINARY", "BINARY", "BIT", "GEOMETRY":
return reflect.TypeOf([]byte{})
case "TINYTEXT", "MEDIUMTEXT", "LONGTEXT", "TEXT",
"VARCHAR", "CHAR", "DECIMAL", "ENUM", "SET", "JSON", "TIME":
if nullable {
return reflect.TypeOf(sql.NullString{})
}
return reflect.TypeOf("")
}

// unknown datatype
return tp.ScanType()
}

Expand Down Expand Up @@ -443,6 +433,30 @@ func (table *table) RowValues() string {
return table.RowBuffer().String()
}

func writeString(b *bytes.Buffer, s string) {
fmt.Fprintf(b, "'%s'", sanitize(s))
}

func writeBool(b *bytes.Buffer, s bool) {
if s {
fmt.Fprintf(b, "1")
} else {
fmt.Fprintf(b, "0")
}
}

func writeBinary(b *bytes.Buffer, s []byte) {
if len(s) == 0 {
b.WriteString(nullType)
} else {
fmt.Fprintf(b, "_binary '%s'", sanitize(string(s)))
}
}

func writeTime(b *bytes.Buffer, s time.Time) {
fmt.Fprintf(b, "'%s'", sanitize(s.UTC().Format(time.DateTime)))
}

func (table *table) RowBuffer() *bytes.Buffer {
var b bytes.Buffer
b.WriteString("(")
Expand All @@ -454,9 +468,51 @@ func (table *table) RowBuffer() *bytes.Buffer {
switch s := value.(type) {
case nil:
b.WriteString(nullType)
case *string:
writeString(&b, *s)
case *sql.NullString:
if s.Valid {
fmt.Fprintf(&b, "'%s'", sanitize(s.String))
writeString(&b, s.String)
} else {
b.WriteString(nullType)
}
case *bool:
writeBool(&b, *s)
case *sql.NullBool:
if s.Valid {
writeBool(&b, s.Bool)
} else {
b.WriteString(nullType)
}
case *uint:
fmt.Fprintf(&b, "%d", *s)
case *uint8:
fmt.Fprintf(&b, "%d", *s)
case *uint16:
fmt.Fprintf(&b, "%d", *s)
case *uint32:
fmt.Fprintf(&b, "%d", *s)
case *uint64:
fmt.Fprintf(&b, "%d", *s)
case *int:
fmt.Fprintf(&b, "%d", *s)
case *int8:
fmt.Fprintf(&b, "%d", *s)
case *int16:
fmt.Fprintf(&b, "%d", *s)
case *int32:
fmt.Fprintf(&b, "%d", *s)
case *int64:
fmt.Fprintf(&b, "%d", *s)
case *sql.NullInt16:
if s.Valid {
fmt.Fprintf(&b, "%d", s.Int16)
} else {
b.WriteString(nullType)
}
case *sql.NullInt32:
if s.Valid {
fmt.Fprintf(&b, "%d", s.Int32)
} else {
b.WriteString(nullType)
}
Expand All @@ -466,17 +522,27 @@ func (table *table) RowBuffer() *bytes.Buffer {
} else {
b.WriteString(nullType)
}
case *float32:
fmt.Fprintf(&b, "%f", *s)
case *float64:
fmt.Fprintf(&b, "%f", *s)
case *sql.NullFloat64:
if s.Valid {
fmt.Fprintf(&b, "%f", s.Float64)
} else {
b.WriteString(nullType)
}
case *[]byte:
writeBinary(&b, *s)
case *sql.RawBytes:
if len(*s) == 0 {
b.WriteString(nullType)
writeBinary(&b, *s)
case *time.Time:
writeTime(&b, *s)
case *sql.NullTime:
if s.Valid {
writeTime(&b, s.Time)
} else {
fmt.Fprintf(&b, "_binary '%s'", sanitize(string(*s)))
b.WriteString(nullType)
}
default:
fmt.Fprintf(&b, "'%s'", value)
Expand Down
10 changes: 5 additions & 5 deletions dump_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ func TestCreateTableAllValuesWithNil(t *testing.T) {
AddRow("email", "").
AddRow("name", "")

rows := sqlmock.NewRowsWithColumnDefinition(c("id", 0), c("email", ""), c("name", "")).
rows := sqlmock.NewRowsWithColumnDefinition(c("id", 0), c("email", sql.NullString{}), c("name", "")).
AddRow(1, nil, "Test Name 1").
AddRow(2, "[email protected]", "Test Name 2").
AddRow(3, "", "Test Name 3")
Expand Down Expand Up @@ -266,7 +266,7 @@ func TestCreateTableOk(t *testing.T) {
AddRow("email", "").
AddRow("name", "")

createTableValueRows := sqlmock.NewRowsWithColumnDefinition(c("id", 0), c("email", ""), c("name", "")).
createTableValueRows := sqlmock.NewRowsWithColumnDefinition(c("id", 0), c("email", sql.NullString{}), c("name", "")).
AddRow(1, nil, "Test Name 1").
AddRow(2, "[email protected]", "Test Name 2")

Expand Down Expand Up @@ -294,7 +294,7 @@ func TestCreateTableOk(t *testing.T) {

DROP TABLE IF EXISTS ~Test_Table~;
/*!40101 SET @saved_cs_client = @@character_set_client */;
SET character_set_client = utf8mb4 ;
/*!50503 SET character_set_client = utf8mb4 */;
CREATE TABLE 'Test_Table' (~id~ int(11) NOT NULL AUTO_INCREMENT,~s~ char(60) DEFAULT NULL, PRIMARY KEY (~id~))ENGINE=InnoDB DEFAULT CHARSET=latin1;
/*!40101 SET character_set_client = @saved_cs_client */;

Expand Down Expand Up @@ -325,7 +325,7 @@ func TestCreateTableOkSmallPackets(t *testing.T) {
AddRow("email", "").
AddRow("name", "")

createTableValueRows := sqlmock.NewRowsWithColumnDefinition(c("id", 0), c("email", ""), c("name", "")).
createTableValueRows := sqlmock.NewRowsWithColumnDefinition(c("id", 0), c("email", sql.NullString{}), c("name", "")).
AddRow(1, nil, "Test Name 1").
AddRow(2, "[email protected]", "Test Name 2")

Expand Down Expand Up @@ -353,7 +353,7 @@ func TestCreateTableOkSmallPackets(t *testing.T) {

DROP TABLE IF EXISTS ~Test_Table~;
/*!40101 SET @saved_cs_client = @@character_set_client */;
SET character_set_client = utf8mb4 ;
/*!50503 SET character_set_client = utf8mb4 */;
CREATE TABLE 'Test_Table' (~id~ int(11) NOT NULL AUTO_INCREMENT,~s~ char(60) DEFAULT NULL, PRIMARY KEY (~id~))ENGINE=InnoDB DEFAULT CHARSET=latin1;
/*!40101 SET character_set_client = @saved_cs_client */;

Expand Down
2 changes: 1 addition & 1 deletion mysqldump.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Register a new dumper.
*/
func Register(db *sql.DB, dir, format string) (*Data, error) {
if !isDir(dir) {
return nil, errors.New("Invalid directory")
return nil, errors.New("invalid directory")
}

name := time.Now().Format(format)
Expand Down
Loading