From d8c853ad13d50d4ede25d10cbb976e3887fd7851 Mon Sep 17 00:00:00 2001 From: yuanruji Date: Thu, 31 Oct 2024 17:40:30 +0800 Subject: [PATCH] stage save --- .../mysql/db-simulation/app/syntax/syntax.go | 18 +-- .../db-simulation/handler/dbsimulation.go | 110 ++++++++++----- .../mysql/db-simulation/handler/handler.go | 126 +++++++----------- .../db-simulation/handler/syntax_check.go | 87 +++++------- .../db-simulation/handler/syntax_rule.go | 46 +++---- .../mysql/db-simulation/router/router.go | 27 ++-- 6 files changed, 203 insertions(+), 211 deletions(-) diff --git a/dbm-services/mysql/db-simulation/app/syntax/syntax.go b/dbm-services/mysql/db-simulation/app/syntax/syntax.go index ff699e68b2..35320233af 100644 --- a/dbm-services/mysql/db-simulation/app/syntax/syntax.go +++ b/dbm-services/mysql/db-simulation/app/syntax/syntax.go @@ -67,6 +67,7 @@ type TmysqlParse struct { bkRepoClient *bkrepo.BkRepoClient TmysqlParseBinPath string BaseWorkdir string + DbType string mu sync.Mutex } @@ -111,7 +112,7 @@ type RiskInfo struct { const DdlMapFileSubffix = ".tbl.map" // Do 运行语法检查 For SQL 文件 -func (tf *TmysqlParseFile) Do(dbtype string, versions []string) (result map[string]*CheckInfo, err error) { +func (tf *TmysqlParseFile) Do(versions []string) (result map[string]*CheckInfo, err error) { logger.Info("doing....") tf.result = make(map[string]*CheckInfo) tf.tmpWorkdir = tf.BaseWorkdir @@ -132,7 +133,7 @@ func (tf *TmysqlParseFile) Do(dbtype string, versions []string) (result map[stri var errs []error for _, version := range versions { - if err = tf.doSingleVersion(dbtype, version); err != nil { + if err = tf.doSingleVersion(version); err != nil { logger.Error("when do [%s],syntax check,failed:%s", version, err.Error()) errs = append(errs, err) } @@ -141,7 +142,7 @@ func (tf *TmysqlParseFile) Do(dbtype string, versions []string) (result map[stri return tf.result, errors.Join(errs...) } -func (tf *TmysqlParseFile) doSingleVersion(dbtype string, mysqlVersion string) (err error) { +func (tf *TmysqlParseFile) doSingleVersion(mysqlVersion string) (err error) { errChan := make(chan error, 1) alreadExecutedSqlfileChan := make(chan string, len(tf.Param.FileNames)) signalChan := make(chan struct{}) @@ -157,7 +158,7 @@ func (tf *TmysqlParseFile) doSingleVersion(dbtype string, mysqlVersion string) ( // 对tmysqlparse的处理结果进行分析,为json文件,后面用到了rule go func() { logger.Info("start to analyze the parsing result") - if err = tf.AnalyzeParseResult(alreadExecutedSqlfileChan, mysqlVersion, dbtype); err != nil { + if err = tf.AnalyzeParseResult(alreadExecutedSqlfileChan, mysqlVersion); err != nil { logger.Error("failed to analyze the parsing result:%s", err.Error()) errChan <- err } @@ -376,8 +377,7 @@ func (t *TmysqlParse) getAbsoutputfilePath(sqlFile, version string) string { } // AnalyzeParseResult 分析tmysqlparse 解析的结果 -func (t *TmysqlParse) AnalyzeParseResult(alreadExecutedSqlfileCh chan string, mysqlVersion string, - dbtype string) (err error) { +func (t *TmysqlParse) AnalyzeParseResult(alreadExecutedSqlfileCh chan string, mysqlVersion string) (err error) { var errs []error c := make(chan struct{}, 10) errChan := make(chan error, 5) @@ -388,7 +388,7 @@ func (t *TmysqlParse) AnalyzeParseResult(alreadExecutedSqlfileCh chan string, my c <- struct{}{} go func(fileName string) { defer wg.Done() - err = t.AnalyzeOne(fileName, mysqlVersion, dbtype) + err = t.AnalyzeOne(fileName, mysqlVersion) if err != nil { errChan <- err } @@ -512,7 +512,7 @@ func (t *TmysqlParse) getSyntaxErrorResult(res ParseLineQueryBase, mysqlVersion } // AnalyzeOne 分析单个文件 -func (t *TmysqlParse) AnalyzeOne(inputfileName, mysqlVersion, dbtype string) (err error) { +func (t *TmysqlParse) AnalyzeOne(inputfileName, mysqlVersion string) (err error) { var idx int var syntaxFailInfos []FailedInfo var buf []byte @@ -570,7 +570,7 @@ func (t *TmysqlParse) AnalyzeOne(inputfileName, mysqlVersion, dbtype string) (er continue } // tmysqlparse检查结果全部正确,开始判断语句是否符合定义的规则(即虽然语法正确,但语句可能是高危语句或禁用的命令) - switch dbtype { + switch t.DbType { case app.MySQL: checkResult.parseResult(R.CommandRule.HighRiskCommandRule, res, mysqlVersion) checkResult.parseResult(R.CommandRule.BanCommandRule, res, mysqlVersion) diff --git a/dbm-services/mysql/db-simulation/handler/dbsimulation.go b/dbm-services/mysql/db-simulation/handler/dbsimulation.go index 610b4d0268..1e306f4b47 100644 --- a/dbm-services/mysql/db-simulation/handler/dbsimulation.go +++ b/dbm-services/mysql/db-simulation/handler/dbsimulation.go @@ -12,6 +12,7 @@ package handler import ( "fmt" + "net/http" "strings" "github.com/gin-gonic/gin" @@ -22,6 +23,11 @@ import ( "dbm-services/mysql/db-simulation/model" ) +// SimulationHandler TODO +type SimulationHandler struct { + BaseHandler +} + // QueryFileResultParam 获取模拟执行文件的结果 type QueryFileResultParam struct { RootID string `json:"root_id" binding:"required" ` @@ -29,11 +35,9 @@ type QueryFileResultParam struct { } // QuerySimulationFileResult 查询模拟执行每个文件的执行结果 -func QuerySimulationFileResult(r *gin.Context) { +func (s *SimulationHandler) QuerySimulationFileResult(r *gin.Context) { var param QueryFileResultParam - if err := r.ShouldBindJSON(¶m); err != nil { - logger.Error("ShouldBind failed %s", err) - SendResponse(r, err, "failed to deserialize parameters", "") + if s.Prepare(r, ¶m) != nil { return } task_id := fmt.Sprintf("%s_%s", param.RootID, param.VersionID) @@ -41,39 +45,32 @@ func QuerySimulationFileResult(r *gin.Context) { err := model.DB.Where("task_id = ? ", task_id).Find(&data).Error if err != nil { logger.Error("query file task result failed %v", err) - SendResponse(r, err, err.Error(), "") + s.SendResponse(r, err, err.Error()) return } - SendResponse(r, nil, data, "") + s.SendResponse(r, nil, data) } // TendbSimulation Tendb simulation handler -func TendbSimulation(r *gin.Context) { +func (s *SimulationHandler) TendbSimulation(r *gin.Context) { var param service.BaseParam - requestID := r.GetString("request_id") - if err := r.ShouldBindJSON(¶m); err != nil { - logger.Error("ShouldBind failed %s", err) - SendResponse(r, err, "failed to deserialize parameters", requestID) - return - } - if requestID == "" { - SendResponse(r, fmt.Errorf("create request id failed"), nil, requestID) + if s.Prepare(r, ¶m) != nil { return } version := param.MySQLVersion img, err := getImgFromMySQLVersion(version) if err != nil { logger.Error("GetImgFromMySQLVersion %s failed:%s", version, err.Error()) - SendResponse(r, err, nil, requestID) + s.SendResponse(r, err, nil) return } - if err := model.CreateTask(param.TaskId, requestID, version, param.Uid); err != nil { + if err := model.CreateTask(param.TaskId, s.RequestId, version, param.Uid); err != nil { logger.Error("create task db record error %s", err.Error()) - SendResponse(r, err, nil, requestID) + s.SendResponse(r, err, nil) return } tsk := service.SimulationTask{ - RequestId: requestID, + RequestId: s.RequestId, DbPodSets: service.NewDbPodSets(), BaseParam: ¶m, Version: version, @@ -83,40 +80,37 @@ func TendbSimulation(r *gin.Context) { PodName: fmt.Sprintf("tendb-%s-%s", strings.ToLower(version), replaceUnderSource(param.TaskId)), Lables: map[string]string{"task_id": replaceUnderSource(param.TaskId), - "request_id": requestID}, + "request_id": s.RequestId}, RootPwd: param.TaskId, Args: param.BuildStartArgs(), Charset: param.MySQLCharSet, } service.TaskChan <- tsk - SendResponse(r, nil, "request successful", requestID) + s.SendResponse(r, nil, "request successful") } // TendbClusterSimulation TendbCluster simulation handler -func TendbClusterSimulation(r *gin.Context) { +func (s *SimulationHandler) TendbClusterSimulation(r *gin.Context) { var param service.SpiderSimulationExecParam - RequestID := r.GetString("request_id") - if err := r.ShouldBindJSON(¶m); err != nil { - logger.Error("ShouldBind failed %s", err) - SendResponse(r, err, "failed to deserialize parameters", RequestID) + if s.Prepare(r, ¶m) != nil { return } version := param.MySQLVersion img, err := getImgFromMySQLVersion(version) if err != nil { logger.Error("GetImgFromMySQLVersion %s failed:%s", version, err.Error()) - SendResponse(r, err, nil, RequestID) + s.SendResponse(r, err, nil) return } - if err := model.CreateTask(param.TaskId, RequestID, version, param.Uid); err != nil { + if err := model.CreateTask(param.TaskId, s.RequestId, version, param.Uid); err != nil { logger.Error("create task db record error %s", err.Error()) - SendResponse(r, err, nil, RequestID) + s.SendResponse(r, err, nil) return } tsk := service.SimulationTask{ - RequestId: RequestID, + RequestId: s.RequestId, DbPodSets: service.NewDbPodSets(), BaseParam: ¶m.BaseParam, Version: version, @@ -131,10 +125,62 @@ func TendbClusterSimulation(r *gin.Context) { PodName: fmt.Sprintf("spider-%s-%s", strings.ToLower(version), replaceUnderSource(param.TaskId)), Lables: map[string]string{"task_id": replaceUnderSource(param.TaskId), - "request_id": RequestID}, + "request_id": s.RequestId}, RootPwd: rootPwd, Charset: param.MySQLCharSet, } service.SpiderTaskChan <- tsk - SendResponse(r, nil, "request successful", RequestID) + s.SendResponse(r, nil, "request successful") +} + +// T 请求查询模拟执行整体任务的执行状态参数 +type T struct { + TaskID string `json:"task_id"` +} + +// QueryTask 查询模拟执行整体任务的执行状态 +func (s *SimulationHandler) QueryTask(r *gin.Context) { + var param T + if s.Prepare(r, ¶m) != nil { + return + } + logger.Info("get task_id is %s", param.TaskID) + var tasks []model.TbSimulationTask + if err := model.DB.Where(&model.TbSimulationTask{TaskId: param.TaskID}).Find(&tasks).Error; err != nil { + logger.Error("query task failed %s", err.Error()) + s.SendResponse(r, err, map[string]interface{}{"stderr": err.Error()}) + return + } + allSuccessful := false + for _, task := range tasks { + if task.Phase != model.PhaseDone { + r.JSON(http.StatusOK, Response{ + Code: 2, + Message: fmt.Sprintf("task current phase is %s", task.Phase), + Data: "", + }) + return + } + switch task.Status { + case model.TaskFailed: + allSuccessful = false + s.SendResponse(r, fmt.Errorf("%s", task.SysErrMsg), map[string]interface{}{ + "simulation_version": task.MySQLVersion, + "stdout": task.Stdout, + "stderr": task.Stderr, + "errmsg": fmt.Sprintf("the program has been run with abnormal status:%s", task.Status)}) + + case model.TaskSuccess: + allSuccessful = true + default: + allSuccessful = false + s.SendResponse(r, fmt.Errorf("unknown transition state"), map[string]interface{}{ + "stdout": task.Stdout, + "stderr": task.Stderr, + "errmsg": fmt.Sprintf("the program has been run with abnormal status:%s", task.Status)}) + } + } + if allSuccessful { + s.SendResponse(r, nil, map[string]interface{}{"stdout": "all ok", "stderr": "all ok"}) + } } diff --git a/dbm-services/mysql/db-simulation/handler/handler.go b/dbm-services/mysql/db-simulation/handler/handler.go index 869952e08f..3d49f9f632 100644 --- a/dbm-services/mysql/db-simulation/handler/handler.go +++ b/dbm-services/mysql/db-simulation/handler/handler.go @@ -20,12 +20,55 @@ import ( "github.com/gin-gonic/gin" "github.com/samber/lo" + "dbm-services/common/go-pubpkg/cmutil" "dbm-services/common/go-pubpkg/logger" "dbm-services/mysql/db-simulation/app/config" "dbm-services/mysql/db-simulation/app/service" "dbm-services/mysql/db-simulation/model" ) +// BaseHandler base handler +type BaseHandler struct { + RequestId string +} + +// Prepare prepare request +func (b *BaseHandler) Prepare(r *gin.Context, schema interface{}) error { + requestId := r.GetString("request_id") + if cmutil.IsEmpty(requestId) { + err := fmt.Errorf("get request id error ~") + b.SendResponse(r, err, nil) + return err + } + b.RequestId = requestId + if err := r.ShouldBind(&schema); err != nil { + logger.Error("ShouldBind Failed %s", err.Error()) + b.SendResponse(r, err, nil) + return err + } + logger.Info("param is %v", schema) + return nil +} + +// SendResponse send response to client +func (b *BaseHandler) SendResponse(r *gin.Context, err error, data interface{}) { + if err != nil { + r.JSON(http.StatusOK, Response{ + Code: 1, + Message: err.Error(), + Data: data, + RequestID: b.RequestId, + }) + return + } + r.JSON(http.StatusOK, Response{ + Code: 0, + Message: "successfully", + Data: data, + RequestID: b.RequestId, + }) +} + // Response response data define type Response struct { Data interface{} `json:"data"` @@ -42,11 +85,9 @@ type CreateClusterParam struct { } // CreateTmpSpiderPodCluster 创建临时的spider的集群,多用于测试,debug -func CreateTmpSpiderPodCluster(r *gin.Context) { +func (b *BaseHandler) CreateTmpSpiderPodCluster(r *gin.Context) { var param CreateClusterParam - if err := r.ShouldBindJSON(¶m); err != nil { - logger.Error("ShouldBind failed %s", err) - SendResponse(r, err, "failed to deserialize parameters", "") + if err := b.Prepare(r, param); err != nil { return } ps := service.NewDbPodSets() @@ -61,88 +102,13 @@ func CreateTmpSpiderPodCluster(r *gin.Context) { logger.Error(err.Error()) return } - SendResponse(r, nil, "ok", "") + b.SendResponse(r, nil, "ok") } func replaceUnderSource(str string) string { return strings.ReplaceAll(str, "_", "-") } -// T 请求查询模拟执行整体任务的执行状态参数 -type T struct { - TaskID string `json:"task_id"` -} - -// QueryTask 查询模拟执行整体任务的执行状态 -func QueryTask(c *gin.Context) { - var param T - if err := c.ShouldBindJSON(¶m); err != nil { - logger.Error("ShouldBind failed %s", err) - SendResponse(c, err, map[string]interface{}{"stderr": "failed to deserialize parameters"}, "") - return - } - logger.Info("get task_id is %s", param.TaskID) - var tasks []model.TbSimulationTask - if err := model.DB.Where(&model.TbSimulationTask{TaskId: param.TaskID}).Find(&tasks).Error; err != nil { - logger.Error("query task failed %s", err.Error()) - SendResponse(c, err, map[string]interface{}{"stderr": err.Error()}, "") - return - } - allSuccessful := false - for _, task := range tasks { - if task.Phase != model.PhaseDone { - c.JSON(http.StatusOK, Response{ - Code: 2, - Message: fmt.Sprintf("task current phase is %s", task.Phase), - Data: "", - }) - return - } - switch task.Status { - case model.TaskFailed: - allSuccessful = false - SendResponse(c, fmt.Errorf("%s", task.SysErrMsg), map[string]interface{}{ - "simulation_version": task.MySQLVersion, - "stdout": task.Stdout, - "stderr": task.Stderr, - "errmsg": fmt.Sprintf("the program has been run with abnormal status:%s", task.Status)}, - "") - - case model.TaskSuccess: - allSuccessful = true - default: - allSuccessful = false - SendResponse(c, fmt.Errorf("unknown transition state"), map[string]interface{}{ - "stdout": task.Stdout, - "stderr": task.Stderr, - "errmsg": fmt.Sprintf("the program has been run with abnormal status:%s", task.Status)}, - "") - } - } - if allSuccessful { - SendResponse(c, nil, map[string]interface{}{"stdout": "all ok", "stderr": "all ok"}, "") - } -} - -// SendResponse return response data to http client -func SendResponse(r *gin.Context, err error, data interface{}, requestid string) { - if err != nil { - r.JSON(http.StatusOK, Response{ - Code: 1, - Message: err.Error(), - Data: data, - RequestID: requestid, - }) - return - } - r.JSON(http.StatusOK, Response{ - Code: 0, - Message: "successfully", - Data: data, - RequestID: requestid, - }) -} - // getImgFromMySQLVersion 根据版本获取模拟执行运行的镜像配置 func getImgFromMySQLVersion(version string) (img string, err error) { img, errx := model.GetImageName("mysql", version) diff --git a/dbm-services/mysql/db-simulation/handler/syntax_check.go b/dbm-services/mysql/db-simulation/handler/syntax_check.go index 2060aadb46..a7fac30789 100644 --- a/dbm-services/mysql/db-simulation/handler/syntax_check.go +++ b/dbm-services/mysql/db-simulation/handler/syntax_check.go @@ -46,7 +46,9 @@ func init() { } // SyntaxHandler 语法检查 handler -type SyntaxHandler struct{} +type SyntaxHandler struct { + BaseHandler +} // CheckSQLStringParam sql string 语法检查参数 type CheckSQLStringParam struct { @@ -56,31 +58,22 @@ type CheckSQLStringParam struct { } // SyntaxCheckSQL 语法检查入参SQL string -func SyntaxCheckSQL(r *gin.Context) { - requestID := r.GetString("request_id") +func (c *SyntaxHandler) SyntaxCheckSQL(r *gin.Context) { var param CheckSQLStringParam var data map[string]*syntax.CheckInfo var versions []string - // 将request中的数据按照json格式直接解析到结构体中 - if err := r.ShouldBindJSON(¶m); err != nil { - logger.Error("ShouldBind failed %s", err) - SendResponse(r, err, nil, requestID) + if c.Prepare(r, ¶m) != nil { return } - logger.Info("versions: %v", param.Versions) - if len(param.Versions) == 0 { - versions = []string{""} - } else { - versions = rebuildVersion(param.Versions) - } + versions = rebuildVersion(param.Versions) sqlContext := strings.Join(param.Sqls, "\n") fileName := "ce_" + cmutil.RandStr(10) + ".sql" f := path.Join(workdir, fileName) err := os.WriteFile(f, []byte(sqlContext), 0600) if err != nil { - SendResponse(r, err, err.Error(), requestID) + c.SendResponse(r, err, err.Error()) return } @@ -88,6 +81,7 @@ func SyntaxCheckSQL(r *gin.Context) { TmysqlParse: syntax.TmysqlParse{ TmysqlParseBinPath: tmysqlParserBin, BaseWorkdir: workdir, + DbType: getTmysqlParseDbtype(param.ClusterType), }, IsLocalFile: true, Param: syntax.CheckSQLFileParam{ @@ -97,21 +91,23 @@ func SyntaxCheckSQL(r *gin.Context) { } logger.Info("cluster type :%s,versions:%v", param.ClusterType, versions) + data, err = check.Do(versions) + if err != nil { + c.SendResponse(r, err, data) + return + } + c.SendResponse(r, nil, data) +} - switch strings.ToLower(param.ClusterType) { +func getTmysqlParseDbtype(clusterType string) string { + switch strings.ToLower(clusterType) { case app.Spider, app.TendbCluster: - data, err = check.Do(app.Spider, []string{""}) + return app.Spider case app.MySQL: - data, err = check.Do(app.MySQL, versions) + return app.MySQL default: - data, err = check.Do(app.MySQL, versions) + return app.MySQL } - - if err != nil { - SendResponse(r, err, data, requestID) - return - } - SendResponse(r, nil, data, requestID) } // CheckFileParam 语法检查请求参数 @@ -123,29 +119,21 @@ type CheckFileParam struct { } // SyntaxCheckFile 运行语法检查 -func SyntaxCheckFile(r *gin.Context) { - requestID := r.GetString("request_id") +func (c *SyntaxHandler) SyntaxCheckFile(r *gin.Context) { var param CheckFileParam var data map[string]*syntax.CheckInfo var err error var versions []string // 将request中的数据按照json格式直接解析到结构体中 - if err = r.ShouldBindJSON(¶m); err != nil { - logger.Error("ShouldBind failed %s", err) - SendResponse(r, err, nil, requestID) + if c.Prepare(r, ¶m) != nil { return } - - if len(param.Versions) == 0 { - versions = []string{""} - } else { - versions = rebuildVersion(param.Versions) - } - + versions = rebuildVersion(param.Versions) check := &syntax.TmysqlParseFile{ TmysqlParse: syntax.TmysqlParse{ TmysqlParseBinPath: tmysqlParserBin, BaseWorkdir: workdir, + DbType: getTmysqlParseDbtype(param.ClusterType), }, Param: syntax.CheckSQLFileParam{ BkRepoBasePath: param.Path, @@ -154,30 +142,19 @@ func SyntaxCheckFile(r *gin.Context) { } logger.Info("cluster type :%s", param.ClusterType) - switch strings.ToLower(param.ClusterType) { - case app.Spider, app.TendbCluster: - data, err = check.Do(app.Spider, []string{""}) - case app.MySQL: - data, err = check.Do(app.MySQL, versions) - default: - data, err = check.Do(app.MySQL, versions) - } - + data, err = check.Do(versions) if err != nil { - SendResponse(r, err, data, requestID) + c.SendResponse(r, err, data) return } - SendResponse(r, nil, data, requestID) + c.SendResponse(r, nil, data) } // CreateAndUploadDDLTblListFile 分析变更SQL DDL操作的表,并将文件上传到制品库 -func CreateAndUploadDDLTblListFile(r *gin.Context) { - requestID := r.GetString("request_id") +func (c *SyntaxHandler) CreateAndUploadDDLTblListFile(r *gin.Context) { var param CheckFileParam // 将request中的数据按照json格式直接解析到结构体中 - if err := r.ShouldBindJSON(¶m); err != nil { - logger.Error("ShouldBind failed %s", err) - SendResponse(r, err, nil, requestID) + if c.Prepare(r, ¶m) != nil { return } check := &syntax.TmysqlParseFile{ @@ -191,16 +168,16 @@ func CreateAndUploadDDLTblListFile(r *gin.Context) { }, } if err := check.CreateAndUploadDDLTblFile(); err != nil { - SendResponse(r, err, nil, requestID) + c.SendResponse(r, err, nil) return } - SendResponse(r, nil, "ok", requestID) + c.SendResponse(r, nil, "ok") } // rebuildVersion tmysql 需要指定特殊的version func rebuildVersion(versions []string) (rebuildVers []string) { if len(versions) == 0 { - return + return []string{""} } rebuildVers = make([]string, 0) for _, bVer := range versions { diff --git a/dbm-services/mysql/db-simulation/handler/syntax_rule.go b/dbm-services/mysql/db-simulation/handler/syntax_rule.go index 10b963c64d..dde76c25a7 100644 --- a/dbm-services/mysql/db-simulation/handler/syntax_rule.go +++ b/dbm-services/mysql/db-simulation/handler/syntax_rule.go @@ -20,6 +20,11 @@ import ( "dbm-services/mysql/db-simulation/model" ) +// ManageRuleHandler operation syntax rule +type ManageRuleHandler struct { + BaseHandler +} + // OptRuleParam 语法规则管理参数 type OptRuleParam struct { RuleID int `json:"rule_id" binding:"required"` @@ -27,32 +32,30 @@ type OptRuleParam struct { } // ManageRule 语法规则管理 -func ManageRule(c *gin.Context) { +func (m *ManageRuleHandler) ManageRule(r *gin.Context) { var param OptRuleParam - if err := c.ShouldBindJSON(¶m); err != nil { - logger.Error("ShouldBind failed %s", err) - SendResponse(c, err, "failed to deserialize parameters", "") + if m.Prepare(r, ¶m) != nil { return } result := model.DB.Model(&model.TbSyntaxRule{}).Where(&model.TbSyntaxRule{ID: param.RuleID}).Update("status", param.Status).Limit(1) if result.Error != nil { logger.Error("update rule status failed %s,affect rows %d", result.Error.Error(), result.RowsAffected) - SendResponse(c, result.Error, result.Error, "") + m.SendResponse(r, result.Error, result.Error) return } - SendResponse(c, nil, "ok", "") + m.SendResponse(r, nil, "ok") } // GetAllRule 获取所有权限规则 -func GetAllRule(c *gin.Context) { +func (m *ManageRuleHandler) GetAllRule(r *gin.Context) { var rs []model.TbSyntaxRule if err := model.DB.Find(&rs).Error; err != nil { logger.Error("query rules failed %s", err.Error()) - SendResponse(c, err, err.Error(), "") + m.SendResponse(r, err, err.Error()) return } - SendResponse(c, nil, rs, "") + m.SendResponse(r, nil, rs) } // UpdateRuleParam 更新语法规则参数 @@ -62,13 +65,10 @@ type UpdateRuleParam struct { } // UpdateRule update syntax rule -func UpdateRule(r *gin.Context) { - logger.Info("UpdateRule...") +func (m *ManageRuleHandler) UpdateRule(r *gin.Context) { var param UpdateRuleParam // 将request中的数据按照json格式直接解析到结构体中 - if err := r.ShouldBindJSON(¶m); err != nil { - logger.Error("ShouldBind failed %s", err) - SendResponse(r, err, nil, "") + if m.Prepare(r, ¶m) != nil { return } var tsr model.TbSyntaxRule @@ -80,52 +80,52 @@ func UpdateRule(r *gin.Context) { // 判断float64存的是整数 if v == float64(int64(v)) { if !(tsr.ItemType == "int") { - errReturn(r, &tsr) + m.errReturn(r, &tsr) return } updateTable(param.ID, int(v)) } else { err = errors.New("not int") logger.Error("Type of error: %s", err) - SendResponse(r, err, nil, "") + m.SendResponse(r, err, nil) return } case bool: if tsr.ItemType == "bool" { updateTable(param.ID, fmt.Sprintf("%t", v)) } else { - errReturn(r, &tsr) + m.errReturn(r, &tsr) return } case string: if tsr.ItemType == "string" { updateTable(param.ID, fmt.Sprintf("%+q", v)) } else { - errReturn(r, &tsr) + m.errReturn(r, &tsr) return } case []interface{}: if tsr.ItemType == "arry" { updateTable(param.ID, fmt.Sprintf("%+q", v)) } else { - errReturn(r, &tsr) + m.errReturn(r, &tsr) return } default: err = errors.New("illegal type") logger.Error("%s", err) - SendResponse(r, err, nil, "") + m.SendResponse(r, err, nil) return } - SendResponse(r, nil, "sucessed", "") + m.SendResponse(r, nil, "sucessed") } func updateTable(id int, item interface{}) { model.DB.Model(&model.TbSyntaxRule{}).Where("id", id).Update("item", item) } -func errReturn(r *gin.Context, tsr *model.TbSyntaxRule) { +func (m *ManageRuleHandler) errReturn(r *gin.Context, tsr *model.TbSyntaxRule) { err := fmt.Errorf("%s type required", tsr.ItemType) logger.Error("Item type error: %s", err) - SendResponse(r, err, nil, "") + m.SendResponse(r, err, nil) } diff --git a/dbm-services/mysql/db-simulation/router/router.go b/dbm-services/mysql/db-simulation/router/router.go index 4adbaccf28..c2b6adbfbc 100644 --- a/dbm-services/mysql/db-simulation/router/router.go +++ b/dbm-services/mysql/db-simulation/router/router.go @@ -28,28 +28,31 @@ func RegisterRouter(engine *gin.Engine) { }) engine.POST("/app/debug", TurnOnDebug) + simulationHandler := handler.SimulationHandler{} // query simulation task status info t := engine.Group("/simulation") - t.POST("/task/file", handler.QuerySimulationFileResult) - t.POST("/task", handler.QueryTask) + t.POST("/task/file", simulationHandler.QuerySimulationFileResult) + t.POST("/task", simulationHandler.QueryTask) // mysql g := engine.Group("/mysql") - g.POST("/simulation", handler.TendbSimulation) - g.POST("/task", handler.QueryTask) + g.POST("/simulation", simulationHandler.TendbSimulation) + g.POST("/task", simulationHandler.QueryTask) // syntax + syntaxHandler := handler.SyntaxHandler{} s := engine.Group("/syntax") - s.POST("/check/file", handler.SyntaxCheckFile) - s.POST("/check/sql", handler.SyntaxCheckSQL) - s.POST("/upload/ddl/tbls", handler.CreateAndUploadDDLTblListFile) + s.POST("/check/file", syntaxHandler.SyntaxCheckFile) + s.POST("/check/sql", syntaxHandler.SyntaxCheckSQL) + s.POST("/upload/ddl/tbls", syntaxHandler.CreateAndUploadDDLTblListFile) // rule + manageRuleHandler := handler.ManageRuleHandler{} r := engine.Group("/rule") - r.POST("/manage", handler.ManageRule) - r.GET("/getall", handler.GetAllRule) - r.POST("/update", handler.UpdateRule) + r.POST("/manage", manageRuleHandler.ManageRule) + r.GET("/getall", manageRuleHandler.GetAllRule) + r.POST("/update", manageRuleHandler.UpdateRule) // spider sp := engine.Group("/spider") - sp.POST("/simulation", handler.TendbClusterSimulation) - sp.POST("/create", handler.CreateTmpSpiderPodCluster) + sp.POST("/simulation", simulationHandler.TendbClusterSimulation) + sp.POST("/create", simulationHandler.CreateTmpSpiderPodCluster) } // TurnOnDebug turn on debug,not del simulation pod