Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: 新增MySQL规则 #2753

Merged
merged 11 commits into from
Nov 25, 2024
2 changes: 1 addition & 1 deletion sqle/driver/mysql/audit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,7 @@ ALTER TABLE exist_db.exist_tb_1 Add index idx_2 (id,id);
ALTER TABLE exist_db.exist_tb_1 Add index (id,id);
`,
newTestResult().add(driver.RuleLevelError, DuplicateIndexedColumnMessage, "(匿名)",
"id").addResult(rulepkg.DDLCheckIndexPrefix, "idx_"),
"id").addResult(rulepkg.DDLCheckIndexPrefix, "idx_").addResult(rulepkg.DDLCheckIndexNameExisted),
)
}

Expand Down
334 changes: 321 additions & 13 deletions sqle/driver/mysql/rule/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ const (
DDLNotAllowRenaming = "ddl_not_allow_renaming"
DDLCheckObjectNameIsUpperAndLowerLetterMixed = "ddl_check_object_name_is_upper_and_lower_letter_mixed"
DDLCheckFieldNotNUllMustContainDefaultValue = "ddl_check_field_not_null_must_contain_default_value"
DDLCheckIndexNameExisted = "ddl_check_index_name_existed"
DDLCheckTableRowLength = "ddl_check_table_row_length"
)

// inspector DML rules
Expand Down Expand Up @@ -169,6 +171,8 @@ const (
ConfigOptimizeIndexEnabled = "optimize_index_enabled"
ConfigDMLExplainPreCheckEnable = "dml_enable_explain_pre_check"
ConfigSQLIsExecuted = "sql_is_executed"
ConfigAvoidSet = "config_avoid_set"
ConfigCheckEventScheduler = "config_check_event_scheduler"
)

type RuleHandlerInput struct {
Expand Down Expand Up @@ -1881,6 +1885,62 @@ var RuleHandlers = []RuleHandler{
Message: "禁止使用rename或change对表名字段名进行修改",
Func: ddlNotAllowRenaming,
},
{
Rule: driver.Rule{
Name: DDLCheckIndexNameExisted,
Desc: "索引必须设置索引名",
Annotation: "普通索引定义索引名,且名称遵循固定的命名规范、避免特殊字符的使用,可以提高代码的可读性、可维护性,并减少潜在的兼容性和语法问题。",
Level: driver.RuleLevelNormal,
Category: RuleTypeNamingConvention,
},
AllowOffline: true,
Message: "索引必须设置索引名",
Func: checkIndexNameExisted,
},
{
Rule: driver.Rule{
Name: DDLCheckTableRowLength,
Desc: "表设计做到行不跨页",
Annotation: "在表设计时,应该尽量确保每一行数据都不会跨越数据页(Page)的边界,以提高数据的读取和写入性能,减少物理I/O操作,并优化存储空间的利用率。",
Level: driver.RuleLevelWarn,
Category: RuleTypeDDLConvention,
Params: params.Params{
&params.Param{
Key: DefaultSingleParamKeyName,
Value: "65535",
Desc: "最大行长 (byte)",
Type: params.ParamTypeInt,
},
},
},
AllowOffline: true,
Message: "表设计做到行不跨页",
Func: checkTableRowLength,
},
{
Rule: driver.Rule{
Name: ConfigAvoidSet,
Desc: "不允许使用SET操作",
Annotation: "禁止使用SET命令来修改MySQL的系统参数,以确保数据库的稳定性、一致性和安全性。",
Level: driver.RuleLevelError,
Category: RuleTypeGlobalConfig,
},
AllowOffline: true,
Message: "不允许使用SET操作",
Func: avoidSet,
},
{
Rule: driver.Rule{
Name: ConfigCheckEventScheduler,
Desc: "禁止使用event scheduler",
Annotation: "禁用MySQL的事件调度器(event_scheduler),以提高数据库的安全性、稳定性和可控性,避免非预期的事件执行对系统造成影响。",
Level: driver.RuleLevelError,
Category: RuleTypeGlobalConfig,
},
AllowOffline: true,
Message: "禁止使用event schedule",
Func: checkEventScheduler,
},
}

func checkFieldNotNUllMustContainDefaultValue(input *RuleHandlerInput) error {
Expand Down Expand Up @@ -4352,12 +4412,13 @@ var createTriggerReg1 = regexp.MustCompile(`(?i)create[\s]+trigger[\s]+[\S\s]+be
var createTriggerReg2 = regexp.MustCompile(`(?i)create[\s]+[\s\S]+[\s]+trigger[\s]+[\S\s]+before|after`)

