Skip to content

Commit

Permalink
override table name (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
JunNishimura committed Mar 30, 2024
1 parent ce6d7c7 commit bc9f8b5
Showing 1 changed file with 21 additions and 18 deletions.
39 changes: 21 additions & 18 deletions adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func connectDB(driverName, dataSourceName string) (*bun.DB, error) {
func (a *bunAdapter) createTalbe() error {
if _, err := a.db.NewCreateTable().
Model((*CasbinPolicy)(nil)).
Table(a.tableName).
ModelTableExpr(a.tableName).
IfNotExists().
Exec(context.Background()); err != nil {
return err
Expand All @@ -141,7 +141,7 @@ func (a *bunAdapter) LoadPolicy(model model.Model) error {
var policies []CasbinPolicy
err := a.db.NewSelect().
Model(&policies).
Table(a.tableName).
ModelTableExpr(a.tableName).
Scan(context.Background())
if err != nil {
return err
Expand Down Expand Up @@ -200,23 +200,26 @@ func (a *bunAdapter) savePolicyRecords(policies []CasbinPolicy) error {
// bulk insert new policies
if _, err := a.db.NewInsert().
Model(&policies).
Table(a.tableName).
ModelTableExpr(a.tableName).
Exec(context.Background()); err != nil {
return err
}

return nil
}

// delete all policy rules from the storage.
// drop and recreate the table
func (a *bunAdapter) refreshTable() error {
if _, err := a.db.NewTruncateTable().
Model((*CasbinPolicy)(nil)).
Table(a.tableName).
// just truncate the table could be a better choice
// but NewTruncateTable() does not support ModelTableExpr
// so we drop and recreate the table instead
if _, err := a.db.NewDropTable().
ModelTableExpr(a.tableName).
IfExists().
Exec(context.Background()); err != nil {
return err
}
return nil
return a.createTalbe()
}

// AddPolicy adds a policy rule to the storage.
Expand All @@ -225,7 +228,7 @@ func (a *bunAdapter) AddPolicy(sec string, ptype string, rule []string) error {
newPolicy := newCasbinPolicy(ptype, rule)
if _, err := a.db.NewInsert().
Model(&newPolicy).
Table(a.tableName).
ModelTableExpr(a.tableName).
Exec(context.Background()); err != nil {
return err
}
Expand All @@ -241,7 +244,7 @@ func (a *bunAdapter) AddPolicies(sec string, ptype string, rules [][]string) err
}
if _, err := a.db.NewInsert().
Model(&policies).
Table(a.tableName).
ModelTableExpr(a.tableName).
Exec(context.Background()); err != nil {
return err
}
Expand Down Expand Up @@ -274,7 +277,7 @@ func (a *bunAdapter) RemovePolicies(sec string, ptype string, rules [][]string)

func (a *bunAdapter) deleteRecord(existingPolicy CasbinPolicy) error {
query := a.db.NewDelete().
Table(a.tableName).
ModelTableExpr(a.tableName).
Where("ptype = ?", existingPolicy.PType)

values := existingPolicy.filterValuesWithKey()
Expand All @@ -284,7 +287,7 @@ func (a *bunAdapter) deleteRecord(existingPolicy CasbinPolicy) error {

func (a *bunAdapter) deleteRecordInTx(tx bun.Tx, existingPolicy CasbinPolicy) error {
query := tx.NewDelete().
Table(a.tableName).
ModelTableExpr(a.tableName).
Where("ptype = ?", existingPolicy.PType)

values := existingPolicy.filterValuesWithKey()
Expand Down Expand Up @@ -317,7 +320,7 @@ func (a *bunAdapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex i

func (a *bunAdapter) deleteFilteredPolicy(ptype string, fieldIndex int, fieldValues ...string) error {
query := a.db.NewDelete().
Table(a.tableName).
ModelTableExpr(a.tableName).
Where("ptype = ?", ptype)

// Note that empty string in fieldValues could be any word.
Expand Down Expand Up @@ -388,7 +391,7 @@ func (a *bunAdapter) UpdatePolicy(sec string, ptype string, oldRule, newRule []s
func (a *bunAdapter) updateRecord(oldPolicy, newPolicy CasbinPolicy) error {
query := a.db.NewUpdate().
Model(&newPolicy).
Table(a.tableName).
ModelTableExpr(a.tableName).
Where("ptype = ?", oldPolicy.PType)

values := oldPolicy.filterValuesWithKey()
Expand All @@ -399,7 +402,7 @@ func (a *bunAdapter) updateRecord(oldPolicy, newPolicy CasbinPolicy) error {
func (a *bunAdapter) updateRecordInTx(tx bun.Tx, oldPolicy, newPolicy CasbinPolicy) error {
query := tx.NewUpdate().
Model(&newPolicy).
Table(a.tableName).
ModelTableExpr(a.tableName).
Where("ptype = ?", oldPolicy.PType)

values := oldPolicy.filterValuesWithKey()
Expand Down Expand Up @@ -453,10 +456,10 @@ func (a *bunAdapter) UpdateFilteredPolicies(sec string, ptype string, newRules [
}

selectQuery := tx.NewSelect().
Table(a.tableName).
ModelTableExpr(a.tableName).
Where("ptype = ?", ptype)
deleteQuery := tx.NewDelete().
Table(a.tableName).
ModelTableExpr(a.tableName).
Where("ptype = ?", ptype)

// Note that empty string in fieldValues could be any word.
Expand Down Expand Up @@ -541,7 +544,7 @@ func (a *bunAdapter) UpdateFilteredPolicies(sec string, ptype string, newRules [
// create new policies
if _, err := tx.NewInsert().
Model(&newPolicies).
Table(a.tableName).
ModelTableExpr(a.tableName).
Exec(context.Background()); err != nil {
if err := tx.Rollback(); err != nil {
return nil, err
Expand Down

0 comments on commit bc9f8b5

Please sign in to comment.