diff --git a/gin/admin/logic.go b/gin/admin/logic.go index 83454ca..f82d468 100644 --- a/gin/admin/logic.go +++ b/gin/admin/logic.go @@ -67,6 +67,9 @@ func FindByUserName(scopes ...func(db *gorm.DB) *gorm.DB) (*Response, error) { } func FindPasswordByUserName(db *gorm.DB, username string, scopes ...func(db *gorm.DB) *gorm.DB) (*LoginResponse, error) { + if db == nil { + return nil, gorm.ErrInvalidDB + } admin := &LoginResponse{} db = db.Model(&Admin{}).Select("id,password").Where("username = ?", username) @@ -169,6 +172,9 @@ func AddRoleForUser(admin *Admin) error { } func UpdateAvatar(db *gorm.DB, id uint, avatar string) error { + if db == nil { + return gorm.ErrInvalidDB + } err := db.Model(&Admin{}).Where("id = ?", id).Update("header_img", avatar).Error if err != nil { zap_server.ZAPLOG.Error(err.Error()) diff --git a/gin/admin/model.go b/gin/admin/model.go index 9fe7d70..8c13f64 100644 --- a/gin/admin/model.go +++ b/gin/admin/model.go @@ -28,6 +28,9 @@ type Avatar struct { // Create 添加 func (item *Admin) Create(db *gorm.DB) (uint, error) { + if db == nil { + return 0, gorm.ErrInvalidDB + } err := db.Model(item).Create(item).Error if err != nil { zap_server.ZAPLOG.Error(err.Error()) @@ -53,6 +56,9 @@ func (item *Admin) mc() map[string]interface{} { // Update 更新 func (item *Admin) Update(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) error { + if db == nil { + return gorm.ErrInvalidDB + } err := db.Model(item).Scopes(scopes...).Updates(item.mc()).Error if err != nil { zap_server.ZAPLOG.Error(err.Error()) @@ -63,6 +69,9 @@ func (item *Admin) Update(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) err // Delete 删除 func (item *Admin) Delete(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) error { + if db == nil { + return gorm.ErrInvalidDB + } err := db.Model(item).Unscoped().Scopes(scopes...).Delete(item).Error if err != nil { zap_server.ZAPLOG.Error(err.Error()) diff --git a/gin/admin/response.go b/gin/admin/response.go index 86ba5da..ca16f8b 100644 --- a/gin/admin/response.go +++ b/gin/admin/response.go @@ -38,6 +38,9 @@ type LoginResponse struct { } func (res *Response) First(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) error { + if db == nil { + return gorm.ErrInvalidDB + } err := db.Model(&Admin{}).Scopes(scopes...).First(res).Error if err != nil { zap_server.ZAPLOG.Error(err.Error()) @@ -54,6 +57,9 @@ type PageResponse struct { } func (res *PageResponse) Paginate(db *gorm.DB, pageScope func(db *gorm.DB) *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) (int64, error) { + if db == nil { + return 0, gorm.ErrInvalidDB + } db = db.Model(&Admin{}) var count int64 if len(scopes) > 0 { @@ -76,6 +82,9 @@ func (res *PageResponse) Paginate(db *gorm.DB, pageScope func(db *gorm.DB) *gorm } func (res *PageResponse) Find(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) error { + if db == nil { + return gorm.ErrInvalidDB + } db = db.Model(&Admin{}) err := db.Scopes(scopes...).Find(&res.Item).Error if err != nil { diff --git a/gin/api/logic.go b/gin/api/logic.go index c45c0ec..011d583 100644 --- a/gin/api/logic.go +++ b/gin/api/logic.go @@ -11,6 +11,9 @@ import ( // CreatenInBatches 批量加入 func CreatenInBatches(db *gorm.DB, apis ApiCollection) error { + if db == nil { + return gorm.ErrInvalidDB + } err := db.Model(&Api{}).CreateInBatches(&apis, 500).Error if err != nil { zap_server.ZAPLOG.Error(err.Error()) diff --git a/gin/api/model.go b/gin/api/model.go index 0504796..e912547 100644 --- a/gin/api/model.go +++ b/gin/api/model.go @@ -38,6 +38,9 @@ func (item *Api) mc() map[string]interface{} { // Create 添加 func (item *Api) Create(db *gorm.DB) (uint, error) { + if db == nil { + return 0, gorm.ErrInvalidDB + } err := db.Model(item).Create(item).Error if err != nil { zap_server.ZAPLOG.Error(err.Error()) @@ -48,6 +51,9 @@ func (item *Api) Create(db *gorm.DB) (uint, error) { // Update 更新 func (item *Api) Update(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) error { + if db == nil { + return gorm.ErrInvalidDB + } err := db.Model(item).Scopes(scopes...).Updates(item.mc()).Error if err != nil { zap_server.ZAPLOG.Error(err.Error()) @@ -58,6 +64,9 @@ func (item *Api) Update(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) error // Delete 删除 func (item *Api) Delete(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) error { + if db == nil { + return gorm.ErrInvalidDB + } err := db.Model(item).Unscoped().Scopes(scopes...).Delete(item).Error if err != nil { zap_server.ZAPLOG.Error(err.Error()) diff --git a/gin/api/response.go b/gin/api/response.go index a01eb58..710608e 100644 --- a/gin/api/response.go +++ b/gin/api/response.go @@ -15,6 +15,9 @@ type Response struct { } func (res *Response) First(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) error { + if db == nil { + return gorm.ErrInvalidDB + } err := db.Model(&Api{}).Scopes(scopes...).First(res).Error if err != nil { zap_server.ZAPLOG.Error(err.Error()) @@ -29,6 +32,9 @@ type PageResponse struct { } func (res *PageResponse) Paginate(db *gorm.DB, pageScope func(db *gorm.DB) *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) (int64, error) { + if db == nil { + return 0, gorm.ErrInvalidDB + } db = db.Model(&Api{}) var count int64 if len(scopes) > 0 { @@ -49,6 +55,9 @@ func (res *PageResponse) Paginate(db *gorm.DB, pageScope func(db *gorm.DB) *gorm } func (res *PageResponse) Find(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) error { + if db == nil { + return gorm.ErrInvalidDB + } db = db.Model(&Api{}) err := db.Scopes(scopes...).Find(&res.Item).Error if err != nil { diff --git a/gin/authority/model.go b/gin/authority/model.go index bfb4fb2..4ae030b 100644 --- a/gin/authority/model.go +++ b/gin/authority/model.go @@ -34,6 +34,9 @@ func (item *Authority) mc() map[string]interface{} { // Create 添加 func (item *Authority) Create(db *gorm.DB) (uint, error) { + if db == nil { + return 0, gorm.ErrInvalidDB + } err := db.Model(item).Create(item).Error if err != nil { zap_server.ZAPLOG.Error(err.Error()) @@ -44,6 +47,9 @@ func (item *Authority) Create(db *gorm.DB) (uint, error) { // Update 更新 func (item *Authority) Update(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) error { + if db == nil { + return gorm.ErrInvalidDB + } err := db.Model(item).Scopes(scopes...).Updates(item.mc()).Error if err != nil { zap_server.ZAPLOG.Error(err.Error()) @@ -54,6 +60,9 @@ func (item *Authority) Update(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) // Delete 删除 func (item *Authority) Delete(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) error { + if db == nil { + return gorm.ErrInvalidDB + } err := db.Model(item).Unscoped().Scopes(scopes...).Delete(item).Error if err != nil { zap_server.ZAPLOG.Error(err.Error()) diff --git a/gin/authority/response.go b/gin/authority/response.go index 45b62f7..2054ee5 100644 --- a/gin/authority/response.go +++ b/gin/authority/response.go @@ -17,6 +17,9 @@ type Response struct { } func (res *Response) First(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) error { + if db == nil { + return gorm.ErrInvalidDB + } err := db.Model(&Authority{}).Scopes(scopes...).First(res).Error if err != nil { zap_server.ZAPLOG.Error(err.Error()) @@ -32,6 +35,9 @@ type PageResponse struct { } func (res *PageResponse) Paginate(db *gorm.DB, pageScope func(db *gorm.DB) *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) (int64, error) { + if db == nil { + return 0, gorm.ErrInvalidDB + } db = db.Model(&Authority{}) var count int64 if len(scopes) > 0 { @@ -61,6 +67,9 @@ func (res *PageResponse) Paginate(db *gorm.DB, pageScope func(db *gorm.DB) *gorm } func (res *PageResponse) Find(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) error { + if db == nil { + return gorm.ErrInvalidDB + } db = db.Model(&Authority{}) err := db.Scopes(scopes...).Find(&res.Item).Error if err != nil { @@ -83,7 +92,11 @@ func (res *PageResponse) Find(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) // findChildrenAuthority func findChildrenAuthority(item *Response) error { - err := database.Instance().Where("parent_id = ?", item.Id).Find(&item.Children).Error + db := database.Instance() + if db == nil { + return gorm.ErrInvalidDB + } + err := db.Where("parent_id = ?", item.Id).Find(&item.Children).Error if len(item.Children) > 0 { for k := range item.Children { err = findChildrenAuthority(&item.Children[k]) @@ -95,7 +108,11 @@ func findChildrenAuthority(item *Response) error { // getPermsForRoleMap func getPermsForRoleMap(uuid string) []map[string]string { apisForRoles := []map[string]string{} - perms := casbin.Instance().GetPermissionsForUser(uuid) + ca := casbin.Instance() + if ca == nil { + return nil + } + perms := ca.GetPermissionsForUser(uuid) for _, perm := range perms { if len(perm) < 3 { continue diff --git a/gin/oplog/response.go b/gin/oplog/response.go index cacf460..065ed3f 100644 --- a/gin/oplog/response.go +++ b/gin/oplog/response.go @@ -11,6 +11,9 @@ type Response struct { } func (res *Response) First(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) error { + if db == nil { + return gorm.ErrInvalidDB + } err := db.Model(&operation.Oplog{}).Scopes(scopes...).First(res).Error if err != nil { zap_server.ZAPLOG.Error(err.Error()) @@ -25,6 +28,9 @@ type PageResponse struct { } func (res *PageResponse) Paginate(db *gorm.DB, pageScope func(db *gorm.DB) *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) (int64, error) { + if db == nil { + return 0, gorm.ErrInvalidDB + } db = db.Model(&operation.Oplog{}) var count int64 err := db.Scopes(scopes...).Count(&count).Error @@ -42,6 +48,9 @@ func (res *PageResponse) Paginate(db *gorm.DB, pageScope func(db *gorm.DB) *gorm } func (res *PageResponse) Find(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) error { + if db == nil { + return gorm.ErrInvalidDB + } db = db.Model(&operation.Oplog{}) err := db.Scopes(scopes...).Find(&res.Item).Error if err != nil { diff --git a/go.json b/go.json index b00b21e..1f77fa0 100644 --- a/go.json +++ b/go.json @@ -64,9 +64,12 @@ "import (", " \"os\"", " \"testing\"", + " \"strconv\"", " \"github.com/snowlyg/httptest\"", " \"github.com/snowlyg/iris-admin/server/web/common\"", + " \"github.com/snowlyg/helper/str\"", " \"github.com/snowlyg/iris-admin/server/web/web_gin\"", + " \"github.com/snowlyg/iris-admin/server/cache\"", ")", "var TestServer *web_gin.WebServer", "var TestClient *httptest.Client", @@ -78,9 +81,21 @@ " $4", " code := m.Run()", " common.AfterTestMain(uuid, true)", + " cache.Remove()", " $5", " os.Exit(code)", - "}" + "}", + "func Create(client *httptest.Client, data map[string]interface{}) uint {", + " pageKeys := httptest.IdKeys()", + " client.POST(str.Join(Uri, \"create\"), httptest.NewResponses(http.StatusOK, response.ResponseOkMessage, pageKeys), httptest.NewWithJsonParamFunc(data))", + " return pageKeys.GetId()", + "}", + "func Delete(client *httptest.Client, id uint) {", + " client.DELETE(str.Join(Uri, \"delete/\", strconv.FormatUint(uint64(id), 10)), httptest.SuccessResponse)", + "}", + "func Detail(client *httptest.Client, id uint, detail httptest.Responses) {", + " client.GET(str.Join(Uri, \"detail/\", strconv.FormatUint(uint64(id), 10)), httptest.NewResponses(http.StatusOK, response.ResponseOkMessage, detail))", + "}", ], "description": "Print iris-admin test main" }, @@ -122,9 +137,25 @@ ], "description": "Print iris-admin test func" }, + "Print iris-admin test create func": { + "prefix": "iris-admintcf", + "body": [ + "data := map[string]interface{}{", + " $1", + "}", + "id := Create(TestClient, data)", + "if id == 0 {", + " t.Error(\"create is failed\")", + " return", + "}", + "defer Delete(TestClient, id)", + ], + "description": "Print iris-admin test create func" + }, "Print iris-admin controller resource function": { "prefix": "iris-admincrf", "body": [ + "package ${TM_DIRECTORY/^.+[\\/\\\\]+(.*)$/$1/}", "// List list", "func List(ctx *gin.Context) {", " req := &ReqPaginate{}", @@ -159,7 +190,7 @@ " return", " }", " item := new($1)", - " item.Base$1 = req.Base$1", + " item.Base$1 = data.Base$1", " err := item.Update(database.Instance(), scope.IdScope(req.Id))", " if err != nil {", " response.FailWithMessage(err.Error(), ctx)", @@ -229,18 +260,20 @@ "description": "Print iris-admin scope" }, "Print iris-admin route group": { - "prefix": "iris-adming", + "prefix": "iris-adminrg", "body": [ + "package ${TM_DIRECTORY/^.+[\\/\\\\]+(.*)$/$1/}", "import (", " \"github.com/gin-gonic/gin\"", " \"github.com/snowlyg/iris-admin-rbac/gin/middleware\"", ")", "func Group(app *gin.RouterGroup) {", - " router := app.Group(\"$1\", middleware.Auth(), middleware.CasbinHandler(), middleware.OperationRecord())", + " router := app.Group(\"${TM_DIRECTORY/^.+[\\/\\\\]+(.*)$/$1/}\", middleware.Auth(), middleware.CasbinHandler(), middleware.OperationRecord())", " {", " router.GET(\"/list\", List)", " router.POST(\"/create\", Create)", " router.PUT(\"/update/:id\", Update)", + " router.GET(\"/detail/:id\", Detail)", " router.DELETE(\"/delete/:id\", Delete)", " }", "}", @@ -248,31 +281,38 @@ "description": "Print iris-admin route group" }, "Print iris-admin migrate": { - "prefix": "iris-adminmi", + "prefix": "iris-adminmig", "body": [ + "package ${TM_DIRECTORY/^.+[\\/\\\\]+(.*)$/$1/}", + "import (", + " \"github.com/go-gormigrate/gormigrate/v2\"", + " \"gorm.io/gorm\"", + ")", "func GetMigration() *gormigrate.Migration {", " return &gormigrate.Migration{", - " // 20211215120700_create_xxxxs_table", - " ID: \"$CURRENT_YEAR$CURRENT_MONTH$CURRENT_DATE$CURRENT_HOUR$CURRENT_MINUTE$CURRENT_SECOND$1\",", + " // 20230314170919_create_xxxx_table", + " ID: \"$CURRENT_YEAR$CURRENT_MONTH$CURRENT_DATE$CURRENT_HOUR$CURRENT_MINUTE$CURRENT_SECOND$1_table\",", " Migrate: func(tx *gorm.DB) error {", - " return tx.AutoMigrate(&$2{})", + " return tx.AutoMigrate(&$3{})", " },", " Rollback: func(tx *gorm.DB) error {", - " return tx.Migrator().DropTable(\"$3\")", + " return tx.Migrator().DropTable(\"$2\")", " },", " }", "}" ], - "description": "Print iris-admin route group" + "description": "Print iris-admin route migrate" }, "Print iris-admin model": { "prefix": "iris-adminm", "body": [ + "package ${TM_DIRECTORY/^.+[\\/\\\\]+(.*)$/$1/}", "import (", " \"time\"", " \"gorm.io/gorm\"", " \"github.com/snowlyg/iris-admin/server/zap_server\"", ")", + "// gorm:\"column:xxxx_xxxx\" json:\"xxxx_xxxx\" form:\"xxxx_xxxx\" uri:\"xxxx_xxxx\" param:\"xxxx_xxxx\"", "type $1 struct {", " gorm.Model", " Base$1", @@ -280,8 +320,16 @@ "type Base$1 struct {", " $2", "}", + "func (item *$1) mc() map[string]interface{} {", + " return map[string]interface{}{", + " $2", + " }", + "}", "// Create create", "func (item *$1) Create(db *gorm.DB) (uint, error) {", + " if db == nil {", + " return 0, gorm.ErrInvalidDB", + " }", " err := db.Model(&$1{}).Create(item).Error", " if err != nil {", " zap_server.ZAPLOG.Error(err.Error())", @@ -291,10 +339,10 @@ "}", "// Update update", "func (item *$1) Update(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) error {", - " data := map[string]interface{}{", - " $3", - " }", - " err := db.Model(&$1{}).Scopes(scopes...).Updates(data).Error", + " if db == nil {", + " return gorm.ErrInvalidDB", + " }", + " err := db.Model(&$1{}).Scopes(scopes...).Updates(item.mc()).Error", " if err != nil {", " zap_server.ZAPLOG.Error(err.Error())", " return err", @@ -303,6 +351,9 @@ "}", "// Delete delete", "func (item *$1) Delete(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) error {", + " if db == nil {", + " return gorm.ErrInvalidDB", + " }", " err := db.Unscoped().Scopes(scopes...).Delete(item).Error", " if err != nil {", " zap_server.ZAPLOG.Error(err.Error())", @@ -316,6 +367,7 @@ "Print iris-admin request": { "prefix": "iris-adminreq", "body": [ + "package ${TM_DIRECTORY/^.+[\\/\\\\]+(.*)$/$1/}", "import (", " \"github.com/gin-gonic/gin\"", " \"github.com/snowlyg/iris-admin/server/database/orm\"", @@ -341,6 +393,7 @@ "Print iris-admin response": { "prefix": "iris-adminres", "body": [ + "package ${TM_DIRECTORY/^.+[\\/\\\\]+(.*)$/$1/}", "import (", " \"github.com/gin-gonic/gin\"", " \"github.com/snowlyg/iris-admin/server/database/orm\"", @@ -351,6 +404,9 @@ " Base$1", "}", "func (res *Response) First(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) error {", + " if db == nil {", + " return gorm.ErrInvalidDB", + " }", " db = db.Model(&$1{})", " if len(scopes) > 0 {", " db.Scopes(scopes...)", @@ -367,6 +423,9 @@ " Item []*Response", "}", "func (res *PageResponse) Paginate(db *gorm.DB, pageScope func(db *gorm.DB) *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) (int64, error) {", + " if db == nil {", + " return 0,gorm.ErrInvalidDB", + " }", " db = db.Model(&$1{})", " if len(scopes) > 0 {", " db.Scopes(scopes...)", @@ -386,6 +445,9 @@ " return count, nil", "}", "func (res *PageResponse) Find(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) error {", + " if db == nil {", + " return gorm.ErrInvalidDB", + " }", " db = db.Model(&$1{})", " if len(scopes) > 0 {", " db.Scopes(scopes...)", diff --git a/iris/oplog/response.go b/iris/oplog/response.go index 9e268d4..3cf14ec 100644 --- a/iris/oplog/response.go +++ b/iris/oplog/response.go @@ -11,6 +11,9 @@ type Response struct { } func (res *Response) First(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) error { + if db == nil { + return gorm.ErrInvalidDB + } err := db.Model(&operation.Oplog{}).Scopes(scopes...).First(res).Error if err != nil { zap_server.ZAPLOG.Error(err.Error()) @@ -25,6 +28,9 @@ type PageResponse struct { } func (res *PageResponse) Paginate(db *gorm.DB, pageScope func(db *gorm.DB) *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) (int64, error) { + if db == nil { + return 0, gorm.ErrInvalidDB + } db = db.Model(&operation.Oplog{}) var count int64 err := db.Scopes(scopes...).Count(&count).Error @@ -42,6 +48,9 @@ func (res *PageResponse) Paginate(db *gorm.DB, pageScope func(db *gorm.DB) *gorm } func (res *PageResponse) Find(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) error { + if db == nil { + return gorm.ErrInvalidDB + } db = db.Model(&operation.Oplog{}) err := db.Scopes(scopes...).Find(&res.Item).Error if err != nil { diff --git a/iris/perm/logic.go b/iris/perm/logic.go index 40172a7..5f067a3 100644 --- a/iris/perm/logic.go +++ b/iris/perm/logic.go @@ -10,6 +10,9 @@ import ( // CreatenInBatches 批量加入 func CreatenInBatches(db *gorm.DB, perms PermCollection) error { + if db == nil { + return gorm.ErrInvalidDB + } err := db.Model(&Permission{}).CreateInBatches(&perms, 500).Error if err != nil { zap_server.ZAPLOG.Error(err.Error()) diff --git a/iris/perm/model.go b/iris/perm/model.go index db04323..297bfb1 100644 --- a/iris/perm/model.go +++ b/iris/perm/model.go @@ -28,6 +28,9 @@ type BasePermission struct { // Create 添加 func (item *Permission) Create(db *gorm.DB) (uint, error) { + if db == nil { + return 0, gorm.ErrInvalidDB + } if !CheckNameAndAct(NameScope(item.Name), ActScope(item.Act)) { return item.ID, errors.New(str.Join("权限[", item.Name, "-", item.Act, "]已存在")) } @@ -41,6 +44,9 @@ func (item *Permission) Create(db *gorm.DB) (uint, error) { // Update 更新 func (item *Permission) Update(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) error { + if db == nil { + return gorm.ErrInvalidDB + } err := db.Model(item).Scopes(scopes...).Updates(item).Error if err != nil { zap_server.ZAPLOG.Error(err.Error()) @@ -51,6 +57,9 @@ func (item *Permission) Update(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB // Delete 删除 func (item *Permission) Delete(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) error { + if db == nil { + return gorm.ErrInvalidDB + } err := db.Model(item).Unscoped().Scopes(scopes...).Delete(item).Error if err != nil { zap_server.ZAPLOG.Error(err.Error()) diff --git a/iris/perm/response.go b/iris/perm/response.go index 29d5207..703f8a7 100644 --- a/iris/perm/response.go +++ b/iris/perm/response.go @@ -14,6 +14,9 @@ type Response struct { } func (res *Response) First(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) error { + if db == nil { + return gorm.ErrInvalidDB + } err := db.Model(&Permission{}).Scopes(scopes...).First(res).Error if err != nil { zap_server.ZAPLOG.Error("获取权限失败", zap.String("First()", err.Error())) @@ -28,6 +31,9 @@ type PageResponse struct { } func (res *PageResponse) Paginate(db *gorm.DB, pageScope func(db *gorm.DB) *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) (int64, error) { + if db == nil { + return 0, gorm.ErrInvalidDB + } db = db.Model(&Permission{}) var count int64 err := db.Scopes(scopes...).Count(&count).Error @@ -45,6 +51,9 @@ func (res *PageResponse) Paginate(db *gorm.DB, pageScope func(db *gorm.DB) *gorm } func (res *PageResponse) Find(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) error { + if db == nil { + return gorm.ErrInvalidDB + } db = db.Model(&Permission{}) err := db.Scopes(scopes...).Find(&res.Item).Error if err != nil { diff --git a/iris/role/model.go b/iris/role/model.go index cbdbc69..70ce978 100644 --- a/iris/role/model.go +++ b/iris/role/model.go @@ -22,6 +22,9 @@ type BaseRole struct { // Create 添加 func (item *Role) Create(db *gorm.DB) (uint, error) { + if db == nil { + return 0, gorm.ErrInvalidDB + } err := db.Model(item).Create(item).Error if err != nil { zap_server.ZAPLOG.Error(err.Error()) @@ -32,6 +35,9 @@ func (item *Role) Create(db *gorm.DB) (uint, error) { // Update 更新 func (item *Role) Update(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) error { + if db == nil { + return gorm.ErrInvalidDB + } err := db.Model(item).Scopes(scopes...).Updates(item).Error if err != nil { zap_server.ZAPLOG.Error(err.Error()) @@ -42,6 +48,9 @@ func (item *Role) Update(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) erro // Delete 删除 func (item *Role) Delete(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) error { + if db == nil { + return gorm.ErrInvalidDB + } err := db.Model(item).Unscoped().Scopes(scopes...).Delete(item).Error if err != nil { zap_server.ZAPLOG.Error(err.Error()) diff --git a/iris/role/response.go b/iris/role/response.go index 3ef3057..1251842 100644 --- a/iris/role/response.go +++ b/iris/role/response.go @@ -12,6 +12,9 @@ type Response struct { } func (res *Response) First(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) error { + if db == nil { + return gorm.ErrInvalidDB + } err := db.Model(&Role{}).Scopes(scopes...).First(res).Error if err != nil { zap_server.ZAPLOG.Error(err.Error()) @@ -26,6 +29,9 @@ type PageResponse struct { } func (res *PageResponse) Paginate(db *gorm.DB, pageScope func(db *gorm.DB) *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) (int64, error) { + if db == nil { + return 0, gorm.ErrInvalidDB + } db = db.Model(&Role{}) var count int64 err := db.Scopes(scopes...).Count(&count).Error @@ -43,6 +49,9 @@ func (res *PageResponse) Paginate(db *gorm.DB, pageScope func(db *gorm.DB) *gorm } func (res *PageResponse) Find(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) error { + if db == nil { + return gorm.ErrInvalidDB + } db = db.Model(&Role{}) err := db.Scopes(scopes...).Find(&res.Item).Error if err != nil { diff --git a/iris/user/logic.go b/iris/user/logic.go index 24b10a1..dcdf890 100644 --- a/iris/user/logic.go +++ b/iris/user/logic.go @@ -64,6 +64,9 @@ func FindByUserName(scopes ...func(db *gorm.DB) *gorm.DB) (*Response, error) { func FindPasswordByUserName(db *gorm.DB, username string, scopes ...func(db *gorm.DB) *gorm.DB) (*LoginResponse, error) { user := &LoginResponse{} + if db == nil { + return nil, gorm.ErrInvalidDB + } db = db.Model(&User{}).Select("id,password"). Where("username = ?", username) @@ -138,6 +141,9 @@ func IsAdminUser(id uint) error { func FindById(db *gorm.DB, id uint) (Response, error) { user := Response{} + if db == nil { + return user, gorm.ErrInvalidDB + } err := db.Model(&User{}).Where("id = ?", id).First(&user).Error if err != nil { zap_server.ZAPLOG.Error(err.Error()) @@ -199,6 +205,9 @@ func CleanToken(authorityType int, userId string) error { } func UpdateAvatar(db *gorm.DB, id uint, avatar string) error { + if db == nil { + return gorm.ErrInvalidDB + } err := db.Model(&User{}).Where("id = ?", id).Update("avatar", avatar).Error if err != nil { zap_server.ZAPLOG.Error(err.Error()) diff --git a/iris/user/model.go b/iris/user/model.go index 2c831e5..faed7bf 100644 --- a/iris/user/model.go +++ b/iris/user/model.go @@ -25,6 +25,9 @@ type Avatar struct { // Create 添加 func (item *User) Create(db *gorm.DB) (uint, error) { + if db == nil { + return 0, gorm.ErrInvalidDB + } err := db.Model(item).Create(item).Error if err != nil { zap_server.ZAPLOG.Error(err.Error()) @@ -35,6 +38,9 @@ func (item *User) Create(db *gorm.DB) (uint, error) { // Update 更新 func (item *User) Update(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) error { + if db == nil { + return gorm.ErrInvalidDB + } err := db.Model(item).Scopes(scopes...).Updates(item).Error if err != nil { zap_server.ZAPLOG.Error(err.Error()) @@ -45,6 +51,9 @@ func (item *User) Update(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) erro // Delete 删除 func (item *User) Delete(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) error { + if db == nil { + return gorm.ErrInvalidDB + } err := db.Model(item).Unscoped().Scopes(scopes...).Delete(item).Error if err != nil { zap_server.ZAPLOG.Error(err.Error()) diff --git a/iris/user/response.go b/iris/user/response.go index dc176cd..99662c1 100644 --- a/iris/user/response.go +++ b/iris/user/response.go @@ -33,6 +33,9 @@ type LoginResponse struct { } func (res *Response) First(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) error { + if db == nil { + return gorm.ErrInvalidDB + } err := db.Model(&User{}).Scopes(scopes...).First(res).Error if err != nil { zap_server.ZAPLOG.Error(err.Error()) @@ -48,6 +51,9 @@ type PageResponse struct { } func (res *PageResponse) Paginate(db *gorm.DB, pageScope func(db *gorm.DB) *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) (int64, error) { + if db == nil { + return 0, gorm.ErrInvalidDB + } db = db.Model(&User{}) var count int64 err := db.Scopes(scopes...).Count(&count).Error @@ -65,6 +71,9 @@ func (res *PageResponse) Paginate(db *gorm.DB, pageScope func(db *gorm.DB) *gorm } func (res *PageResponse) Find(db *gorm.DB, scopes ...func(db *gorm.DB) *gorm.DB) error { + if db == nil { + return gorm.ErrInvalidDB + } err := db.Model(&User{}).Scopes(scopes...).Find(&res.Item).Error if err != nil { zap_server.ZAPLOG.Error(err.Error())