Skip to content

Commit

Permalink
Composite primary keys (#47)
Browse files Browse the repository at this point in the history
* WIP composite primary keys

* Add tests for t.PrimaryKey()

* Add composite primary key support in translators

* Fix remaining issues

* Add doc about composite primary key

* Change primaryKey field name to primaryKeys

This way it matches the getter name.
  • Loading branch information
stanislas-m authored Apr 13, 2019
1 parent b1746c6 commit 42b3487
Show file tree
Hide file tree
Showing 13 changed files with 269 additions and 4 deletions.
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,18 @@ For example for PostgreSQL you could pass `jsonb`and it will be supported, howev
* `after` - (MySQL Only) Add a column after another column in the table. `example: {"after":"created_at"}`
* `first` - (MySQL Only) Add a column to the first position in the table. `example: {"first": true}`

#### Composite primary key

```javascript
create_table("user_privileges") {
t.Column("user_id", "int")
t.Column("privilege_id", "int")
t.PrimaryKey("user_id", "privilege_id")
}
```

Please note that the `t.PrimaryKey` statement MUST be after the columns definitions.

## Drop a Table

``` javascript
Expand Down
45 changes: 44 additions & 1 deletion tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ type Table struct {
Columns []Column
Indexes []Index
ForeignKeys []ForeignKey
primaryKeys []string
Options map[string]interface{}
columnsCache map[string]struct{}
}
Expand Down Expand Up @@ -61,6 +62,14 @@ func (t Table) Fizz() string {
if timestampsOpt {
buff.WriteString("\tt.Timestamps()\n")
}
// Write primary key (single column pk will be written in inline form as the column opt)
if len(t.primaryKeys) > 1 {
pks := make([]string, len(t.primaryKeys))
for i, pk := range t.primaryKeys {
pks[i] = fmt.Sprintf("\"%s\"", pk)
}
buff.WriteString(fmt.Sprintf("\tt.PrimaryKey(%s)\n", strings.Join(pks, ", ")))
}
// Write indexes
for _, i := range t.Indexes {
buff.WriteString(fmt.Sprintf("\t%s\n", i.String()))
Expand Down Expand Up @@ -89,7 +98,11 @@ func (t *Table) Column(name string, colType string, options Options) error {
}
var primary bool
if _, ok := options["primary"]; ok {
if t.primaryKeys != nil {
return errors.New("could not define multiple primary keys")
}
primary = true
t.primaryKeys = []string{name}
}
c := Column{
Name: name,
Expand All @@ -114,7 +127,7 @@ func (t *Table) Column(name string, colType string, options Options) error {
func (t *Table) ForeignKey(column string, refs interface{}, options Options) error {
fkr, err := parseForeignKeyRef(refs)
if err != nil {
return errors.WithStack(err)
return errors.Wrap(err, "could not parse foreign key")
}
fk := ForeignKey{
Column: column,
Expand Down Expand Up @@ -178,6 +191,36 @@ func (t *Table) Timestamps() error {
return t.Timestamp("updated_at")
}

// PrimaryKey adds a primary key to the table. It's useful to define a composite
// primary key.
func (t *Table) PrimaryKey(pk ...string) error {
if len(pk) == 0 {
return errors.New("missing columns for primary key")
}
if t.primaryKeys != nil {
return errors.New("duplicate primary key")
}
if !t.HasColumns(pk...) {
return errors.New("columns must be declared before the primary key")
}
if len(pk) == 1 {
for i, c := range t.Columns {
if c.Name == pk[0] {
t.Columns[i].Primary = true
break
}
}
}
t.primaryKeys = make([]string, 0)
t.primaryKeys = append(t.primaryKeys, pk...)
return nil
}

// PrimaryKeys gets the list of registered primary key fields.
func (t *Table) PrimaryKeys() []string {
return t.primaryKeys
}

// ColumnNames returns the names of the Table's columns.
func (t *Table) ColumnNames() []string {
cols := make([]string, len(t.Columns))
Expand Down
64 changes: 64 additions & 0 deletions tables_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,67 @@ func Test_Table_AddEmptyIndex(t *testing.T) {
r.NoError(table.Column("email", "string", nil))
r.Error(table.Index([]string{}, nil))
}

func Test_Table_AddPrimaryKey(t *testing.T) {
r := require.New(t)

// Add single primary key
expected :=
`create_table("users") {
t.Column("id", "int", {primary: true})
t.Column("name", "string")
t.Column("email", "string")
t.Timestamps()
}`
table := fizz.NewTable("users", nil)
r.NoError(table.Column("id", "int", fizz.Options{"primary": true}))
r.NoError(table.Column("name", "string", nil))
r.NoError(table.Column("email", "string", nil))
r.Equal(expected, table.String())

table = fizz.NewTable("users", nil)
r.NoError(table.Column("id", "int", nil))
r.NoError(table.Column("name", "string", nil))
r.NoError(table.Column("email", "string", nil))
r.NoError(table.PrimaryKey("id"))
r.Equal(expected, table.String())

// Add composite primary key
expected =
`create_table("user_privileges") {
t.Column("user_id", "int")
t.Column("privilege_id", "int")
t.Timestamps()
t.PrimaryKey("user_id", "privilege_id")
}`
table = fizz.NewTable("user_privileges", nil)
r.NoError(table.Column("user_id", "int", nil))
r.NoError(table.Column("privilege_id", "int", nil))
r.NoError(table.PrimaryKey("user_id", "privilege_id"))
r.Equal(expected, table.String())
}

func Test_Table_AddPrimaryKey_Errors(t *testing.T) {
r := require.New(t)

// Primary key on unknown column
table := fizz.NewTable("users", nil)
r.NoError(table.Column("id", "int", nil))
r.Error(table.PrimaryKey("id2"))

// Duplicate primary key
table = fizz.NewTable("users", nil)
r.NoError(table.Column("id", "int", nil))
r.NoError(table.PrimaryKey("id"))
r.Error(table.PrimaryKey("id"))

// Duplicate primary key
table = fizz.NewTable("users", nil)
r.NoError(table.Column("id", "int", fizz.Options{"primary": true}))
r.Error(table.PrimaryKey("id"))

// Duplicate inline primary key
table = fizz.NewTable("users", nil)
r.NoError(table.Column("id", "int", fizz.Options{"primary": true}))
r.Error(table.Column("id2", "int", fizz.Options{"primary": true}))
}
9 changes: 9 additions & 0 deletions translators/cockroach.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@ func (p *Cockroach) CreateTable(t fizz.Table) (string, error) {
cols = append(cols, p.buildForeignKey(t, fk, true))
}

primaryKeys := t.PrimaryKeys()
if len(primaryKeys) > 1 {
pks := make([]string, len(primaryKeys))
for i, pk := range primaryKeys {
pks[i] = fmt.Sprintf("\"%s\"", pk)
}
cols = append(cols, fmt.Sprintf("PRIMARY KEY(%s)", strings.Join(pks, ", ")))
}

s = fmt.Sprintf("CREATE TABLE \"%s\" (\n%s\n);COMMIT TRANSACTION;BEGIN TRANSACTION;", t.Name, strings.Join(cols, ",\n"))
sql = append(sql, s)

Expand Down
20 changes: 20 additions & 0 deletions translators/cockroach_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,26 @@ CONSTRAINT profiles_users_id_fk FOREIGN KEY (user_id) REFERENCES users (id)
r.Equal(ddl, res)
}

func (p *CockroachSuite) Test_Cockroach_CreateTables_WithCompositePrimaryKey() {
r := p.Require()
ddl := `CREATE TABLE "user_profiles" (
"user_id" INT NOT NULL,
"profile_id" INT NOT NULL,
"created_at" timestamp NOT NULL,
"updated_at" timestamp NOT NULL,
PRIMARY KEY("user_id", "profile_id")
);COMMIT TRANSACTION;BEGIN TRANSACTION;`

res, _ := fizz.AString(`
create_table("user_profiles") {
t.Column("user_id", "INT")
t.Column("profile_id", "INT")
t.PrimaryKey("user_id", "profile_id")
}
`, p.crdbt())
r.Equal(ddl, res)
}

func (p *CockroachSuite) Test_Cockroach_DropTable() {
r := p.Require()

Expand Down
14 changes: 12 additions & 2 deletions translators/mssqlserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@ import (
"github.com/pkg/errors"
)

type MsSqlServer struct {
}
// MsSqlServer is a MS SqlServer-specific translator.
type MsSqlServer struct{}

// NewMsSqlServer constructs a new MsSqlServer translator.
func NewMsSqlServer() *MsSqlServer {
return &MsSqlServer{}
}
Expand All @@ -31,6 +32,15 @@ func (p *MsSqlServer) CreateTable(t fizz.Table) (string, error) {
cols = append(cols, s)
}

primaryKeys := t.PrimaryKeys()
if len(primaryKeys) > 1 {
pks := make([]string, len(primaryKeys))
for i, pk := range primaryKeys {
pks[i] = fmt.Sprintf("[%s]", pk)
}
cols = append(cols, fmt.Sprintf("PRIMARY KEY(%s)", strings.Join(pks, ", ")))
}

s = fmt.Sprintf("CREATE TABLE %s (\n%s\n);", t.Name, strings.Join(cols, ",\n"))
sql = append(sql, s)

Expand Down
20 changes: 20 additions & 0 deletions translators/mssqlserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,26 @@ ALTER TABLE profiles ADD CONSTRAINT profiles_users_id_fk FOREIGN KEY (user_id) R
r.Equal(ddl, res)
}

func (p *MsSqlServerSQLSuite) Test_MsSqlServer_CreateTables_WithCompositePrimaryKey() {
r := p.Require()
ddl := `CREATE TABLE user_profiles (
user_id INT NOT NULL,
profile_id INT NOT NULL,
created_at DATETIME NOT NULL,
updated_at DATETIME NOT NULL
PRIMARY KEY([user_id], [profile_id])
);`

res, _ := fizz.AString(`
create_table("user_profiles") {
t.Column("user_id", "INT")
t.Column("profile_id", "INT")
t.PrimaryKey("user_id", "profile_id")
}
`, sqlsrv)
r.Equal(ddl, res)
}

func (p *MsSqlServerSQLSuite) Test_MsSqlServer_DropTable() {
r := p.Require()

Expand Down
9 changes: 9 additions & 0 deletions translators/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,15 @@ func (p *MySQL) CreateTable(t fizz.Table) (string, error) {
cols = append(cols, p.buildForeignKey(t, fk, true))
}

primaryKeys := t.PrimaryKeys()
if len(primaryKeys) > 1 {
pks := make([]string, len(primaryKeys))
for i, pk := range primaryKeys {
pks[i] = fmt.Sprintf("`%s`", pk)
}
cols = append(cols, fmt.Sprintf("PRIMARY KEY(%s)", strings.Join(pks, ", ")))
}

s := fmt.Sprintf("CREATE TABLE %s (\n%s\n) ENGINE=InnoDB;", p.escapeIdentifier(t.Name), strings.Join(cols, ",\n"))
sql = append(sql, s)

Expand Down
20 changes: 20 additions & 0 deletions translators/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,26 @@ FOREIGN KEY (` + "`user_id`" + `) REFERENCES ` + "`users`" + ` (` + "`id`" + `)
r.Equal(ddl, res)
}

func (p *MySQLSuite) Test_MySQL_CreateTables_WithCompositePrimaryKey() {
r := p.Require()
ddl := `CREATE TABLE ` + "`user_profiles`" + ` (
` + "`user_id`" + ` INTEGER NOT NULL,
` + "`profile_id`" + ` INTEGER NOT NULL,
` + "`created_at`" + ` DATETIME NOT NULL,
` + "`updated_at`" + ` DATETIME NOT NULL,
PRIMARY KEY(` + "`user_id`" + `, ` + "`profile_id`" + `)
) ENGINE=InnoDB;`

res, _ := fizz.AString(`
create_table("user_profiles") {
t.Column("user_id", "INT")
t.Column("profile_id", "INT")
t.PrimaryKey("user_id", "profile_id")
}
`, myt)
r.Equal(ddl, res)
}

func (p *MySQLSuite) Test_MySQL_DropTable() {
r := p.Require()

Expand Down
9 changes: 9 additions & 0 deletions translators/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,15 @@ func (p *Postgres) CreateTable(t fizz.Table) (string, error) {
cols = append(cols, p.buildForeignKey(t, fk, true))
}

primaryKeys := t.PrimaryKeys()
if len(primaryKeys) > 1 {
pks := make([]string, len(primaryKeys))
for i, pk := range primaryKeys {
pks[i] = fmt.Sprintf("\"%s\"", pk)
}
cols = append(cols, fmt.Sprintf("PRIMARY KEY(%s)", strings.Join(pks, ", ")))
}

s = fmt.Sprintf("CREATE TABLE \"%s\" (\n%s\n);", t.Name, strings.Join(cols, ",\n"))
sql = append(sql, s)

Expand Down
22 changes: 21 additions & 1 deletion translators/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ PRIMARY KEY("uuid")
r.Equal(ddl, res)
}

func (p *PostgreSQLSuite) Test_Postgre_CreateTables_WithForeignKeys() {
func (p *PostgreSQLSuite) Test_Postgres_CreateTables_WithForeignKeys() {
r := p.Require()
ddl := `CREATE TABLE "users" (
"id" SERIAL NOT NULL,
Expand Down Expand Up @@ -153,6 +153,26 @@ FOREIGN KEY (user_id) REFERENCES users (id)
r.Equal(ddl, res)
}

func (p *PostgreSQLSuite) Test_Postgres_CreateTables_WithCompositePrimaryKey() {
r := p.Require()
ddl := `CREATE TABLE "user_profiles" (
"user_id" INT NOT NULL,
"profile_id" INT NOT NULL,
"created_at" timestamp NOT NULL,
"updated_at" timestamp NOT NULL,
PRIMARY KEY("user_id", "profile_id")
);`

res, _ := fizz.AString(`
create_table("user_profiles") {
t.Column("user_id", "INT")
t.Column("profile_id", "INT")
t.PrimaryKey("user_id", "profile_id")
}
`, pgt)
r.Equal(ddl, res)
}

func (p *PostgreSQLSuite) Test_Postgres_DropTable() {
r := p.Require()

Expand Down
9 changes: 9 additions & 0 deletions translators/sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,15 @@ func (p *SQLite) CreateTable(t fizz.Table) (string, error) {
cols = append(cols, p.buildForeignKey(t, fk, true))
}

primaryKeys := t.PrimaryKeys()
if len(primaryKeys) > 1 {
pks := make([]string, len(primaryKeys))
for i, pk := range primaryKeys {
pks[i] = fmt.Sprintf("\"%s\"", pk)
}
cols = append(cols, fmt.Sprintf("PRIMARY KEY(%s)", strings.Join(pks, ", ")))
}

s = fmt.Sprintf("CREATE TABLE \"%s\" (\n%s\n);", t.Name, strings.Join(cols, ",\n"))
sql = append(sql, s)

Expand Down
Loading

0 comments on commit 42b3487

Please sign in to comment.