Skip to content

Commit

Permalink
Merge pull request #802 from gobuffalo/add-eager-after-hook
Browse files Browse the repository at this point in the history
Add `AfterEagerFind()`
  • Loading branch information
sio4 authored Jan 13, 2023
2 parents 5dfc37d + 3596719 commit e2fe45b
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 16 deletions.
22 changes: 19 additions & 3 deletions callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,24 @@ type AfterFindable interface {
AfterFind(*Connection) error
}

func (m *Model) afterFind(c *Connection) error {
if x, ok := m.Value.(AfterFindable); ok {
// AfterEagerFindable callback will be called after a record, or records,
// has been retrieved from the database and their associations have been
// eagerly loaded.
type AfterEagerFindable interface {
AfterEagerFind(*Connection) error
}

func (m *Model) afterFind(c *Connection, eager bool) error {
if x, ok := m.Value.(AfterFindable); ok && !eager {
if err := x.AfterFind(c); err != nil {
return err
}
}
if x, ok := m.Value.(AfterEagerFindable); ok && eager {
if err := x.AfterEagerFind(c); err != nil {
return err
}
}

// if the "model" is a slice/array we want
// to loop through each of the elements in the collection
Expand All @@ -34,9 +46,13 @@ func (m *Model) afterFind(c *Connection) error {
wg.Go(func() error {
y := rv.Index(i)
y = y.Addr()
if x, ok := y.Interface().(AfterFindable); ok {
if x, ok := y.Interface().(AfterFindable); ok && !eager {
return x.AfterFind(c)
}

if x, ok := y.Interface().(AfterEagerFindable); ok && eager {
return x.AfterEagerFind(c)
}
return nil
})
}(i)
Expand Down
13 changes: 11 additions & 2 deletions callbacks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ func Test_Callbacks(t *testing.T) {
r.Equal("AF", user.AfterF)
r.NoError(tx.Find(user, user.ID))
r.Equal("AfterFind", user.AfterF)
r.Empty(user.AfterEF)

r.NoError(tx.Eager().Find(user, user.ID))
r.Equal("AfterEagerFind", user.AfterEF)

r.NoError(tx.Destroy(user))

Expand All @@ -70,11 +74,16 @@ func Test_Callbacks_on_Slice(t *testing.T) {

users := CallbacksUsers{}
r.NoError(tx.All(&users))

r.Len(users, 2)

for _, u := range users {
r.Equal("AfterFind", u.AfterF)
r.Empty(u.AfterEF)
}

r.NoError(tx.Eager().All(&users))
r.Len(users, 2)
for _, u := range users {
r.Equal("AfterEagerFind", u.AfterEF)
}
})
}
37 changes: 26 additions & 11 deletions finders.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,24 +66,29 @@ func (c *Connection) First(model interface{}) error {
//
// q.Where("name = ?", "mark").First(&User{})
func (q *Query) First(model interface{}) error {
var m *Model
err := q.Connection.timeFunc("First", func() error {
q.Limit(1)
m := NewModel(model, q.Connection.Context())
m = NewModel(model, q.Connection.Context())
if err := q.Connection.Dialect.SelectOne(q.Connection, m, *q); err != nil {
return err
}
return m.afterFind(q.Connection)
return m.afterFind(q.Connection, false)
})

if err != nil {
return err
}

if q.eager {
err = q.eagerAssociations(model)
err := q.eagerAssociations(model)
q.disableEager()
return err
if err != nil {
return err
}
return m.afterFind(q.Connection, true)
}

return nil
}

Expand All @@ -98,14 +103,15 @@ func (c *Connection) Last(model interface{}) error {
//
// q.Where("name = ?", "mark").Last(&User{})
func (q *Query) Last(model interface{}) error {
var m *Model
err := q.Connection.timeFunc("Last", func() error {
q.Limit(1)
q.Order("created_at DESC, id DESC")
m := NewModel(model, q.Connection.Context())
m = NewModel(model, q.Connection.Context())
if err := q.Connection.Dialect.SelectOne(q.Connection, m, *q); err != nil {
return err
}
return m.afterFind(q.Connection)
return m.afterFind(q.Connection, false)
})

if err != nil {
Expand All @@ -115,7 +121,10 @@ func (q *Query) Last(model interface{}) error {
if q.eager {
err = q.eagerAssociations(model)
q.disableEager()
return err
if err != nil {
return err
}
return m.afterFind(q.Connection, true)
}

return nil
Expand All @@ -132,17 +141,20 @@ func (c *Connection) All(models interface{}) error {
//
// q.Where("name = ?", "mark").All(&[]User{})
func (q *Query) All(models interface{}) error {
var m *Model
err := q.Connection.timeFunc("All", func() error {
m := NewModel(models, q.Connection.Context())
m = NewModel(models, q.Connection.Context())
err := q.Connection.Dialect.SelectMany(q.Connection, m, *q)
if err != nil {
return err
}

err = q.paginateModel(models)
if err != nil {
return err
}
return m.afterFind(q.Connection)

return m.afterFind(q.Connection, false)
})

if err != nil {
Expand All @@ -152,7 +164,10 @@ func (q *Query) All(models interface{}) error {
if q.eager {
err = q.eagerAssociations(models)
q.disableEager()
return err
if err != nil {
return err
}
return m.afterFind(q.Connection, true)
}

return nil
Expand Down Expand Up @@ -301,7 +316,7 @@ func (q *Query) eagerDefaultAssociations(model interface{}) error {
// Exists returns true/false if a record exists in the database that matches
// the query.
//
// q.Where("name = ?", "mark").Exists(&User{})
// q.Where("name = ?", "mark").Exists(&User{})
func (q *Query) Exists(model interface{}) (bool, error) {
tmpQuery := Q(q.Connection)
q.Clone(tmpQuery) // avoid meddling with original query
Expand Down
6 changes: 6 additions & 0 deletions pop_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@ type CallbacksUser struct {
AfterU string `db:"after_u"`
AfterD string `db:"after_d"`
AfterF string `db:"after_f"`
AfterEF string `db:"after_ef"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
}
Expand Down Expand Up @@ -420,6 +421,11 @@ func (u *CallbacksUser) AfterFind(tx *Connection) error {
return nil
}

func (u *CallbacksUser) AfterEagerFind(tx *Connection) error {
u.AfterEF = "AfterEagerFind"
return nil
}

type Label struct {
ID string `db:"id"`
}
Expand Down
1 change: 1 addition & 0 deletions testdata/migrations/20181104135800_callbacks_users.up.fizz
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ create_table("callbacks_users") {
t.Column("after_u", "string", {})
t.Column("after_d", "string", {})
t.Column("after_f", "string", {})
t.Column("after_ef", "string", {})
t.Column("before_v", "string", {})
t.Timestamps()
}

0 comments on commit e2fe45b

Please sign in to comment.