Skip to content

Commit

Permalink
feat: mongo util pagination
Browse files Browse the repository at this point in the history
  • Loading branch information
withchao committed Nov 23, 2023
1 parent 68786ae commit f38d03e
Show file tree
Hide file tree
Showing 3 changed files with 231 additions and 12 deletions.
177 changes: 177 additions & 0 deletions mgoutil/mongo.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
package mgoutil

import (
"context"
"github.com/OpenIMSDK/tools/errs"
"github.com/OpenIMSDK/tools/pagination"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)

func basic[T any]() bool {
var t T
switch any(t).(type) {
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, string, []byte:
return true
case *int, *int8, *int16, *int32, *int64, *uint, *uint8, *uint16, *uint32, *uint64, *float32, *float64, *string, *[]byte:
return true
default:
return false
}
}

func anes[T any](ts []T) []any {
val := make([]any, len(ts))
for i := range ts {
val[i] = ts[i]
}
return val
}

func findOptionToCountOption(opts []*options.FindOptions) *options.CountOptions {
countOpt := options.Count()
for _, opt := range opts {
if opt.Skip != nil {
countOpt.SetSkip(*opt.Skip)
}
if opt.Limit != nil {
countOpt.SetLimit(*opt.Limit)
}
}
return countOpt
}

func InsertMany[T any](ctx context.Context, coll *mongo.Collection, val []T, opts ...*options.InsertManyOptions) error {
_, err := coll.InsertMany(ctx, anes(val), opts...)
if err != nil {
return errs.Wrap(err)
}
return nil
}

func UpdateOne(ctx context.Context, coll *mongo.Collection, filter any, update any, notMatchedErr bool, opts ...*options.UpdateOptions) error {
res, err := coll.UpdateOne(ctx, filter, update, opts...)
if err != nil {
return errs.Wrap(err)
}
if notMatchedErr && res.MatchedCount == 0 {
return errs.Wrap(mongo.ErrNoDocuments)
}
return nil
}

func UpdateMany(ctx context.Context, coll *mongo.Collection, filter any, update any, opts ...*options.UpdateOptions) (*mongo.UpdateResult, error) {
res, err := coll.UpdateMany(ctx, filter, update, opts...)
if err != nil {
return nil, errs.Wrap(err)
}
return res, nil
}

func Find[T any](ctx context.Context, coll *mongo.Collection, filter any, opts ...*options.FindOptions) ([]T, error) {
cur, err := coll.Find(ctx, filter, opts...)
if err != nil {
return nil, errs.Wrap(err)
}
defer cur.Close(ctx)
var res []T
if basic[T]() {
var temp []map[string]T
if err := cur.All(ctx, &temp); err != nil {
return nil, errs.Wrap(err)
}
res = make([]T, 0, len(temp))
for _, m := range temp {
if len(m) != 1 {
return nil, errs.ErrInternalServer.Wrap("mongo find result len(m) != 1")
}
for _, t := range m {
res = append(res, t)
}
}
} else {
if err := cur.All(ctx, &res); err != nil {
return nil, errs.Wrap(err)
}
}
return res, nil
}

func FindOne[T any](ctx context.Context, coll *mongo.Collection, filter any, opts ...*options.FindOneOptions) (res T, err error) {
cur := coll.FindOne(ctx, filter, opts...)
if err := cur.Err(); err != nil {
return res, errs.Wrap(err)
}
if err := cur.Decode(&res); err != nil {
return res, errs.Wrap(err)
}
return res, nil
}

func FindPage[T any](ctx context.Context, coll *mongo.Collection, filter any, pagination pagination.Pagination, opts ...*options.FindOptions) (int64, []T, error) {
count, err := Count(ctx, coll, filter, findOptionToCountOption(opts))
if err != nil {
return 0, nil, err
}
if count == 0 || pagination == nil {
return count, nil, nil
}
skip := int64(pagination.GetPageNumber()-1) * int64(pagination.GetShowNumber())
if skip < 0 || skip >= count || pagination.GetShowNumber() <= 0 {
return count, nil, nil
}
opt := options.Find().SetSkip(skip).SetLimit(int64(pagination.GetShowNumber()))
res, err := Find[T](ctx, coll, filter, append(opts, opt)...)
if err != nil {
return 0, nil, err
}
return count, res, nil
}

