Skip to content

Commit

Permalink
Generate DefaultCount function (#282)
Browse files Browse the repository at this point in the history
* Generate DefaultCount function

* dependent function changes

* added demo example
  • Loading branch information
Ramky-Infoblox authored Dec 10, 2024
1 parent 1938937 commit 3f7c3c4
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 0 deletions.
36 changes: 36 additions & 0 deletions example/feature_demo/demo_service.pb.gorm.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

91 changes: 91 additions & 0 deletions plugin/handlergen.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ func (p *OrmPlugin) generateDefaultHandlers(file *generator.FileDescriptor) {

p.generateApplyFieldMask(message)
p.generateListHandler(message)

typeName := p.TypeName(message)
ormable := p.getOrmable(typeName)
if p.listHasPagination(ormable) {
p.generateListCountHandler(message)
}
}
}
}
Expand Down Expand Up @@ -632,6 +638,65 @@ func (p *OrmPlugin) generateListHandler(message *generator.Descriptor) {
p.generateAfterListHookDef(ormable)
}

func (p *OrmPlugin) generateListCountHandler(message *generator.Descriptor) {
typeName := p.TypeName(message)
ormable := p.getOrmable(typeName)

p.P(`// DefaultCount`, typeName, ` executes a gorm total record count call`)
listSign := fmt.Sprint(`func DefaultCount`, typeName, `(ctx context.Context, db *`, p.Import(gormImport), `.DB`)
var f string
if p.listHasFiltering(ormable) {
listSign += fmt.Sprint(`, f `, `*`, p.Import(queryImport), `.Filtering`)
f = "f"
} else {
f = "nil"
}
listSign += fmt.Sprint(`) (int64`, `, error) {`)
p.P(listSign)
p.P(`in := `, typeName, `{}`)
p.P(`ormObj, err := in.ToORM(ctx)`)
p.P(`if err != nil {`)
p.P(`return 0, err`)
p.P(`}`)
p.generateBeforeCountHookCall(ormable, "ApplyQuery")
p.P(`db, err = `, p.Import(tkgormImport), `.ApplyCollectionOperators(ctx, db, &`, ormable.Name, `{}, &`, typeName, `{}, `, f, `,nil,nil,nil`, `)`)
p.P(`if err != nil {`)
p.P(`return 0, err`)
p.P(`}`)
p.generateAfterCountHookCall(ormable, "ApplyQuery")
p.P(`db = db.Where(&ormObj)`)
p.P(`var total int64`)
p.P(`if err = db.Model(&ormObj).Count(&total).Error; err != nil {`)
p.P(`return 0, err`)
p.P(`}`)
p.P(`return total, nil`)
p.P(`}`)
p.generateBeforeCountHookDef(ormable, "ApplyQuery")
p.generateAfterCountHookDef(ormable, "ApplyQuery")
}

func (p *OrmPlugin) generateBeforeCountHookDef(orm *OrmableType, suffix string) {
p.P(`type `, orm.Name, `WithBeforeCount`, suffix, ` interface {`)
hookSign := fmt.Sprint(`BeforeCount`, suffix, `(context.Context, *`, p.Import(gormImport), `.DB`)
if p.listHasFiltering(orm) {
hookSign += fmt.Sprint(`, *`, p.Import(queryImport), `.Filtering`)
}
hookSign += fmt.Sprint(`) (*`, p.Import(gormImport), `.DB, error)`)
p.P(hookSign)
p.P(`}`)
}

func (p *OrmPlugin) generateAfterCountHookDef(orm *OrmableType, suffix string) {
p.P(`type `, orm.Name, `WithAfterCount`, suffix, ` interface {`)
hookSign := fmt.Sprint(`AfterCount`, suffix, `(context.Context, *`, p.Import(gormImport), `.DB`)
if p.listHasFiltering(orm) {
hookSign += fmt.Sprint(`, *`, p.Import(queryImport), `.Filtering`)
}
hookSign += fmt.Sprint(`) (*`, p.Import(gormImport), `.DB, error)`)
p.P(hookSign)
p.P(`}`)
}

func (p *OrmPlugin) generateBeforeListHookDef(orm *OrmableType, suffix string) {
p.P(`type `, orm.Name, `WithBeforeList`, suffix, ` interface {`)
hookSign := fmt.Sprint(`BeforeList`, suffix, `(context.Context, *`, p.Import(gormImport), `.DB`)
Expand Down Expand Up @@ -694,6 +759,32 @@ func (p *OrmPlugin) generateBeforeListHookCall(orm *OrmableType, suffix string)
p.P(`}`)
}

func (p *OrmPlugin) generateBeforeCountHookCall(orm *OrmableType, suffix string) {
p.P(`if hook, ok := interface{}(&ormObj).(`, orm.Name, `WithBeforeCount`, suffix, `); ok {`)
hookCall := fmt.Sprint(`if db, err = hook.BeforeCount`, suffix, `(ctx, db`)
if p.listHasFiltering(orm) {
hookCall += `,f`
}
hookCall += `); err != nil {`
p.P(hookCall)
p.P(`return 0, err`)
p.P(`}`)
p.P(`}`)
}

func (p *OrmPlugin) generateAfterCountHookCall(orm *OrmableType, suffix string) {
p.P(`if hook, ok := interface{}(&ormObj).(`, orm.Name, `WithAfterCount`, suffix, `); ok {`)
hookCall := fmt.Sprint(`if db, err = hook.AfterCount`, suffix, `(ctx, db`)
if p.listHasFiltering(orm) {
hookCall += `,f`
}
hookCall += `); err != nil {`
p.P(hookCall)
p.P(`return 0, err`)
p.P(`}`)
p.P(`}`)
}

func (p *OrmPlugin) generateAfterListHookCall(orm *OrmableType) {
p.P(`if hook, ok := interface{}(&ormObj).(`, orm.Name, `WithAfterListFind); ok {`)
hookCall := fmt.Sprint(`if err = hook.AfterListFind(ctx, db, &ormResponse`)
Expand Down

0 comments on commit 3f7c3c4

Please sign in to comment.