From 5e64b66b0afb1c2a2d09f6416df5ce1de27d5b7f Mon Sep 17 00:00:00 2001 From: xhd2015 Date: Sun, 7 Apr 2024 22:09:33 +0800 Subject: [PATCH] make patch consts work for go1.22 --- cmd/xgo/edit.go | 13 + cmd/xgo/exec_tool/debug.go | 2 +- cmd/xgo/main.go | 57 ++-- cmd/xgo/patch_compiler.go | 10 + cmd/xgo/patch_compiler_ast_type_check.go | 218 ++++++++++++++ cmd/xgo/patch_support.go | 61 ++++ cmd/xgo/version.go | 4 +- patch/pkgdata/pkgdata.go | 120 ++++++-- patch/syntax/call_expr_go1.17_18_19.go | 10 +- patch/syntax/call_expr_go1.20.go | 2 +- patch/syntax/syntax.go | 51 ++++ patch/syntax/vars.go | 355 +++++++++++++++++++++-- runtime/core/version.go | 4 +- runtime/test/debug/debug_test.go | 25 +- runtime/test/patch/patch_const_test.go | 231 +++++++++++++++ runtime/test/patch/patch_var_test.go | 84 +++--- runtime/test/patch/sub/sub.go | 2 + support/strutil/strutil.go | 52 ++-- 18 files changed, 1136 insertions(+), 165 deletions(-) create mode 100644 cmd/xgo/patch_compiler_ast_type_check.go create mode 100644 cmd/xgo/patch_support.go create mode 100644 runtime/test/patch/patch_const_test.go diff --git a/cmd/xgo/edit.go b/cmd/xgo/edit.go index 88bd5129..f5fe15c5 100644 --- a/cmd/xgo/edit.go +++ b/cmd/xgo/edit.go @@ -48,6 +48,18 @@ func addContentAt(content string, beginMark string, endMark string, seq []string return insertContentNoDuplicate(content, beginMark, endMark, idx, addContent) } +func addContentAtIndex(content string, beginMark string, endMark string, seq []string, i int, before bool, addContent string) string { + offset, endOffset := strutil.SeqenceOffset(content, seq, i, before) + if offset < 0 { + panic(fmt.Errorf("sequence missing: %v", seq)) + } + anotherOff, _ := strutil.SeqenceOffset(content[endOffset:], seq, i, false) + if anotherOff >= 0 { + panic(fmt.Errorf("sequence duplicate: %v", seq)) + } + return insertContentNoDuplicate(content, beginMark, endMark, offset, addContent) +} + func replaceContentAfter(content string, beginMark string, endMark string, seq []string, target string, replaceContent string) string { if replaceContent == "" { return content @@ -72,6 +84,7 @@ func replaceContentAfter(content string, beginMark string, endMark string, seq [ } // signature example: /**/ {content} /**/ +// insert content at index func insertContentNoDuplicate(content string, beginMark string, endMark string, idx int, insertContent string) string { if insertContent == "" { return content diff --git a/cmd/xgo/exec_tool/debug.go b/cmd/xgo/exec_tool/debug.go index 4799cbc8..5eef5144 100644 --- a/cmd/xgo/exec_tool/debug.go +++ b/cmd/xgo/exec_tool/debug.go @@ -102,7 +102,7 @@ func (c *VscodeDebugConfig) ToMap() (map[string]interface{}, error) { } const vscodeRemoteDebug = `{ - "name": "dlv remoe localhost:2345", + "name": "dlv remote localhost:2345", "type": "go", "request": "attach", "mode": "remote", diff --git a/cmd/xgo/main.go b/cmd/xgo/main.go index e795a073..31b3d018 100644 --- a/cmd/xgo/main.go +++ b/cmd/xgo/main.go @@ -160,7 +160,7 @@ func handleBuild(cmd string, args []string) error { logStartup() } - goroot, err := checkGoroot(withGoroot) + goroot, err := checkGoroot(projectDir, withGoroot) if err != nil { return err } @@ -526,29 +526,46 @@ func getXgoHome(xgoHome string) (string, error) { return absHome, nil } -func checkGoroot(goroot string) (string, error) { +func getGoEnvRoot(dir string) (string, error) { + goroot, err := cmd.Dir(dir).Output("go", "env", "GOROOT") + if err != nil { + return "", err + } + // remove special characters from output + goroot = strings.ReplaceAll(goroot, "\n", "") + goroot = strings.ReplaceAll(goroot, "\r", "") + return goroot, nil +} + +func checkGoroot(dir string, goroot string) (string, error) { if goroot == "" { + // use env first because dir will affect the actual + // go version used + // because with new go toolchain mechanism, even with: + // which go -> /Users/xhd2015/installed/go1.21.7/bin/go + // the actual go still points to go1.22.0 + // go version -> + // go tool chain: https://go.dev/doc/toolchain + // GOTOOLCHAIN=auto + envGoroot, envErr := getGoEnvRoot(dir) + if envErr == nil && envGoroot != "" { + return envGoroot, nil + } goroot = runtime.GOROOT() - if goroot == "" { - envGoroot, err := cmd.Output("go", "env", "GOROOT") - if err != nil { - var errMsg string - if e, ok := err.(*exec.ExitError); ok { - errMsg = string(e.Stderr) - } else { - errMsg = err.Error() - } - return "", fmt.Errorf("requires GOROOT or --with-goroot: go env GOROOT: %v", errMsg) - } - // remove special characters from output - envGoroot = strings.ReplaceAll(envGoroot, "\n", "") - envGoroot = strings.ReplaceAll(envGoroot, "\r", "") - goroot = envGoroot + if goroot != "" { + return goroot, nil } - if goroot == "" { - return "", fmt.Errorf("requires GOROOT or --with-goroot") + + if envErr != nil { + var errMsg string + if e, ok := envErr.(*exec.ExitError); ok { + errMsg = string(e.Stderr) + } else { + errMsg = envErr.Error() + } + return "", fmt.Errorf("requires GOROOT or --with-goroot: go env GOROOT: %v", errMsg) } - return goroot, nil + return "", fmt.Errorf("requires GOROOT or --with-goroot") } _, err := os.Stat(goroot) if err == nil { diff --git a/cmd/xgo/patch_compiler.go b/cmd/xgo/patch_compiler.go index a56fe912..5e921bd5 100644 --- a/cmd/xgo/patch_compiler.go +++ b/cmd/xgo/patch_compiler.go @@ -39,6 +39,12 @@ var compilerFiles = []_FilePath{ compilerRuntimeDefFile, compilerRuntimeDefFile18, compilerRuntimeDefFile16, + + type2ExprPatch.FilePath, + type2AssignmentsPatch.FilePath, + syntaxWalkPatch.FilePath, + noderWriterPatch.FilePath, + syntaxExtra, } func patchCompiler(origGoroot string, goroot string, goVersion *goinfo.GoVersion, xgoSrc string, forceReset bool, syncWithLink bool) error { @@ -94,6 +100,10 @@ func patchCompilerInternal(goroot string, goVersion *goinfo.GoVersion) error { if err != nil { return fmt.Errorf("patching gc main:%w", err) } + err = patchCompilerAstTypeCheck(goroot) + if err != nil { + return fmt.Errorf("patch ast type check:%w", err) + } return nil } diff --git a/cmd/xgo/patch_compiler_ast_type_check.go b/cmd/xgo/patch_compiler_ast_type_check.go new file mode 100644 index 00000000..94c77d2b --- /dev/null +++ b/cmd/xgo/patch_compiler_ast_type_check.go @@ -0,0 +1,218 @@ +package main + +import "os" + +const convertXY = ` +if xgoConv, ok := x.expr.(*syntax.XgoSimpleConvert); ok { + var isConst bool + switch y.expr.(type) { + case *syntax.XgoSimpleConvert,*syntax.BasicLit: + isConst=true + } + if !isConst{ + t := y.typ + + callExpr := xgoConv.X.(*syntax.CallExpr) + ct := callExpr.GetTypeInfo() + ct.Type = t + callExpr.SetTypeInfo(ct) + + name := callExpr.Fun.(*syntax.Name) + nt := name.GetTypeInfo() + nt.Type = t + name.SetTypeInfo(nt) + name.Value = t.String() + + xt := xgoConv.GetTypeInfo() + xt.Type = t + xgoConv.SetTypeInfo(xt) + + x.typ = t + } +}else if xgoConv,ok := y.expr.(*syntax.XgoSimpleConvert);ok { + var isConst bool + switch x.expr.(type) { + case *syntax.XgoSimpleConvert,*syntax.BasicLit: + isConst=true + } + if !isConst{ + t := x.typ + callExpr := xgoConv.X.(*syntax.CallExpr) + ct := callExpr.GetTypeInfo() + ct.Type = t + callExpr.SetTypeInfo(ct) + + name := callExpr.Fun.(*syntax.Name) + nt := name.GetTypeInfo() + nt.Type = t + name.SetTypeInfo(nt) + name.Value = t.String() + + xt := xgoConv.GetTypeInfo() + xt.Type = t + xgoConv.SetTypeInfo(xt) + + y.typ = t + } +} +` + +var type2ExprPatch = &FilePatch{ + FilePath: _FilePath{"src", "cmd", "compile", "internal", "types2", "expr.go"}, + Patches: []*Patch{ + { + Mark: "type2_check_xgo_simple_convert", + InsertIndex: 5, + InsertBefore: true, + Anchors: []string{ + `(check *Checker) exprInternal`, + "\n", + `default:`, + `case *syntax.Operation:`, + `case *syntax.KeyValueExpr:`, + `default:`, + "\n", + }, + Content: ` +case *syntax.XgoSimpleConvert: + kind := check.rawExpr(nil, x, e.X, nil, false) + x.expr = e + return kind +`, + }, + { + Mark: "type2_match_type_xgo_simple_convert", + InsertIndex: 1, + Anchors: []string{ + `func (check *Checker) matchTypes(x, y *operand) {`, + "\n", + }, + Content: convertXY, + }, + { + Mark: "type2_comparison_xgo_simple_convert", + InsertIndex: 2, + Anchors: []string{ + `func (check *Checker) comparison(x, y *operand`, + "{", + "\n", + }, + Content: convertXY, + }, + }, +} + +var type2AssignmentsPatch = &FilePatch{ + FilePath: _FilePath{"src", "cmd", "compile", "internal", "types2", "assignments.go"}, + Patches: []*Patch{ + { + Mark: "type2_assignment_rewrite_xgo_simple_convert", + InsertIndex: 1, + InsertBefore: true, + Anchors: []string{ + `func (check *Checker) assignment(`, + `switch x.mode {`, + }, + Content: ` + if xgoConv, ok := x.expr.(*syntax.XgoSimpleConvert); ok { + callExpr := xgoConv.X.(*syntax.CallExpr) + funName := callExpr.Fun.(*syntax.Name) + t := funName.GetTypeInfo() + t.Type = T + funName.SetTypeInfo(t) + funName.Value = T.String() + + ct := callExpr.GetTypeInfo() + ct.Type = T + callExpr.SetTypeInfo(ct) + + x.expr = callExpr + x.typ = T + xt := xgoConv.GetTypeInfo() + xt.Type = T + xgoConv.SetTypeInfo(xt) + } + `, + }, + }, +} + +var syntaxWalkPatch = &FilePatch{ + FilePath: _FilePath{"src", "cmd", "compile", "internal", "syntax", "walk.go"}, + Patches: []*Patch{ + { + Mark: "syntax_walk_xgo_simple_convert", + InsertIndex: 4, + InsertBefore: true, + Anchors: []string{ + `func (w walker) node(n Node) {`, + `case *RangeClause:`, + `case *CaseClause:`, + `case *CommClause:`, + `default`, + }, + Content: ` + case *XgoSimpleConvert: + w.node(n.X) + `, + }, + }, +} + +var noderWriterPatch = &FilePatch{ + FilePath: _FilePath{"src", "cmd", "compile", "internal", "noder", "writer.go"}, + Patches: []*Patch{ + { + Mark: "noder_write_xgo_simple_convert", + InsertIndex: 3, + InsertBefore: true, + Anchors: []string{ + `func (w *writer) expr(expr syntax.Expr) {`, + `switch expr := expr.(type) {`, + `case *syntax.Operation:`, + `case *syntax.CallExpr:`, + }, + Content: ` + case *syntax.XgoSimpleConvert: + w.expr(expr.X) + `, + }, + }, +} +var syntaxExtra = _FilePath{"src", "cmd", "compile", "internal", "syntax", "xgo_extra.go"} + +const syntaxExtraPatch = ` +package syntax + +// helper: convert anything to which +// the type is expected +type XgoSimpleConvert struct { + X Expr + expr +} +` + +func patchCompilerAstTypeCheck(goroot string) error { + err := type2ExprPatch.Apply(goroot) + if err != nil { + return err + } + err = type2AssignmentsPatch.Apply(goroot) + if err != nil { + return err + } + err = syntaxWalkPatch.Apply(goroot) + if err != nil { + return err + } + err = noderWriterPatch.Apply(goroot) + if err != nil { + return err + } + syntaxExtraFile := syntaxExtra.Join(goroot) + err = os.WriteFile(syntaxExtraFile, []byte(syntaxExtraPatch), 0755) + if err != nil { + return err + } + return nil +} diff --git a/cmd/xgo/patch_support.go b/cmd/xgo/patch_support.go new file mode 100644 index 00000000..9fc3edd9 --- /dev/null +++ b/cmd/xgo/patch_support.go @@ -0,0 +1,61 @@ +package main + +import "fmt" + +type FilePatch struct { + FilePath _FilePath + Patches []*Patch +} + +type Patch struct { + Mark string + + InsertIndex int // insert before which anchor, 0: insert at head, -1: tail + InsertBefore bool + // anchor should be unique + // appears exactly once + Anchors []string + + Content string +} + +func (c *FilePatch) Apply(goroot string) error { + if goroot == "" { + return fmt.Errorf("requires goroot") + } + if len(c.FilePath) == 0 { + return fmt.Errorf("invalid file path") + } + + if len(c.Patches) == 0 { + return nil + } + file := c.FilePath.Join(goroot) + + // validate patch mark + seenMark := make(map[string]bool, len(c.Patches)) + for i, patch := range c.Patches { + if patch.Content == "" { + return fmt.Errorf("empty content at: %d", i) + } + if patch.Mark == "" { + return fmt.Errorf("empty mark at %d", i) + } + if len(patch.Anchors) == 0 { + return fmt.Errorf("empty anchors at: %d", i) + } + if _, ok := seenMark[patch.Mark]; ok { + return fmt.Errorf("duplicate mark: %s", patch.Mark) + } + seenMark[patch.Mark] = true + } + + return editFile(file, func(content string) (string, error) { + for _, patch := range c.Patches { + beginMark := fmt.Sprintf("/**/", patch.Mark) + endMark := fmt.Sprintf("/**/", patch.Mark) + content = addContentAtIndex(content, beginMark, endMark, patch.Anchors, patch.InsertIndex, patch.InsertBefore, patch.Content) + } + return content, nil + }) +} diff --git a/cmd/xgo/version.go b/cmd/xgo/version.go index 78d9f0f6..83a09036 100644 --- a/cmd/xgo/version.go +++ b/cmd/xgo/version.go @@ -3,8 +3,8 @@ package main import "fmt" const VERSION = "1.0.19" -const REVISION = "7b94c97a1b438510f46aa86fca2b626396b9d2e1+1" -const NUMBER = 166 +const REVISION = "7fd3f2b160c52c890d248efceb23a5e070156e0c+1" +const NUMBER = 167 func getRevision() string { revSuffix := "" diff --git a/patch/pkgdata/pkgdata.go b/patch/pkgdata/pkgdata.go index ad7d2252..f2e701d4 100644 --- a/patch/pkgdata/pkgdata.go +++ b/patch/pkgdata/pkgdata.go @@ -12,10 +12,15 @@ import ( type PackageData struct { Vars map[string]bool - Consts map[string]bool + Consts map[string]*ConstInfo Funcs map[string]bool } +type ConstInfo struct { + Type string // t=xx + Untyped bool // no 'e='(explicit) flag +} + var pkgDataMapping map[string]*PackageData func GetPkgData(pkgPath string) *PackageData { @@ -47,11 +52,35 @@ func WritePkgData(pkgPath string, pkgData *PackageData) error { } defer w.Close() - writeSection := func(section string, m map[string]bool) error { - if len(m) == 0 { - return nil - } - _, err := io.WriteString(w, section) + err = writeConstSection(w, "[const]", pkgData.Consts) + if err != nil { + return err + } + writeSection(w, "[var]", pkgData.Vars) + if err != nil { + return err + } + writeSection(w, "[func]", pkgData.Funcs) + if err != nil { + return err + } + + return nil +} +func writeSection(w io.Writer, section string, m map[string]bool) error { + if len(m) == 0 { + return nil + } + _, err := io.WriteString(w, section) + if err != nil { + return err + } + _, err = io.WriteString(w, "\n") + if err != nil { + return err + } + for k := range m { + _, err := io.WriteString(w, k) if err != nil { return err } @@ -59,29 +88,62 @@ func WritePkgData(pkgPath string, pkgData *PackageData) error { if err != nil { return err } - for k := range m { - _, err := io.WriteString(w, k) - if err != nil { - return err - } - _, err = io.WriteString(w, "\n") - if err != nil { - return err - } - } + } + return nil +} + +func writeConstSection(w io.Writer, section string, m map[string]*ConstInfo) error { + if len(m) == 0 { return nil } - err = writeSection("[const]", pkgData.Consts) + _, err := io.WriteString(w, section) if err != nil { return err } - writeSection("[var]", pkgData.Vars) + _, err = io.WriteString(w, "\n") if err != nil { return err } - writeSection("[func]", pkgData.Funcs) - if err != nil { - return err + for k, v := range m { + _, err := io.WriteString(w, k) + if err != nil { + return err + } + err = writeConst(w, v) + if err != nil { + return err + } + _, err = io.WriteString(w, "\n") + if err != nil { + return err + } + } + return nil +} +func writeConst(w io.Writer, info *ConstInfo) error { + if !info.Untyped { + _, err := io.WriteString(w, " ") + if err != nil { + return err + } + _, err = io.WriteString(w, "e") + if err != nil { + return err + } + } else { + // only untyped requires type info + _, err := io.WriteString(w, " ") + if err != nil { + return err + } + _, err = io.WriteString(w, "t=") + if err != nil { + return err + } + _, err = io.WriteString(w, info.Type) + if err != nil { + return err + } } return nil @@ -144,6 +206,7 @@ func parsePkgData(content string) (*PackageData, error) { if idx >= 0 { name = line[:idx] } + extra := line[idx+1:] if name == "" { break } @@ -160,9 +223,20 @@ func parsePkgData(content string) (*PackageData, error) { p.Vars[name] = true case Section_Const: if p.Consts == nil { - p.Consts = make(map[string]bool, 1) + p.Consts = make(map[string]*ConstInfo, 1) + } + constInfo := &ConstInfo{Untyped: true} + kvs := strings.Split(extra, " ") + for _, v := range kvs { + if v == "e" { + constInfo.Untyped = false + continue + } + if strings.HasPrefix(v, "t=") { + constInfo.Type = v[len("t="):] + } } - p.Consts[name] = true + p.Consts[name] = constInfo default: // ignore others } diff --git a/patch/syntax/call_expr_go1.17_18_19.go b/patch/syntax/call_expr_go1.17_18_19.go index 1744f4ca..03cf3e0d 100644 --- a/patch/syntax/call_expr_go1.17_18_19.go +++ b/patch/syntax/call_expr_go1.17_18_19.go @@ -5,12 +5,6 @@ package syntax import "cmd/compile/internal/syntax" -func (ctx *BlockContext) traverseCallExpr(node *syntax.CallExpr, globaleNames map[string]*DeclInfo, imports map[string]string) *syntax.CallExpr { - if node == nil { - return nil - } - for i, arg := range node.ArgList { - node.ArgList[i] = ctx.traverseExpr(arg, globaleNames, imports) - } - return node +func (ctx *BlockContext) traverseCallStmtCallExpr(node *syntax.CallExpr, globaleNames map[string]*DeclInfo, imports map[string]string) *syntax.CallExpr { + return ctx.traverseCallExpr(node, globaleNames, imports) } diff --git a/patch/syntax/call_expr_go1.20.go b/patch/syntax/call_expr_go1.20.go index aba8968e..4ad71eb7 100644 --- a/patch/syntax/call_expr_go1.20.go +++ b/patch/syntax/call_expr_go1.20.go @@ -5,6 +5,6 @@ package syntax import "cmd/compile/internal/syntax" -func (ctx *BlockContext) traverseCallExpr(node syntax.Expr, globaleNames map[string]*DeclInfo, imports map[string]string) syntax.Expr { +func (ctx *BlockContext) traverseCallStmtCallExpr(node syntax.Expr, globaleNames map[string]*DeclInfo, imports map[string]string) syntax.Expr { return ctx.traverseExpr(node, globaleNames, imports) } diff --git a/patch/syntax/syntax.go b/patch/syntax/syntax.go index b4d3314e..f8ad43de 100644 --- a/patch/syntax/syntax.go +++ b/patch/syntax/syntax.go @@ -43,6 +43,55 @@ func AfterFilesParsed(fileList []*syntax.File, addFile func(name string, r io.Re registerFuncs(fileList, addFile) } +// typeinfo not used +// func AfterSyntaxTypeCheck(pkgPath string, files []*syntax.File, info *types2.Info) { +// if pkgPath != "github.com/xhd2015/xgo/runtime/test/debug" { +// return +// } +// if true { +// return +// } +// stmt := files[0].DeclList[2].(*syntax.FuncDecl).Body.List[0] +// call := stmt.(*syntax.ExprStmt).X.(*syntax.CallExpr) +// name := call.ArgList[0].(*syntax.Name) +// if false { +// v := &syntax.BasicLit{Value: "11", Kind: syntax.IntLit} +// t := syntax.TypeAndValue{ +// Type: name.GetTypeInfo().Type, +// Value: constant.MakeInt64(11), +// } +// t.SetIsValue() +// v.SetTypeInfo(t) +// call.ArgList[0] = v +// } + +// _ = name +// } + +func debugPkgSyntax(files []*syntax.File) { + if false { + return + } + pkgPath := xgo_ctxt.GetPkgPath() + if pkgPath != "github.com/xhd2015/xgo/runtime/test/debug" { + return + } + + stmt := files[0].DeclList[2].(*syntax.FuncDecl).Body.List[1] + call := stmt.(*syntax.ExprStmt).X.(*syntax.CallExpr) + name := call.ArgList[0].(*syntax.Name) + // if false { + call.ArgList[0] = &syntax.XgoSimpleConvert{ + X: &syntax.CallExpr{ + Fun: syntax.NewName(name.Pos(), "int"), + ArgList: []syntax.Expr{ + name, + }, + }, + } + // } +} + func GetSyntaxDeclMapping() map[string]map[LineCol]*DeclInfo { return getSyntaxDeclMapping() } @@ -139,6 +188,8 @@ func registerFuncs(fileList []*syntax.File, addFile func(name string, r io.Reade if len(fileList) > 0 { pkgName = fileList[0].PkgName.Value } + + // debugPkgSyntax(fileList) // if true { // return // } diff --git a/patch/syntax/vars.go b/patch/syntax/vars.go index 91fb94fd..550d5d07 100644 --- a/patch/syntax/vars.go +++ b/patch/syntax/vars.go @@ -49,14 +49,21 @@ func collectVarDecls(declKind DeclKind, names []*syntax.Name, typ syntax.Expr) [ func trapVariables(pkgPath string, fileList []*syntax.File, funcDelcs []*DeclInfo) { names := make(map[string]*DeclInfo, len(funcDelcs)) varNames := make(map[string]bool) - constNames := make(map[string]bool) + constNames := make(map[string]*pkgdata.ConstInfo) for _, funcDecl := range funcDelcs { identityName := funcDecl.IdentityName() names[identityName] = funcDecl if funcDecl.Kind == Kind_Var || funcDecl.Kind == Kind_VarPtr { varNames[identityName] = true } else if funcDecl.Kind == Kind_Const { - constNames[identityName] = true + constDecl := funcDecl.ConstDecl + constInfo := &pkgdata.ConstInfo{Untyped: true} + if constDecl.Type != nil { + constInfo.Untyped = false + } else { + constInfo.Type = getConstDeclValueType(constDecl.Values) + } + constNames[identityName] = constInfo } } err := pkgdata.WritePkgData(pkgPath, &pkgdata.PackageData{ @@ -121,13 +128,26 @@ type BlockContext struct { Names map[string]bool - OperationParent map[syntax.Node]*syntax.Operation + // node appears as RHS of var decl + ListExprParent map[syntax.Node]*syntax.ListExpr + RHSVarDeclParent map[syntax.Node]*syntax.VarDecl + OperationParent map[syntax.Node]*syntax.Operation + ArgCallExprParent map[syntax.Node]*syntax.CallExpr + RHSAssignNoDefParent map[syntax.Node]*syntax.AssignStmt + CaseClauseParent map[syntax.Node]*syntax.CaseClause + ReturnStmtParent map[syntax.Node]*syntax.ReturnStmt + + // const info + ConstInfo map[syntax.Node]*ConstInfo // to be inserted InsertList []syntax.Stmt TrapNames []*NameAndDecl } +type ConstInfo struct { + Type string +} type NameAndDecl struct { TakeAddr bool @@ -152,6 +172,9 @@ func (c *BlockContext) Has(name string) bool { return c.Parent.Has(name) } +// avoid unused warning +var _ = (*BlockContext).traverseNode + // imports: name -> pkgPath func (ctx *BlockContext) traverseNode(node syntax.Node, globaleNames map[string]*DeclInfo, imports map[string]string) syntax.Node { if node == nil { @@ -188,7 +211,7 @@ func (ctx *BlockContext) traverseStmt(node syntax.Stmt, globaleNames map[string] return ctx.traverseBlockStmt(node, globaleNames, imports) case *syntax.CallStmt: // defer, go - node.Call = ctx.traverseCallExpr(node.Call, globaleNames, imports) + node.Call = ctx.traverseCallStmtCallExpr(node.Call, globaleNames, imports) return node case *syntax.IfStmt: node.Init = ctx.traverseSimpleStmt(node.Init, globaleNames, imports) @@ -220,6 +243,12 @@ func (ctx *BlockContext) traverseStmt(node syntax.Stmt, globaleNames map[string] case *syntax.BranchStmt: // ignore continue or continue label case *syntax.ReturnStmt: + if node.Results != nil { + if ctx.ReturnStmtParent == nil { + ctx.ReturnStmtParent = make(map[syntax.Node]*syntax.ReturnStmt, 1) + } + ctx.ReturnStmtParent[node.Results] = node + } node.Results = ctx.traverseExpr(node.Results, globaleNames, imports) default: // unknown @@ -253,6 +282,11 @@ func (ctx *BlockContext) traverseSimpleStmt(node syntax.SimpleStmt, globaleNames } } } + } else { + if ctx.RHSAssignNoDefParent == nil { + ctx.RHSAssignNoDefParent = make(map[syntax.Node]*syntax.AssignStmt, 1) + } + ctx.RHSAssignNoDefParent[node.Rhs] = node } node.Rhs = ctx.traverseExpr(node.Rhs, globaleNames, imports) case *syntax.RangeClause: @@ -280,6 +314,7 @@ func (ctx *BlockContext) traverseBlockStmt(node *syntax.BlockStmt, globaleNames return nil } n := len(node.List) + base := len(ctx.Children) for i := 0; i < n; i++ { subCtx := &BlockContext{ Parent: ctx, @@ -290,7 +325,7 @@ func (ctx *BlockContext) traverseBlockStmt(node *syntax.BlockStmt, globaleNames node.List[i] = subCtx.traverseStmt(node.List[i], globaleNames, imports) } for i := n - 1; i >= 0; i-- { - node.List = insertBefore(node.List, i, ctx.Children[i].InsertList) + node.List = insertBefore(node.List, i, ctx.Children[i+base].InsertList) } return node } @@ -299,6 +334,13 @@ func (ctx *BlockContext) traverseCaseClause(node *syntax.CaseClause, globaleName if node == nil { return nil } + if node.Cases != nil { + if ctx.CaseClauseParent == nil { + ctx.CaseClauseParent = make(map[syntax.Node]*syntax.CaseClause, 1) + } + ctx.CaseClauseParent[node.Cases] = node + } + node.Cases = ctx.traverseExpr(node.Cases, globaleNames, imports) fakeBlock := &syntax.BlockStmt{ List: node.Body, @@ -345,6 +387,17 @@ func (ctx *BlockContext) traverseExpr(node syntax.Expr, globaleNames map[string] return node case *syntax.ParenExpr: node.X = ctx.traverseExpr(node.X, globaleNames, imports) + if xgoConv, ok := node.X.(*syntax.XgoSimpleConvert); ok { + constType := getConstType(xgoConv) + newNode := &syntax.XgoSimpleConvert{ + X: &syntax.CallExpr{ + Fun: syntax.NewName(node.Pos(), constType), + ArgList: []syntax.Expr{node}, + }, + } + ctx.recordConstType(newNode, constType) + return newNode + } case *syntax.SelectorExpr: newNode, selIsName := ctx.trapSelector(node, node, false, globaleNames, imports) if newNode != nil { @@ -364,11 +417,10 @@ func (ctx *BlockContext) traverseExpr(node syntax.Expr, globaleNames map[string] case *syntax.AssertExpr: node.X = ctx.traverseExpr(node.X, globaleNames, imports) case *syntax.TypeSwitchGuard: - res := ctx.traverseExpr(node.X, globaleNames, imports) + node.X = ctx.traverseExpr(node.X, globaleNames, imports) if node.Lhs != nil { ctx.Add(node.Lhs.Value) } - return res case *syntax.Operation: // take addr? if node.Op == syntax.And && node.Y == nil { @@ -396,14 +448,41 @@ func (ctx *BlockContext) traverseExpr(node syntax.Expr, globaleNames map[string] // x op y node.X = ctx.traverseExpr(node.X, globaleNames, imports) node.Y = ctx.traverseExpr(node.Y, globaleNames, imports) + // if both side are const, then the operation should also + // be wrapped in a const + if node.X != nil && node.Y != nil { + xConst := ctx.ConstInfo[node.X] + yConst := ctx.ConstInfo[node.Y] + if xConst != nil && yConst != nil { + newNode := &syntax.XgoSimpleConvert{ + X: &syntax.CallExpr{ + Fun: syntax.NewName(node.Pos(), xConst.Type), + ArgList: []syntax.Expr{node}, + }, + } + ctx.recordConstType(newNode, xConst.Type) + return newNode + } + } else if node.Y == nil && (node.Op == syntax.Add || node.Op == syntax.Sub) { + if xgoConv, ok := node.X.(*syntax.XgoSimpleConvert); ok { + constType := getConstType(xgoConv) + newNode := createConv(node, constType) + ctx.recordConstType(newNode, constType) + return newNode + } + } return node case *syntax.CallExpr: - // NOTE: we skip capturing a name as a function - // node.Fun = ctx.traverseExpr(node.Fun, globaleNames, imports) - for i, arg := range node.ArgList { - node.ArgList[i] = ctx.traverseExpr(arg, globaleNames, imports) - } + return ctx.traverseCallExpr(node, globaleNames, imports) case *syntax.ListExpr: + if len(node.ElemList) > 0 { + if ctx.ListExprParent == nil { + ctx.ListExprParent = make(map[syntax.Node]*syntax.ListExpr, len(node.ElemList)) + } + for _, elem := range node.ElemList { + ctx.ListExprParent[elem] = node + } + } for i, elem := range node.ElemList { node.ElemList[i] = ctx.traverseExpr(elem, globaleNames, imports) } @@ -417,6 +496,13 @@ func (ctx *BlockContext) traverseExpr(node syntax.Expr, globaleNames map[string] case *syntax.ChanType: case *syntax.MapType: case *syntax.BasicLit: + constType := getBasicLitConstType(node.Kind) + if constType != "" { + if ctx.ConstInfo == nil { + ctx.ConstInfo = make(map[syntax.Node]*ConstInfo, 1) + } + ctx.ConstInfo[node] = &ConstInfo{Type: constType} + } case *syntax.BadExpr: default: // unknown @@ -427,6 +513,45 @@ func (ctx *BlockContext) traverseExpr(node syntax.Expr, globaleNames map[string] return node } +func getConstType(xgoConv *syntax.XgoSimpleConvert) string { + return xgoConv.X.(*syntax.CallExpr).Fun.(*syntax.Name).Value +} + +func createConv(node syntax.Expr, constType string) *syntax.XgoSimpleConvert { + return &syntax.XgoSimpleConvert{ + X: &syntax.CallExpr{ + Fun: syntax.NewName(node.Pos(), constType), + ArgList: []syntax.Expr{node}, + }, + } +} + +func (ctx *BlockContext) recordConstType(node syntax.Node, constType string) { + if ctx.ConstInfo == nil { + ctx.ConstInfo = make(map[syntax.Node]*ConstInfo, 1) + } + ctx.ConstInfo[node] = &ConstInfo{Type: constType} +} + +func (ctx *BlockContext) traverseCallExpr(node *syntax.CallExpr, globaleNames map[string]*DeclInfo, imports map[string]string) *syntax.CallExpr { + if node == nil { + return nil + } + if ctx.ArgCallExprParent == nil { + ctx.ArgCallExprParent = make(map[syntax.Node]*syntax.CallExpr, len(node.ArgList)) + for _, arg := range node.ArgList { + ctx.ArgCallExprParent[arg] = node + } + } + + // NOTE: we skip capturing a name as a function + // node.Fun = ctx.traverseExpr(node.Fun, globaleNames, imports) + for i, arg := range node.ArgList { + node.ArgList[i] = ctx.traverseExpr(arg, globaleNames, imports) + } + return node +} + func (ctx *BlockContext) traverseDecl(node syntax.Decl, globaleNames map[string]*DeclInfo, imports map[string]string) syntax.Decl { if node == nil { return nil @@ -435,6 +560,14 @@ func (ctx *BlockContext) traverseDecl(node syntax.Decl, globaleNames map[string] case *syntax.ConstDecl: case *syntax.TypeDecl: case *syntax.VarDecl: + // var a int64 = N + if node.Values != nil { + if ctx.RHSVarDeclParent == nil { + ctx.RHSVarDeclParent = make(map[syntax.Node]*syntax.VarDecl, 1) + } + ctx.RHSVarDeclParent[node.Values] = node + node.Values = ctx.traverseExpr(node.Values, globaleNames, imports) + } default: // unknown if os.Getenv("XGO_DEBUG_VAR_TRAP_LOOSE") != "true" { @@ -454,19 +587,71 @@ func (c *BlockContext) trapValueNode(node *syntax.Name, globaleNames map[string] if decl == nil { return node } + var explicitType syntax.Expr + var rhsAssign *syntax.AssignStmt + var isCallArg bool + var untypedConstType string if decl.Kind == Kind_Var || decl.Kind == Kind_VarPtr { // good to go } else if decl.Kind == Kind_Const { - if _, ok := c.OperationParent[node]; ok { - // directly inside an operation - return node + // untyped const(most cases) should only be used in + // serveral cases because runtime type is unknown + if decl.ConstDecl.Type == nil { + untypedConstType = getConstDeclValueType(decl.ConstDecl.Values) + var ok bool + explicitType, ok = c.isConstOKToTrap(node) + if !ok { + // debug + if _, ok := c.ArgCallExprParent[node]; ok { + isCallArg = true + } + if !isCallArg { + return node + } + } } } else { return node } - preStmts, tmpVarName := trapVar(node, syntax.NewName(node.Pos(), XgoLocalPkgName), node.Value, false) + preStmts, varDefStmt, tmpVarName := trapVar(node, syntax.NewName(node.Pos(), XgoLocalPkgName), node.Value, false) + if rhsAssign != nil { + varDefStmt.Op = 0 + preStmts = append([]syntax.Stmt{ + &syntax.AssignStmt{ + Op: syntax.Def, + Lhs: syntax.NewName(node.Pos(), tmpVarName), + Rhs: rhsAssign.Lhs, + }, + }, preStmts...) + } + c.InsertList = append(c.InsertList, preStmts...) - return syntax.NewName(node.Pos(), tmpVarName) + newName := syntax.NewName(node.Pos(), tmpVarName) + if explicitType != nil { + return &syntax.CallExpr{ + Fun: explicitType, + ArgList: []syntax.Expr{ + newName, + }, + } + } + if untypedConstType != "" { + newNode := &syntax.XgoSimpleConvert{ + X: &syntax.CallExpr{ + Fun: syntax.NewName(node.Pos(), untypedConstType), + ArgList: []syntax.Expr{ + newName, + }, + }, + } + if c.ConstInfo == nil { + c.ConstInfo = make(map[syntax.Node]*ConstInfo, 1) + } + c.ConstInfo[node] = &ConstInfo{Type: untypedConstType} + c.ConstInfo[newNode] = &ConstInfo{Type: untypedConstType} + return newNode + } + return newName } func (ctx *BlockContext) trapSelector(node syntax.Expr, sel *syntax.SelectorExpr, takeAddr bool, globaleNames map[string]*DeclInfo, imports map[string]string) (newExpr syntax.Expr, selIsName bool) { @@ -489,16 +674,123 @@ func (ctx *BlockContext) trapSelector(node syntax.Expr, sel *syntax.SelectorExpr if !allowPkgVarTrap(pkgPath) { return nil, true } + var explicitType syntax.Expr pkgData := pkgdata.GetPkgData(pkgPath) - if pkgData.Consts[sel.Sel.Value] { - // is const and inside operation - if _, ok := ctx.OperationParent[node]; ok { - return nil, true + var isCallArg bool + var untypedConstType string + if constInfo, ok := pkgData.Consts[sel.Sel.Value]; ok { + if constInfo.Untyped { + untypedConstType = constInfo.Type + var ok bool + explicitType, ok = ctx.isConstOKToTrap(node) + if !ok { + // debug + if _, ok := ctx.ArgCallExprParent[node]; ok { + isCallArg = true + } + if !isCallArg { + return nil, true + } + + } } + } else if pkgData.Vars[sel.Sel.Value] { + // ok to go + } else { + return nil, true } - preStmts, tmpVarName := trapVar(node, newStringLit(pkgPath), sel.Sel.Value, takeAddr) + preStmts, _, tmpVarName := trapVar(node, newStringLit(pkgPath), sel.Sel.Value, takeAddr) ctx.InsertList = append(ctx.InsertList, preStmts...) - return syntax.NewName(node.Pos(), tmpVarName), true + newName := syntax.NewName(node.Pos(), tmpVarName) + if explicitType != nil { + return &syntax.CallExpr{ + Fun: explicitType, + ArgList: []syntax.Expr{ + newName, + }, + }, true + } + if untypedConstType != "" { + newNode := &syntax.XgoSimpleConvert{ + X: &syntax.CallExpr{ + Fun: syntax.NewName(node.Pos(), untypedConstType), + ArgList: []syntax.Expr{ + newName, + }, + }, + } + if ctx.ConstInfo == nil { + ctx.ConstInfo = make(map[syntax.Node]*ConstInfo, 1) + } + ctx.ConstInfo[sel] = &ConstInfo{Type: untypedConstType} + ctx.ConstInfo[node] = &ConstInfo{Type: untypedConstType} + return newNode, true + } + return newName, true +} + +func (ctx *BlockContext) isConstOKToTrap(node syntax.Node) (explicitType syntax.Expr, ok bool) { + if true { + return nil, true + } + // is const and inside operation + if _, ok := ctx.OperationParent[node]; ok { + return nil, false + } + + // NOTE: will this: int64(a) not work? maybe we + // can make it work + if _, ok := ctx.ArgCallExprParent[node]; ok { + // directly as argument to a call + return nil, false + } + if _, ok := ctx.CaseClauseParent[node]; ok { + return nil, false + } + if _, ok := ctx.ReturnStmtParent[node]; ok { + return nil, false + } + if varDecl, ok := ctx.RHSVarDeclParent[node]; ok { + return varDecl.Type, true + } + + if _, ok := ctx.RHSAssignNoDefParent[node]; ok { + // a=CONST -> tmp:=CONST,a=tmp + // not working + // rhsAssign = assign + return nil, false + } + listExprParent, ok := ctx.ListExprParent[node] + if !ok { + return nil, true + } + return ctx.isConstOKToTrap(listExprParent) +} + +func getConstDeclValueType(expr syntax.Expr) string { + switch expr := expr.(type) { + case *syntax.BasicLit: + return getBasicLitConstType(expr.Kind) + case *syntax.Name: + if expr.Value == "true" || expr.Value == "false" { + return "bool" + } + // NOTE: nil is not a constant + } + return "" +} +func getBasicLitConstType(kind syntax.LitKind) string { + switch kind { + case syntax.IntLit: + return "int" + case syntax.StringLit: + return "string" + case syntax.RuneLit: + return "rune" + case syntax.FloatLit: + return "float64" + } + return "" } func (c *BlockContext) trapAddrNode(node *syntax.Operation, nameNode *syntax.Name, globaleNames map[string]*DeclInfo) syntax.Expr { @@ -511,12 +803,12 @@ func (c *BlockContext) trapAddrNode(node *syntax.Operation, nameNode *syntax.Nam if decl == nil || !decl.Kind.IsVarOrConst() { return node } - preStmts, tmpVarName := trapVar(node, syntax.NewName(nameNode.Pos(), XgoLocalPkgName), name, true) + preStmts, _, tmpVarName := trapVar(node, syntax.NewName(nameNode.Pos(), XgoLocalPkgName), name, true) c.InsertList = append(c.InsertList, preStmts...) return syntax.NewName(node.Pos(), tmpVarName) } -func trapVar(expr syntax.Expr, pkgRef syntax.Expr, name string, takeAddr bool) (preStmts []syntax.Stmt, tmpVarName string) { +func trapVar(expr syntax.Expr, pkgRef syntax.Expr, name string, takeAddr bool) (preStmts []syntax.Stmt, varDefStmt *syntax.AssignStmt, tmpVarName string) { pos := expr.Pos() line := pos.Line() col := pos.Col() @@ -530,12 +822,14 @@ func trapVar(expr syntax.Expr, pkgRef syntax.Expr, name string, takeAddr bool) ( // &a -> __m varName := fmt.Sprintf("__xgo_%s_%d_%d", name, line, col) // a: - - preStmts = append(preStmts, &syntax.AssignStmt{ + varDefStmt = &syntax.AssignStmt{ Op: syntax.Def, Lhs: syntax.NewName(pos, varName), Rhs: expr, - }, + } + + preStmts = append(preStmts, + varDefStmt, &syntax.ExprStmt{ X: &syntax.CallExpr{ Fun: syntax.NewName(pos, "__xgo_link_trap_var_for_generated"), @@ -563,10 +857,13 @@ func trapVar(expr syntax.Expr, pkgRef syntax.Expr, name string, takeAddr bool) ( for _, preStmt := range preStmts { fillPos(pos, preStmt) } - return preStmts, varName + return preStmts, varDefStmt, varName } func insertBefore(list []syntax.Stmt, i int, add []syntax.Stmt) []syntax.Stmt { + if len(add) == 0 { + return list + } return append(append(list[:i:i], add...), list[i:]...) } diff --git a/runtime/core/version.go b/runtime/core/version.go index 2bfd1403..74296c39 100644 --- a/runtime/core/version.go +++ b/runtime/core/version.go @@ -7,8 +7,8 @@ import ( ) const VERSION = "1.0.19" -const REVISION = "7b94c97a1b438510f46aa86fca2b626396b9d2e1+1" -const NUMBER = 166 +const REVISION = "7fd3f2b160c52c890d248efceb23a5e070156e0c+1" +const NUMBER = 167 // these fields will be filled by compiler const XGO_VERSION = "" diff --git a/runtime/test/debug/debug_test.go b/runtime/test/debug/debug_test.go index d18d33d6..6fc0b04f 100644 --- a/runtime/test/debug/debug_test.go +++ b/runtime/test/debug/debug_test.go @@ -6,21 +6,26 @@ package debug import ( - "context" "testing" - "github.com/xhd2015/xgo/runtime/core" "github.com/xhd2015/xgo/runtime/mock" - "github.com/xhd2015/xgo/runtime/test/mock_var/sub" ) -func TestMockVarInOtherPkg(t *testing.T) { - mock.Mock(&sub.A, func(ctx context.Context, fn *core.FuncInfo, args, results core.Object) error { - results.GetFieldIndex(0).Set("mockA") - return nil +const N = 50 +const M = 20 + +func TestPatchConstOperationShouldCompileAndSkipMock(t *testing.T) { + // should have no effect + mock.PatchByName("github.com/xhd2015/xgo/runtime/test/debug", "N", func() int { + return 10 }) - b := sub.A - if b != "mockA" { - t.Fatalf("expect sub.A to be %s, actual: %s", "mockA", b) + // because N is used inside an operation + // it's type is not yet determined, so + // should not rewrite it + var size int64 = M + N + t.Logf("size=%d", size) + // size := (N + 1) * unsafe.Sizeof(int(0)) + if size != 11 { + t.Fatalf("expect N not patched and size to be %d, actual: %d\n", 11, size) } } diff --git a/runtime/test/patch/patch_const_test.go b/runtime/test/patch/patch_const_test.go new file mode 100644 index 00000000..98307736 --- /dev/null +++ b/runtime/test/patch/patch_const_test.go @@ -0,0 +1,231 @@ +package patch + +import ( + "fmt" + "os" + "testing" + "unsafe" + + "github.com/xhd2015/xgo/runtime/mock" + "github.com/xhd2015/xgo/runtime/test/patch/sub" +) + +const testVersion = "1.0" + +func TestPatchConstByNamePtrTest(t *testing.T) { + mock.PatchByName(pkgPath, "testVersion", func() string { + return "1.5" + }) + version := testVersion + if version != "1.5" { + t.Fatalf("expect patched version a to be %s, actual: %s", "1.5", version) + } +} + +func TestPatchConstByNameWrongTypeShouldFail(t *testing.T) { + var pe interface{} + func() { + defer func() { + pe = recover() + }() + mock.PatchByName(pkgPath, "a", func() string { + return "1.5" + }) + }() + expectMsg := "replacer should have type: func() int, actual: func() string" + if pe == nil { + t.Fatalf("expect panic: %q, actual nil", expectMsg) + } + msg := fmt.Sprint(pe) + if msg != expectMsg { + t.Fatalf("expect err %q, actual: %q", expectMsg, msg) + } +} + +const N = 50 + +func TestPatchConstOperationShouldCompileAndSkipMock(t *testing.T) { + // should have effect + mock.PatchByName(pkgPath, "N", func() int { + return 10 + }) + // because N is used inside an operation + // it's type is not yet determined, so + // should not rewrite it + size := N * unsafe.Sizeof(int(0)) + if size != 80 { + t.Fatalf("expect N not patched and size to be %d, actual: %d\n", 80, size) + } +} + +func TestPatchOtherPkgConstOperationShouldWork(t *testing.T) { + // should have effect + mock.PatchByName(subPkgPath, "N", func() int { + return 10 + }) + // because N is used inside an operation + // it's type is not yet determined, so + // should not rewrite it + size := sub.N * unsafe.Sizeof(int(0)) + if size != 80 { + t.Fatalf("expect N not patched and size to be %d, actual: %d\n", 80, size) + } +} + +func TestConstOperationNaked(t *testing.T) { + mock.PatchByName(pkgPath, "N", func() int { + return 10 + }) + var size int64 = N + 1 + if size != 11 { + t.Fatalf("expect N not patched and size to be %d, actual: %d\n", 11, size) + } +} + +const M = 10 + +func TestTwoConstAdd(t *testing.T) { + mock.PatchByName(pkgPath, "N", func() int { + return 10 + }) + var size int64 = (N + M) * 2 + if size != 40 { + t.Fatalf("expect N not patched and size to be %d, actual: %d\n", 40, size) + } +} + +func TestConstOperationParen(t *testing.T) { + mock.PatchByName(pkgPath, "N", func() int { + return 10 + }) + var size int64 = (N + 1) * 2 + if size != 22 { + t.Fatalf("expect N not patched and size to be %d, actual: %d\n", 22, size) + } +} + +// local const +func TestPatchConstInAssignmentShouldWork(t *testing.T) { + mock.PatchByName(pkgPath, "N", func() int { + return 10 + }) + var a int64 = N + + if a != 10 { + t.Fatalf("expect a to be %d, actual: %d\n", 10, a) + } +} + +func TestPatchConstInAssignmentNoDefShouldWork(t *testing.T) { + mock.PatchByName(pkgPath, "N", func() int { + return 10 + }) + var a int64 = 100 + if os.Getenv("nothing") == "" { + a = N + } + + if a != 10 { + t.Fatalf("expect a to be %d, actual: %d\n", 10, a) + } +} + +func TestPatchConstInFuncArgShouldSkip(t *testing.T) { + mock.PatchByName(pkgPath, "N", func() int { + return 10 + }) + a := f(N) + if a != 10 { + t.Fatalf("expect a to be %d, actual: %d\n", 10, a) + } +} + +func TestPatchConstInTypeConvertArgShouldWork(t *testing.T) { + mock.PatchByName(pkgPath, "N", func() int { + return 10 + }) + a := int64(N) + if a != 10 { + t.Fatalf("expect a to be %d, actual: %d\n", 10, a) + } +} +func f(a int64) int64 { + return a +} + +func TestCaseConstShouldSkip(t *testing.T) { + n := int64(50) + switch n { + case N: + case N + 1: + t.Fatalf("should not faill to N+1") + default: + t.Fatalf("should not fall to default") + } + + switch n { + case N, N + 1: + default: + t.Fatalf("should not fall to default") + } +} + +func TestReturnConstShouldWork(t *testing.T) { + mock.PatchByName(pkgPath, "N", func() int { + return 10 + }) + n := getN() + if n != 10 { + t.Fatalf("expect a to be %d, actual: %d\n", 50, n) + } +} + +func getN() int64 { + return N +} + +// other package +func TestPatchOtherPackageConstInAssignmentShouldWork(t *testing.T) { + mock.PatchByName(subPkgPath, "N", func() int { + return 10 + }) + var a int64 = sub.N + + if a != 10 { + t.Fatalf("expect a to be %d, actual: %d\n", 10, a) + } +} + +func TestPatchOtherPackageConstInAssignmentNoDefShouldWork(t *testing.T) { + mock.PatchByName(subPkgPath, "N", func() int { + return 10 + }) + var a int64 = 100 + if os.Getenv("nothing") == "" { + a = sub.N + } + + if a != 10 { + t.Fatalf("expect a to be %d, actual: %d\n", 10, a) + } +} + +func TestPatchOtherPackageConstInFuncArgShouldWork(t *testing.T) { + mock.PatchByName(subPkgPath, "N", func() int { + return 10 + }) + a := f(sub.N) + if a != 10 { + t.Fatalf("expect a to be %d, actual: %d\n", 10, a) + } +} + +func TestPatchOtherPackageConstInTypeConvertArgShouldWork(t *testing.T) { + mock.PatchByName(subPkgPath, "N", func() int { + return 10 + }) + a := int64(sub.N) + if a != 10 { + t.Fatalf("expect a to be %d, actual: %d\n", 10, a) + } +} diff --git a/runtime/test/patch/patch_var_test.go b/runtime/test/patch/patch_var_test.go index 82aec209..ec31a075 100644 --- a/runtime/test/patch/patch_var_test.go +++ b/runtime/test/patch/patch_var_test.go @@ -1,9 +1,10 @@ package patch import ( + "encoding/json" "fmt" + "os" "testing" - "unsafe" "github.com/xhd2015/xgo/runtime/mock" "github.com/xhd2015/xgo/runtime/test/patch/sub" @@ -72,64 +73,45 @@ func TestPatchVarByNamePtrTest(t *testing.T) { } } -const testVersion = "1.0" - -func TestPatchConstByNamePtrTest(t *testing.T) { - mock.PatchByName(pkgPath, "testVersion", func() string { - return "1.5" - }) - version := testVersion - if version != "1.5" { - t.Fatalf("expect patched version a to be %s, actual: %s", "1.5", version) - } +func TestPatchSwitchCaseShouldCompile(t *testing.T) { + toJSONRaw(10) } -func TestPatchConstByNameWrongTypeShouldFail(t *testing.T) { - var pe interface{} - func() { - defer func() { - pe = recover() - }() - mock.PatchByName(pkgPath, "a", func() string { - return "1.5" - }) - }() - expectMsg := "replacer should have type: func() int, actual: func() string" - if pe == nil { - t.Fatalf("expect panic: %q, actual nil", expectMsg) +func toJSONRaw(v interface{}) (json.RawMessage, error) { + if v == nil { + return nil, nil } - msg := fmt.Sprint(pe) - if msg != expectMsg { - t.Fatalf("expect err %q, actual: %q", expectMsg, msg) + switch v := v.(type) { + case []byte: + return v, nil + case json.RawMessage: + return v, nil + case string: + return json.RawMessage([]byte(v)), nil + default: + return json.Marshal(v) } } -const N = 50 +const a3 = 4 -func TestPatchConstOperationShouldCompileAndSkipMock(t *testing.T) { - // should have no effect - mock.PatchByName(pkgPath, "N", func() int { - return 10 - }) - // because N is used inside an operation - // it's type is not yet determined, so - // should not rewrite it - size := N * unsafe.Sizeof(int(0)) - if size != 400 { - t.Logf("expect N not patched and size to be %d, actual: %d\n", 400, size) +func TestPatchInElseShouldWork(t *testing.T) { + if os.Getenv("nothing") == "nothing" { + t.Fatalf("should go else") + } else { + mock.PatchByName(pkgPath, "a3", func() int { + return 5 + }) + b := a3 + + if b != 5 { + t.Fatalf("expect b to be %d,actual: %d", 5, b) + } } } -func TestPatchOtherPkgConstOperationShouldCompileAndSkipMock(t *testing.T) { - // should have no effect - mock.PatchByName(subPkgPath, "N", func() int { - return 10 - }) - // because N is used inside an operation - // it's type is not yet determined, so - // should not rewrite it - size := sub.N * unsafe.Sizeof(int(0)) - if size != 400 { - t.Logf("expect N not patched and size to be %d, actual: %d\n", 400, size) - } +func TestMakeInOtherPackageShouldCompile(t *testing.T) { + // previous error:sub.NameSet (type) is not an expression + set := make(sub.NameSet) + _ = set } diff --git a/runtime/test/patch/sub/sub.go b/runtime/test/patch/sub/sub.go index 8348c677..cbc28085 100644 --- a/runtime/test/patch/sub/sub.go +++ b/runtime/test/patch/sub/sub.go @@ -1,3 +1,5 @@ package sub const N = 50 + +type NameSet map[string]bool diff --git a/support/strutil/strutil.go b/support/strutil/strutil.go index b07be792..1628c3b1 100644 --- a/support/strutil/strutil.go +++ b/support/strutil/strutil.go @@ -6,41 +6,57 @@ import ( ) func IndexSequenceAt(s string, sequence []string, begin bool) int { - _, idx := indexSequence(s, sequence, begin) - return idx + idx := 0 + if !begin { + idx = -1 + } + off, _ := indexSequence(s, sequence, idx, begin) + return off +} + +func SeqenceOffset(s string, sequence []string, i int, begin bool) (offset int, endOffset int) { + return indexSequence(s, sequence, i, begin) } func IndexSequence(s string, sequence []string) int { - _, idx := indexSequence(s, sequence, false) - return idx + off, _ := indexSequence(s, sequence, -1, false) + return off } -func indexSequence(s string, sequence []string, begin bool) (int, int) { + +// [a,b,c] +// before -> +func indexSequence(s string, sequence []string, seqIdx int, begin bool) (offset int, endOffset int) { if len(sequence) == 0 { return 0, 0 } - firstIdx := -1 - base := 0 + if seqIdx == -1 { + seqIdx = len(sequence) - 1 + } else if seqIdx < 0 || seqIdx >= len(sequence) { + return -1, -1 + } + var recordOff int + cursor := 0 for i, seq := range sequence { idx := strings.Index(s, seq) if idx < 0 { - return i, -1 - } - if firstIdx < 0 { - firstIdx = idx + return -1, -1 } s = s[idx+len(seq):] - base += idx + len(seq) - } - if begin { - return -1, firstIdx + cursor += idx + len(seq) + if i == seqIdx { + recordOff = cursor + if begin { + recordOff -= len(seq) + } + } } - return -1, base + return recordOff, cursor } func CheckSequence(output string, sequence []string) error { - missing, idx := indexSequence(output, sequence, false) + idx, _ := indexSequence(output, sequence, -1, false) if idx < 0 { - return fmt.Errorf("sequence at %d: missing %q", missing, sequence[missing]) + return fmt.Errorf("sequence %q missing from %q", sequence, output) } return nil }