diff --git a/README.md b/README.md index 8e13860a..3977013d 100644 --- a/README.md +++ b/README.md @@ -139,11 +139,11 @@ go test -v ./ Output: ```sh -WARNING: failed to link __xgo_link_on_init_finished.(xgo required) -WARNING: failed to link __xgo_link_on_goexit.(xgo required) +WARNING: failed to link __xgo_link_on_init_finished(requires xgo). +WARNING: failed to link __xgo_link_on_goexit(requires xgo). === RUN TestFuncMock -WARNING: failed to link __xgo_link_set_trap.(xgo required) -WARNING: failed to link __xgo_link_init_finished.(xgo required) +WARNING: failed to link __xgo_link_set_trap(requires xgo). +WARNING: failed to link __xgo_link_init_finished(requires xgo). demo_test.go:21: expect MyFunc() to be 'mock func', actual: my func --- FAIL: TestFuncMock (0.00s) FAIL diff --git a/README_zh_cn.md b/README_zh_cn.md index 99fbc941..384dec43 100644 --- a/README_zh_cn.md +++ b/README_zh_cn.md @@ -135,11 +135,11 @@ go test -v ./ 输出: ```sh -WARNING: failed to link __xgo_link_on_init_finished.(xgo required) -WARNING: failed to link __xgo_link_on_goexit.(xgo required) +WARNING: failed to link __xgo_link_on_init_finished(requires xgo). +WARNING: failed to link __xgo_link_on_goexit(requires xgo). === RUN TestFuncMock -WARNING: failed to link __xgo_link_set_trap.(xgo required) -WARNING: failed to link __xgo_link_init_finished.(xgo required) +WARNING: failed to link __xgo_link_set_trap(requires xgo). +WARNING: failed to link __xgo_link_init_finished(requires xgo). demo_test.go:21: expect MyFunc() to be 'mock func', actual: my func --- FAIL: TestFuncMock (0.00s) FAIL diff --git a/cmd/xgo/patch.go b/cmd/xgo/patch.go index e2f261ad..99b427d8 100644 --- a/cmd/xgo/patch.go +++ b/cmd/xgo/patch.go @@ -4,21 +4,15 @@ import ( "bytes" "errors" "fmt" - "go/ast" - "io/fs" "io/ioutil" "os" "os/exec" "path/filepath" - "runtime" - "sort" "strings" - "github.com/xhd2015/xgo/cmd/xgo/patch" "github.com/xhd2015/xgo/support/filecopy" "github.com/xhd2015/xgo/support/goinfo" "github.com/xhd2015/xgo/support/osinfo" - "github.com/xhd2015/xgo/support/transform" ) // assume go 1.20 @@ -50,62 +44,6 @@ func patchRuntimeAndCompiler(origGoroot string, goroot string, xgoSrc string, go return nil } -func patchRuntimeAndTesting(goroot string) error { - err := patchRuntimeProc(goroot) - if err != nil { - return err - } - err = patchRuntimeTesting(goroot) - if err != nil { - return err - } - return nil -} - -func patchRuntimeProc(goroot string) error { - anchors := []string{ - "func main() {", - "doInit(", "runtime_inittask", ")", // first doInit for runtime - "doInit(", // second init for main - "close(main_init_done)", - "\n", - } - procGo := filepath.Join(goroot, "src", "runtime", "proc.go") - err := editFile(procGo, func(content string) (string, error) { - content = addContentAfter(content, "/**/", "/**/", anchors, patch.RuntimeProcPatch) - - // goexit1() is called for every exited goroutine - content = addContentAfter(content, - "/**/", "/**/", - []string{"func goexit1() {", "\n"}, - patch.RuntimeProcGoroutineExitPatch, - ) - return content, nil - }) - if err != nil { - return err - } - return nil -} - -func patchRuntimeTesting(goroot string) error { - testingFile := filepath.Join(goroot, "src", "testing", "testing.go") - return editFile(testingFile, func(content string) (string, error) { - // func tRunner(t *T, fn func(t *T)) { - anchor := []string{"func tRunner(t *T", "{", "\n"} - content = addContentBefore(content, - "/**/", "/**/", - anchor, - patch.TestingCallbackDeclarations, - ) - content = addContentAfter(content, - "/**/", "/**/", - anchor, - patch.TestingStart, - ) - return content, nil - }) -} func getInternalPatch(goroot string, subDirs ...string) string { dir := filepath.Join(goroot, "src", "cmd", "compile", "internal", "xgo_rewrite_internal", "patch") if len(subDirs) > 0 { @@ -113,184 +51,6 @@ func getInternalPatch(goroot string, subDirs ...string) string { } return dir } -func importCompileInternalPatch(goroot string, xgoSrc string, forceReset bool, syncWithLink bool) error { - dstDir := getInternalPatch(goroot) - if isDevelopment { - symLink := syncWithLink - if osinfo.FORCE_COPY_UNSYM { - // windows: A required privilege is not held by the client. - symLink = false - } - // copy compiler internal dependencies - err := filecopy.CopyReplaceDir(filepath.Join(xgoSrc, "patch"), dstDir, symLink) - if err != nil { - return err - } - - // remove patch/go.mod - err = os.RemoveAll(filepath.Join(dstDir, "go.mod")) - if err != nil { - if errors.Is(err, os.ErrNotExist) { - return nil - } - return err - } - return nil - } - - if forceReset { - // -a causes repatch - err := os.RemoveAll(dstDir) - if err != nil { - return err - } - } else { - // check if already copied - _, statErr := os.Stat(dstDir) - if statErr == nil { - // skip copy if already exists - return nil - } - } - - // read from embed - err := fs.WalkDir(patchEmbed, "patch_compiler", func(path string, d fs.DirEntry, err error) error { - if err != nil { - return err - } - if path == "patch_compiler" { - return os.MkdirAll(dstDir, 0755) - } - // TODO: test on windows if "/" works - dstPath := filepath.Join(dstDir, strings.TrimPrefix(path, "patch_compiler/")) - if d.IsDir() { - return os.MkdirAll(dstPath, 0755) - } - - content, err := patchEmbed.ReadFile(path) - if err != nil { - return err - } - return os.WriteFile(dstPath, content, 0755) - }) - if err != nil { - return err - } - - return nil -} - -func patchRuntimeDef(origGoroot string, goroot string, goVersion *goinfo.GoVersion) error { - err := prepareRuntimeDefs(goroot, goVersion) - if err != nil { - return err - } - - // run mkbuiltin - cmd := exec.Command(filepath.Join(origGoroot, "bin", "go"), "run", "mkbuiltin.go") - cmd.Stderr = os.Stderr - cmd.Stdout = os.Stdout - - var dirs []string - if goVersion.Major > 1 || (goVersion.Major == 1 && goVersion.Minor > 16) { - dirs = []string{goroot, "src", "cmd", "compile", "internal", "typecheck"} - } else { - dirs = []string{goroot, "src", "cmd", "compile", "internal", "gc"} - } - cmd.Dir = filepath.Join(dirs...) - cmd.Env = os.Environ() - cmd.Env, err = patchEnvWithGoroot(cmd.Env, origGoroot) - if err != nil { - return err - } - - err = cmd.Run() - if err != nil { - return err - } - - return nil -} - -func prepareRuntimeDefs(goRoot string, goVersion *goinfo.GoVersion) error { - runtimeDefFiles := []string{"src", "cmd", "compile", "internal", "typecheck", "_builtin", "runtime.go"} - if goVersion.Major == 1 && goVersion.Minor <= 19 { - if goVersion.Minor > 16 { - // in go1.19 and below, builtin has no _ prefix - runtimeDefFiles = []string{"src", "cmd", "compile", "internal", "typecheck", "builtin", "runtime.go"} - } else { - runtimeDefFiles = []string{"src", "cmd", "compile", "internal", "gc", "builtin", "runtime.go"} - } - } - runtimeDefFile := filepath.Join(runtimeDefFiles...) - fullFile := filepath.Join(goRoot, runtimeDefFile) - - extraDef := patch.RuntimeExtraDef - return editFile(fullFile, func(content string) (string, error) { - content = addContentAfter(content, - `/**/`, `/**/`, - []string{`var x86HasFMA bool`, `var armHasVFPv4 bool`, `var arm64HasATOMICS bool`}, - extraDef, - ) - return content, nil - }) -} - -func patchCompiler(origGoroot string, goroot string, goVersion *goinfo.GoVersion, xgoSrc string, forceReset bool, syncWithLink bool) error { - // copy compiler internal dependencies - err := importCompileInternalPatch(goroot, xgoSrc, forceReset, syncWithLink) - if err != nil { - return err - } - runtimeDefUpdated, err := addRuntimeFunctions(goroot, goVersion, xgoSrc) - if err != nil { - return err - } - - if runtimeDefUpdated { - err = patchRuntimeDef(origGoroot, goroot, goVersion) - if err != nil { - return err - } - } - - // NOTE: not adding reflect to access any method - if false { - err = addReflectFunctions(goroot, goVersion, xgoSrc) - if err != nil { - return err - } - } - - err = patchCompilerInternal(goroot, goVersion) - if err != nil { - return err - } - return nil -} - -func patchCompilerInternal(goroot string, goVersion *goinfo.GoVersion) error { - // src/cmd/compile/internal/noder/noder.go - err := patchCompilerNoder(goroot, goVersion) - if err != nil { - return fmt.Errorf("patching noder: %w", err) - } - if goVersion.Major == 1 && (goVersion.Minor == 18 || goVersion.Minor == 19) { - err := poatchIRGenericGen(goroot, goVersion) - if err != nil { - return fmt.Errorf("patching generic trap: %w", err) - } - } - err = patchSynatxNode(goroot, goVersion) - if err != nil { - return fmt.Errorf("patching syntax node:%w", err) - } - err = patchGcMain(goroot, goVersion) - if err != nil { - return fmt.Errorf("patching gc main:%w", err) - } - return nil -} func readXgoSrc(xgoSrc string, paths []string) ([]byte, error) { if isDevelopment { @@ -307,302 +67,6 @@ func replaceBuildIgnore(content []byte) ([]byte, error) { return replaceMarkerNewline(content, []byte(buildIgnore)) } -// addRuntimeFunctions always copy file -func addRuntimeFunctions(goroot string, goVersion *goinfo.GoVersion, xgoSrc string) (updated bool, err error) { - if false { - // seems unnecessary - // TODO: needs to debug to see what will happen with auto generated files - // we need to skip when debugging - - // add debug file - // rational: when debugging, dlv will jump to __xgo_autogen_register_func_helper.go - // previousely this file does not exist, making the debugging blind - runtimeAutoGenFile := filepath.Join(goroot, "src", "runtime", "__xgo_autogen_register_func_helper.go") - srcAutoGen := getInternalPatch(goroot, "syntax", "helper_code.go") - err = filecopy.CopyFile(srcAutoGen, runtimeAutoGenFile) - if err != nil { - return false, err - } - } - - dstFile := filepath.Join(goroot, "src", "runtime", "xgo_trap.go") - content, err := readXgoSrc(xgoSrc, []string{"trap_runtime", "xgo_trap.go"}) - if err != nil { - return false, err - } - - content, err = replaceBuildIgnore(content) - if err != nil { - return false, fmt.Errorf("file %s: %w", filepath.Base(dstFile), err) - } - - // the func.entry is a field, not a function - if goVersion.Major == 1 && goVersion.Minor <= 17 { - entryPatch := "fn.entry() /*>=go1.18*/" - entryPatchBytes := []byte(entryPatch) - idx := bytes.Index(content, entryPatchBytes) - if idx < 0 { - return false, fmt.Errorf("expect %q in xgo_trap.go, actually not found", entryPatch) - } - content = bytes.ReplaceAll(content, entryPatchBytes, []byte("fn.entry")) - } - - // func name patch - if goVersion.Major > 1 || goVersion.Minor > 22 { - panic("should check the implementation of runtime.FuncForPC(pc).Name() to ensure __xgo_get_pc_name is not wrapped in print format above go1.22") - } - if goVersion.Major > 1 || goVersion.Minor >= 21 { - content = append(content, []byte(patch.RuntimeGetFuncName_Go121)...) - } else if goVersion.Major == 1 { - if goVersion.Minor >= 17 { - // go1.17,go1.18,go1.19 - content = append(content, []byte(patch.RuntimeGetFuncName_Go117_120)...) - } - } - - return true, os.WriteFile(dstFile, content, 0755) -} - -func addReflectFunctions(goroot string, goVersion *goinfo.GoVersion, xgoSrc string) error { - dstFile := filepath.Join(goroot, "src", "reflect", "xgo_reflect.go") - content, err := readXgoSrc(xgoSrc, []string{"trap_runtime", "xgo_reflect.go"}) - if err != nil { - return err - } - - content, err = replaceBuildIgnore(content) - if err != nil { - return fmt.Errorf("file %s: %w", filepath.Base(dstFile), err) - } - - valCode, err := transformReflectValue(filepath.Join(goroot, "src", "reflect", "value.go")) - if err != nil { - return fmt.Errorf("transforming reflect/value.go: %w", err) - } - typeCode, err := transformReflectType(filepath.Join(goroot, "src", "reflect", "type.go")) - if err != nil { - return fmt.Errorf("transforming reflect/type.go: %w", err) - } - - // fmt.Printf("typCode: %s\n", typeCode) - - // concat all code - content = bytes.Join([][]byte{content, []byte(valCode), []byte(typeCode)}, []byte("\n")) - return os.WriteFile(dstFile, content, 0755) -} - -const xgoGetAllMethodByName = "__xgo_get_all_method_by_name" - -func transformReflectValue(reflectValueFile string) (string, error) { - file, err := transform.Parse(reflectValueFile) - if err != nil { - return "", err - } - - fnDecl := file.GetMethodDecl("Value", "MethodByName") - if fnDecl == nil { - return "", fmt.Errorf("cannot find Value.MethodByName") - } - - code, err := replaceIdent(file, fnDecl, xgoGetAllMethodByName, func(n ast.Node) (*ast.Ident, string) { - sel, ok := n.(*ast.SelectorExpr) - if !ok { - return nil, "" - } - - idt := sel.Sel - switch idt.Name { - case "MethodByName": - return idt, xgoGetAllMethodByName - case "Method": // method by index - return idt, "__xgo_get_all_method_index" - } - - return nil, "" - }) - if err != nil { - return "", fmt.Errorf("replacing MethodByName: %w", err) - } - - methodDecl := file.GetMethodDecl("Value", "Method") // method by index - if methodDecl == nil { - return "", fmt.Errorf("cannot find Value.Method") - } - code2, err := replaceIdent(file, methodDecl, "__xgo_get_all_method_index", func(n ast.Node) (*ast.Ident, string) { - sel, ok := n.(*ast.SelectorExpr) - if !ok { - return nil, "" - } - - idt := sel.Sel - switch idt.Name { - case "NumMethod": // method by index - return idt, "__xgo_get_all_method_num" - } - return nil, "" - }) - if err != nil { - return "", fmt.Errorf("replacing Method: %w", err) - } - - codef := strings.Join([]string{code, code2}, "\n") - return codef, nil -} - -func transformReflectType(reflectTypeFile string) (string, error) { - file, err := transform.Parse(reflectTypeFile) - if err != nil { - return "", err - } - fnDecl := file.GetMethodDecl("rtype", "MethodByName") - if fnDecl == nil { - return "", fmt.Errorf("cannot find rtype.MethodByName") - } - m0, err := replaceIdent(file, fnDecl, xgoGetAllMethodByName, func(n ast.Node) (*ast.Ident, string) { - sel, ok := n.(*ast.SelectorExpr) - if !ok { - return nil, "" - } - - idt := sel.Sel - if idt.Name == "ExportedMethods" { - return idt, "Methods" - } else if idt.Name == "Method" { - return idt, "__xgo_get_all_method_index" - } - return nil, "" - }) - if err != nil { - return "", fmt.Errorf("replacing ExportedMethods: %w", err) - } - - fnDecl2 := file.GetMethodDecl("rtype", "exportedMethods") - if fnDecl2 == nil { - return "", fmt.Errorf("cannot find rtype.exportedMethods") - } - - m1, err := replaceIdent(file, fnDecl2, "__xgo_all_methods", func(n ast.Node) (*ast.Ident, string) { - sel, ok := n.(*ast.SelectorExpr) - if !ok { - return nil, "" - } - - idt := sel.Sel - if idt.Name == "ExportedMethods" { - return idt, "Methods" - } - return nil, "" - }) - if err != nil { - return "", err - } - - methodDecl := file.GetMethodDecl("rtype", "Method") - if methodDecl == nil { - return "", fmt.Errorf("cannot find rtype.Method") - } - m2, err := replaceIdent(file, methodDecl, "__xgo_get_all_method_index", func(n ast.Node) (*ast.Ident, string) { - sel, ok := n.(*ast.SelectorExpr) - if !ok { - return nil, "" - } - - idt := sel.Sel - if idt.Name == "exportedMethods" { - return idt, "__xgo_all_methods" - } - return nil, "" - }) - if err != nil { - return "", fmt.Errorf("replacing Method: %w", err) - } - - numA := file.GetMethodDecl("rtype", "NumMethod") - if numA == nil { - return "", fmt.Errorf("cannot find rtype.NumMethod") - } - m3, err := replaceIdent(file, numA, "__xgo_get_all_method_num", func(n ast.Node) (*ast.Ident, string) { - sel, ok := n.(*ast.SelectorExpr) - if !ok { - return nil, "" - } - - idt := sel.Sel - if idt.Name == "exportedMethods" { - return idt, "__xgo_all_methods" - } - return nil, "" - }) - if err != nil { - return "", fmt.Errorf("replacing Method: %w", err) - } - code := strings.Join([]string{m0, m1, m2, m3}, "\n") - return code, nil -} - -func replaceIdent(file *transform.File, fnDecl *ast.FuncDecl, replaceFuncName string, identReplacer func(n ast.Node) (*ast.Ident, string)) (string, error) { - type replaceIdent struct { - idt *ast.Ident - rep string - } - var replaceIdents []replaceIdent - ast.Inspect(fnDecl.Body, func(n ast.Node) bool { - if n == nil { - // post action - return false - } - idt, replace := identReplacer(n) - if idt != nil { - replaceIdents = append(replaceIdents, replaceIdent{ - idt: idt, - rep: replace, - }) - } - return true - }) - if len(replaceIdents) == 0 { - return "", errors.New("no replace found") - } - if replaceFuncName != "" { - // replace the name - replaceIdents = append(replaceIdents, replaceIdent{ - idt: fnDecl.Name, - rep: replaceFuncName, - }) - } - // find assignment to x - sort.Slice(replaceIdents, func(i, j int) bool { - a := replaceIdents[i].idt - b := replaceIdents[j].idt - return file.Fset.Position(a.Pos()).Offset < file.Fset.Position(b.Pos()).Offset - }) - - // replace - n := len(replaceIdents) - baseOffset := file.Fset.Position(fnDecl.Pos()).Offset - - code := file.GetCode(fnDecl) - for i := n - 1; i >= 0; i-- { - rp := replaceIdents[i] - offset := file.Fset.Position(rp.idt.Pos()).Offset - baseOffset - - var buf bytes.Buffer - buf.Grow(len(code)) - buf.Write(code[:offset]) - buf.WriteString(rp.rep) - buf.Write(code[offset+len(rp.idt.Name):]) - - code = buf.Bytes() - // NOTE: don't use slice append, content will be override - if false { - newCode := append(code[:offset:offset], []byte(rp.rep)...) - newCode = append(newCode, code[offset+len(rp.idt.Name):]...) - code = newCode - } - } - return string(code), nil -} - // content = bytes.Replace(content, []byte("//go:build ignore\n"), nil, 1) func replaceMarkerNewline(content []byte, marker []byte) ([]byte, error) { idx := bytes.Index(content, marker) @@ -618,234 +82,6 @@ func replaceMarkerNewline(content []byte, marker []byte) ([]byte, error) { } return content[idx:], nil } -func patchCompilerNoder(goroot string, goVersion *goinfo.GoVersion) error { - files := []string{"src", "cmd", "compile", "internal", "noder", "noder.go"} - var noderFiles string - if goVersion.Major == 1 { - minor := goVersion.Minor - if minor == 16 { - files = []string{"src", "cmd", "compile", "internal", "gc", "noder.go"} - noderFiles = patch.NoderFiles_1_17 - } else if minor == 17 { - noderFiles = patch.NoderFiles_1_17 - } else if minor == 18 { - noderFiles = patch.NoderFiles_1_17 - } else if minor == 19 { - noderFiles = patch.NoderFiles_1_17 - } else if minor == 20 { - noderFiles = patch.NoderFiles_1_20 - } else if minor == 21 { - noderFiles = patch.NoderFiles_1_21 - } else if minor == 22 { - noderFiles = patch.NoderFiles_1_21 - } - } - if noderFiles == "" { - return fmt.Errorf("unsupported: %v", goVersion) - } - file := filepath.Join(files...) - return editFile(filepath.Join(goroot, file), func(content string) (string, error) { - content = addCodeAfterImports(content, - "/**/", "/**/", - []string{ - `xgo_syntax "cmd/compile/internal/xgo_rewrite_internal/patch/syntax"`, - `"io"`, - }, - ) - var anchors []string - if goVersion.Major == 1 && goVersion.Minor <= 16 { - anchors = []string{ - "func parseFiles(filenames []string)", - "for _, p := range noders {", - "localpkg.Height = myheight", - "\n", - } - } else { - anchors = []string{ - `func LoadPackage`, - `for _, p := range noders {`, - `base.Timer.AddEvent(int64(lines), "lines")`, - "\n", - } - } - content = addContentAfter(content, "/**/", "/**/", anchors, - noderFiles) - return content, nil - }) -} - -func poatchIRGenericGen(goroot string, goVersion *goinfo.GoVersion) error { - file := filepath.Join(goroot, "src", "cmd", "compile", "internal", "noder", "irgen.go") - return editFile(file, func(content string) (string, error) { - imports := []string{ - `xgo_patch "cmd/compile/internal/xgo_rewrite_internal/patch"`, - } - if goVersion.Major == 1 && goVersion.Minor >= 19 { - imports = append(imports, `"os"`) - } - content = addCodeAfterImports(content, - "/**/", "/**/", - imports, - ) - content = addContentAfter(content, "/**/", "/**/", []string{ - `func (g *irgen) generate(noders []*noder) {`, - `types.DeferCheckSize()`, - `base.ExitIfErrors()`, - `typecheck.DeclareUniverse()`, - "\n", - }, - patch.GenericTrapForGo118And119) - return content, nil - }) -} - -func patchSynatxNode(goroot string, goVersion *goinfo.GoVersion) error { - if goVersion.Major > 1 || goVersion.Minor >= 22 { - return nil - } - var fragments []string - - if goVersion.Major == 1 { - if goVersion.Minor < 22 { - fragments = append(fragments, patch.NodesGen) - } - if goVersion.Minor <= 17 { - fragments = append(fragments, patch.Nodes_Inspect_117) - } - } - if len(fragments) == 0 { - return nil - } - file := filepath.Join(goroot, "src", "cmd", "compile", "internal", "syntax", "xgo_nodes.go") - return os.WriteFile(file, []byte("package syntax\n"+strings.Join(fragments, "\n")), 0755) -} - -func patchGcMain(goroot string, goVersion *goinfo.GoVersion) error { - file := filepath.Join(goroot, "src", "cmd", "compile", "internal", "gc", "main.go") - go116AndUnder := goVersion.Major == 1 && goVersion.Minor <= 16 - go117 := goVersion.Major == 1 && goVersion.Minor == 17 - go118 := goVersion.Major == 1 && goVersion.Minor == 18 - go119 := goVersion.Major == 1 && goVersion.Minor == 19 - go119AndUnder := goVersion.Major == 1 && goVersion.Minor <= 19 - go120 := goVersion.Major == 1 && goVersion.Minor == 20 - go121 := goVersion.Major == 1 && goVersion.Minor == 21 - go122 := goVersion.Major == 1 && goVersion.Minor == 22 - - return editFile(file, func(content string) (string, error) { - imports := []string{ - `xgo_patch "cmd/compile/internal/xgo_rewrite_internal/patch"`, - `xgo_record "cmd/compile/internal/xgo_rewrite_internal/patch/record"`, - } - content = addCodeAfterImports(content, - "/**/", "/**/", - imports, - ) - initRuntimeTypeCheckGo117 := `typecheck.InitRuntime()` - - var beforePatchContent string - var patchAnchors []string - - if go116AndUnder { - // go1.16 is pretty old - patchAnchors = []string{ - "loadsys()", - "parseFiles(flag.Args())", - "finishUniverse()", - "recordPackageName()", - } - } else { - patchAnchors = []string{`noder.LoadPackage(flag.Args())`, `dwarfgen.RecordPackageName()`} - if !go117 { - patchAnchors = append(patchAnchors, `ssagen.InitConfig()`) - } else { - // go 1.17 needs to call typecheck.InitRuntime() before patch - beforePatchContent = initRuntimeTypeCheckGo117 + "\n" - } - } - patchAnchors = append(patchAnchors, "\n") - content = addContentAfter(content, - "/**/", "/**/", - patchAnchors, - ` // insert trap points - if os.Getenv("XGO_COMPILER_ENABLE")=="true" { - `+beforePatchContent+`xgo_patch.Patch() - } -`) - - if go117 { - // go1.17 needs to adjust typecheck.InitRuntime before patch - content = replaceContentAfter(content, - "/**/", "/**/", - []string{`escape.Funcs(typecheck.Target.Decls)`, `if base.Flag.CompilingRuntime {`, "}", "\n"}, - initRuntimeTypeCheckGo117, - `if os.Getenv("XGO_COMPILER_ENABLE")!="true" { - `+initRuntimeTypeCheckGo117+` - }`, - ) - } - - // turn off inline when there is rewrite(gcflags=-l) - // windows: also turn off optimization(gcflags=-N) - var flagNSwitch = "" - if runtime.GOOS == "windows" { - flagNSwitch = "\n" + "base.Flag.N = 1" - } - - // there are two ways to turn off inline - // - 1. by not calling to inline.InlinePackage - // - 2. by override base.Flag.LowerL to 0 - // prefer 1 because it is more focused - if go116AndUnder { - inlineGuard := `if Debug.l != 0 {` - inlineAnchors := []string{ - `fninit(xtop)`, - `Curfn = nil`, - `// Phase 5: Inlining`, - `if Debug_typecheckinl != 0 {`, - "\n", - } - content = replaceContentAfter(content, - "/**/", "/**/", - inlineAnchors, - inlineGuard, - ` // NOTE: turn off inline if there is any rewrite - `+strings.TrimSuffix(inlineGuard, " {")+` && !xgo_record.HasRewritten() {`+flagNSwitch) - } else if go117 || go118 || go119 || go120 || go121 { - inlineCall := `inline.InlinePackage(profile)` - if go119AndUnder { - // go1.19 and under does not hae PGO - inlineCall = `inline.InlinePackage()` - } - // go1.20 does not respect rewritten content when inlined - content = replaceContentAfter(content, - "/**/", "/**/", - []string{`base.Timer.Start("fe", "inlining")`, `if base.Flag.LowerL != 0 {`, "\n"}, - inlineCall, - ` // NOTE: turn off inline if there is any rewrite - if !xgo_record.HasRewritten() { - `+inlineCall+` - }else{`+flagNSwitch+` - } -`) - } else if go122 { - // go1.22 also does not respect rewritten content when inlined - // NOTE: the override of LowerL is inserted after xgo_patch.Patch() - content = addContentAfter(content, - "/**/", "/**/", - []string{`if base.Flag.LowerL <= 1 {`, `base.Flag.LowerL = 1 - base.Flag.LowerL`, "}", "xgo_patch.Patch()", "}", "\n"}, - ` // NOTE: turn off inline if there is any rewrite - if xgo_record.HasRewritten() {`+flagNSwitch+` - base.Flag.LowerL = 0 - } - `) - } else { - return "", fmt.Errorf("inline for %v not defined", goVersion) - } - - return content, nil - }) -} - func checkRevisionChanged(revisionFile string, currentRevision string) (bool, error) { savedRevision, err := readOrEmpty(revisionFile) if err != nil { diff --git a/cmd/xgo/patch/runtime_def.go b/cmd/xgo/patch/runtime_def.go index 542dbf9d..9d9f2029 100644 --- a/cmd/xgo/patch/runtime_def.go +++ b/cmd/xgo/patch/runtime_def.go @@ -7,6 +7,12 @@ for _, fn := range __xgo_on_init_finished_callbacks { __xgo_on_init_finished_callbacks = nil ` +const RuntimeProcGoroutineCreatedPatch = `for _, fn := range __xgo_on_gonewproc_callbacks { + fn(uintptr(unsafe.Pointer(newg))) +} +return newg +` + // added after goroutine exit1 const RuntimeProcGoroutineExitPatch = `for _, fn := range __xgo_on_goexits { fn() diff --git a/cmd/xgo/patch/runtime_def_gen.go b/cmd/xgo/patch/runtime_def_gen.go index 3cf4cc54..50a23332 100644 --- a/cmd/xgo/patch/runtime_def_gen.go +++ b/cmd/xgo/patch/runtime_def_gen.go @@ -13,6 +13,7 @@ func __xgo_register_func(info interface{}) func __xgo_retrieve_all_funcs_and_clear(f func(info interface{})) func __xgo_init_finished() bool func __xgo_on_init_finished(fn func()) +func __xgo_on_gonewproc(fn func(g uintptr)) func __xgo_on_goexit(fn func()) func __xgo_on_test_start(fn interface{}) func __xgo_get_test_starts() []interface{} diff --git a/cmd/xgo/patch_compiler.go b/cmd/xgo/patch_compiler.go new file mode 100644 index 00000000..20e08313 --- /dev/null +++ b/cmd/xgo/patch_compiler.go @@ -0,0 +1,424 @@ +package main + +import ( + "errors" + "fmt" + "io/fs" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + + "github.com/xhd2015/xgo/cmd/xgo/patch" + "github.com/xhd2015/xgo/support/filecopy" + "github.com/xhd2015/xgo/support/goinfo" + "github.com/xhd2015/xgo/support/osinfo" +) + +func patchCompiler(origGoroot string, goroot string, goVersion *goinfo.GoVersion, xgoSrc string, forceReset bool, syncWithLink bool) error { + // copy compiler internal dependencies + err := importCompileInternalPatch(goroot, xgoSrc, forceReset, syncWithLink) + if err != nil { + return err + } + runtimeDefUpdated, err := addRuntimeFunctions(goroot, goVersion, xgoSrc) + if err != nil { + return err + } + + if runtimeDefUpdated { + err = patchRuntimeDef(origGoroot, goroot, goVersion) + if err != nil { + return err + } + } + + // NOTE: not adding reflect to access any method + if false { + err = addReflectFunctions(goroot, goVersion, xgoSrc) + if err != nil { + return err + } + } + + err = patchCompilerInternal(goroot, goVersion) + if err != nil { + return err + } + return nil +} + +func patchSynatxNode(goroot string, goVersion *goinfo.GoVersion) error { + if goVersion.Major > 1 || goVersion.Minor >= 22 { + return nil + } + var fragments []string + + if goVersion.Major == 1 { + if goVersion.Minor < 22 { + fragments = append(fragments, patch.NodesGen) + } + if goVersion.Minor <= 17 { + fragments = append(fragments, patch.Nodes_Inspect_117) + } + } + if len(fragments) == 0 { + return nil + } + file := filepath.Join(goroot, "src", "cmd", "compile", "internal", "syntax", "xgo_nodes.go") + return os.WriteFile(file, []byte("package syntax\n"+strings.Join(fragments, "\n")), 0755) +} + +func patchGcMain(goroot string, goVersion *goinfo.GoVersion) error { + file := filepath.Join(goroot, "src", "cmd", "compile", "internal", "gc", "main.go") + go116AndUnder := goVersion.Major == 1 && goVersion.Minor <= 16 + go117 := goVersion.Major == 1 && goVersion.Minor == 17 + go118 := goVersion.Major == 1 && goVersion.Minor == 18 + go119 := goVersion.Major == 1 && goVersion.Minor == 19 + go119AndUnder := goVersion.Major == 1 && goVersion.Minor <= 19 + go120 := goVersion.Major == 1 && goVersion.Minor == 20 + go121 := goVersion.Major == 1 && goVersion.Minor == 21 + go122 := goVersion.Major == 1 && goVersion.Minor == 22 + + return editFile(file, func(content string) (string, error) { + imports := []string{ + `xgo_patch "cmd/compile/internal/xgo_rewrite_internal/patch"`, + `xgo_record "cmd/compile/internal/xgo_rewrite_internal/patch/record"`, + } + content = addCodeAfterImports(content, + "/**/", "/**/", + imports, + ) + initRuntimeTypeCheckGo117 := `typecheck.InitRuntime()` + + var beforePatchContent string + var patchAnchors []string + + if go116AndUnder { + // go1.16 is pretty old + patchAnchors = []string{ + "loadsys()", + "parseFiles(flag.Args())", + "finishUniverse()", + "recordPackageName()", + } + } else { + patchAnchors = []string{`noder.LoadPackage(flag.Args())`, `dwarfgen.RecordPackageName()`} + if !go117 { + patchAnchors = append(patchAnchors, `ssagen.InitConfig()`) + } else { + // go 1.17 needs to call typecheck.InitRuntime() before patch + beforePatchContent = initRuntimeTypeCheckGo117 + "\n" + } + } + patchAnchors = append(patchAnchors, "\n") + content = addContentAfter(content, + "/**/", "/**/", + patchAnchors, + ` // insert trap points + if os.Getenv("XGO_COMPILER_ENABLE")=="true" { + `+beforePatchContent+`xgo_patch.Patch() + } +`) + + if go117 { + // go1.17 needs to adjust typecheck.InitRuntime before patch + content = replaceContentAfter(content, + "/**/", "/**/", + []string{`escape.Funcs(typecheck.Target.Decls)`, `if base.Flag.CompilingRuntime {`, "}", "\n"}, + initRuntimeTypeCheckGo117, + `if os.Getenv("XGO_COMPILER_ENABLE")!="true" { + `+initRuntimeTypeCheckGo117+` + }`, + ) + } + + // turn off inline when there is rewrite(gcflags=-l) + // windows: also turn off optimization(gcflags=-N) + var flagNSwitch = "" + if runtime.GOOS == "windows" { + flagNSwitch = "\n" + "base.Flag.N = 1" + } + + // there are two ways to turn off inline + // - 1. by not calling to inline.InlinePackage + // - 2. by override base.Flag.LowerL to 0 + // prefer 1 because it is more focused + if go116AndUnder { + inlineGuard := `if Debug.l != 0 {` + inlineAnchors := []string{ + `fninit(xtop)`, + `Curfn = nil`, + `// Phase 5: Inlining`, + `if Debug_typecheckinl != 0 {`, + "\n", + } + content = replaceContentAfter(content, + "/**/", "/**/", + inlineAnchors, + inlineGuard, + ` // NOTE: turn off inline if there is any rewrite + `+strings.TrimSuffix(inlineGuard, " {")+` && !xgo_record.HasRewritten() {`+flagNSwitch) + } else if go117 || go118 || go119 || go120 || go121 { + inlineCall := `inline.InlinePackage(profile)` + if go119AndUnder { + // go1.19 and under does not hae PGO + inlineCall = `inline.InlinePackage()` + } + // go1.20 does not respect rewritten content when inlined + content = replaceContentAfter(content, + "/**/", "/**/", + []string{`base.Timer.Start("fe", "inlining")`, `if base.Flag.LowerL != 0 {`, "\n"}, + inlineCall, + ` // NOTE: turn off inline if there is any rewrite + if !xgo_record.HasRewritten() { + `+inlineCall+` + }else{`+flagNSwitch+` + } +`) + } else if go122 { + // go1.22 also does not respect rewritten content when inlined + // NOTE: the override of LowerL is inserted after xgo_patch.Patch() + content = addContentAfter(content, + "/**/", "/**/", + []string{`if base.Flag.LowerL <= 1 {`, `base.Flag.LowerL = 1 - base.Flag.LowerL`, "}", "xgo_patch.Patch()", "}", "\n"}, + ` // NOTE: turn off inline if there is any rewrite + if xgo_record.HasRewritten() {`+flagNSwitch+` + base.Flag.LowerL = 0 + } + `) + } else { + return "", fmt.Errorf("inline for %v not defined", goVersion) + } + + return content, nil + }) +} + +func patchCompilerNoder(goroot string, goVersion *goinfo.GoVersion) error { + files := []string{"src", "cmd", "compile", "internal", "noder", "noder.go"} + var noderFiles string + if goVersion.Major == 1 { + minor := goVersion.Minor + if minor == 16 { + files = []string{"src", "cmd", "compile", "internal", "gc", "noder.go"} + noderFiles = patch.NoderFiles_1_17 + } else if minor == 17 { + noderFiles = patch.NoderFiles_1_17 + } else if minor == 18 { + noderFiles = patch.NoderFiles_1_17 + } else if minor == 19 { + noderFiles = patch.NoderFiles_1_17 + } else if minor == 20 { + noderFiles = patch.NoderFiles_1_20 + } else if minor == 21 { + noderFiles = patch.NoderFiles_1_21 + } else if minor == 22 { + noderFiles = patch.NoderFiles_1_21 + } + } + if noderFiles == "" { + return fmt.Errorf("unsupported: %v", goVersion) + } + file := filepath.Join(files...) + return editFile(filepath.Join(goroot, file), func(content string) (string, error) { + content = addCodeAfterImports(content, + "/**/", "/**/", + []string{ + `xgo_syntax "cmd/compile/internal/xgo_rewrite_internal/patch/syntax"`, + `"io"`, + }, + ) + var anchors []string + if goVersion.Major == 1 && goVersion.Minor <= 16 { + anchors = []string{ + "func parseFiles(filenames []string)", + "for _, p := range noders {", + "localpkg.Height = myheight", + "\n", + } + } else { + anchors = []string{ + `func LoadPackage`, + `for _, p := range noders {`, + `base.Timer.AddEvent(int64(lines), "lines")`, + "\n", + } + } + content = addContentAfter(content, "/**/", "/**/", anchors, + noderFiles) + return content, nil + }) +} + +func poatchIRGenericGen(goroot string, goVersion *goinfo.GoVersion) error { + file := filepath.Join(goroot, "src", "cmd", "compile", "internal", "noder", "irgen.go") + return editFile(file, func(content string) (string, error) { + imports := []string{ + `xgo_patch "cmd/compile/internal/xgo_rewrite_internal/patch"`, + } + if goVersion.Major == 1 && goVersion.Minor >= 19 { + imports = append(imports, `"os"`) + } + content = addCodeAfterImports(content, + "/**/", "/**/", + imports, + ) + content = addContentAfter(content, "/**/", "/**/", []string{ + `func (g *irgen) generate(noders []*noder) {`, + `types.DeferCheckSize()`, + `base.ExitIfErrors()`, + `typecheck.DeclareUniverse()`, + "\n", + }, + patch.GenericTrapForGo118And119) + return content, nil + }) +} + +func importCompileInternalPatch(goroot string, xgoSrc string, forceReset bool, syncWithLink bool) error { + dstDir := getInternalPatch(goroot) + if isDevelopment { + symLink := syncWithLink + if osinfo.FORCE_COPY_UNSYM { + // windows: A required privilege is not held by the client. + symLink = false + } + // copy compiler internal dependencies + err := filecopy.CopyReplaceDir(filepath.Join(xgoSrc, "patch"), dstDir, symLink) + if err != nil { + return err + } + + // remove patch/go.mod + err = os.RemoveAll(filepath.Join(dstDir, "go.mod")) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil + } + return err + } + return nil + } + + if forceReset { + // -a causes repatch + err := os.RemoveAll(dstDir) + if err != nil { + return err + } + } else { + // check if already copied + _, statErr := os.Stat(dstDir) + if statErr == nil { + // skip copy if already exists + return nil + } + } + + // read from embed + err := fs.WalkDir(patchEmbed, "patch_compiler", func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if path == "patch_compiler" { + return os.MkdirAll(dstDir, 0755) + } + // TODO: test on windows if "/" works + dstPath := filepath.Join(dstDir, strings.TrimPrefix(path, "patch_compiler/")) + if d.IsDir() { + return os.MkdirAll(dstPath, 0755) + } + + content, err := patchEmbed.ReadFile(path) + if err != nil { + return err + } + return os.WriteFile(dstPath, content, 0755) + }) + if err != nil { + return err + } + + return nil +} + +func patchRuntimeDef(origGoroot string, goroot string, goVersion *goinfo.GoVersion) error { + err := prepareRuntimeDefs(goroot, goVersion) + if err != nil { + return err + } + + // run mkbuiltin + cmd := exec.Command(filepath.Join(origGoroot, "bin", "go"), "run", "mkbuiltin.go") + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout + + var dirs []string + if goVersion.Major > 1 || (goVersion.Major == 1 && goVersion.Minor > 16) { + dirs = []string{goroot, "src", "cmd", "compile", "internal", "typecheck"} + } else { + dirs = []string{goroot, "src", "cmd", "compile", "internal", "gc"} + } + cmd.Dir = filepath.Join(dirs...) + cmd.Env = os.Environ() + cmd.Env, err = patchEnvWithGoroot(cmd.Env, origGoroot) + if err != nil { + return err + } + + err = cmd.Run() + if err != nil { + return err + } + + return nil +} + +func prepareRuntimeDefs(goRoot string, goVersion *goinfo.GoVersion) error { + runtimeDefFiles := []string{"src", "cmd", "compile", "internal", "typecheck", "_builtin", "runtime.go"} + if goVersion.Major == 1 && goVersion.Minor <= 19 { + if goVersion.Minor > 16 { + // in go1.19 and below, builtin has no _ prefix + runtimeDefFiles = []string{"src", "cmd", "compile", "internal", "typecheck", "builtin", "runtime.go"} + } else { + runtimeDefFiles = []string{"src", "cmd", "compile", "internal", "gc", "builtin", "runtime.go"} + } + } + runtimeDefFile := filepath.Join(runtimeDefFiles...) + fullFile := filepath.Join(goRoot, runtimeDefFile) + + extraDef := patch.RuntimeExtraDef + return editFile(fullFile, func(content string) (string, error) { + content = addContentAfter(content, + `/**/`, `/**/`, + []string{`var x86HasFMA bool`, `var armHasVFPv4 bool`, `var arm64HasATOMICS bool`}, + extraDef, + ) + return content, nil + }) +} + +func patchCompilerInternal(goroot string, goVersion *goinfo.GoVersion) error { + // src/cmd/compile/internal/noder/noder.go + err := patchCompilerNoder(goroot, goVersion) + if err != nil { + return fmt.Errorf("patching noder: %w", err) + } + if goVersion.Major == 1 && (goVersion.Minor == 18 || goVersion.Minor == 19) { + err := poatchIRGenericGen(goroot, goVersion) + if err != nil { + return fmt.Errorf("patching generic trap: %w", err) + } + } + err = patchSynatxNode(goroot, goVersion) + if err != nil { + return fmt.Errorf("patching syntax node:%w", err) + } + err = patchGcMain(goroot, goVersion) + if err != nil { + return fmt.Errorf("patching gc main:%w", err) + } + return nil +} diff --git a/cmd/xgo/patch_reflect.go b/cmd/xgo/patch_reflect.go new file mode 100644 index 00000000..1455d3a3 --- /dev/null +++ b/cmd/xgo/patch_reflect.go @@ -0,0 +1,257 @@ +// patch reflect package +// NOTE: not used currently +package main + +import ( + "bytes" + "errors" + "fmt" + "go/ast" + "os" + "path/filepath" + "sort" + "strings" + + "github.com/xhd2015/xgo/support/goinfo" + "github.com/xhd2015/xgo/support/transform" +) + +func addReflectFunctions(goroot string, goVersion *goinfo.GoVersion, xgoSrc string) error { + dstFile := filepath.Join(goroot, "src", "reflect", "xgo_reflect.go") + content, err := readXgoSrc(xgoSrc, []string{"trap_runtime", "xgo_reflect.go"}) + if err != nil { + return err + } + + content, err = replaceBuildIgnore(content) + if err != nil { + return fmt.Errorf("file %s: %w", filepath.Base(dstFile), err) + } + + valCode, err := transformReflectValue(filepath.Join(goroot, "src", "reflect", "value.go")) + if err != nil { + return fmt.Errorf("transforming reflect/value.go: %w", err) + } + typeCode, err := transformReflectType(filepath.Join(goroot, "src", "reflect", "type.go")) + if err != nil { + return fmt.Errorf("transforming reflect/type.go: %w", err) + } + + // fmt.Printf("typCode: %s\n", typeCode) + + // concat all code + content = bytes.Join([][]byte{content, []byte(valCode), []byte(typeCode)}, []byte("\n")) + return os.WriteFile(dstFile, content, 0755) +} + +const xgoGetAllMethodByName = "__xgo_get_all_method_by_name" + +func transformReflectValue(reflectValueFile string) (string, error) { + file, err := transform.Parse(reflectValueFile) + if err != nil { + return "", err + } + + fnDecl := file.GetMethodDecl("Value", "MethodByName") + if fnDecl == nil { + return "", fmt.Errorf("cannot find Value.MethodByName") + } + + code, err := replaceIdent(file, fnDecl, xgoGetAllMethodByName, func(n ast.Node) (*ast.Ident, string) { + sel, ok := n.(*ast.SelectorExpr) + if !ok { + return nil, "" + } + + idt := sel.Sel + switch idt.Name { + case "MethodByName": + return idt, xgoGetAllMethodByName + case "Method": // method by index + return idt, "__xgo_get_all_method_index" + } + + return nil, "" + }) + if err != nil { + return "", fmt.Errorf("replacing MethodByName: %w", err) + } + + methodDecl := file.GetMethodDecl("Value", "Method") // method by index + if methodDecl == nil { + return "", fmt.Errorf("cannot find Value.Method") + } + code2, err := replaceIdent(file, methodDecl, "__xgo_get_all_method_index", func(n ast.Node) (*ast.Ident, string) { + sel, ok := n.(*ast.SelectorExpr) + if !ok { + return nil, "" + } + + idt := sel.Sel + switch idt.Name { + case "NumMethod": // method by index + return idt, "__xgo_get_all_method_num" + } + return nil, "" + }) + if err != nil { + return "", fmt.Errorf("replacing Method: %w", err) + } + + codef := strings.Join([]string{code, code2}, "\n") + return codef, nil +} + +func transformReflectType(reflectTypeFile string) (string, error) { + file, err := transform.Parse(reflectTypeFile) + if err != nil { + return "", err + } + fnDecl := file.GetMethodDecl("rtype", "MethodByName") + if fnDecl == nil { + return "", fmt.Errorf("cannot find rtype.MethodByName") + } + m0, err := replaceIdent(file, fnDecl, xgoGetAllMethodByName, func(n ast.Node) (*ast.Ident, string) { + sel, ok := n.(*ast.SelectorExpr) + if !ok { + return nil, "" + } + + idt := sel.Sel + if idt.Name == "ExportedMethods" { + return idt, "Methods" + } else if idt.Name == "Method" { + return idt, "__xgo_get_all_method_index" + } + return nil, "" + }) + if err != nil { + return "", fmt.Errorf("replacing ExportedMethods: %w", err) + } + + fnDecl2 := file.GetMethodDecl("rtype", "exportedMethods") + if fnDecl2 == nil { + return "", fmt.Errorf("cannot find rtype.exportedMethods") + } + + m1, err := replaceIdent(file, fnDecl2, "__xgo_all_methods", func(n ast.Node) (*ast.Ident, string) { + sel, ok := n.(*ast.SelectorExpr) + if !ok { + return nil, "" + } + + idt := sel.Sel + if idt.Name == "ExportedMethods" { + return idt, "Methods" + } + return nil, "" + }) + if err != nil { + return "", err + } + + methodDecl := file.GetMethodDecl("rtype", "Method") + if methodDecl == nil { + return "", fmt.Errorf("cannot find rtype.Method") + } + m2, err := replaceIdent(file, methodDecl, "__xgo_get_all_method_index", func(n ast.Node) (*ast.Ident, string) { + sel, ok := n.(*ast.SelectorExpr) + if !ok { + return nil, "" + } + + idt := sel.Sel + if idt.Name == "exportedMethods" { + return idt, "__xgo_all_methods" + } + return nil, "" + }) + if err != nil { + return "", fmt.Errorf("replacing Method: %w", err) + } + + numA := file.GetMethodDecl("rtype", "NumMethod") + if numA == nil { + return "", fmt.Errorf("cannot find rtype.NumMethod") + } + m3, err := replaceIdent(file, numA, "__xgo_get_all_method_num", func(n ast.Node) (*ast.Ident, string) { + sel, ok := n.(*ast.SelectorExpr) + if !ok { + return nil, "" + } + + idt := sel.Sel + if idt.Name == "exportedMethods" { + return idt, "__xgo_all_methods" + } + return nil, "" + }) + if err != nil { + return "", fmt.Errorf("replacing Method: %w", err) + } + code := strings.Join([]string{m0, m1, m2, m3}, "\n") + return code, nil +} + +func replaceIdent(file *transform.File, fnDecl *ast.FuncDecl, replaceFuncName string, identReplacer func(n ast.Node) (*ast.Ident, string)) (string, error) { + type replaceIdent struct { + idt *ast.Ident + rep string + } + var replaceIdents []replaceIdent + ast.Inspect(fnDecl.Body, func(n ast.Node) bool { + if n == nil { + // post action + return false + } + idt, replace := identReplacer(n) + if idt != nil { + replaceIdents = append(replaceIdents, replaceIdent{ + idt: idt, + rep: replace, + }) + } + return true + }) + if len(replaceIdents) == 0 { + return "", errors.New("no replace found") + } + if replaceFuncName != "" { + // replace the name + replaceIdents = append(replaceIdents, replaceIdent{ + idt: fnDecl.Name, + rep: replaceFuncName, + }) + } + // find assignment to x + sort.Slice(replaceIdents, func(i, j int) bool { + a := replaceIdents[i].idt + b := replaceIdents[j].idt + return file.Fset.Position(a.Pos()).Offset < file.Fset.Position(b.Pos()).Offset + }) + + // replace + n := len(replaceIdents) + baseOffset := file.Fset.Position(fnDecl.Pos()).Offset + + code := file.GetCode(fnDecl) + for i := n - 1; i >= 0; i-- { + rp := replaceIdents[i] + offset := file.Fset.Position(rp.idt.Pos()).Offset - baseOffset + + var buf bytes.Buffer + buf.Grow(len(code)) + buf.Write(code[:offset]) + buf.WriteString(rp.rep) + buf.Write(code[offset+len(rp.idt.Name):]) + + code = buf.Bytes() + // NOTE: don't use slice append, content will be override + if false { + newCode := append(code[:offset:offset], []byte(rp.rep)...) + newCode = append(newCode, code[offset+len(rp.idt.Name):]...) + code = newCode + } + } + return string(code), nil +} diff --git a/cmd/xgo/patch_runtime.go b/cmd/xgo/patch_runtime.go new file mode 100644 index 00000000..f3a1f2ab --- /dev/null +++ b/cmd/xgo/patch_runtime.go @@ -0,0 +1,134 @@ +package main + +import ( + "bytes" + "fmt" + "os" + "path/filepath" + + "github.com/xhd2015/xgo/cmd/xgo/patch" + "github.com/xhd2015/xgo/support/filecopy" + "github.com/xhd2015/xgo/support/goinfo" +) + +func patchRuntimeAndTesting(goroot string) error { + err := patchRuntimeProc(goroot) + if err != nil { + return err + } + err = patchRuntimeTesting(goroot) + if err != nil { + return err + } + return nil +} + +// addRuntimeFunctions always copy file +func addRuntimeFunctions(goroot string, goVersion *goinfo.GoVersion, xgoSrc string) (updated bool, err error) { + if false { + // seems unnecessary + // TODO: needs to debug to see what will happen with auto generated files + // we need to skip when debugging + + // add debug file + // rational: when debugging, dlv will jump to __xgo_autogen_register_func_helper.go + // previousely this file does not exist, making the debugging blind + runtimeAutoGenFile := filepath.Join(goroot, "src", "runtime", "__xgo_autogen_register_func_helper.go") + srcAutoGen := getInternalPatch(goroot, "syntax", "helper_code.go") + err = filecopy.CopyFile(srcAutoGen, runtimeAutoGenFile) + if err != nil { + return false, err + } + } + + dstFile := filepath.Join(goroot, "src", "runtime", "xgo_trap.go") + content, err := readXgoSrc(xgoSrc, []string{"trap_runtime", "xgo_trap.go"}) + if err != nil { + return false, err + } + + content, err = replaceBuildIgnore(content) + if err != nil { + return false, fmt.Errorf("file %s: %w", filepath.Base(dstFile), err) + } + + // the func.entry is a field, not a function + if goVersion.Major == 1 && goVersion.Minor <= 17 { + entryPatch := "fn.entry() /*>=go1.18*/" + entryPatchBytes := []byte(entryPatch) + idx := bytes.Index(content, entryPatchBytes) + if idx < 0 { + return false, fmt.Errorf("expect %q in xgo_trap.go, actually not found", entryPatch) + } + content = bytes.ReplaceAll(content, entryPatchBytes, []byte("fn.entry")) + } + + // func name patch + if goVersion.Major > 1 || goVersion.Minor > 22 { + panic("should check the implementation of runtime.FuncForPC(pc).Name() to ensure __xgo_get_pc_name is not wrapped in print format above go1.22") + } + if goVersion.Major > 1 || goVersion.Minor >= 21 { + content = append(content, []byte(patch.RuntimeGetFuncName_Go121)...) + } else if goVersion.Major == 1 { + if goVersion.Minor >= 17 { + // go1.17,go1.18,go1.19 + content = append(content, []byte(patch.RuntimeGetFuncName_Go117_120)...) + } + } + + return true, os.WriteFile(dstFile, content, 0755) +} + +func patchRuntimeProc(goroot string) error { + anchors := []string{ + "func main() {", + "doInit(", "runtime_inittask", ")", // first doInit for runtime + "doInit(", // second init for main + "close(main_init_done)", + "\n", + } + procGo := filepath.Join(goroot, "src", "runtime", "proc.go") + err := editFile(procGo, func(content string) (string, error) { + content = addContentAfter(content, "/**/", "/**/", anchors, patch.RuntimeProcPatch) + + // goexit1() is called for every exited goroutine + content = addContentAfter(content, + "/**/", "/**/", + []string{"func goexit1() {", "\n"}, + patch.RuntimeProcGoroutineExitPatch, + ) + + content = replaceContentAfter(content, + "/**/", "/**/", + []string{ + "func newproc1(", "*g {", + }, + "return newg", + patch.RuntimeProcGoroutineCreatedPatch, + ) + return content, nil + }) + if err != nil { + return err + } + return nil +} + +func patchRuntimeTesting(goroot string) error { + testingFile := filepath.Join(goroot, "src", "testing", "testing.go") + return editFile(testingFile, func(content string) (string, error) { + // func tRunner(t *T, fn func(t *T)) { + anchor := []string{"func tRunner(t *T", "{", "\n"} + content = addContentBefore(content, + "/**/", "/**/", + anchor, + patch.TestingCallbackDeclarations, + ) + content = addContentAfter(content, + "/**/", "/**/", + anchor, + patch.TestingStart, + ) + return content, nil + }) +} diff --git a/cmd/xgo/version.go b/cmd/xgo/version.go index f37d8089..44d59df4 100644 --- a/cmd/xgo/version.go +++ b/cmd/xgo/version.go @@ -2,9 +2,9 @@ package main import "fmt" -const VERSION = "1.0.11" -const REVISION = "43756010e13cabfae008c1de9d72f98b946b0a09+1" -const NUMBER = 144 +const VERSION = "1.0.12" +const REVISION = "f3d7271450fef6b7575368a82d6fe254c894a97e+1" +const NUMBER = 147 func getRevision() string { return fmt.Sprintf("%s %s BUILD_%d", VERSION, REVISION, NUMBER) diff --git a/patch/link_name.go b/patch/link_name.go index 22fcc59d..6d781260 100644 --- a/patch/link_name.go +++ b/patch/link_name.go @@ -29,6 +29,7 @@ var linkMap = map[string]string{ xgo_syntax.XgoLinkTrapForGenerated: XgoTrapForGenerated, "__xgo_link_init_finished": "__xgo_init_finished", "__xgo_link_on_init_finished": "__xgo_on_init_finished", + "__xgo_link_on_gonewproc": "__xgo_on_gonewproc", "__xgo_link_on_goexit": "__xgo_on_goexit", "__xgo_link_on_test_start": xgoOnTestStart, "__xgo_link_get_test_starts": "__xgo_get_test_starts", diff --git a/patch/syntax/func_stub_def.go b/patch/syntax/func_stub_def.go index fcbaff04..82b952cb 100644 --- a/patch/syntax/func_stub_def.go +++ b/patch/syntax/func_stub_def.go @@ -16,9 +16,11 @@ const expected__xgo_stub_def = `struct { ArgNames []string ResNames []string - // can be retrieved at runtime + // Deprecated + // these two fields can be retrieved at runtime FirstArgCtx bool // first argument is context.Context or sub type? - LastResErr bool // last res is error or sub type? + // Deprecated + LastResErr bool // last res is error or sub type? File string Line int diff --git a/patch/syntax/helper_code.go b/patch/syntax/helper_code.go index e23f52b5..f49a83df 100644 --- a/patch/syntax/helper_code.go +++ b/patch/syntax/helper_code.go @@ -19,9 +19,11 @@ type __xgo_local_func_stub struct { ArgNames []string ResNames []string - // can be retrieved at runtime + // Deprecated + // these two fields can be retrieved at runtime FirstArgCtx bool // first argument is context.Context or sub type? - LastResErr bool // last res is error or sub type? + // Deprecated + LastResErr bool // last res is error or sub type? File string Line int diff --git a/patch/syntax/helper_code_gen.go b/patch/syntax/helper_code_gen.go index 531fec0a..4a3ab607 100755 --- a/patch/syntax/helper_code_gen.go +++ b/patch/syntax/helper_code_gen.go @@ -18,9 +18,11 @@ const __xgo_stub_def = `struct { ArgNames []string ResNames []string - // can be retrieved at runtime + // Deprecated + // these two fields can be retrieved at runtime FirstArgCtx bool // first argument is context.Context or sub type? - LastResErr bool // last res is error or sub type? + // Deprecated + LastResErr bool // last res is error or sub type? File string Line int @@ -44,9 +46,11 @@ type __xgo_local_func_stub struct { ArgNames []string ResNames []string - // can be retrieved at runtime + // Deprecated + // these two fields can be retrieved at runtime FirstArgCtx bool // first argument is context.Context or sub type? - LastResErr bool // last res is error or sub type? + // Deprecated + LastResErr bool // last res is error or sub type? File string Line int diff --git a/patch/syntax/rewrite.go b/patch/syntax/rewrite.go index 4655be9a..8598c5ec 100644 --- a/patch/syntax/rewrite.go +++ b/patch/syntax/rewrite.go @@ -11,6 +11,28 @@ import ( const XgoLinkTrapForGenerated = "__xgo_link_trap_for_generated" +func fillFuncArgResNames(fileList []*syntax.File) { + if base.Flag.Std { + return + } + for _, file := range fileList { + syntax.Inspect(file, func(n syntax.Node) bool { + if decl, ok := n.(*syntax.FuncDecl); ok { + if decl.Body == nil { + return true + } + preset := getPresetNames(decl) + fillNames(decl.Pos(), decl.Recv, decl.Type, preset) + } else if funcLit, ok := n.(*syntax.FuncLit); ok { + preset := getPresetNames(funcLit.Type) + fillNames(funcLit.Pos(), nil, funcLit.Type, preset) + } + + return true + }) + } +} + func rewriteStdAndGenericFuncs(funcDecls []*DeclInfo, pkgPath string) { for _, fn := range funcDecls { if fn.Interface { @@ -42,7 +64,7 @@ func rewriteStdAndGenericFuncs(funcDecls []*DeclInfo, pkgPath string) { preset := getPresetNames(newDecl) - fillNames(newDecl, preset) + fillNames(pos, newDecl.Recv, newDecl.Type, preset) if preset[XgoLinkTrapForGenerated] { // cannot trap continue @@ -216,12 +238,12 @@ func fillPos(pos syntax.Pos, node syntax.Node) { }) } -func fillNames(decl *syntax.FuncDecl, preset map[string]bool) { - if decl.Recv != nil { - fillFieldNames([]*syntax.Field{decl.Recv}, preset, "_x") +func fillNames(pos syntax.Pos, recv *syntax.Field, funcType *syntax.FuncType, preset map[string]bool) { + if recv != nil { + fillFieldNames(pos, []*syntax.Field{recv}, preset, "_x") } - fillFieldNames(decl.Type.ParamList, preset, "_a") - fillFieldNames(decl.Type.ResultList, preset, "_r") + fillFieldNames(pos, funcType.ParamList, preset, "_a") + fillFieldNames(pos, funcType.ResultList, preset, "_r") } func getRefSlice(pos syntax.Pos, fields []*syntax.Field) []syntax.Expr { @@ -254,10 +276,12 @@ func doGetRefAddrSlice(pos syntax.Pos, fields []*syntax.Field, addr bool) []synt } // _a0,_a1 -func fillFieldNames(fields []*syntax.Field, preset map[string]bool, prefix string) { +func fillFieldNames(pos syntax.Pos, fields []*syntax.Field, preset map[string]bool, prefix string) { for i, f := range fields { if f.Name == nil { - f.Name = &syntax.Name{} + name := &syntax.Name{} + (ISetPos)(name).SetPos(pos) + f.Name = name } else if f.Name.Value != "" && f.Name.Value != "_" { continue } @@ -433,6 +457,11 @@ func copyExpr(expr syntax.Expr) syntax.Expr { x := *expr x.ElemList = copyExprs(expr.ElemList) return &x + case *syntax.MapType: + x := *expr + x.Key = copyExpr(expr.Key) + x.Value = copyExpr(expr.Value) + return &x default: panic(fmt.Errorf("unrecognized expr while copying: %T", expr)) } diff --git a/patch/syntax/syntax.go b/patch/syntax/syntax.go index b58c45bb..7117cf64 100644 --- a/patch/syntax/syntax.go +++ b/patch/syntax/syntax.go @@ -24,6 +24,7 @@ func init() { func AfterFilesParsed(fileList []*syntax.File, addFile func(name string, r io.Reader)) { debugSyntax(fileList) patchVersions(fileList) + fillFuncArgResNames(fileList) afterFilesParsed(fileList, addFile) } @@ -476,11 +477,14 @@ func getFuncDeclInfo(fileIndex int, f *syntax.File, file string, fn *syntax.Func } var firstArgCtx bool var lastResErr bool - if len(fn.Type.ParamList) > 0 && hasQualifiedName(fn.Type.ParamList[0].Type, "context", "Context") { - firstArgCtx = true - } - if len(fn.Type.ResultList) > 0 && isName(fn.Type.ResultList[len(fn.Type.ResultList)-1].Type, "error") { - lastResErr = true + if false { + // NOTE: these fields will be retrieved at runtime dynamically + if len(fn.Type.ParamList) > 0 && hasQualifiedName(fn.Type.ParamList[0].Type, "context", "Context") { + firstArgCtx = true + } + if len(fn.Type.ResultList) > 0 && isName(fn.Type.ResultList[len(fn.Type.ResultList)-1].Type, "error") { + lastResErr = true + } } return &DeclInfo{ diff --git a/runtime/core/version.go b/runtime/core/version.go index 50aa03e3..3bda771d 100644 --- a/runtime/core/version.go +++ b/runtime/core/version.go @@ -6,9 +6,9 @@ import ( "os" ) -const VERSION = "1.0.11" -const REVISION = "43756010e13cabfae008c1de9d72f98b946b0a09+1" -const NUMBER = 144 +const VERSION = "1.0.12" +const REVISION = "f3d7271450fef6b7575368a82d6fe254c894a97e+1" +const NUMBER = 147 // these fields will be filled by compiler const XGO_VERSION = "" diff --git a/runtime/functab/functab.go b/runtime/functab/functab.go index 80aff57f..2bc1bde5 100644 --- a/runtime/functab/functab.go +++ b/runtime/functab/functab.go @@ -22,11 +22,11 @@ func init() { // a call to runtime.__xgo_for_each_func func __xgo_link_retrieve_all_funcs_and_clear(f func(fn interface{})) { // linked at runtime - fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_retrieve_all_funcs_and_clear.(xgo required)") + fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_retrieve_all_funcs_and_clear(requires xgo).") } func __xgo_link_on_init_finished(f func()) { - fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_on_init_finished.(xgo required)") + fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_on_init_finished(requires xgo).") } func __xgo_link_get_pc_name(pc uintptr) string { @@ -166,21 +166,24 @@ func ensureMapping() { generic := rv.FieldByName("Generic").Bool() f := rv.FieldByName("Fn").Interface() - firstArgCtx := rv.FieldByName("FirstArgCtx").Bool() - lastResErr := rv.FieldByName("LastResErr").Bool() + var firstArgCtx bool + var lastResErr bool var pc uintptr var fullName string if !generic && !interface_ { if f != nil { - if closure { - // TODO: move all ctx, err check logic here - ft := reflect.TypeOf(f) - if ft.NumIn() > 0 && ft.In(0).Implements(ctxType) { - firstArgCtx = true - } - if ft.NumOut() > 0 && ft.Out(ft.NumOut()-1).Implements(errType) { - lastResErr = true - } + // TODO: move all ctx, err check logic here + ft := reflect.TypeOf(f) + off := 0 + if recvTypeName != "" { + off = 1 + } + if ft.NumIn() > off && ft.In(off).Implements(ctxType) { + firstArgCtx = true + } + // NOTE: use == instead of implements + if ft.NumOut() > 0 && ft.Out(ft.NumOut()-1) == errType { + lastResErr = true } pc = getFuncPC(f) fullName = __xgo_link_get_pc_name(pc) diff --git a/runtime/mock/patch.go b/runtime/mock/patch.go new file mode 100644 index 00000000..27d4822c --- /dev/null +++ b/runtime/mock/patch.go @@ -0,0 +1,69 @@ +package mock + +import ( + "context" + "fmt" + "reflect" + + "github.com/xhd2015/xgo/runtime/core" +) + +func PatchByName(pkgPath string, funcName string, replacer interface{}) func() { + return MockByName(pkgPath, funcName, buildInterceptorFromPatch(replacer)) +} + +func PatchMethodByName(instance interface{}, method string, replacer interface{}) func() { + return MockMethodByName(instance, method, buildInterceptorFromPatch(replacer)) +} + +func buildInterceptorFromPatch(replacer interface{}) func(ctx context.Context, fn *core.FuncInfo, args, results core.Object) error { + v := reflect.ValueOf(replacer) + t := v.Type() + if t.Kind() != reflect.Func { + panic(fmt.Errorf("requires func, given %T", replacer)) + } + if v.IsNil() { + panic("replacer is nil") + } + nIn := t.NumIn() + return func(ctx context.Context, fn *core.FuncInfo, args, results core.Object) error { + // assemble arguments + callArgs := make([]reflect.Value, nIn) + src := 0 + dst := 0 + if fn.RecvType != "" { + src++ + } + if fn.FirstArgCtx { + callArgs[0] = reflect.ValueOf(ctx) + dst++ + } + for i := 0; i < nIn-dst; i++ { + callArgs[dst+i] = reflect.ValueOf(args.GetFieldIndex(src + i).Value()) + } + + // call the function + var res []reflect.Value + if !t.IsVariadic() { + res = v.Call(callArgs) + } else { + res = v.CallSlice(callArgs) + } + + // assign result + nOut := len(res) + resLen := nOut + if fn.LastResultErr { + resLen-- + } + for i := 0; i < resLen; i++ { + results.GetFieldIndex(i).Set(res[i].Interface()) + } + + if fn.LastResultErr { + results.(core.ObjectWithErr).GetErr().Set(res[nOut-1].Interface()) + } + + return nil + } +} diff --git a/runtime/mock/patch_go1.17.go b/runtime/mock/patch_go1.17.go new file mode 100644 index 00000000..4e39ac50 --- /dev/null +++ b/runtime/mock/patch_go1.17.go @@ -0,0 +1,8 @@ +//go:build !go1.18 +// +build !go1.18 + +package mock + +func Patch(fn interface{}, replacer interface{}) func() { + return Mock(fn, buildInterceptorFromPatch(replacer)) +} diff --git a/runtime/mock/patch_go1.18.go b/runtime/mock/patch_go1.18.go new file mode 100644 index 00000000..d921cad7 --- /dev/null +++ b/runtime/mock/patch_go1.18.go @@ -0,0 +1,10 @@ +//go:build go1.18 +// +build go1.18 + +package mock + +// TODO: what if `fn` is a Type function +// instead of an instance method? +func Patch[T any](fn T, replacer T) func() { + return Mock(fn, buildInterceptorFromPatch(replacer)) +} diff --git a/runtime/test/debug/debug_test.go b/runtime/test/debug/debug_test.go index 67674a65..b691915e 100644 --- a/runtime/test/debug/debug_test.go +++ b/runtime/test/debug/debug_test.go @@ -7,37 +7,37 @@ package debug import ( "context" + "fmt" "testing" "github.com/xhd2015/xgo/runtime/core" "github.com/xhd2015/xgo/runtime/trap" ) -func TestTrapAbortInTheMiddle(t *testing.T) { +func TestNewGoroutineShouldInheritInterceptor(t *testing.T) { trap.AddInterceptor(&trap.Interceptor{ - Pre: func(ctx context.Context, f *core.FuncInfo, args, result core.Object) (interface{}, error) { - if f.IdentityName == "double" { - panic("should be aborted") - } - return nil, nil - }, - }) - trap.AddInterceptor(&trap.Interceptor{ - Pre: func(ctx context.Context, f *core.FuncInfo, args, result core.Object) (interface{}, error) { - if f.IdentityName == "double" { - args.GetField("i").Set(20) + Pre: func(ctx context.Context, f *core.FuncInfo, args, result core.Object) (data interface{}, err error) { + if f.IdentityName == "greet" { + result.GetFieldIndex(0).Set("mock " + args.GetFieldIndex(0).Value().(string)) return nil, trap.ErrAbort } - return nil, nil + return }, }) - d := double(1) - if d != 0 { - t.Fatalf("expect double(1) to be aborted so return %d actual: %d", 0, d) - } + done := make(chan struct{}) + go func() { + s := greet("world") + if s != "mock world" { + panic(fmt.Errorf("expect greet returns %q, actual: %q", "mock world", s)) + } + + close(done) + }() + + <-done } -func double(i int) int { - return 2 * i +func greet(s string) string { + return "hello " + s } diff --git a/runtime/test/patch/err_test.go b/runtime/test/patch/err_test.go new file mode 100644 index 00000000..a57ef3d1 --- /dev/null +++ b/runtime/test/patch/err_test.go @@ -0,0 +1,32 @@ +package patch + +import ( + "fmt" + "testing" + + "github.com/xhd2015/xgo/runtime/mock" +) + +func toErr(s string) error { + return fmt.Errorf("err: %v", s) +} + +func TestPatchShouldWorkWithErrReturn(t *testing.T) { + mock.Patch(toErr, func(s string) error { + return fmt.Errorf("mock: %v", s) + }) + err := toErr("test") + if err.Error() != "mock: test" { + t.Fatalf("expect toErr() patched to be %q, actual: %q", "mock: test", err.Error()) + } +} + +func TestPatchShouldWorkWithErrNilReturn(t *testing.T) { + mock.Patch(toErr, func(s string) error { + return nil + }) + err := toErr("test") + if err != nil { + t.Fatalf("expect toErr() patched to be nil, actual: %v", err) + } +} diff --git a/runtime/test/patch/patch_test.go b/runtime/test/patch/patch_test.go new file mode 100644 index 00000000..b73c4d4b --- /dev/null +++ b/runtime/test/patch/patch_test.go @@ -0,0 +1,60 @@ +package patch + +import ( + "strings" + "testing" + + "github.com/xhd2015/xgo/runtime/mock" +) + +func greet(s string) string { + return "hello " + s +} + +func greetVaradic(s ...string) string { + return "hello " + strings.Join(s, ",") +} + +func TestPatchSimpleFunc(t *testing.T) { + mock.Patch(greet, func(s string) string { + return "mock " + s + }) + + res := greet("world") + if res != "mock world" { + t.Fatalf("expect patched result to be %q, actual: %q", "mock world", res) + } +} + +func TestPatchVaradicFunc(t *testing.T) { + mock.Patch(greetVaradic, func(s ...string) string { + return "mock " + strings.Join(s, ",") + }) + + res := greetVaradic("earth", "moon") + if res != "mock earth,moon" { + t.Fatalf("expect patched result to be %q, actual: %q", "mock earth,moon", res) + } +} + +type struct_ struct { + s string +} + +func (c *struct_) greet() string { + return "hello " + c.s +} + +func TestPatchMethod(t *testing.T) { + ins := &struct_{ + s: "world", + } + mock.Patch(ins.greet, func() string { + return "mock " + ins.s + }) + + res := ins.greet() + if res != "mock world" { + t.Fatalf("expect patched result to be %q, actual: %q", "mock world", res) + } +} diff --git a/runtime/test/testing_callback/main_test.go b/runtime/test/testing_callback/main_test.go index 386bdff2..8aae032a 100644 --- a/runtime/test/testing_callback/main_test.go +++ b/runtime/test/testing_callback/main_test.go @@ -6,7 +6,7 @@ import ( ) func __xgo_link_on_test_start(fn func(t *testing.T, fn func(t *testing.T))) { - panic("WARNING: failed to link __xgo_link_on_test_start.(xgo required)") + panic("WARNING: failed to link __xgo_link_on_test_start(requires xgo).") // link by compiler } func init() { diff --git a/runtime/test/trap/trap_inherit_goroutine_test.go b/runtime/test/trap/trap_inherit_goroutine_test.go new file mode 100644 index 00000000..1fb78157 --- /dev/null +++ b/runtime/test/trap/trap_inherit_goroutine_test.go @@ -0,0 +1,38 @@ +package trap + +import ( + "context" + "fmt" + "testing" + + "github.com/xhd2015/xgo/runtime/core" + "github.com/xhd2015/xgo/runtime/trap" +) + +func TestNewGoroutineShouldInheritInterceptor(t *testing.T) { + trap.AddInterceptor(&trap.Interceptor{ + Pre: func(ctx context.Context, f *core.FuncInfo, args, result core.Object) (data interface{}, err error) { + if f.IdentityName == "greet" { + result.GetFieldIndex(0).Set("mock " + args.GetFieldIndex(0).Value().(string)) + return nil, trap.ErrAbort + } + return + }, + }) + + done := make(chan struct{}) + go func() { + s := greet("world") + if s != "mock world" { + panic(fmt.Errorf("expect greet returns %q, actual: %q", "mock world", s)) + } + + close(done) + }() + + <-done +} + +func greet(s string) string { + return "hello " + s +} diff --git a/runtime/test/trap/trap_overlay_test.go b/runtime/test/trap/trap_overlay_test.go new file mode 100644 index 00000000..832391b2 --- /dev/null +++ b/runtime/test/trap/trap_overlay_test.go @@ -0,0 +1,42 @@ +// when multiple interceptors added, the order is reversed +package trap + +import ( + "context" + "testing" + + "github.com/xhd2015/xgo/runtime/core" + "github.com/xhd2015/xgo/runtime/trap" +) + +func TestTrapOverlay(t *testing.T) { + trap.AddInterceptor(&trap.Interceptor{ + Pre: func(ctx context.Context, f *core.FuncInfo, args, result core.Object) (data interface{}, err error) { + if f.IdentityName == "testOverlay" { + panic("first trap should not be called") + } + return nil, nil + }, + }) + + // overlay + var trapCalled bool + trap.AddInterceptor(&trap.Interceptor{ + Pre: func(ctx context.Context, f *core.FuncInfo, args, result core.Object) (data interface{}, err error) { + if f.IdentityName == "testOverlay" { + trapCalled = true + return nil, trap.ErrAbort + } + return + }, + }) + testOverlay() + + if !trapCalled { + t.Fatalf("expect trap to have been called, actually not called") + } +} + +func testOverlay() { + +} diff --git a/runtime/test/trap_args/closure_test.go b/runtime/test/trap_args/closure_test.go new file mode 100644 index 00000000..41f5a2a4 --- /dev/null +++ b/runtime/test/trap_args/closure_test.go @@ -0,0 +1,54 @@ +package trap_args + +import ( + "context" + "testing" + + "github.com/xhd2015/xgo/runtime/core" +) + +var gc = func(ctx context.Context) { + panic("gc should be trapped") +} + +var gcUnnamed = func(context.Context) { + panic("gcUnnamed should be trapped") +} + +func TestClosureShouldRetrieveCtxInfoAtTrapTime(t *testing.T) { + ctx := context.Background() + ctx = context.WithValue(ctx, "test", "mock") + callAndCheck(func() { + gc(ctx) + }, func(trapCtx context.Context, f *core.FuncInfo, args, result core.Object) error { + if !f.FirstArgCtx { + t.Fatalf("expect closure also mark firstArgCtx, actually not marked") + } + if trapCtx == nil { + t.Fatalf("expect trapCtx to be non nil, atcual nil") + } + if trapCtx != ctx { + t.Fatalf("expect trapCtx to be the same with ctx, actully different") + } + return nil + }) +} + +func TestClosureUnnamedArgShouldRetrieveCtxInfo(t *testing.T) { + ctx := context.Background() + ctx = context.WithValue(ctx, "test", "mock") + callAndCheck(func() { + gcUnnamed(ctx) + }, func(trapCtx context.Context, f *core.FuncInfo, args, result core.Object) error { + if !f.FirstArgCtx { + t.Fatalf("expect closure also mark firstArgCtx, actually not marked") + } + if trapCtx == nil { + t.Fatalf("expect trapCtx to be non nil, atcual nil") + } + if trapCtx != ctx { + t.Fatalf("expect trapCtx to be the same with ctx, actully different") + } + return nil + }) +} diff --git a/runtime/test/trap_args/ctx_test.go b/runtime/test/trap_args/ctx_test.go new file mode 100644 index 00000000..af4be039 --- /dev/null +++ b/runtime/test/trap_args/ctx_test.go @@ -0,0 +1,51 @@ +package trap_args + +import ( + "context" + "testing" + + "github.com/xhd2015/xgo/runtime/core" +) + +func TestPlainCtxArgCanBeRecognized(t *testing.T) { + ctx := context.Background() + ctx = context.WithValue(ctx, "test", "mock") + callAndCheck(func() { + f2(ctx) + }, func(trapCtx context.Context, f *core.FuncInfo, args, result core.Object) error { + if !f.FirstArgCtx { + t.Fatalf("expect first arg to be context") + } + if trapCtx != ctx { + t.Fatalf("expect context passed unchanged, actually different") + } + ctxVal := trapCtx.Value("test").(string) + if ctxVal != "mock" { + t.Fatalf("expect context value to be %q, actual: %q", "mock", ctxVal) + } + return nil + }) +} + +func TestCtxVariantCanBeRecognized(t *testing.T) { + ctx := context.Background() + ctx = context.WithValue(ctx, "test", "mock") + + myCtx := &MyContext{Context: ctx} + + callAndCheck(func() { + f3(myCtx) + }, func(trapCtx context.Context, f *core.FuncInfo, args, result core.Object) error { + if !f.FirstArgCtx { + t.Fatalf("expect first arg to be context") + } + if trapCtx != myCtx { + t.Fatalf("expect context passed unchanged, actually different") + } + ctxVal := trapCtx.Value("test").(string) + if ctxVal != "mock" { + t.Fatalf("expect context value to be %q, actual: %q", "mock", ctxVal) + } + return nil + }) +} diff --git a/runtime/test/trap_args/err_test.go b/runtime/test/trap_args/err_test.go new file mode 100644 index 00000000..988e3f96 --- /dev/null +++ b/runtime/test/trap_args/err_test.go @@ -0,0 +1,69 @@ +package trap_args + +import ( + "context" + "errors" + "testing" + + "github.com/xhd2015/xgo/runtime/core" + "github.com/xhd2015/xgo/runtime/trap" +) + +func plainErr() error { + panic("plainErr should be mocked") +} + +func subErr() *Error { + return &Error{"sub error"} +} + +func TestPlainErrShouldSetErrRes(t *testing.T) { + mockErr := errors.New("mock err") + var err error + callAndCheck(func() { + err = plainErr() + }, func(trapCtx context.Context, f *core.FuncInfo, args, result core.Object) error { + if !f.LastResultErr { + t.Fatalf("expect f.LastResultErr to be true, actual: false") + } + return mockErr + }) + + if err != mockErr { + t.Fatalf("expect return err %v, actual %v", mockErr, err) + } +} + +func TestSubErrShouldNotSetErrRes(t *testing.T) { + mockErr := errors.New("mock err") + var err *Error + var recoverErr interface{} + func() { + defer func() { + // NOTE: this may have impact + trap.Skip() + recoverErr = recover() + }() + callAndCheck(func() { + err = subErr() + }, func(trapCtx context.Context, f *core.FuncInfo, args, result core.Object) error { + if f.LastResultErr { + t.Fatalf("expect f.LastResultErr to be false, actual: true") + } + // even not pl should fail + return mockErr + }) + }() + + if recoverErr == nil { + t.Fatalf("expect error via panic, actually no panic") + } + + if err != nil { + t.Fatalf("expect return error not set, actual: %v", err) + } + + if recoverErr != mockErr { + t.Fatalf("expect panic err to be %v, actual: %v", mockErr, recoverErr) + } +} diff --git a/runtime/test/trap_args/method_test.go b/runtime/test/trap_args/method_test.go new file mode 100644 index 00000000..a702777e --- /dev/null +++ b/runtime/test/trap_args/method_test.go @@ -0,0 +1,7 @@ +package trap_args + +import "testing" + +func TestWhenRecvIsCtxShouldNotRecognize(t *testing.T) { + +} diff --git a/runtime/test/trap_args/trap_args_test.go b/runtime/test/trap_args/trap_args_test.go new file mode 100644 index 00000000..8ab4fb93 --- /dev/null +++ b/runtime/test/trap_args/trap_args_test.go @@ -0,0 +1,104 @@ +package trap_args + +import ( + "context" + "fmt" + "reflect" + "runtime" + "testing" + + "github.com/xhd2015/xgo/runtime/core" + "github.com/xhd2015/xgo/runtime/trap" +) + +func f1() { + +} +func f2(ctx context.Context) error { + panic("f2 should be mocked") +} + +func f3(ctx *MyContext) *Error { + panic("f3 should be mocked") +} + +type MyContext struct { + context.Context +} + +type Error struct { + msg string +} + +func (c *Error) Error() string { + return c.msg +} + +type struct_ struct { +} + +func (c *struct_) f2(ctx context.Context) error { + panic(fmt.Errorf("struct_.f2 should be mocked")) +} +func TestCtxArgWithRecvCanBeRecognized(t *testing.T) { + st := &struct_{} + ctx := context.Background() + ctx = context.WithValue(ctx, "test", "mock") + + var callCount int + trap.AddInterceptor(&trap.Interceptor{ + Pre: func(trapCtx context.Context, f *core.FuncInfo, args, result core.Object) (data interface{}, err error) { + callCount++ + if !f.FirstArgCtx { + t.Fatalf("expect first arg to be context") + } + if trapCtx != ctx { + t.Fatalf("expect context passed unchanged, actually different") + } + ctxVal := trapCtx.Value("test").(string) + if ctxVal != "mock" { + t.Fatalf("expect context value to be %q, actual: %q", "mock", ctxVal) + } + return nil, trap.ErrAbort + }, + }) + st.f2(ctx) + + if callCount != 1 { + t.Fatalf("expect call trap once, actual: %d", callCount) + } +} + +func callAndCheck(fn func(), check func(trapCtx context.Context, f *core.FuncInfo, args, result core.Object) error) error { + pc := reflect.ValueOf(fn).Pointer() + + pcName := runtime.FuncForPC(pc).Name() + if pcName == "" { + return fmt.Errorf("cannot get pc name") + } + + var callCount int + trap.AddInterceptor(&trap.Interceptor{ + Pre: func(trapCtx context.Context, f *core.FuncInfo, args, result core.Object) (data interface{}, err error) { + if callCount == 0 { + callCount++ + if pcName != f.FullName { + panic(fmt.Errorf("expect first call hit %s, actual: %s", pcName, f.FullName)) + } + return + } + callCount++ + err = check(trapCtx, f, args, result) + if err != nil { + return nil, err + } + return nil, trap.ErrAbort + }, + }) + fn() + + if callCount != 2 { + fmt.Errorf("expect call trap twice, actual: %d", callCount) + } + return nil +} diff --git a/runtime/trace/trace.go b/runtime/trace/trace.go index 916282d9..deaf43ab 100644 --- a/runtime/trace/trace.go +++ b/runtime/trace/trace.go @@ -47,26 +47,26 @@ func init() { // link by compiler func __xgo_link_on_test_start(fn func(t *testing.T, fn func(t *testing.T))) { - fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_on_test_start.(xgo required)") + fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_on_test_start(requires xgo).") } // link by compiler func __xgo_link_getcurg() unsafe.Pointer { - fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_getcurg.(xgo required)") + fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_getcurg(requires xgo).") return nil } func __xgo_link_on_goexit(fn func()) { - fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_on_goexit.(xgo required)") + fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_on_goexit(requires xgo).") } func __xgo_link_init_finished() bool { - fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_init_finished.(xgo required)") + fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_init_finished(requires xgo).") return false } // linked by compiler func __xgo_link_peek_panic() interface{} { - fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_peek_panic.(xgo required)") + fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_peek_panic(requires xgo).") return nil } diff --git a/runtime/trap/interceptor.go b/runtime/trap/interceptor.go index 25be6822..c28a598a 100644 --- a/runtime/trap/interceptor.go +++ b/runtime/trap/interceptor.go @@ -17,17 +17,17 @@ var ErrAbort error = errors.New("abort trap interceptor") // link by compiler func __xgo_link_getcurg() unsafe.Pointer { - fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_getcurg.(xgo required)") + fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_getcurg(requires xgo).") return nil } func __xgo_link_init_finished() bool { - fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_init_finished.(xgo required)") + fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_init_finished(requires xgo).") return false } func __xgo_link_on_goexit(fn func()) { - fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_on_goexit.(xgo required)") + fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_on_goexit(requires xgo).") } func __xgo_link_get_pc_name(pc uintptr) string { fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_get_pc_name(requires xgo)") @@ -78,7 +78,7 @@ func GetInterceptors() []*Interceptor { } func GetLocalInterceptors() []*Interceptor { - key := __xgo_link_getcurg() + key := uintptr(__xgo_link_getcurg()) val, ok := localInterceptors.Load(key) if !ok { return nil @@ -107,7 +107,7 @@ func GetAllInterceptors() []*Interceptor { // NOTE: if not called correctly,there might be memory leak func addLocalInterceptor(interceptor *Interceptor) func() { ensureTrapInstall() - key := __xgo_link_getcurg() + key := uintptr(__xgo_link_getcurg()) list := &interceptorList{} val, loaded := localInterceptors.LoadOrStore(key, list) if loaded { @@ -121,6 +121,10 @@ func addLocalInterceptor(interceptor *Interceptor) func() { if removed { panic(fmt.Errorf("remove interceptor more than once")) } + curKey := uintptr(__xgo_link_getcurg()) + if key != curKey { + panic(fmt.Errorf("remove interceptor from another goroutine")) + } removed = true var idx int = -1 for i, intc := range list.interceptors { @@ -149,7 +153,7 @@ type interceptorList struct { } func clearLocalInterceptorsAndMark() { - key := __xgo_link_getcurg() + key := uintptr(__xgo_link_getcurg()) localInterceptors.Delete(key) clearTrappingMark() diff --git a/runtime/trap/object.go b/runtime/trap/object.go index 2f078f98..23455f86 100644 --- a/runtime/trap/object.go +++ b/runtime/trap/object.go @@ -68,6 +68,13 @@ func (c field) Name() string { } func (c field) Set(val interface{}) { + // if val is nil, then reflect.ValueOf(val) + // is invalid + if val == nil { + // clear + reflect.ValueOf(c.valPtr).Elem().Set(reflect.Zero(reflect.TypeOf(c.valPtr).Elem())) + return + } reflect.ValueOf(c.valPtr).Elem().Set(reflect.ValueOf(val)) } func (c field) Ptr() interface{} { diff --git a/runtime/trap/trap.go b/runtime/trap/trap.go index cb7f77fc..bc757f9a 100644 --- a/runtime/trap/trap.go +++ b/runtime/trap/trap.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "os" + "reflect" "sync" "github.com/xhd2015/xgo/runtime/core" @@ -17,9 +18,27 @@ func ensureTrapInstall() { __xgo_link_set_trap(trapImpl) }) } +func init() { + __xgo_link_on_gonewproc(func(g uintptr) { + interceptors := GetLocalInterceptors() + if len(interceptors) == 0 { + return + } + copyInterceptors := make([]*Interceptor, len(interceptors)) + copy(copyInterceptors, interceptors) + + // inherit interceptors + localInterceptors.Store(g, &interceptorList{ + interceptors: copyInterceptors, + }) + }) +} func __xgo_link_set_trap(trapImpl func(pkgPath string, identityName string, generic bool, pc uintptr, recv interface{}, args []interface{}, results []interface{}) (func(), bool)) { - fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_set_trap.(xgo required)") + fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_set_trap(requires xgo).") +} +func __xgo_link_on_gonewproc(f func(g uintptr)) { + fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_on_gonewproc(requires xgo).") } // Skip serves as mark to tell xgo not insert @@ -30,15 +49,6 @@ func __xgo_link_set_trap(trapImpl func(pkgPath string, identityName string, gene // sense at compile time. func Skip() {} -func GetTrappingPC() uintptr { - key := uintptr(__xgo_link_getcurg()) - val, ok := trappingPC.Load(key) - if !ok { - return 0 - } - return val.(uintptr) -} - var trappingMark sync.Map // -> struct{}{} var trappingPC sync.Map // -> PC @@ -82,6 +92,35 @@ func trapImpl(pkgPath string, identityName string, generic bool, pc uintptr, rec return nil, false } + // retrieve context + var ctx context.Context + if f.FirstArgCtx { + // TODO: is *HttpRequest a *Context? + ctx = reflect.ValueOf(args[0]).Elem().Interface().(context.Context) + // ctx = *(args[0].(*context.Context)) + } else if f.Closure { + if len(args) > 0 { + argCtx, ok := reflect.ValueOf(args[0]).Elem().Interface().(context.Context) + if ok { + // modify on the fly + f.FirstArgCtx = true + ctx = argCtx + } + } + } + var perr *error + if f.LastResultErr { + perr = results[len(results)-1].(*error) + } else if f.Closure { + if len(results) > 0 { + resErr, ok := reflect.ValueOf(results[0]).Interface().(*error) + if ok { + f.LastResultErr = true + perr = resErr + } + } + } + // TODO: set FirstArgCtx and LastResultErr req := make(object, 0, len(args)) result := make(object, 0, len(results)) @@ -130,18 +169,6 @@ func trapImpl(pkgPath string, identityName string, generic bool, pc uintptr, rec // Results: results, // } - // TODO: will results always have names? - var perr *error - if f.LastResultErr { - perr = results[len(results)-1].(*error) - } - - // NOTE: ctx - var ctx context.Context - if f.FirstArgCtx { - // TODO: is *HttpRequest a *Context? - ctx = *(args[0].(*context.Context)) - } // NOTE: context.TODO() is a constant if ctx == nil { ctx = context.TODO() @@ -222,6 +249,14 @@ func trapImpl(pkgPath string, identityName string, generic bool, pc uintptr, rec }, false } +func GetTrappingPC() uintptr { + key := uintptr(__xgo_link_getcurg()) + val, ok := trappingPC.Load(key) + if !ok { + return 0 + } + return val.(uintptr) +} func setTrappingMark() func() { key := uintptr(__xgo_link_getcurg()) _, trapping := trappingMark.LoadOrStore(key, struct{}{}) diff --git a/runtime/trap_runtime/xgo_trap.go b/runtime/trap_runtime/xgo_trap.go index 8645509a..46973e67 100644 --- a/runtime/trap_runtime/xgo_trap.go +++ b/runtime/trap_runtime/xgo_trap.go @@ -77,8 +77,14 @@ func __xgo_on_init_finished(fn func()) { __xgo_on_init_finished_callbacks = append(__xgo_on_init_finished_callbacks, fn) } +// goroutine creates and exits callbacks +var __xgo_on_gonewproc_callbacks []func(g uintptr) var __xgo_on_goexits []func() +func __xgo_on_gonewproc(fn func(g uintptr)) { + __xgo_on_gonewproc_callbacks = append(__xgo_on_gonewproc_callbacks, fn) +} + func __xgo_on_goexit(fn func()) { __xgo_on_goexits = append(__xgo_on_goexits, fn) } diff --git a/script/run-test/main.go b/script/run-test/main.go index 515ccf53..4f5329a7 100644 --- a/script/run-test/main.go +++ b/script/run-test/main.go @@ -44,6 +44,8 @@ var runtimeTests = []string{ "mock_closure", "mock_stdlib", "mock_generic", + "trap_args", + "patch", } func main() { diff --git a/test/trap_test.go b/test/trap_test.go index 9a9947f3..4dde9c3d 100644 --- a/test/trap_test.go +++ b/test/trap_test.go @@ -24,7 +24,7 @@ func TestTrap(t *testing.T) { // go test -run TestTrapNormalBuildShouldWarn -v ./test func TestTrapNormalBuildShouldWarn(t *testing.T) { t.Parallel() - expectOrigStderr := "WARNING: failed to link __xgo_link_set_trap.(xgo required)" + expectOrigStderr := "WARNING: failed to link __xgo_link_set_trap(requires xgo)." var origStderr bytes.Buffer runAndCheckInstrumentOutput(t, "./testdata/trap", func(output string) error { diff --git a/test/xgo_test/link_on_go_new_proc/link_on_go_new_proc_test.go b/test/xgo_test/link_on_go_new_proc/link_on_go_new_proc_test.go new file mode 100644 index 00000000..a44b24bd --- /dev/null +++ b/test/xgo_test/link_on_go_new_proc/link_on_go_new_proc_test.go @@ -0,0 +1,38 @@ +package trap_set + +import ( + "testing" + "unsafe" +) + +func __xgo_link_on_gonewproc(f func(g uintptr)) { + panic("failed to link __xgo_link_on_gonewproc(requires xgo).") +} + +func __xgo_link_getcurg() unsafe.Pointer { + panic("failed to link __xgo_link_getcurg(requires xgo).") +} + +func TestLinkOnGoNewPrco(t *testing.T) { + var newg uintptr + __xgo_link_on_gonewproc(func(g uintptr) { + newg = g + }) + + var rang uintptr + done := make(chan struct{}) + go func() { + rang = uintptr(__xgo_link_getcurg()) + close(done) + }() + + <-done + + if newg == 0 { + t.Fatalf("expect newg captured, actually not") + } + if newg != rang { + t.Fatalf("expect newg to be 0x%x, actual: 0x%x", rang, newg) + } + +} diff --git a/test/xgo_test/link_on_init_finished/link_on_init_finished_test.go b/test/xgo_test/link_on_init_finished/link_on_init_finished_test.go index d5b75776..1086ada8 100644 --- a/test/xgo_test/link_on_init_finished/link_on_init_finished_test.go +++ b/test/xgo_test/link_on_init_finished/link_on_init_finished_test.go @@ -7,7 +7,7 @@ import ( ) func __xgo_link_on_init_finished(f func()) { - fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_on_init_finished.(xgo required)") + fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_on_init_finished(requires xgo).") } var ran bool diff --git a/test/xgo_test/trap_set/trap_set_test.go b/test/xgo_test/trap_set/trap_set_test.go index a2588dff..07898a84 100644 --- a/test/xgo_test/trap_set/trap_set_test.go +++ b/test/xgo_test/trap_set/trap_set_test.go @@ -5,7 +5,7 @@ import ( ) func __xgo_link_set_trap(trapImpl func(pkgPath string, identityName string, generic bool, pc uintptr, recv interface{}, args []interface{}, results []interface{}) (func(), bool)) { - panic("WARNING: failed to link __xgo_link_set_trap.(xgo required)") + panic("WARNING: failed to link __xgo_link_set_trap(requires xgo).") } var haveCalledTrap bool