func FindPageOnly[T any](ctx context.Context, coll *mongo.Collection, filter any, pagination pagination.Pagination, opts ...*options.FindOptions) ([]T, error) {
skip := int64(pagination.GetPageNumber()-1) * int64(pagination.GetShowNumber())
if skip < 0 || pagination.GetShowNumber() <= 0 {
return nil, nil
}
opt := options.Find().SetSkip(skip).SetLimit(int64(pagination.GetShowNumber()))
return Find[T](ctx, coll, filter, append(opts, opt)...)
}

func Count(ctx context.Context, coll *mongo.Collection, filter any, opts ...*options.CountOptions) (int64, error) {
return coll.CountDocuments(ctx, filter, opts...)
}

func Exist(ctx context.Context, coll *mongo.Collection, filter any, opts ...*options.CountOptions) (bool, error) {
opts = append(opts, options.Count().SetLimit(1))
count, err := Count(ctx, coll, filter, opts...)
if err != nil {
return false, err
}
return count > 0, nil
}

func DeleteOne(ctx context.Context, coll *mongo.Collection, filter any, opts ...*options.DeleteOptions) error {
if _, err := coll.DeleteOne(ctx, filter, opts...); err != nil {
return errs.Wrap(err)
}
return nil
}

func DeleteMany(ctx context.Context, coll *mongo.Collection, filter any, opts ...*options.DeleteOptions) error {
if _, err := coll.DeleteMany(ctx, filter, opts...); err != nil {
return errs.Wrap(err)
}
return nil
}

func Aggregate[T any](ctx context.Context, coll *mongo.Collection, pipeline any, opts ...*options.AggregateOptions) ([]T, error) {
cur, err := coll.Aggregate(ctx, pipeline, opts...)
if err != nil {
return nil, err
}
var ts []T
if err := cur.All(ctx, &ts); err != nil {
return nil, err
}
return ts, nil
}
6 changes: 6 additions & 0 deletions pagination/pagination.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package pagination

type Pagination interface {
GetPageNumber() int32
GetShowNumber() int32
}
60 changes: 48 additions & 12 deletions tx/mongo.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,68 @@ package tx

import (
"context"
"go.mongodb.org/mongo-driver/bson"
"sync"

"go.mongodb.org/mongo-driver/mongo"

"github.com/OpenIMSDK/tools/utils"
)

func NewMongo(client *mongo.Client) CtxTx {
return &_Mongo{
client: client,
initialized: false,
lock: &sync.Mutex{},
client: client,
}
}

type _Mongo struct {
client *mongo.Client
initialized bool
lock sync.Locker
client *mongo.Client
tx func(func(ctx context.Context) error) error
}

func (m *_Mongo) Transaction(ctx context.Context, fn func(ctx context.Context) error) error {
sess, err := m.client.StartSession()
if err != nil {
func (m *_Mongo) init(ctx context.Context) (err error) {
m.lock.Lock()
defer func() {
if err == nil {
m.initialized = true
}
m.lock.Unlock()
}()
if m.initialized {
return nil
}
var res map[string]any
if err := m.client.Database("admin").RunCommand(ctx, bson.M{"isMaster": 1}).Decode(&res); err != nil {
return err
}
sCtx := mongo.NewSessionContext(ctx, sess)
defer sess.EndSession(sCtx)
if err := fn(sCtx); err != nil {
_ = sess.AbortTransaction(sCtx)
_, allowTx := res["setName"]
if !allowTx {
return nil
}
m.tx = func(fn func(ctx context.Context) error) error {
sess, err := m.client.StartSession()
if err != nil {
return err
}
defer sess.EndSession(ctx)
_, err = sess.WithTransaction(ctx, func(sessCtx mongo.SessionContext) (interface{}, error) {
return nil, fn(sessCtx)
})
return err
}
return utils.Wrap(sess.CommitTransaction(sCtx), "")
return nil
}

func (m *_Mongo) Transaction(ctx context.Context, fn func(ctx context.Context) error) error {
if !m.initialized {
if err := m.init(ctx); err != nil {
return err
}
}
if m.tx == nil {
return fn(ctx)
}
return m.tx(fn)
}

0 comments on commit f38d03e

Please sign in to comment.