diff --git a/README.md b/README.md index 344a918d..18b536fd 100644 --- a/README.md +++ b/README.md @@ -60,6 +60,18 @@ For example for PostgreSQL you could pass `jsonb`and it will be supported, howev * `after` - (MySQL Only) Add a column after another column in the table. `example: {"after":"created_at"}` * `first` - (MySQL Only) Add a column to the first position in the table. `example: {"first": true}` +#### Composite primary key + +```javascript +create_table("user_privileges") { + t.Column("user_id", "int") + t.Column("privilege_id", "int") + t.PrimaryKey("user_id", "privilege_id") +} +``` + +Please note that the `t.PrimaryKey` statement MUST be after the columns definitions. + ## Drop a Table ``` javascript diff --git a/tables.go b/tables.go index 44fdc9f3..f7a4f03d 100644 --- a/tables.go +++ b/tables.go @@ -17,6 +17,7 @@ type Table struct { Columns []Column Indexes []Index ForeignKeys []ForeignKey + primaryKeys []string Options map[string]interface{} columnsCache map[string]struct{} } @@ -61,6 +62,14 @@ func (t Table) Fizz() string { if timestampsOpt { buff.WriteString("\tt.Timestamps()\n") } + // Write primary key (single column pk will be written in inline form as the column opt) + if len(t.primaryKeys) > 1 { + pks := make([]string, len(t.primaryKeys)) + for i, pk := range t.primaryKeys { + pks[i] = fmt.Sprintf("\"%s\"", pk) + } + buff.WriteString(fmt.Sprintf("\tt.PrimaryKey(%s)\n", strings.Join(pks, ", "))) + } // Write indexes for _, i := range t.Indexes { buff.WriteString(fmt.Sprintf("\t%s\n", i.String())) @@ -89,7 +98,11 @@ func (t *Table) Column(name string, colType string, options Options) error { } var primary bool if _, ok := options["primary"]; ok { + if t.primaryKeys != nil { + return errors.New("could not define multiple primary keys") + } primary = true + t.primaryKeys = []string{name} } c := Column{ Name: name, @@ -114,7 +127,7 @@ func (t *Table) Column(name string, colType string, options Options) error { func (t *Table) ForeignKey(column string, refs interface{}, options Options) error { fkr, err := parseForeignKeyRef(refs) if err != nil { - return errors.WithStack(err) + return errors.Wrap(err, "could not parse foreign key") } fk := ForeignKey{ Column: column, @@ -178,6 +191,36 @@ func (t *Table) Timestamps() error { return t.Timestamp("updated_at") } +// PrimaryKey adds a primary key to the table. It's useful to define a composite +// primary key. +func (t *Table) PrimaryKey(pk ...string) error { + if len(pk) == 0 { + return errors.New("missing columns for primary key") + } + if t.primaryKeys != nil { + return errors.New("duplicate primary key") + } + if !t.HasColumns(pk...) { + return errors.New("columns must be declared before the primary key") + } + if len(pk) == 1 { + for i, c := range t.Columns { + if c.Name == pk[0] { + t.Columns[i].Primary = true + break + } + } + } + t.primaryKeys = make([]string, 0) + t.primaryKeys = append(t.primaryKeys, pk...) + return nil +} + +// PrimaryKeys gets the list of registered primary key fields. +func (t *Table) PrimaryKeys() []string { + return t.primaryKeys +} + // ColumnNames returns the names of the Table's columns. func (t *Table) ColumnNames() []string { cols := make([]string, len(t.Columns)) diff --git a/tables_test.go b/tables_test.go index 8a26703a..3a9d4059 100644 --- a/tables_test.go +++ b/tables_test.go @@ -204,3 +204,67 @@ func Test_Table_AddEmptyIndex(t *testing.T) { r.NoError(table.Column("email", "string", nil)) r.Error(table.Index([]string{}, nil)) } + +func Test_Table_AddPrimaryKey(t *testing.T) { + r := require.New(t) + + // Add single primary key + expected := + `create_table("users") { + t.Column("id", "int", {primary: true}) + t.Column("name", "string") + t.Column("email", "string") + t.Timestamps() +}` + table := fizz.NewTable("users", nil) + r.NoError(table.Column("id", "int", fizz.Options{"primary": true})) + r.NoError(table.Column("name", "string", nil)) + r.NoError(table.Column("email", "string", nil)) + r.Equal(expected, table.String()) + + table = fizz.NewTable("users", nil) + r.NoError(table.Column("id", "int", nil)) + r.NoError(table.Column("name", "string", nil)) + r.NoError(table.Column("email", "string", nil)) + r.NoError(table.PrimaryKey("id")) + r.Equal(expected, table.String()) + + // Add composite primary key + expected = + `create_table("user_privileges") { + t.Column("user_id", "int") + t.Column("privilege_id", "int") + t.Timestamps() + t.PrimaryKey("user_id", "privilege_id") +}` + table = fizz.NewTable("user_privileges", nil) + r.NoError(table.Column("user_id", "int", nil)) + r.NoError(table.Column("privilege_id", "int", nil)) + r.NoError(table.PrimaryKey("user_id", "privilege_id")) + r.Equal(expected, table.String()) +} + +func Test_Table_AddPrimaryKey_Errors(t *testing.T) { + r := require.New(t) + + // Primary key on unknown column + table := fizz.NewTable("users", nil) + r.NoError(table.Column("id", "int", nil)) + r.Error(table.PrimaryKey("id2")) + + // Duplicate primary key + table = fizz.NewTable("users", nil) + r.NoError(table.Column("id", "int", nil)) + r.NoError(table.PrimaryKey("id")) + r.Error(table.PrimaryKey("id")) + + // Duplicate primary key + table = fizz.NewTable("users", nil) + r.NoError(table.Column("id", "int", fizz.Options{"primary": true})) + r.Error(table.PrimaryKey("id")) + + // Duplicate inline primary key + table = fizz.NewTable("users", nil) + r.NoError(table.Column("id", "int", fizz.Options{"primary": true})) + r.Error(table.Column("id2", "int", fizz.Options{"primary": true})) +} diff --git a/translators/cockroach.go b/translators/cockroach.go index 15ef275d..9190bef6 100644 --- a/translators/cockroach.go +++ b/translators/cockroach.go @@ -49,6 +49,15 @@ func (p *Cockroach) CreateTable(t fizz.Table) (string, error) { cols = append(cols, p.buildForeignKey(t, fk, true)) } + primaryKeys := t.PrimaryKeys() + if len(primaryKeys) > 1 { + pks := make([]string, len(primaryKeys)) + for i, pk := range primaryKeys { + pks[i] = fmt.Sprintf("\"%s\"", pk) + } + cols = append(cols, fmt.Sprintf("PRIMARY KEY(%s)", strings.Join(pks, ", "))) + } + s = fmt.Sprintf("CREATE TABLE \"%s\" (\n%s\n);COMMIT TRANSACTION;BEGIN TRANSACTION;", t.Name, strings.Join(cols, ",\n")) sql = append(sql, s) diff --git a/translators/cockroach_test.go b/translators/cockroach_test.go index 99f96e44..753b1f50 100644 --- a/translators/cockroach_test.go +++ b/translators/cockroach_test.go @@ -122,6 +122,26 @@ CONSTRAINT profiles_users_id_fk FOREIGN KEY (user_id) REFERENCES users (id) r.Equal(ddl, res) } +func (p *CockroachSuite) Test_Cockroach_CreateTables_WithCompositePrimaryKey() { + r := p.Require() + ddl := `CREATE TABLE "user_profiles" ( +"user_id" INT NOT NULL, +"profile_id" INT NOT NULL, +"created_at" timestamp NOT NULL, +"updated_at" timestamp NOT NULL, +PRIMARY KEY("user_id", "profile_id") +);COMMIT TRANSACTION;BEGIN TRANSACTION;` + + res, _ := fizz.AString(` + create_table("user_profiles") { + t.Column("user_id", "INT") + t.Column("profile_id", "INT") + t.PrimaryKey("user_id", "profile_id") + } + `, p.crdbt()) + r.Equal(ddl, res) +} + func (p *CockroachSuite) Test_Cockroach_DropTable() { r := p.Require() diff --git a/translators/mssqlserver.go b/translators/mssqlserver.go index 894c8c83..5c603526 100644 --- a/translators/mssqlserver.go +++ b/translators/mssqlserver.go @@ -8,9 +8,10 @@ import ( "github.com/pkg/errors" ) -type MsSqlServer struct { -} +// MsSqlServer is a MS SqlServer-specific translator. +type MsSqlServer struct{} +// NewMsSqlServer constructs a new MsSqlServer translator. func NewMsSqlServer() *MsSqlServer { return &MsSqlServer{} } @@ -31,6 +32,15 @@ func (p *MsSqlServer) CreateTable(t fizz.Table) (string, error) { cols = append(cols, s) } + primaryKeys := t.PrimaryKeys() + if len(primaryKeys) > 1 { + pks := make([]string, len(primaryKeys)) + for i, pk := range primaryKeys { + pks[i] = fmt.Sprintf("[%s]", pk) + } + cols = append(cols, fmt.Sprintf("PRIMARY KEY(%s)", strings.Join(pks, ", "))) + } + s = fmt.Sprintf("CREATE TABLE %s (\n%s\n);", t.Name, strings.Join(cols, ",\n")) sql = append(sql, s) diff --git a/translators/mssqlserver_test.go b/translators/mssqlserver_test.go index e69e9cd9..896e88a6 100644 --- a/translators/mssqlserver_test.go +++ b/translators/mssqlserver_test.go @@ -116,6 +116,26 @@ ALTER TABLE profiles ADD CONSTRAINT profiles_users_id_fk FOREIGN KEY (user_id) R r.Equal(ddl, res) } +func (p *MsSqlServerSQLSuite) Test_MsSqlServer_CreateTables_WithCompositePrimaryKey() { + r := p.Require() + ddl := `CREATE TABLE user_profiles ( +user_id INT NOT NULL, +profile_id INT NOT NULL, +created_at DATETIME NOT NULL, +updated_at DATETIME NOT NULL +PRIMARY KEY([user_id], [profile_id]) +);` + + res, _ := fizz.AString(` + create_table("user_profiles") { + t.Column("user_id", "INT") + t.Column("profile_id", "INT") + t.PrimaryKey("user_id", "profile_id") + } + `, sqlsrv) + r.Equal(ddl, res) +} + func (p *MsSqlServerSQLSuite) Test_MsSqlServer_DropTable() { r := p.Require() diff --git a/translators/mysql.go b/translators/mysql.go index 1dba588a..3aa25411 100644 --- a/translators/mysql.go +++ b/translators/mysql.go @@ -40,6 +40,15 @@ func (p *MySQL) CreateTable(t fizz.Table) (string, error) { cols = append(cols, p.buildForeignKey(t, fk, true)) } + primaryKeys := t.PrimaryKeys() + if len(primaryKeys) > 1 { + pks := make([]string, len(primaryKeys)) + for i, pk := range primaryKeys { + pks[i] = fmt.Sprintf("`%s`", pk) + } + cols = append(cols, fmt.Sprintf("PRIMARY KEY(%s)", strings.Join(pks, ", "))) + } + s := fmt.Sprintf("CREATE TABLE %s (\n%s\n) ENGINE=InnoDB;", p.escapeIdentifier(t.Name), strings.Join(cols, ",\n")) sql = append(sql, s) diff --git a/translators/mysql_test.go b/translators/mysql_test.go index 2bc69ab2..9b9efaa7 100644 --- a/translators/mysql_test.go +++ b/translators/mysql_test.go @@ -142,6 +142,26 @@ FOREIGN KEY (` + "`user_id`" + `) REFERENCES ` + "`users`" + ` (` + "`id`" + `) r.Equal(ddl, res) } +func (p *MySQLSuite) Test_MySQL_CreateTables_WithCompositePrimaryKey() { + r := p.Require() + ddl := `CREATE TABLE ` + "`user_profiles`" + ` ( +` + "`user_id`" + ` INTEGER NOT NULL, +` + "`profile_id`" + ` INTEGER NOT NULL, +` + "`created_at`" + ` DATETIME NOT NULL, +` + "`updated_at`" + ` DATETIME NOT NULL, +PRIMARY KEY(` + "`user_id`" + `, ` + "`profile_id`" + `) +) ENGINE=InnoDB;` + + res, _ := fizz.AString(` + create_table("user_profiles") { + t.Column("user_id", "INT") + t.Column("profile_id", "INT") + t.PrimaryKey("user_id", "profile_id") + } + `, myt) + r.Equal(ddl, res) +} + func (p *MySQLSuite) Test_MySQL_DropTable() { r := p.Require() diff --git a/translators/postgres.go b/translators/postgres.go index 4c9346c1..3f442bc0 100644 --- a/translators/postgres.go +++ b/translators/postgres.go @@ -41,6 +41,15 @@ func (p *Postgres) CreateTable(t fizz.Table) (string, error) { cols = append(cols, p.buildForeignKey(t, fk, true)) } + primaryKeys := t.PrimaryKeys() + if len(primaryKeys) > 1 { + pks := make([]string, len(primaryKeys)) + for i, pk := range primaryKeys { + pks[i] = fmt.Sprintf("\"%s\"", pk) + } + cols = append(cols, fmt.Sprintf("PRIMARY KEY(%s)", strings.Join(pks, ", "))) + } + s = fmt.Sprintf("CREATE TABLE \"%s\" (\n%s\n);", t.Name, strings.Join(cols, ",\n")) sql = append(sql, s) diff --git a/translators/postgres_test.go b/translators/postgres_test.go index 5594d0e1..1c6d2ea7 100644 --- a/translators/postgres_test.go +++ b/translators/postgres_test.go @@ -116,7 +116,7 @@ PRIMARY KEY("uuid") r.Equal(ddl, res) } -func (p *PostgreSQLSuite) Test_Postgre_CreateTables_WithForeignKeys() { +func (p *PostgreSQLSuite) Test_Postgres_CreateTables_WithForeignKeys() { r := p.Require() ddl := `CREATE TABLE "users" ( "id" SERIAL NOT NULL, @@ -153,6 +153,26 @@ FOREIGN KEY (user_id) REFERENCES users (id) r.Equal(ddl, res) } +func (p *PostgreSQLSuite) Test_Postgres_CreateTables_WithCompositePrimaryKey() { + r := p.Require() + ddl := `CREATE TABLE "user_profiles" ( +"user_id" INT NOT NULL, +"profile_id" INT NOT NULL, +"created_at" timestamp NOT NULL, +"updated_at" timestamp NOT NULL, +PRIMARY KEY("user_id", "profile_id") +);` + + res, _ := fizz.AString(` + create_table("user_profiles") { + t.Column("user_id", "INT") + t.Column("profile_id", "INT") + t.PrimaryKey("user_id", "profile_id") + } + `, pgt) + r.Equal(ddl, res) +} + func (p *PostgreSQLSuite) Test_Postgres_DropTable() { r := p.Require() diff --git a/translators/sqlite.go b/translators/sqlite.go index 8a00a6ef..71f2317c 100644 --- a/translators/sqlite.go +++ b/translators/sqlite.go @@ -51,6 +51,15 @@ func (p *SQLite) CreateTable(t fizz.Table) (string, error) { cols = append(cols, p.buildForeignKey(t, fk, true)) } + primaryKeys := t.PrimaryKeys() + if len(primaryKeys) > 1 { + pks := make([]string, len(primaryKeys)) + for i, pk := range primaryKeys { + pks[i] = fmt.Sprintf("\"%s\"", pk) + } + cols = append(cols, fmt.Sprintf("PRIMARY KEY(%s)", strings.Join(pks, ", "))) + } + s = fmt.Sprintf("CREATE TABLE \"%s\" (\n%s\n);", t.Name, strings.Join(cols, ",\n")) sql = append(sql, s) diff --git a/translators/sqlite_test.go b/translators/sqlite_test.go index 555804e8..4e022a4e 100644 --- a/translators/sqlite_test.go +++ b/translators/sqlite_test.go @@ -125,6 +125,26 @@ func (p *SQLiteSuite) Test_SQLite_CreateTable_UUID() { r.Equal(ddl, res) } +func (p *SQLiteSuite) Test_SQLite_CreateTables_WithCompositePrimaryKey() { + r := p.Require() + ddl := `CREATE TABLE "user_profiles" ( +"user_id" INTEGER NOT NULL, +"profile_id" INTEGER NOT NULL, +"created_at" DATETIME NOT NULL, +"updated_at" DATETIME NOT NULL, +PRIMARY KEY("user_id", "profile_id") +);` + + res, _ := fizz.AString(` + create_table("user_profiles") { + t.Column("user_id", "INT") + t.Column("profile_id", "INT") + t.PrimaryKey("user_id", "profile_id") + } + `, sqt) + r.Equal(ddl, res) +} + func (p *SQLiteSuite) Test_SQLite_DropTable() { r := p.Require()