Skip to content

Commit

Permalink
stage save
Browse files Browse the repository at this point in the history
  • Loading branch information
ymakedaq committed Oct 31, 2024
1 parent 6b4e370 commit d8c853a
Show file tree
Hide file tree
Showing 6 changed files with 203 additions and 211 deletions.
18 changes: 9 additions & 9 deletions dbm-services/mysql/db-simulation/app/syntax/syntax.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ type TmysqlParse struct {
bkRepoClient *bkrepo.BkRepoClient
TmysqlParseBinPath string
BaseWorkdir string
DbType string
mu sync.Mutex
}

Expand Down Expand Up @@ -111,7 +112,7 @@ type RiskInfo struct {
const DdlMapFileSubffix = ".tbl.map"

// Do 运行语法检查 For SQL 文件
func (tf *TmysqlParseFile) Do(dbtype string, versions []string) (result map[string]*CheckInfo, err error) {
func (tf *TmysqlParseFile) Do(versions []string) (result map[string]*CheckInfo, err error) {
logger.Info("doing....")
tf.result = make(map[string]*CheckInfo)
tf.tmpWorkdir = tf.BaseWorkdir
Expand All @@ -132,7 +133,7 @@ func (tf *TmysqlParseFile) Do(dbtype string, versions []string) (result map[stri

var errs []error
for _, version := range versions {
if err = tf.doSingleVersion(dbtype, version); err != nil {
if err = tf.doSingleVersion(version); err != nil {
logger.Error("when do [%s],syntax check,failed:%s", version, err.Error())
errs = append(errs, err)
}
Expand All @@ -141,7 +142,7 @@ func (tf *TmysqlParseFile) Do(dbtype string, versions []string) (result map[stri
return tf.result, errors.Join(errs...)
}

func (tf *TmysqlParseFile) doSingleVersion(dbtype string, mysqlVersion string) (err error) {
func (tf *TmysqlParseFile) doSingleVersion(mysqlVersion string) (err error) {
errChan := make(chan error, 1)
alreadExecutedSqlfileChan := make(chan string, len(tf.Param.FileNames))
signalChan := make(chan struct{})
Expand All @@ -157,7 +158,7 @@ func (tf *TmysqlParseFile) doSingleVersion(dbtype string, mysqlVersion string) (
// 对tmysqlparse的处理结果进行分析,为json文件,后面用到了rule
go func() {
logger.Info("start to analyze the parsing result")
if err = tf.AnalyzeParseResult(alreadExecutedSqlfileChan, mysqlVersion, dbtype); err != nil {
if err = tf.AnalyzeParseResult(alreadExecutedSqlfileChan, mysqlVersion); err != nil {
logger.Error("failed to analyze the parsing result:%s", err.Error())
errChan <- err
}
Expand Down Expand Up @@ -376,8 +377,7 @@ func (t *TmysqlParse) getAbsoutputfilePath(sqlFile, version string) string {
}

// AnalyzeParseResult 分析tmysqlparse 解析的结果
func (t *TmysqlParse) AnalyzeParseResult(alreadExecutedSqlfileCh chan string, mysqlVersion string,
dbtype string) (err error) {
func (t *TmysqlParse) AnalyzeParseResult(alreadExecutedSqlfileCh chan string, mysqlVersion string) (err error) {
var errs []error
c := make(chan struct{}, 10)
errChan := make(chan error, 5)
Expand All @@ -388,7 +388,7 @@ func (t *TmysqlParse) AnalyzeParseResult(alreadExecutedSqlfileCh chan string, my
c <- struct{}{}
go func(fileName string) {
defer wg.Done()
err = t.AnalyzeOne(fileName, mysqlVersion, dbtype)
err = t.AnalyzeOne(fileName, mysqlVersion)
if err != nil {
errChan <- err
}
Expand Down Expand Up @@ -512,7 +512,7 @@ func (t *TmysqlParse) getSyntaxErrorResult(res ParseLineQueryBase, mysqlVersion
}

// AnalyzeOne 分析单个文件
func (t *TmysqlParse) AnalyzeOne(inputfileName, mysqlVersion, dbtype string) (err error) {
func (t *TmysqlParse) AnalyzeOne(inputfileName, mysqlVersion string) (err error) {
var idx int
var syntaxFailInfos []FailedInfo
var buf []byte
Expand Down Expand Up @@ -570,7 +570,7 @@ func (t *TmysqlParse) AnalyzeOne(inputfileName, mysqlVersion, dbtype string) (er
continue
}
// tmysqlparse检查结果全部正确,开始判断语句是否符合定义的规则(即虽然语法正确,但语句可能是高危语句或禁用的命令)
switch dbtype {
switch t.DbType {
case app.MySQL:
checkResult.parseResult(R.CommandRule.HighRiskCommandRule, res, mysqlVersion)
checkResult.parseResult(R.CommandRule.BanCommandRule, res, mysqlVersion)
Expand Down
110 changes: 78 additions & 32 deletions dbm-services/mysql/db-simulation/handler/dbsimulation.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ package handler

import (
"fmt"
"net/http"
"strings"

"github.com/gin-gonic/gin"
Expand All @@ -22,58 +23,54 @@ import (
"dbm-services/mysql/db-simulation/model"
)

// SimulationHandler TODO
type SimulationHandler struct {
BaseHandler
}

// QueryFileResultParam 获取模拟执行文件的结果
type QueryFileResultParam struct {
RootID string `json:"root_id" binding:"required" `
VersionID string `json:"version_id" binding:"required"`
}

// QuerySimulationFileResult 查询模拟执行每个文件的执行结果
func QuerySimulationFileResult(r *gin.Context) {
func (s *SimulationHandler) QuerySimulationFileResult(r *gin.Context) {
var param QueryFileResultParam
if err := r.ShouldBindJSON(&param); err != nil {
logger.Error("ShouldBind failed %s", err)
SendResponse(r, err, "failed to deserialize parameters", "")
if s.Prepare(r, &param) != nil {
return
}
task_id := fmt.Sprintf("%s_%s", param.RootID, param.VersionID)
var data []model.TbSqlFileSimulationInfo
err := model.DB.Where("task_id = ? ", task_id).Find(&data).Error
if err != nil {
logger.Error("query file task result failed %v", err)
SendResponse(r, err, err.Error(), "")
s.SendResponse(r, err, err.Error())
return
}
SendResponse(r, nil, data, "")
s.SendResponse(r, nil, data)
}

// TendbSimulation Tendb simulation handler
func TendbSimulation(r *gin.Context) {
func (s *SimulationHandler) TendbSimulation(r *gin.Context) {
var param service.BaseParam
requestID := r.GetString("request_id")
if err := r.ShouldBindJSON(&param); err != nil {
logger.Error("ShouldBind failed %s", err)
SendResponse(r, err, "failed to deserialize parameters", requestID)
return
}
if requestID == "" {
SendResponse(r, fmt.Errorf("create request id failed"), nil, requestID)
if s.Prepare(r, &param) != nil {
return
}
version := param.MySQLVersion
img, err := getImgFromMySQLVersion(version)
if err != nil {
logger.Error("GetImgFromMySQLVersion %s failed:%s", version, err.Error())
SendResponse(r, err, nil, requestID)
s.SendResponse(r, err, nil)
return
}
if err := model.CreateTask(param.TaskId, requestID, version, param.Uid); err != nil {
if err := model.CreateTask(param.TaskId, s.RequestId, version, param.Uid); err != nil {
logger.Error("create task db record error %s", err.Error())
SendResponse(r, err, nil, requestID)
s.SendResponse(r, err, nil)
return
}
tsk := service.SimulationTask{
RequestId: requestID,
RequestId: s.RequestId,
DbPodSets: service.NewDbPodSets(),
BaseParam: &param,
Version: version,
Expand All @@ -83,40 +80,37 @@ func TendbSimulation(r *gin.Context) {
PodName: fmt.Sprintf("tendb-%s-%s", strings.ToLower(version),
replaceUnderSource(param.TaskId)),
Lables: map[string]string{"task_id": replaceUnderSource(param.TaskId),
"request_id": requestID},
"request_id": s.RequestId},
RootPwd: param.TaskId,
Args: param.BuildStartArgs(),
Charset: param.MySQLCharSet,
}
service.TaskChan <- tsk

SendResponse(r, nil, "request successful", requestID)
s.SendResponse(r, nil, "request successful")
}

// TendbClusterSimulation TendbCluster simulation handler
func TendbClusterSimulation(r *gin.Context) {
func (s *SimulationHandler) TendbClusterSimulation(r *gin.Context) {
var param service.SpiderSimulationExecParam
RequestID := r.GetString("request_id")
if err := r.ShouldBindJSON(&param); err != nil {
logger.Error("ShouldBind failed %s", err)
SendResponse(r, err, "failed to deserialize parameters", RequestID)
if s.Prepare(r, &param) != nil {
return
}
version := param.MySQLVersion
img, err := getImgFromMySQLVersion(version)
if err != nil {
logger.Error("GetImgFromMySQLVersion %s failed:%s", version, err.Error())
SendResponse(r, err, nil, RequestID)
s.SendResponse(r, err, nil)
return
}

if err := model.CreateTask(param.TaskId, RequestID, version, param.Uid); err != nil {
if err := model.CreateTask(param.TaskId, s.RequestId, version, param.Uid); err != nil {
logger.Error("create task db record error %s", err.Error())
SendResponse(r, err, nil, RequestID)
s.SendResponse(r, err, nil)
return
}
tsk := service.SimulationTask{
RequestId: RequestID,
RequestId: s.RequestId,
DbPodSets: service.NewDbPodSets(),
BaseParam: &param.BaseParam,
Version: version,
Expand All @@ -131,10 +125,62 @@ func TendbClusterSimulation(r *gin.Context) {
PodName: fmt.Sprintf("spider-%s-%s", strings.ToLower(version),
replaceUnderSource(param.TaskId)),
Lables: map[string]string{"task_id": replaceUnderSource(param.TaskId),
"request_id": RequestID},
"request_id": s.RequestId},
RootPwd: rootPwd,
Charset: param.MySQLCharSet,
}
service.SpiderTaskChan <- tsk
SendResponse(r, nil, "request successful", RequestID)
s.SendResponse(r, nil, "request successful")
}

// T 请求查询模拟执行整体任务的执行状态参数
type T struct {
TaskID string `json:"task_id"`
}

// QueryTask 查询模拟执行整体任务的执行状态
func (s *SimulationHandler) QueryTask(r *gin.Context) {
var param T
if s.Prepare(r, &param) != nil {
return
}
logger.Info("get task_id is %s", param.TaskID)
var tasks []model.TbSimulationTask
if err := model.DB.Where(&model.TbSimulationTask{TaskId: param.TaskID}).Find(&tasks).Error; err != nil {
logger.Error("query task failed %s", err.Error())
s.SendResponse(r, err, map[string]interface{}{"stderr": err.Error()})
return
}
allSuccessful := false
for _, task := range tasks {
if task.Phase != model.PhaseDone {
r.JSON(http.StatusOK, Response{
Code: 2,
Message: fmt.Sprintf("task current phase is %s", task.Phase),
Data: "",
})
return
}
switch task.Status {
case model.TaskFailed:
allSuccessful = false
s.SendResponse(r, fmt.Errorf("%s", task.SysErrMsg), map[string]interface{}{
"simulation_version": task.MySQLVersion,
"stdout": task.Stdout,
"stderr": task.Stderr,
"errmsg": fmt.Sprintf("the program has been run with abnormal status:%s", task.Status)})

case model.TaskSuccess:
allSuccessful = true
default:
allSuccessful = false
s.SendResponse(r, fmt.Errorf("unknown transition state"), map[string]interface{}{
"stdout": task.Stdout,
"stderr": task.Stderr,
"errmsg": fmt.Sprintf("the program has been run with abnormal status:%s", task.Status)})
}
}
if allSuccessful {
s.SendResponse(r, nil, map[string]interface{}{"stdout": "all ok", "stderr": "all ok"})
}
}
Loading

0 comments on commit d8c853a

Please sign in to comment.