// CREATE
// [DEFINER = user]
// TRIGGER trigger_name
// trigger_time trigger_event
// ON tbl_name FOR EACH ROW
// [trigger_order]
// trigger_body
//
// [DEFINER = user]
// TRIGGER trigger_name
// trigger_time trigger_event
// ON tbl_name FOR EACH ROW
// [trigger_order]
// trigger_body
//
// ref:https://dev.mysql.com/doc/refman/8.0/en/create-trigger.html
//
Expand All @@ -4378,10 +4439,11 @@ var createFunctionReg1 = regexp.MustCompile(`(?i)create[\s]+function[\s]+[\S\s]+
var createFunctionReg2 = regexp.MustCompile(`(?i)create[\s]+[\s\S]+[\s]+function[\s]+[\S\s]+returns`)

// CREATE
// [DEFINER = user]
// FUNCTION sp_name ([func_parameter[,...]])
// RETURNS type
// [characteristic ...] routine_body
//
// [DEFINER = user]
// FUNCTION sp_name ([func_parameter[,...]])
// RETURNS type
// [characteristic ...] routine_body
//
// ref: https://dev.mysql.com/doc/refman/5.7/en/create-procedure.html
// For now, we do character matching for CREATE FUNCTION Statement. Maybe we need
Expand All @@ -4401,9 +4463,10 @@ var createProcedureReg1 = regexp.MustCompile(`(?i)create[\s]+procedure[\s]+[\S\s
var createProcedureReg2 = regexp.MustCompile(`(?i)create[\s]+[\s\S]+[\s]+procedure[\s]+[\S\s]+`)

// CREATE
// [DEFINER = user]
// PROCEDURE sp_name ([proc_parameter[,...]])
// [characteristic ...] routine_body
//
// [DEFINER = user]
// PROCEDURE sp_name ([proc_parameter[,...]])
// [characteristic ...] routine_body
//
// ref: https://dev.mysql.com/doc/refman/8.0/en/create-procedure.html
// For now, we do character matching for CREATE PROCEDURE Statement. Maybe we need
Expand Down Expand Up @@ -5097,3 +5160,248 @@ func ddlNotAllowRenaming(input *RuleHandlerInput) error {
}
return nil
}

func checkIndexNameExisted(input *RuleHandlerInput) error {
indexNameNotExisted := false
switch stmt := input.Node.(type) {
case *ast.CreateTableStmt:
for _, constraint := range stmt.Constraints {
switch constraint.Tp {
case ast.ConstraintIndex, ast.ConstraintUniqIndex, ast.ConstraintKey, ast.ConstraintUniqKey:
if constraint.Name == "" {
indexNameNotExisted = true
winfredLIN marked this conversation as resolved.
Show resolved Hide resolved
break
}
default:
return nil
}
}
case *ast.AlterTableStmt:
for _, spec := range stmt.Specs {
if spec.Tp == ast.AlterTableAddConstraint && IsIndexConstraint(spec.Constraint.Tp) {
// 遍历Keys
if spec.Constraint.Name == "" {
indexNameNotExisted = true
winfredLIN marked this conversation as resolved.
Show resolved Hide resolved
break
}
}
}
default:
return nil
}
if indexNameNotExisted {
addResult(input.Res, input.Rule, DDLCheckIndexNameExisted)
}
return nil
}

func IsIndexConstraint(constraintType ast.ConstraintType) bool {
return constraintType == ast.ConstraintIndex || constraintType == ast.ConstraintUniqIndex || constraintType == ast.ConstraintKey || constraintType == ast.ConstraintUniqKey
}

func checkTableRowLength(input *RuleHandlerInput) error {
var rowLengthLimit = input.Rule.Params.GetParam(DefaultSingleParamKeyName).Int()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

若表实际的页长小于人工设置的页长,则该规则失效

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

补充进文档

rowLength := 0
switch stmt := input.Node.(type) {
case *ast.CreateTableStmt:
charsetNum := GetTableCharsetNum(stmt.Options)
for _, col := range stmt.Cols {
colCharsetNum := MappingCharsetLength(col.Tp.Charset)
// 可能会设置列级别的字符串
if charsetNum != colCharsetNum {
charsetNum = colCharsetNum
}
oneColumnLength := ComputeOneColumnLength(col, charsetNum)
rowLength += oneColumnLength
}
case *ast.AlterTableStmt:
// 获取在线表信息
tableStmt, tableExist, err := input.Ctx.GetCreateTableStmt(stmt.Table)
if !tableExist || err != nil {
return err
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果table不存在,那error会不等于nil吗?
这是两种情况,都用返回error处理可以吗

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果表不存在,返回error,会不会导致审核阻塞?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

嗯 这里有问题

}
charsetNum := GetTableCharsetNum(tableStmt.Options)
columnLengthMap := make(map[string]int, len(tableStmt.Cols))
// 计算原表的长度
for _, col := range tableStmt.Cols {
colCharsetNum := MappingCharsetLength(col.Tp.Charset)
if charsetNum != colCharsetNum {
charsetNum = colCharsetNum
}
oneColumnLength := ComputeOneColumnLength(col, charsetNum)
rowLength += oneColumnLength
columnLengthMap[col.Name.String()] = oneColumnLength
}
// 计算alter语句修改列之后的长度
for _, alteredSpec := range stmt.Specs {
for _, alterCol := range alteredSpec.NewColumns {
if alterCol.Tp == nil {
// 不是对于列类型相关的变更
continue
}
// 可能会设置列级别的字符串
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不理解这个注释的含义
注释+代码,需要能够说明这里为什么这么做,做了什么

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

增加注释

colCharsetNum := MappingCharsetLength(alterCol.Tp.Charset)
if charsetNum != colCharsetNum {
charsetNum = colCharsetNum
}
if alteredSpec.Tp == ast.AlterTableAddColumns {
rowLength += ComputeOneColumnLength(alterCol, charsetNum)
}
if alteredSpec.Tp == ast.AlterTableModifyColumn {
// 如果是修改某个字段,减去原来字段的长度,使用新的字段长度
rowLength -= columnLengthMap[alterCol.Name.String()]
rowLength += ComputeOneColumnLength(alterCol, charsetNum)
}
}
}
default:
return nil
}
if rowLength > rowLengthLimit {
addResult(input.Res, input.Rule, DDLCheckTableRowLength)
}
return nil
}

func GetTableCharsetNum(options []*ast.TableOption) int {
charsetNum := 4
for _, opt := range options {
if opt.Tp == ast.TableOptionCharset {
charsetNum = MappingCharsetLength(opt.StrValue)
}
}
return charsetNum
}

// ComputeOneColumnLength 计算一个列的长度
Copy link
Collaborator

@winfredLIN winfredLIN Nov 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

贴一下长度计算的依据:

  1. 官方文档
  2. 用ai使用mysql的风格格式绘制一个表格放在这里

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

增加了注释

func ComputeOneColumnLength(columnDef *ast.ColumnDef, charsetNum int) int {
oneColumnLength := 0
switch columnDef.Tp.Tp {
case mysql.TypeVarchar:
// 0~255 长度需要一个字节存储长度
lLength := 1
if columnDef.Tp.Flen > 255 {
// > 255 需要两个字节来存储长度
lLength = 2
}
// length * charsetNum + notNull + lLength
oneColumnLength = columnDef.Tp.Flen*charsetNum + OptionNotNullLength(columnDef.Options) + lLength
case mysql.TypeString:
oneColumnLength = columnDef.Tp.Flen*charsetNum + OptionNotNullLength(columnDef.Options)
case mysql.TypeYear, mysql.TypeTiny:
oneColumnLength = 1 + OptionNotNullLength(columnDef.Options)
case mysql.TypeDate, mysql.TypeInt24:
// DATE MEDIUMINT
oneColumnLength = 3 + OptionNotNullLength(columnDef.Options)
case mysql.TypeDuration:
// TIME
oneColumnLength = 3 + OptionNotNullLength(columnDef.Options) + typeTimePrecisionLength(columnDef.Tp.Decimal)
case mysql.TypeDatetime:
oneColumnLength = 5 + OptionNotNullLength(columnDef.Options) + typeTimePrecisionLength(columnDef.Tp.Decimal)
case mysql.TypeTimestamp:
oneColumnLength = 4 + OptionNotNullLength(columnDef.Options) + typeTimePrecisionLength(columnDef.Tp.Decimal)
case mysql.TypeShort:
// SMALLINT
oneColumnLength = 2 + OptionNotNullLength(columnDef.Options)
case mysql.TypeLong, mysql.TypeFloat:
// INT FLOAT
oneColumnLength = 4 + OptionNotNullLength(columnDef.Options)
case mysql.TypeLonglong:
// BIGINT
oneColumnLength = 8 + OptionNotNullLength(columnDef.Options)
case mysql.TypeDouble:
// BIGINT DOUBLE REAL
oneColumnLength = 8 + OptionNotNullLength(columnDef.Options)
case mysql.TypeNewDecimal:
// 整数部分
partition := (columnDef.Tp.Flen - columnDef.Tp.Decimal) / 9
oneColumnLength += partition * 4
oneColumnLength += decimalLeftoverLength((columnDef.Tp.Flen - columnDef.Tp.Decimal) % 9)
// 小数部分
decimalPartition := columnDef.Tp.Decimal / 9
oneColumnLength += decimalPartition * 4
oneColumnLength += decimalLeftoverLength((columnDef.Tp.Decimal) % 9)
}
return oneColumnLength
}

// typeTimePrecisionLength 时间类型会根据精度的不同有不同的存储大小
// decimal bytes
// 0 0
// 1,2 1
// 3,4 2
// 5,6 3
func typeTimePrecisionLength(decimal int) int {
if decimal < 0 {
return 0
} else if decimal < 3 {
return 1
} else if decimal < 5 {
return 2
} else if decimal < 7 {
return 3
}
return 0
}

// decimalLeftoverLength decimal被9整除后的部分,根据位数使用相印字节数
// leftover bytes
// 1-2 1
// 3-4 2
// 5-6 3
// 7-9 4
func decimalLeftoverLength(leftover int) int {
if leftover < 0 {
return 0
} else if leftover < 3 {
return 1
} else if leftover < 5 {
return 2
} else if leftover < 7 {
return 3
} else if leftover < 10 {
return 4
}
return 0
}

// OptionNotNullLength 当有not null 约束时会占用一个字节
func OptionNotNullLength(columnOptions []*ast.ColumnOption) int {
for _, option := range columnOptions {
if option.Tp == ast.ColumnOptionNotNull {
return 0
}
}
return 1
}

// MappingCharsetLength 不同的字符集会用不同数量表示一个字符
func MappingCharsetLength(charset string) int {
charNum := 4
switch charset {
case "utf8mb4", "utf16", "utf16le", "utf32":
charNum = 4
case "utf8":
charNum = 3
default:
charNum = 4
}
return charNum
}

func avoidSet(input *RuleHandlerInput) error {
switch input.Node.(type) {
case *ast.SetStmt:
addResult(input.Res, input.Rule, ConfigAvoidSet)
default:
return nil
}
return nil
}

func checkEventScheduler(input *RuleHandlerInput) error {
if utils.IsOpenEventScheduler(input.Node.Text()) {
addResult(input.Res, input.Rule, input.Rule.Name)
}
return nil
}
Loading
Loading