diff --git a/cmd/sponge/commands/generate/handler-pb.go b/cmd/sponge/commands/generate/handler-pb.go index 99e038a6..5f1e283d 100644 --- a/cmd/sponge/commands/generate/handler-pb.go +++ b/cmd/sponge/commands/generate/handler-pb.go @@ -131,7 +131,7 @@ func runGenHandlerPbCommand(moduleName string, serverName string, codes map[stri "userExample.pb.go", "userExample.pb.validate.go", "userExample_grpc.pb.go", "userExample_router.pb.go", // api/serverNameExample "systemCode_http.go", "systemCode_rpc.go", "userExample_rpc.go", // internal/ecode "init.go", "init_test.go", // internal/model - "handler/userExample.go", "handler/userExample_test.go", // internal/handler + "handler/userExample.go", "handler/userExample_test.go", "handler/userExample_logic_test.go", // internal/handler "doc.go", "cacheNameExample.go", "cacheNameExample_test.go", // internal/cache } diff --git a/cmd/sponge/commands/generate/handler.go b/cmd/sponge/commands/generate/handler.go index c6ad2003..2b0dac6d 100644 --- a/cmd/sponge/commands/generate/handler.go +++ b/cmd/sponge/commands/generate/handler.go @@ -117,7 +117,7 @@ func runGenHandlerCommand(moduleName string, codes map[string]string, outPath st "routers.go", "routers_test.go", "routers_pbExample.go", "routers_pbExample_test.go", "userExample_router.go", // internal/routers "swagger_types.go", // internal/types "doc.go", "cacheNameExample.go", "cacheNameExample_test.go", // internal/cache - "handler/userExample_logic.go", // internal/handler + "handler/userExample_logic.go", "handler/userExample_logic_test.go", // internal/handler } r.SetSubDirsAndFiles(subDirs) diff --git a/cmd/sponge/commands/generate/http.go b/cmd/sponge/commands/generate/http.go index e740eb1a..e2335b97 100644 --- a/cmd/sponge/commands/generate/http.go +++ b/cmd/sponge/commands/generate/http.go @@ -150,7 +150,7 @@ func runGenHTTPCommand(moduleName string, serverName string, projectName string, "routers_pbExample.go", "routers_pbExample_test.go", "userExample_router.go", // internal/routers "grpc.go", "grpc_option.go", "grpc_test.go", // internal/server "doc.go", "cacheNameExample.go", "cacheNameExample_test.go", // internal/cache - "handler/userExample_logic.go", // internal/handler + "handler/userExample_logic.go", "handler/userExample_logic_test.go", // internal/handler } r.SetSubDirsAndFiles(subDirs, subFiles...) diff --git a/codecov.yml b/codecov.yml index 2a339c0a..83c0bb26 100644 --- a/codecov.yml +++ b/codecov.yml @@ -7,7 +7,7 @@ coverage: status: project: default: - target: '80' + target: '75' patch: default: target: '60' diff --git a/internal/handler/userExample.go b/internal/handler/userExample.go index 1402d863..191e175d 100644 --- a/internal/handler/userExample.go +++ b/internal/handler/userExample.go @@ -96,6 +96,7 @@ func (h *userExampleHandler) Create(c *gin.Context) { func (h *userExampleHandler) DeleteByID(c *gin.Context) { _, id, isAbort := getUserExampleIDFromPath(c) if isAbort { + response.Error(c, ecode.InvalidParams) return } @@ -152,6 +153,7 @@ func (h *userExampleHandler) DeleteByIDs(c *gin.Context) { func (h *userExampleHandler) UpdateByID(c *gin.Context) { _, id, isAbort := getUserExampleIDFromPath(c) if isAbort { + response.Error(c, ecode.InvalidParams) return } @@ -194,6 +196,7 @@ func (h *userExampleHandler) UpdateByID(c *gin.Context) { func (h *userExampleHandler) GetByID(c *gin.Context) { idStr, id, isAbort := getUserExampleIDFromPath(c) if isAbort { + response.Error(c, ecode.InvalidParams) return } @@ -358,7 +361,6 @@ func getUserExampleIDFromPath(c *gin.Context) (string, uint64, bool) { id, err := utils.StrToUint64E(idStr) if err != nil || id == 0 { logger.Warn("StrToUint64E error: ", logger.String("idStr", idStr), middleware.GCtxRequestIDField(c)) - response.Error(c, ecode.InvalidParams) return "", 0, true } diff --git a/internal/handler/userExample_logic_test.go b/internal/handler/userExample_logic_test.go new file mode 100644 index 00000000..594ea8cb --- /dev/null +++ b/internal/handler/userExample_logic_test.go @@ -0,0 +1,474 @@ +package handler + +import ( + "net/http" + "testing" + "time" + + serverNameExampleV1 "github.com/zhufuyi/sponge/api/serverNameExample/v1" + "github.com/zhufuyi/sponge/internal/cache" + "github.com/zhufuyi/sponge/internal/dao" + "github.com/zhufuyi/sponge/internal/ecode" + "github.com/zhufuyi/sponge/internal/model" + + "github.com/zhufuyi/sponge/pkg/gin/response" + "github.com/zhufuyi/sponge/pkg/gohttp" + "github.com/zhufuyi/sponge/pkg/gotest" + "github.com/zhufuyi/sponge/pkg/utils" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/gin-gonic/gin" + "github.com/jinzhu/copier" + "github.com/stretchr/testify/assert" + "github.com/zhufuyi/sponge/api/types" +) + +func newUserExamplePbHandler() *gotest.Handler { + // todo additional test field information + testData := &model.UserExample{} + testData.ID = 1 + testData.CreatedAt = time.Now() + testData.UpdatedAt = testData.CreatedAt + + // init mock cache + c := gotest.NewCache(map[string]interface{}{utils.Uint64ToStr(testData.ID): testData}) + c.ICache = cache.NewUserExampleCache(&model.CacheType{ + CType: "redis", + Rdb: c.RedisClient, + }) + + // init mock dao + d := gotest.NewDao(c, testData) + d.IDao = dao.NewUserExampleDao(d.DB, c.ICache.(cache.UserExampleCache)) + + // init mock handler + h := gotest.NewHandler(d, testData) + h.IHandler = &userExamplePbHandler{userExampleDao: d.IDao.(dao.UserExampleDao)} + iHandler := h.IHandler.(serverNameExampleV1.UserExampleLogicer) + + testFns := []gotest.RouterInfo{ + { + FuncName: "Create", + Method: http.MethodPost, + Path: "/userExample", + HandlerFunc: func(c *gin.Context) { + req := &serverNameExampleV1.CreateUserExampleRequest{} + _ = c.ShouldBindJSON(req) + _, err := iHandler.Create(c, req) + if err != nil { + response.Error(c, ecode.ErrCreateUserExample) + return + } + response.Success(c) + }, + }, + { + FuncName: "DeleteByID", + Method: http.MethodDelete, + Path: "/userExample/:id", + HandlerFunc: func(c *gin.Context) { + req := &serverNameExampleV1.DeleteUserExampleByIDRequest{ + Id: utils.StrToUint64(c.Param("id")), + } + _, err := iHandler.DeleteByID(c, req) + if err != nil { + response.Error(c, ecode.ErrDeleteByIDUserExample) + return + } + response.Success(c) + }, + }, + { + FuncName: "DeleteByIDs", + Method: http.MethodPost, + Path: "/userExample/delete/ids", + HandlerFunc: func(c *gin.Context) { + req := &serverNameExampleV1.DeleteUserExampleByIDsRequest{} + _ = c.ShouldBindJSON(req) + _, err := iHandler.DeleteByIDs(c, req) + if err != nil { + response.Error(c, ecode.ErrDeleteByIDsUserExample) + return + } + response.Success(c) + }, + }, + { + FuncName: "UpdateByID", + Method: http.MethodPut, + Path: "/userExample/:id", + HandlerFunc: func(c *gin.Context) { + req := &serverNameExampleV1.UpdateUserExampleByIDRequest{} + _ = c.ShouldBindJSON(req) + req.Id = utils.StrToUint64(c.Param("id")) + _, err := iHandler.UpdateByID(c, req) + if err != nil { + response.Error(c, ecode.ErrUpdateByIDUserExample) + return + } + response.Success(c) + }, + }, + { + FuncName: "GetByID", + Method: http.MethodGet, + Path: "/userExample/:id", + HandlerFunc: func(c *gin.Context) { + req := &serverNameExampleV1.GetUserExampleByIDRequest{ + Id: utils.StrToUint64(c.Param("id")), + } + _, err := iHandler.GetByID(c, req) + if err != nil { + response.Error(c, ecode.ErrGetByIDUserExample) + return + } + response.Success(c) + }, + }, + { + FuncName: "GetByCondition", + Method: http.MethodPost, + Path: "/userExample/condition", + HandlerFunc: func(c *gin.Context) { + req := &serverNameExampleV1.GetUserExampleByConditionRequest{} + _ = c.ShouldBindJSON(req) + _, err := iHandler.GetByCondition(c, req) + if err != nil { + response.Error(c, ecode.ErrGetByConditionUserExample) + return + } + response.Success(c) + }, + }, + { + FuncName: "ListByIDs", + Method: http.MethodPost, + Path: "/userExample/list/ids", + HandlerFunc: func(c *gin.Context) { + req := &serverNameExampleV1.ListUserExampleByIDsRequest{} + _ = c.ShouldBindJSON(req) + _, err := iHandler.ListByIDs(c, req) + if err != nil { + response.Error(c, ecode.ErrListByIDsUserExample) + return + } + response.Success(c) + }, + }, + { + FuncName: "List", + Method: http.MethodPost, + Path: "/userExample/list", + HandlerFunc: func(c *gin.Context) { + req := &serverNameExampleV1.ListUserExampleRequest{} + _ = c.ShouldBindJSON(req) + _, err := iHandler.List(c, req) + if err != nil { + response.Error(c, ecode.ErrListUserExample) + return + } + response.Success(c) + }, + }, + } + + h.GoRunHTTPServer(testFns) + + time.Sleep(time.Millisecond * 200) + return h +} + +func Test_userExamplePbHandler_Create(t *testing.T) { + h := newUserExamplePbHandler() + defer h.Close() + testData := &serverNameExampleV1.CreateUserExampleRequest{} + _ = copier.Copy(testData, h.TestData.(*model.UserExample)) + + h.MockDao.SQLMock.ExpectBegin() + args := h.MockDao.GetAnyArgs(h.TestData) + h.MockDao.SQLMock.ExpectExec("INSERT INTO .*"). + WithArgs(args[:len(args)-1]...). // adjusted for the amount of test data + WillReturnResult(sqlmock.NewResult(1, 1)) + h.MockDao.SQLMock.ExpectCommit() + + result := &gohttp.StdResult{} + err := gohttp.Post(result, h.GetRequestURL("Create"), testData) + if err != nil { + t.Fatal(err) + } + + t.Logf("%+v", result) + // delete the templates code start + result = &gohttp.StdResult{} + testData = &serverNameExampleV1.CreateUserExampleRequest{ + Name: "foo", + Password: "f447b20a7fcbf53a5d5be013ea0b15af", + Email: "foo@bar.com", + Phone: "16000000001", + Avatar: "http://foo/1.jpg", + Age: 10, + Gender: 1, + } + err = gohttp.Post(result, h.GetRequestURL("Create"), testData) + if err != nil { + t.Fatal(err) + } + t.Logf("%+v", result) + + h.MockDao.SQLMock.ExpectBegin() + h.MockDao.SQLMock.ExpectCommit() + // create error test + result = &gohttp.StdResult{} + err = gohttp.Post(result, h.GetRequestURL("Create"), testData) + assert.NoError(t, err) + // delete the templates code end +} + +func Test_userExamplePbHandler_DeleteByID(t *testing.T) { + h := newUserExamplePbHandler() + defer h.Close() + testData := h.TestData.(*model.UserExample) + + h.MockDao.SQLMock.ExpectBegin() + h.MockDao.SQLMock.ExpectExec("UPDATE .*"). + WithArgs(h.MockDao.AnyTime, testData.ID). // adjusted for the amount of test data + WillReturnResult(sqlmock.NewResult(int64(testData.ID), 1)) + h.MockDao.SQLMock.ExpectCommit() + + result := &gohttp.StdResult{} + err := gohttp.Delete(result, h.GetRequestURL("DeleteByID", testData.ID)) + if err != nil { + t.Fatal(err) + } + if result.Code != 0 { + t.Fatalf("%+v", result) + } + + // zero id error test + err = gohttp.Delete(result, h.GetRequestURL("DeleteByID", 0)) + assert.NoError(t, err) + + // delete error test + err = gohttp.Delete(result, h.GetRequestURL("DeleteByID", 111)) + assert.NoError(t, err) +} + +func Test_userExamplePbHandler_DeleteByIDs(t *testing.T) { + h := newUserExamplePbHandler() + defer h.Close() + testData := h.TestData.(*model.UserExample) + + h.MockDao.SQLMock.ExpectBegin() + h.MockDao.SQLMock.ExpectExec("UPDATE .*"). + WithArgs(h.MockDao.AnyTime, testData.ID). // adjusted for the amount of test data + WillReturnResult(sqlmock.NewResult(int64(testData.ID), 1)) + h.MockDao.SQLMock.ExpectCommit() + + result := &gohttp.StdResult{} + err := gohttp.Post(result, h.GetRequestURL("DeleteByIDs"), &serverNameExampleV1.DeleteUserExampleByIDsRequest{Ids: []uint64{testData.ID}}) + if err != nil { + t.Fatal(err) + } + if result.Code != 0 { + t.Fatalf("%+v", result) + } + + // zero id error test + err = gohttp.Post(result, h.GetRequestURL("DeleteByIDs"), &serverNameExampleV1.DeleteUserExampleByIDsRequest{}) + assert.NoError(t, err) + + // get error test + err = gohttp.Post(result, h.GetRequestURL("DeleteByIDs"), &serverNameExampleV1.DeleteUserExampleByIDsRequest{Ids: []uint64{111}}) + assert.NoError(t, err) +} + +func Test_userExamplePbHandler_UpdateByID(t *testing.T) { + h := newUserExamplePbHandler() + defer h.Close() + testData := &serverNameExampleV1.UpdateUserExampleByIDRequest{} + _ = copier.Copy(testData, h.TestData.(*model.UserExample)) + testData.Id = h.TestData.(*model.UserExample).ID + + h.MockDao.SQLMock.ExpectBegin() + h.MockDao.SQLMock.ExpectExec("UPDATE .*"). + WithArgs(h.MockDao.AnyTime, testData.Id). // adjusted for the amount of test data + WillReturnResult(sqlmock.NewResult(int64(testData.Id), 1)) + h.MockDao.SQLMock.ExpectCommit() + + result := &gohttp.StdResult{} + err := gohttp.Put(result, h.GetRequestURL("UpdateByID", testData.Id), testData) + if err != nil { + t.Fatal(err) + } + if result.Code != 0 { + t.Fatalf("%+v", result) + } + + // zero id error test + err = gohttp.Put(result, h.GetRequestURL("UpdateByID", 0), testData) + assert.NoError(t, err) + + // update error test + err = gohttp.Put(result, h.GetRequestURL("UpdateByID", 111), testData) + assert.NoError(t, err) +} + +func Test_userExamplePbHandler_GetByID(t *testing.T) { + h := newUserExamplePbHandler() + defer h.Close() + testData := h.TestData.(*model.UserExample) + + rows := sqlmock.NewRows([]string{"id", "created_at", "updated_at"}). + AddRow(testData.ID, testData.CreatedAt, testData.UpdatedAt) + + h.MockDao.SQLMock.ExpectQuery("SELECT .*"). + WithArgs(testData.ID). + WillReturnRows(rows) + + result := &gohttp.StdResult{} + err := gohttp.Get(result, h.GetRequestURL("GetByID", testData.ID)) + if err != nil { + t.Fatal(err) + } + if result.Code != 0 { + t.Fatalf("%+v", result) + } + + // zero id error test + err = gohttp.Get(result, h.GetRequestURL("GetByID", 0)) + assert.NoError(t, err) + + // get error test + err = gohttp.Get(result, h.GetRequestURL("GetByID", 111)) + assert.NoError(t, err) +} + +func Test_userExamplePbHandler_GetByCondition(t *testing.T) { + h := newUserExamplePbHandler() + defer h.Close() + testData := h.TestData.(*model.UserExample) + + rows := sqlmock.NewRows([]string{"id", "created_at", "updated_at"}). + AddRow(testData.ID, testData.CreatedAt, testData.UpdatedAt) + + h.MockDao.SQLMock.ExpectQuery("SELECT .*").WillReturnRows(rows) + + result := &gohttp.StdResult{} + err := gohttp.Post(result, h.GetRequestURL("GetByCondition"), &serverNameExampleV1.GetUserExampleByConditionRequest{ + Conditions: &types.Conditions{ + Columns: []*types.Column{ + { + Name: "id", + Value: utils.Uint64ToStr(testData.ID), + }, + }, + }, + }) + if err != nil { + t.Fatal(err) + } + if result.Code != 0 { + t.Fatalf("%+v", result) + } + + // zero error test + err = gohttp.Post(result, h.GetRequestURL("GetByCondition"), nil) + assert.NoError(t, err) + + // valid error test + err = gohttp.Post(result, h.GetRequestURL("GetByCondition"), &serverNameExampleV1.GetUserExampleByConditionRequest{ + Conditions: &types.Conditions{ + Columns: []*types.Column{ + { + Name: "id", + Value: "111", + Exp: "unknown", + }, + }, + }, + }) + + // get error test + err = gohttp.Post(result, h.GetRequestURL("GetByCondition"), &serverNameExampleV1.GetUserExampleByConditionRequest{ + Conditions: &types.Conditions{ + Columns: []*types.Column{ + { + Name: "id", + Value: "111", + }, + }, + }, + }) + assert.NoError(t, err) +} + +func Test_userExamplePbHandler_ListByIDs(t *testing.T) { + h := newUserExamplePbHandler() + defer h.Close() + testData := h.TestData.(*model.UserExample) + + rows := sqlmock.NewRows([]string{"id", "created_at", "updated_at"}). + AddRow(testData.ID, testData.CreatedAt, testData.UpdatedAt) + + h.MockDao.SQLMock.ExpectQuery("SELECT .*").WillReturnRows(rows) + + result := &gohttp.StdResult{} + err := gohttp.Post(result, h.GetRequestURL("ListByIDs"), &serverNameExampleV1.ListUserExampleByIDsRequest{Ids: []uint64{testData.ID}}) + if err != nil { + t.Fatal(err) + } + if result.Code != 0 { + t.Fatalf("%+v", result) + } + + // zero id error test + err = gohttp.Post(result, h.GetRequestURL("ListByIDs"), &serverNameExampleV1.ListUserExampleByIDsRequest{}) + assert.NoError(t, err) + + // get error test + err = gohttp.Post(result, h.GetRequestURL("ListByIDs"), &serverNameExampleV1.ListUserExampleByIDsRequest{Ids: []uint64{111}}) + assert.NoError(t, err) +} + +func Test_userExamplePbHandler_List(t *testing.T) { + h := newUserExamplePbHandler() + defer h.Close() + testData := h.TestData.(*model.UserExample) + + rows := sqlmock.NewRows([]string{"id", "created_at", "updated_at"}). + AddRow(testData.ID, testData.CreatedAt, testData.UpdatedAt) + + h.MockDao.SQLMock.ExpectQuery("SELECT .*").WillReturnRows(rows) + + result := &gohttp.StdResult{} + err := gohttp.Post(result, h.GetRequestURL("List"), &serverNameExampleV1.ListUserExampleRequest{ + Params: &types.Params{ + Page: 0, + Limit: 10, + Sort: "ignore count", // ignore test count + }}) + if err != nil { + t.Fatal(err) + } + if result.Code != 0 { + t.Fatalf("%+v", result) + } + + // nil params error test + err = gohttp.Post(result, h.GetRequestURL("List"), &serverNameExampleV1.ListUserExampleRequest{}) + assert.NoError(t, err) + + // get error test + err = gohttp.Post(result, h.GetRequestURL("List"), &serverNameExampleV1.ListUserExampleRequest{Params: &types.Params{ + Page: 0, + Limit: 10, + }}) + assert.NoError(t, err) +} + +func TestNewUserExamplePbHandler(t *testing.T) { + defer func() { + recover() + }() + _ = NewUserExamplePbHandler() +} diff --git a/internal/handler/userExample_test.go b/internal/handler/userExample_test.go index c8476680..459ee073 100644 --- a/internal/handler/userExample_test.go +++ b/internal/handler/userExample_test.go @@ -139,6 +139,7 @@ func Test_userExampleHandler_Create(t *testing.T) { h.MockDao.SQLMock.ExpectBegin() h.MockDao.SQLMock.ExpectCommit() + // create error test result = &gohttp.StdResult{} err = gohttp.Post(result, h.GetRequestURL("Create"), testData) assert.Error(t, err) diff --git a/internal/service/userExample_logic_test.go b/internal/service/userExample_logic_test.go index 99999bf9..40db65af 100644 --- a/internal/service/userExample_logic_test.go +++ b/internal/service/userExample_logic_test.go @@ -37,6 +37,11 @@ func TestNewUserExampleServiceClient(t *testing.T) { t.Log(reply, err) cancel() }) + utils.SafeRunWithTimeout(time.Second, func(cancel context.CancelFunc) { + reply, err := cli.DeleteByIDs(ctx, nil) + t.Log(reply, err) + cancel() + }) utils.SafeRunWithTimeout(time.Second, func(cancel context.CancelFunc) { reply, err := cli.UpdateByID(ctx, nil) t.Log(reply, err) diff --git a/pkg/errcode/rpc_error_test.go b/pkg/errcode/rpc_error_test.go index e713979b..ba7e77c4 100644 --- a/pkg/errcode/rpc_error_test.go +++ b/pkg/errcode/rpc_error_test.go @@ -1,7 +1,11 @@ package errcode import ( + "github.com/gin-gonic/gin" + "github.com/zhufuyi/sponge/pkg/utils" + "net/http" "testing" + "time" "github.com/stretchr/testify/assert" ) @@ -84,3 +88,25 @@ func TestRCode(t *testing.T) { }() code = RCode(101) } + +func TestHandlers(t *testing.T) { + serverAddr, requestAddr := utils.GetLocalHTTPAddrPairs() + + gin.SetMode(gin.ReleaseMode) + r := gin.New() + r.GET("/codes", gin.WrapF(ListGRPCErrCodes)) + r.GET("/config", gin.WrapF(ShowConfig([]byte(`{"foo": "bar"}`)))) + + go func() { + _ = r.Run(serverAddr) + }() + + time.Sleep(time.Millisecond * 200) + resp, err := http.Get(requestAddr + "/codes") + assert.NoError(t, err) + assert.NotNil(t, resp) + resp, err = http.Get(requestAddr + "/config") + assert.NoError(t, err) + assert.NotNil(t, resp) + time.Sleep(time.Second) +} diff --git a/pkg/gin/handlerfunc/common_test.go b/pkg/gin/handlerfunc/common_test.go index d5fb0de6..38417d04 100644 --- a/pkg/gin/handlerfunc/common_test.go +++ b/pkg/gin/handlerfunc/common_test.go @@ -6,7 +6,6 @@ import ( "testing" "time" - "github.com/zhufuyi/sponge/pkg/errcode" "github.com/zhufuyi/sponge/pkg/gohttp" "github.com/zhufuyi/sponge/pkg/utils" @@ -14,7 +13,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestCheckHealth(t *testing.T) { +func TestCommonHandlers(t *testing.T) { serverAddr, requestAddr := utils.GetLocalHTTPAddrPairs() gin.SetMode(gin.ReleaseMode) @@ -22,8 +21,6 @@ func TestCheckHealth(t *testing.T) { r.GET("/health", CheckHealth) r.GET("/ping", Ping) r.GET("/codes", ListCodes) - r.GET("/codes2", gin.WrapF(errcode.ListGRPCErrCodes)) - r.GET("/config", gin.WrapF(errcode.ShowConfig([]byte(`{"foo": "bar"}`)))) go func() { _ = r.Run(serverAddr) @@ -39,12 +36,7 @@ func TestCheckHealth(t *testing.T) { resp, err = http.Get(requestAddr + "/codes") assert.NoError(t, err) assert.NotNil(t, resp) - resp, err = http.Get(requestAddr + "/codes2") - assert.NoError(t, err) - assert.NotNil(t, resp) - resp, err = http.Get(requestAddr + "/config") - assert.NoError(t, err) - assert.NotNil(t, resp) + time.Sleep(time.Second) } func TestBrowserRefresh(t *testing.T) { diff --git a/pkg/rabbitmq/connect.go b/pkg/rabbitmq/connect.go index 41eedbb4..7cf93023 100644 --- a/pkg/rabbitmq/connect.go +++ b/pkg/rabbitmq/connect.go @@ -29,30 +29,6 @@ type Connection struct { IsConnected bool } -func connect(url string, tlsConfig *tls.Config) (*amqp.Connection, error) { - var ( - conn *amqp.Connection - err error - ) - - if strings.HasPrefix(url, "amqps://") { - if tlsConfig == nil { - return nil, errors.New("tls not set, e.g. NewConnection(url, WithTLSConfig(tlsConfig))") - } - conn, err = amqp.DialTLS(url, tlsConfig) - if err != nil { - return nil, err - } - } else { - conn, err = amqp.Dial(url) - if err != nil { - return nil, err - } - } - - return conn, nil -} - // NewConnection rabbitmq connection func NewConnection(url string, opts ...ConnectionOption) (*Connection, error) { if url == "" { @@ -85,13 +61,28 @@ func NewConnection(url string, opts ...ConnectionOption) (*Connection, error) { return c, nil } -// Close rabbitmq connection -func (c *Connection) Close() { - c.Mutex.Lock() - c.IsConnected = false - c.Mutex.Unlock() +func connect(url string, tlsConfig *tls.Config) (*amqp.Connection, error) { + var ( + conn *amqp.Connection + err error + ) - close(c.Exit) + if strings.HasPrefix(url, "amqps://") { + if tlsConfig == nil { + return nil, errors.New("tls not set, e.g. NewConnection(url, WithTLSConfig(tlsConfig))") + } + conn, err = amqp.DialTLS(url, tlsConfig) + if err != nil { + return nil, err + } + } else { + conn, err = amqp.Dial(url) + if err != nil { + return nil, err + } + } + + return conn, nil } // CheckConnected rabbitmq connection @@ -101,17 +92,6 @@ func (c *Connection) CheckConnected() bool { return c.IsConnected } -func (c *Connection) closeConn() error { - c.Mutex.Lock() - defer c.Mutex.Unlock() - - if c.Conn != nil { - return c.Conn.Close() - } - - return nil -} - func (c *Connection) monitor() { retryCount := 0 reconnectTip := fmt.Sprintf("[rabbitmq connection] lost connection, attempting reconnect in %s", c.reconnectTime) @@ -154,3 +134,23 @@ func (c *Connection) monitor() { } } } + +// Close rabbitmq connection +func (c *Connection) Close() { + c.Mutex.Lock() + c.IsConnected = false + c.Mutex.Unlock() + + close(c.Exit) +} + +func (c *Connection) closeConn() error { + c.Mutex.Lock() + defer c.Mutex.Unlock() + + if c.Conn != nil { + return c.Conn.Close() + } + + return nil +} diff --git a/pkg/rabbitmq/connect_test.go b/pkg/rabbitmq/connect_test.go index 2981875f..fda582ba 100644 --- a/pkg/rabbitmq/connect_test.go +++ b/pkg/rabbitmq/connect_test.go @@ -36,8 +36,10 @@ func TestConnectionOptions(t *testing.T) { func TestNewConnection1(t *testing.T) { utils.SafeRunWithTimeout(time.Second*2, func(cancel context.CancelFunc) { + c, err := NewConnection("") + assert.Error(t, err) - c, err := NewConnection(url) + c, err = NewConnection(url) if err != nil { t.Log(err) return @@ -76,6 +78,7 @@ func TestConnection_monitor(t *testing.T) { IsConnected: true, } + c.CheckConnected() go func() { defer func() { recover() }() c.monitor() diff --git a/pkg/rabbitmq/consumer/consumer.go b/pkg/rabbitmq/consumer/consumer.go index b8b91937..63f28e62 100644 --- a/pkg/rabbitmq/consumer/consumer.go +++ b/pkg/rabbitmq/consumer/consumer.go @@ -25,13 +25,6 @@ type Queue struct { zapLog *zap.Logger } -// Close queue -func (q *Queue) Close() { - if q.ch != nil { - _ = q.ch.Close() - } -} - // Handler message type Handler func(ctx context.Context, data []byte, tagID ...string) error @@ -161,3 +154,10 @@ func (q *Queue) Consume(handler Handler) { } }() } + +// Close queue +func (q *Queue) Close() { + if q.ch != nil { + _ = q.ch.Close() + } +} diff --git a/pkg/rabbitmq/consumer/consumer_test.go b/pkg/rabbitmq/consumer/consumer_test.go index d0cd5f78..5aa33cac 100644 --- a/pkg/rabbitmq/consumer/consumer_test.go +++ b/pkg/rabbitmq/consumer/consumer_test.go @@ -10,6 +10,9 @@ import ( "github.com/zhufuyi/sponge/pkg/rabbitmq" "github.com/zhufuyi/sponge/pkg/rabbitmq/producer" "github.com/zhufuyi/sponge/pkg/utils" + + amqp "github.com/rabbitmq/amqp091-go" + "go.uber.org/zap" ) var url = "amqp://guest:guest@192.168.3.37:5672/" @@ -242,3 +245,40 @@ func producerHeaders(queueName string) error { }) return producerErr } + +func TestNewQueue(t *testing.T) { + c := &rabbitmq.Connection{ + Exit: make(chan struct{}), + ZapLog: zap.NewNop(), + Conn: &amqp.Connection{}, + IsConnected: true, + } + + q, err := NewQueue(context.Background(), "test", c, WithConsumeQos(WithQosPrefetchCount(1))) + if err != nil { + t.Log(err) + return + } + q.ch = &amqp.Channel{} + amqp.NewConnectionProperties() + + utils.SafeRunWithTimeout(time.Second, func(cancel context.CancelFunc) { + err = q.newChannel() + if err != nil { + t.Log(err) + return + } + }) + utils.SafeRunWithTimeout(time.Second, func(cancel context.CancelFunc) { + _, err := q.consumeWithContext() + if err != nil { + t.Log(err) + return + } + }) + utils.SafeRunWithTimeout(time.Second*3, func(cancel context.CancelFunc) { + q.Consume(handler) + }) + time.Sleep(time.Millisecond * 2500) + close(q.c.Exit) +} diff --git a/pkg/rabbitmq/producer/producer.go b/pkg/rabbitmq/producer/producer.go index 91885226..557128ef 100644 --- a/pkg/rabbitmq/producer/producer.go +++ b/pkg/rabbitmq/producer/producer.go @@ -163,13 +163,6 @@ func NewQueue(queueName string, conn *amqp.Connection, exchange *Exchange, opts }, nil } -// Close the queue -func (q *Queue) Close() { - if q.ch != nil { - _ = q.ch.Close() - } -} - // Publish send direct or fanout type message func (q *Queue) Publish(ctx context.Context, body []byte) error { if q.exchange.eType != exchangeTypeDirect && q.exchange.eType != exchangeTypeFanout { @@ -224,3 +217,10 @@ func (q *Queue) PublishHeaders(ctx context.Context, headersKey map[string]interf }, ) } + +// Close the queue +func (q *Queue) Close() { + if q.ch != nil { + _ = q.ch.Close() + } +} diff --git a/pkg/sql2code/parser/parser_test.go b/pkg/sql2code/parser/parser_test.go index e160f15d..3b134aff 100644 --- a/pkg/sql2code/parser/parser_test.go +++ b/pkg/sql2code/parser/parser_test.go @@ -2,10 +2,10 @@ package parser import ( "fmt" - "github.com/blastrain/vitess-sqlparser/tidbparser/dependency/mysql" - "github.com/blastrain/vitess-sqlparser/tidbparser/dependency/types" "testing" + "github.com/blastrain/vitess-sqlparser/tidbparser/dependency/mysql" + "github.com/blastrain/vitess-sqlparser/tidbparser/dependency/types" "github.com/stretchr/testify/assert" ) @@ -154,6 +154,8 @@ func TestGetTableInfo(t *testing.T) { } func Test_initTemplate(t *testing.T) { + initTemplate() + defer func() { recover() }() modelStructTmplRaw = "{{if .foo}}" modelTmplRaw = "{{if .foo}}" diff --git a/pkg/sql2code/sql2code_test.go b/pkg/sql2code/sql2code_test.go index 55c1eae4..626e5234 100644 --- a/pkg/sql2code/sql2code_test.go +++ b/pkg/sql2code/sql2code_test.go @@ -1,8 +1,9 @@ package sql2code import ( - "github.com/stretchr/testify/assert" "testing" + + "github.com/stretchr/testify/assert" ) var sqlData = `