Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue1811 #1813

Merged
merged 6 commits into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion sqle/api/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,9 @@ func StartApi(net *gracenet.Net, exitChan chan struct{}, config config.SqleConfi
v2Router.GET("/projects/:project_name/audit_plans/:audit_plan_name/reports/:audit_plan_report_id/sqls", v2.GetAuditPlanReportSQLs)
v1Router.GET("/projects/:project_name/audit_plans/:audit_plan_name/reports/:audit_plan_report_id/export", v1.ExportAuditPlanReportV1)

// sql audit record
v1Router.POST("/projects/:project_name/sql_audit_record", v1.CreateSQLAuditRecord)

// sql query
if err := cloudbeaver_wrapper.StartApp(e); err != nil {
log.Logger().Errorf("CloudBeaver wrapper configuration failed: %v", err)
Expand All @@ -369,7 +372,6 @@ func StartApi(net *gracenet.Net, exitChan chan struct{}, config config.SqleConfi

// sql audit
v1Router.POST("/sql_audit", v1.DirectAudit)
v2Router.POST("/sql_audit", v2.DirectAudit)
v1Router.POST("/audit_files", v1.DirectAuditFiles)
v1Router.GET("/sql_analysis", v1.DirectGetSQLAnalysis)

Expand Down
273 changes: 272 additions & 1 deletion sqle/api/controller/v1/sql_audit_record.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,25 @@
package v1

import (
"archive/zip"
"context"
"fmt"
"io"
"net/http"
"strings"
"time"

xmlParser "github.com/actiontech/mybatis-mapper-2-sql"
"github.com/actiontech/sqle/sqle/api/controller"
"github.com/actiontech/sqle/sqle/common"
driverV2 "github.com/actiontech/sqle/sqle/driver/v2"
"github.com/actiontech/sqle/sqle/log"
"github.com/actiontech/sqle/sqle/model"
"github.com/actiontech/sqle/sqle/server"
"github.com/actiontech/sqle/sqle/utils"

"github.com/labstack/echo/v4"
"github.com/pkg/errors"
)

type CreateSQLAuditRecordReqV1 struct {
Expand All @@ -23,6 +39,9 @@ type SQLAuditRecordResData struct {
Task *AuditTaskResV1 `json:"task"`
}

// 10M
var maxZipFileSize int64 = 1024 * 1024 * 10

// CreateSQLAuditRecord
// @Summary SQL审核
// @Id CreateSQLAuditRecordV1
Expand All @@ -46,7 +65,258 @@ type SQLAuditRecordResData struct {
// @Success 200 {object} v1.CreateSQLAuditRecordResV1
// @router /v1/projects/{project_name}/sql_audit_record [post]
func CreateSQLAuditRecord(c echo.Context) error {
return nil
req := new(CreateSQLAuditRecordReqV1)
if err := controller.BindAndValidateReq(c, req); err != nil {
return controller.JSONBaseErrorReq(c, err)
}
if req.DbType == "" && req.InstanceName == "" {
return controller.JSONBaseErrorReq(c, errors.New("db_type and instance_name can't both be empty"))
}
projectName := c.Param("project_name")

s := model.GetStorage()
project, exist, err := s.GetProjectByName(projectName)
if err != nil {
return controller.JSONBaseErrorReq(c, err)
}
if !exist {
return controller.JSONBaseErrorReq(c, ErrProjectNotExist(projectName))
}
if project.IsArchived() {
return controller.JSONBaseErrorReq(c, ErrProjectArchived)
}

user, err := controller.GetCurrentUser(c)
if err != nil {
return controller.JSONBaseErrorReq(c, err)
}
if err := CheckIsProjectMember(user.Name, project.Name); err != nil {
return controller.JSONBaseErrorReq(c, err)
}

var sqls string
var source string
if req.Sqls != "" {
sqls, source = req.Sqls, model.TaskSQLSourceFromFormData
} else {
sqls, source, err = getSQLFromFile(c)
if err != nil {
return controller.JSONBaseErrorReq(c, err)
}
}

var task *model.Task
if req.InstanceName != "" {
task, err = buildOnlineTaskForAudit(c, s, user.ID, req.InstanceName, req.InstanceSchema, projectName, source, sqls)
if err != nil {
return controller.JSONBaseErrorReq(c, err)
}
} else {
task, err = buildOfflineTaskForAudit(user.ID, req.DbType, source, sqls)
if err != nil {
return controller.JSONBaseErrorReq(c, err)
}
}

recordId, err := utils.GenUid()
if err != nil {
return controller.JSONBaseErrorReq(c, fmt.Errorf("generate audit record id failed: %v", err))
}
record := model.SQLAuditRecord{
ProjectId: project.ID,
CreatorID: user.ID,
AuditRecordID: recordId,
TaskId: task.ID,
Task: task,
}
if err := s.Save(&record); err != nil {
return controller.JSONBaseErrorReq(c, fmt.Errorf("save sql audit record failed: %v", err))
}

task, err = server.GetSqled().AddTaskWaitResult(fmt.Sprintf("%d", task.ID), server.ActionTypeAudit)
if err != nil {
return controller.JSONBaseErrorReq(c, err)
}
return c.JSON(http.StatusOK, &CreateSQLAuditRecordResV1{
BaseRes: controller.NewBaseReq(nil),
Data: &SQLAuditRecordResData{
Id: record.AuditRecordID,
Task: &AuditTaskResV1{
Id: task.ID,
InstanceName: task.InstanceName(),
InstanceDbType: task.DBType,
InstanceSchema: req.InstanceSchema,
AuditLevel: task.AuditLevel,
Score: task.Score,
PassRate: task.PassRate,
Status: task.Status,
SQLSource: task.SQLSource,
ExecStartTime: task.ExecStartAt,
ExecEndTime: task.ExecEndAt,
},
},
})
}

func buildOnlineTaskForAudit(c echo.Context, s *model.Storage, userId uint, instanceName, instanceSchema, projectName, sourceType, sqls string) (*model.Task, error) {
instance, exist, err := s.GetInstanceByNameAndProjectName(instanceName, projectName)
if err != nil {
return nil, err
}
if !exist {
return nil, ErrInstanceNoAccess
}
can, err := checkCurrentUserCanAccessInstance(c, instance)
if err != nil {
return nil, err
}
if !can {
return nil, ErrInstanceNoAccess
}

plugin, err := common.NewDriverManagerWithoutAudit(log.NewEntry(), instance, "")
if err != nil {
return nil, err
}
defer plugin.Close(context.TODO())

if err := plugin.Ping(context.TODO()); err != nil {
return nil, err
}

task := &model.Task{
Schema: instanceSchema,
InstanceId: instance.ID,
Instance: instance,
CreateUserId: userId,
ExecuteSQLs: []*model.ExecuteSQL{},
SQLSource: sourceType,
DBType: instance.DbType,
}
createAt := time.Now()
task.CreatedAt = createAt

nodes, err := plugin.Parse(context.TODO(), sqls)
if err != nil {
return nil, err
}
for n, node := range nodes {
task.ExecuteSQLs = append(task.ExecuteSQLs, &model.ExecuteSQL{
BaseSQL: model.BaseSQL{
Number: uint(n + 1),
Content: node.Text,
},
})
}
return task, nil
}

func buildOfflineTaskForAudit(userId uint, dbType, sourceType, sqls string) (*model.Task, error) {
task := &model.Task{
CreateUserId: userId,
ExecuteSQLs: []*model.ExecuteSQL{},
SQLSource: sourceType,
DBType: dbType,
}
var err error
var nodes []driverV2.Node
plugin, err := common.NewDriverManagerWithoutCfg(log.NewEntry(), dbType)
if err != nil {
return nil, fmt.Errorf("open plugin failed: %v", err)
}
defer plugin.Close(context.TODO())

nodes, err = plugin.Parse(context.TODO(), sqls)
if err != nil {
return nil, fmt.Errorf("parse sqls failed: %v", err)
}

createAt := time.Now()
task.CreatedAt = createAt

for n, node := range nodes {
task.ExecuteSQLs = append(task.ExecuteSQLs, &model.ExecuteSQL{
BaseSQL: model.BaseSQL{
Number: uint(n + 1),
Content: node.Text,
},
})
}
return task, nil
}

func getSqlsFromZip(c echo.Context) (sqls string, exist bool, err error) {
file, err := c.FormFile(InputZipFileName)
if err == http.ErrMissingFile {
return "", false, nil
}
if err != nil {
return "", false, err
}
f, err := file.Open()
if err != nil {
return "", false, err
}
defer f.Close()

currentPos, err := f.Seek(0, io.SeekEnd) // get size of zip file
if err != nil {
return "", false, err
}
size := currentPos + 1
if size > maxZipFileSize {
return "", false, fmt.Errorf("file can't be bigger than %vM", maxZipFileSize/1024/1024)
}
r, err := zip.NewReader(f, size)
if err != nil {
return "", false, err
}
var sqlBuffer strings.Builder
xmlContents := make([]string, len(r.File))
for i := range r.File {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

压缩包里有文件夹如何处理

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

r.File是已经遍历了所有的文件夹

srcFile := r.File[i]
if srcFile == nil {
continue
}
if !strings.HasSuffix(srcFile.Name, ".xml") && !strings.HasSuffix(srcFile.Name, ".sql") {
continue
}

r, err := srcFile.Open()
if err != nil {
return "", false, fmt.Errorf("open src file failed: %v", err)
}
content, err := io.ReadAll(r)
if err != nil {
return "", false, fmt.Errorf("read src file failed: %v", err)
}

if strings.HasSuffix(srcFile.Name, ".xml") {
xmlContents[i] = string(content)
} else if strings.HasSuffix(srcFile.Name, ".sql") {
if _, err = sqlBuffer.Write(content); err != nil {
return "", false, fmt.Errorf("gather sqls from sql file failed: %v", err)
}
}
}

// parse xml content
ss, err := xmlParser.ParseXMLs(xmlContents, false)
if err != nil {
return "", false, fmt.Errorf("parse sqls from xml failed: %v", err)
}
for i := range ss {
if !strings.HasSuffix(sqlBuffer.String(), ";") {
if _, err = sqlBuffer.WriteString(";"); err != nil {
return "", false, fmt.Errorf("gather sqls from xml file failed: %v", err)
}
}
if _, err = sqlBuffer.WriteString(ss[i]); err != nil {
return "", false, fmt.Errorf("gather sqls from xml file failed: %v", err)
}
}

return sqlBuffer.String(), true, nil
}

type UpdateSQLAuditRecordReqV1 struct {
Expand All @@ -61,6 +331,7 @@ type UpdateSQLAuditRecordReqV1 struct {
// @Id updateSQLAuditRecordV1
// @Tags sql_audit_record
// @Security ApiKeyAuth
// @Param project_name path string true "project name"
// @Param param body v1.UpdateSQLAuditRecordReqV1 true "update SQL audit record"
// @Success 200 {object} controller.BaseRes
// @router /v1/projects/{project_name}/sql_audit_record/{sql_audit_record_id} [patch]
Expand Down
Loading
Loading