diff --git a/README.md b/README.md
index 179709e7..1566ee54 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,6 @@
# xgo
+
[![Go Reference](https://pkg.go.dev/badge/github.com/xhd2015/xgo.svg)](https://pkg.go.dev/github.com/xhd2015/xgo)
[![Go Report Card](https://goreportcard.com/badge/github.com/xhd2015/xgo)](https://goreportcard.com/report/github.com/xhd2015/xgo)
[![Go Coverage](https://img.shields.io/badge/Coverage-82.6%25-brightgreen)](https://github.com/xhd2015/xgo/actions)
@@ -10,7 +11,7 @@
Enable function Trap in `go`, provide tools like Mock and Trace to help go developers write unit test and debug both easier and faster.
-`xgo` works as a preprocessor for `go run`,`go build`, and `go test`(see our [blog](https://blog.xhd2015.xyz/posts/xgo-monkey-patching-in-go-using-toolexec/)).
+`xgo` works as a preprocessor for `go run`,`go build`, and `go test`(see our [blog](https://blog.xhd2015.xyz/posts/xgo-monkey-patching-in-go-using-toolexec)).
It **preprocess** the source code and IR(Intermediate Representation) before invoking `go`, adding missing abilities to go program by cooperating with(or hacking) the go compiler.
@@ -21,6 +22,8 @@ These abilities include:
See [Quick Start](#quick-start) and [Documentation](./doc) for more details.
+> *By the way, I promise you this is an interesting project.*
+
# Installation
```sh
go install github.com/xhd2015/xgo/cmd/xgo@latest
@@ -155,7 +158,6 @@ import (
func init() {
trap.AddInterceptor(&trap.Interceptor{
Pre: func(ctx context.Context, f *core.FuncInfo, args core.Object, results core.Object) (interface{}, error) {
- trap.Skip()
if f.Name == "A" {
fmt.Printf("trap A\n")
return nil, nil
@@ -421,11 +423,22 @@ By default, Trace will write traces to a temp directory under current working di
- `XGO_TRACE_OUTPUT=
`: traces will be written to ``,
- `XGO_TRACE_OUTPUT=off`: turn off trace.
+# Concurrent safety
+I know you guys from other monkey patching library suffer from the unsafety implied by these frameworks.
+
+But I guarantee you mocking in xgo is builtin concurrent safe. That means, you can run multiple tests concurrently as long as you like.
+
+Why? when you run a test, you setup some mock, these mocks will only affect the goroutine test running the test. And these mocks get cleared when the goroutine ends, no matter the test passed or failed.
+
+Want to know why? Stay tuned, we are working on internal documentation.
+
# Implementation Details
> Working in progress...
See [Issue#7](https://github.com/xhd2015/xgo/issues/7) for more details.
+This blog has a basic explanation: https://blog.xhd2015.xyz/posts/xgo-monkey-patching-in-go-using-toolexec
+
# Why `xgo`?
The reason is simple: **NO** interface.
@@ -435,7 +448,7 @@ Extracting interface just for mocking is never an option to me. To the domain of
Monkey patching simply does the right thing for the problem. But existing library are bad at compatibility.
-So I created `xgo`, so I hope `xgo` will also take over other solutions to the mocking problem.
+So I created `xgo`, and hope it will finally take over other solutions to the mocking problem.
# Comparing `xgo` with `monkey`
The project [bouk/monkey](https://github.com/bouk/monkey), was initially created by bouk, as described in his blog https://bou.ke/blog/monkey-patching-in-go.
diff --git a/README_zh_cn.md b/README_zh_cn.md
index 737b62ab..9e35bdd4 100644
--- a/README_zh_cn.md
+++ b/README_zh_cn.md
@@ -9,7 +9,7 @@
允许对`go`的函数进行拦截, 并提供Mock和Trace等工具帮助开发者编写测试和快速调试。
-`xgo`作为一个预处理器工作在`go run`,`go build`,和`go test`之上(查看[blog](https://blog.xhd2015.xyz/zh/posts/xgo-monkey-patching-in-go-using-toolexec/))。
+`xgo`作为一个预处理器工作在`go run`,`go build`,和`go test`之上(查看[blog](https://blog.xhd2015.xyz/zh/posts/xgo-monkey-patching-in-go-using-toolexec))。
`xgo`对源代码和IR(中间码)进行预处理之后, 再调用`go`进行后续的编译工作。通过这种方式, `xgo`实现了一些在`go`中缺乏的能力。
@@ -20,6 +20,8 @@
更多细节, 参见[快速开始](#快速开始)和[文档](./doc)。
+> *顺便说一下, 我可以向你保证这是一个有趣的项目。*
+
# 安装
```sh
go install github.com/xhd2015/xgo/cmd/xgo@latest
@@ -150,7 +152,6 @@ import (
func init() {
trap.AddInterceptor(&trap.Interceptor{
Pre: func(ctx context.Context, f *core.FuncInfo, args core.Object, results core.Object) (interface{}, error) {
- trap.Skip()
if f.Name == "A" {
fmt.Printf("trap A\n")
return nil, nil
@@ -414,12 +415,23 @@ XGO_TRACE_OUTPUT=stdout xgo run ./
- `XGO_TRACE_OUTPUT=`: 堆栈记录被写入到``目录下,
- `XGO_TRACE_OUTPUT=off`: 关闭堆栈记录收集。
+# 并发安全
+我知道大部分人认为Monkey Patching不是并发安全的,但那是现有的库的实现方式决定的。
+
+我可以向你保证,在xgo中进行Monkey Patching是并发安全的,也就意味着,你可以同时并行跑所有的测试用例。
+
+为什么? 因为当你设置mock时,只有当前的goroutine受影响,并且在goroutine退出后清除,不管当前测试失败还是成功。
+
+想知道真正的原因吗? 我们正在整理内部实现的文档,尽请期待。
+
# 实现原理
> 仍在整理中...
参见[Issue#7](https://github.com/xhd2015/xgo/issues/7)
+这个博客作了一些简单的解释: https://blog.xhd2015.xyz/zh/posts/xgo-monkey-patching-in-go-using-toolexec
+
# 为何使用`xgo`?
原因很简单: **避免**interface.
diff --git a/cmd/xgo/main.go b/cmd/xgo/main.go
index 31b3d018..1820a755 100644
--- a/cmd/xgo/main.go
+++ b/cmd/xgo/main.go
@@ -287,7 +287,7 @@ func handleBuild(cmd string, args []string) error {
return err
}
- if resetInstrument || revisionChanged {
+ if resetInstrument {
logDebug("revision changed, reset %s", instrumentDir)
err := os.RemoveAll(instrumentDir)
if err != nil {
diff --git a/cmd/xgo/patch.go b/cmd/xgo/patch.go
index d029d451..ae3bbcb1 100644
--- a/cmd/xgo/patch.go
+++ b/cmd/xgo/patch.go
@@ -167,7 +167,7 @@ func syncGoroot(goroot string, instrumentGoroot string, fullSyncRecordFile strin
// need copy, delete target dst dir first
// TODO: use git worktree add if .git exists
err = filecopy.NewOptions().
- Concurrent(10).
+ Concurrent(2). // 10 is too much
CopyReplaceDir(goroot, instrumentGoroot)
if err != nil {
return err
diff --git a/cmd/xgo/patch_runtime.go b/cmd/xgo/patch_runtime.go
index d95b7ff9..c520162e 100644
--- a/cmd/xgo/patch_runtime.go
+++ b/cmd/xgo/patch_runtime.go
@@ -15,15 +15,44 @@ import (
var xgoAutoGenRegisterFuncHelper = _FilePath{"src", "runtime", "__xgo_autogen_register_func_helper.go"}
var xgoTrap = _FilePath{"src", "runtime", "xgo_trap.go"}
var runtimeProc = _FilePath{"src", "runtime", "proc.go"}
-var testingFile = _FilePath{"src", "testing", "testing.go"}
var runtimeTime _FilePath = _FilePath{"src", "runtime", "time.go"}
var timeSleep _FilePath = _FilePath{"src", "time", "sleep.go"}
+var testingFilePatch = &FilePatch{
+ FilePath: _FilePath{"src", "testing", "testing.go"},
+ Patches: []*Patch{
+ {
+ Mark: "declare_testing_callback_v2",
+ InsertIndex: 0,
+ InsertBefore: true,
+ Anchors: []string{
+ "func tRunner(t *T, fn func",
+ "{",
+ "\n",
+ },
+ Content: patch.TestingCallbackDeclarations,
+ },
+ {
+ Mark: "call_testing_callback_v2",
+ InsertIndex: 4,
+ InsertBefore: true,
+ Anchors: []string{
+ "func tRunner(t *T, fn func",
+ "{",
+ "\n",
+ `t.start = time.Now()`,
+ "fn(t",
+ },
+ Content: patch.TestingStart,
+ },
+ },
+}
+
var runtimeFiles = []_FilePath{
xgoAutoGenRegisterFuncHelper,
xgoTrap,
runtimeProc,
- testingFile,
+ testingFilePatch.FilePath,
runtimeTime,
timeSleep,
}
@@ -136,22 +165,7 @@ func patchRuntimeProc(goroot string) error {
}
func patchRuntimeTesting(goroot string) error {
- testingFile := filepath.Join(goroot, filepath.Join(testingFile...))
- return editFile(testingFile, func(content string) (string, error) {
- // func tRunner(t *T, fn func(t *T)) {
- anchor := []string{"func tRunner(t *T", "{", "\n"}
- content = addContentBefore(content,
- "/**/", "/**/",
- anchor,
- patch.TestingCallbackDeclarations,
- )
- content = addContentAfter(content,
- "/**/", "/**/",
- anchor,
- patch.TestingStart,
- )
- return content, nil
- })
+ return testingFilePatch.Apply(goroot, nil)
}
// only required if need to mock time.Sleep
diff --git a/cmd/xgo/version.go b/cmd/xgo/version.go
index 6e5ac1b4..5bc5517b 100644
--- a/cmd/xgo/version.go
+++ b/cmd/xgo/version.go
@@ -2,9 +2,9 @@ package main
import "fmt"
-const VERSION = "1.0.19"
-const REVISION = "75c04e25cd9ccc811d6893fc0c0c02df889cad66+1"
-const NUMBER = 171
+const VERSION = "1.0.20"
+const REVISION = "7ab84d83e8d847f0c2307a44866705581ba5cbbe+1"
+const NUMBER = 172
func getRevision() string {
revSuffix := ""
diff --git a/patch/ctxt/ctx.go b/patch/ctxt/ctx.go
index ee1c744e..8f301407 100644
--- a/patch/ctxt/ctx.go
+++ b/patch/ctxt/ctx.go
@@ -13,6 +13,8 @@ const XgoRuntimeCorePkg = XgoModule + "/runtime/core"
var XgoMainModule = os.Getenv("XGO_MAIN_MODULE")
var XgoCompilePkgDataDir = os.Getenv("XGO_COMPILE_PKG_DATA_DIR")
+const XgoLinkTrapVarForGenerated = "__xgo_link_trap_var_for_generated"
+
func SkipPackageTrap() bool {
pkgPath := GetPkgPath()
if pkgPath == "" {
@@ -56,68 +58,9 @@ func SkipPackageTrap() bool {
return false
}
-var stdWhitelist = map[string]map[string]bool{
- // "runtime": map[string]bool{
- // "timeSleep": true,
- // },
- "os": map[string]bool{
- // starts with Get
- "OpenFile": true,
- "ReadFile": true,
- "WriteFile": true,
- },
- "io": map[string]bool{
- "ReadAll": true,
- },
- "io/ioutil": map[string]bool{
- "ReadAll": true,
- "ReadFile": true,
- "ReadDir": true,
- },
- "time": map[string]bool{
- "Now": true,
- // time.Sleep is special:
- // if trapped like normal functions
- // runtime/time.go:178:6: ns escapes to heap, not allowed in runtime
- // there are special handling of this, see cmd/xgo/patch_runtime patchRuntimeTime
- "Sleep": true, // NOTE: time.Sleep links to runtime.timeSleep
- "NewTicker": true,
- "Time.Format": true,
- },
- "os/exec": map[string]bool{
- "Command": true,
- "(*Cmd).Run": true,
- "(*Cmd).Output": true,
- "(*Cmd).Start": true,
- },
- "net/http": map[string]bool{
- "Get": true,
- "Head": true,
- "Post": true,
- // Sever
- "Serve": true,
- "Handle": true,
- "(*Client).Do": true,
- "(*Server).Close": true,
- },
- "net": map[string]bool{
- // starts with Dial
- },
-}
-
func AllowPkgFuncTrap(pkgPath string, isStd bool, funcName string) bool {
if isStd {
- if stdWhitelist[pkgPath][funcName] {
- return true
- }
- switch pkgPath {
- case "os":
- return strings.HasPrefix(funcName, "Get")
- case "net":
- return strings.HasPrefix(funcName, "Dial")
- }
- // by default block all
- return false
+ return allowStdFunc(pkgPath, funcName)
}
return true
diff --git a/patch/ctxt/stdlib.go b/patch/ctxt/stdlib.go
new file mode 100644
index 00000000..09a5393e
--- /dev/null
+++ b/patch/ctxt/stdlib.go
@@ -0,0 +1,69 @@
+package ctxt
+
+import "strings"
+
+var stdWhitelist = map[string]map[string]bool{
+ // "runtime": map[string]bool{
+ // "timeSleep": true,
+ // },
+ "os": map[string]bool{
+ // starts with Get
+ "OpenFile": true,
+ "ReadFile": true,
+ "WriteFile": true,
+ },
+ "io": map[string]bool{
+ "ReadAll": true,
+ },
+ "io/ioutil": map[string]bool{
+ "ReadAll": true,
+ "ReadFile": true,
+ "ReadDir": true,
+ },
+ "time": map[string]bool{
+ "Now": true,
+ // time.Sleep is special:
+ // if trapped like normal functions
+ // runtime/time.go:178:6: ns escapes to heap, not allowed in runtime
+ // there are special handling of this, see cmd/xgo/patch_runtime patchRuntimeTime
+ "Sleep": true, // NOTE: time.Sleep links to runtime.timeSleep
+ "NewTicker": true,
+ "Time.Format": true,
+ },
+ "os/exec": map[string]bool{
+ "Command": true,
+ "(*Cmd).Run": true,
+ "(*Cmd).Output": true,
+ "(*Cmd).Start": true,
+ },
+ "net/http": map[string]bool{
+ "Get": true,
+ "Head": true,
+ "Post": true,
+ // Sever
+ "Serve": true,
+ "Handle": true,
+ "(*Client).Do": true,
+ "(*Server).Close": true,
+ },
+ "net": map[string]bool{
+ // starts with Dial
+ },
+ "encoding/json": map[string]bool{
+ "newTypeEncoder": true,
+ },
+}
+
+func allowStdFunc(pkgPath string, funcName string) bool {
+ if stdWhitelist[pkgPath][funcName] {
+ return true
+ }
+ switch pkgPath {
+ case "os":
+ return strings.HasPrefix(funcName, "Get")
+ case "net":
+ return strings.HasPrefix(funcName, "Dial")
+ }
+ // by default block all
+ return false
+}
diff --git a/patch/link_name.go b/patch/link_name.go
index 4de3f275..caa44a80 100644
--- a/patch/link_name.go
+++ b/patch/link_name.go
@@ -58,7 +58,7 @@ func isLinkValid(fnName string, targetName string, pkgPath string) bool {
if disableXgoLink {
return false
}
- safeGenerated := (fnName == xgo_syntax.XgoLinkGeneratedRegisterFunc || fnName == xgo_syntax.XgoLinkTrapForGenerated)
+ safeGenerated := (fnName == xgo_syntax.XgoLinkGeneratedRegisterFunc || fnName == xgo_syntax.XgoLinkTrapForGenerated || fnName == xgo_ctxt.XgoLinkTrapVarForGenerated)
if safeGenerated {
// generated by xgo on the fly for every instrumented package
return true
diff --git a/patch/trap.go b/patch/trap.go
index a6327cf3..279b5f25 100644
--- a/patch/trap.go
+++ b/patch/trap.go
@@ -98,13 +98,6 @@ func trapOrLink(fn *ir.Func) {
typeCheckBody(fn)
xgo_record.SetRewrittenBody(fn, fn.Body)
-
- // debug
- if false {
- if fnName == "Now" {
- ir.Dump("after:", fn)
- }
- }
}
/*
@@ -266,17 +259,16 @@ func InsertTrapForFunc(fn *ir.Func, forGeneric bool) bool {
}, nil)
origBody := fn.Body
- newBody := make([]ir.Node, 1+len(origBody))
- newBody[0] = callAfter
+ newBody := make([]ir.Node, len(origBody))
for i := 0; i < len(origBody); i++ {
- newBody[i+1] = origBody[i]
+ newBody[i] = origBody[i]
}
ifStmt := ir.NewIfStmt(fnPos, stopV, nil, newBody)
if isClosure {
trappedClosures = append(trappedClosures, fn)
}
- fn.Body = []ir.Node{assignStmt, ifStmt}
+ fn.Body = []ir.Node{assignStmt, callAfter, ifStmt}
return true
}
diff --git a/runtime/core/version.go b/runtime/core/version.go
index bd84a5bc..b6ac297e 100644
--- a/runtime/core/version.go
+++ b/runtime/core/version.go
@@ -6,9 +6,9 @@ import (
"os"
)
-const VERSION = "1.0.19"
-const REVISION = "75c04e25cd9ccc811d6893fc0c0c02df889cad66+1"
-const NUMBER = 171
+const VERSION = "1.0.20"
+const REVISION = "7ab84d83e8d847f0c2307a44866705581ba5cbbe+1"
+const NUMBER = 172
// these fields will be filled by compiler
const XGO_VERSION = ""
diff --git a/runtime/functab/functab.go b/runtime/functab/functab.go
index c57126d8..43b544d2 100644
--- a/runtime/functab/functab.go
+++ b/runtime/functab/functab.go
@@ -12,10 +12,9 @@ import (
"github.com/xhd2015/xgo/runtime/core"
)
-const __XGO_SKIP_TRAP = true
-
func init() {
__xgo_link_on_init_finished(ensureMapping)
+ __xgo_link_on_init_finished(ensureTypeMapping)
}
// rewrite at compile time by compiler, the body will be replaced with
@@ -29,6 +28,11 @@ func __xgo_link_on_init_finished(f func()) {
fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_on_init_finished(requires xgo).")
}
+func __xgo_link_init_finished() bool {
+ fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_init_finished(requires xgo).")
+ return false
+}
+
func __xgo_link_get_pc_name(pc uintptr) string {
fmt.Fprintf(os.Stderr, "WARNING: failed to link __xgo_link_get_pc_name(requires xgo).\n")
return ""
@@ -43,12 +47,10 @@ var interfaceMapping map[string]map[string]*core.FuncInfo // pkg -> inter
var typeMethodMapping map[reflect.Type]map[string]*core.FuncInfo // reflect.Type -> interfaceName -> FuncInfo
func GetFuncs() []*core.FuncInfo {
- ensureMapping()
return funcInfos
}
func InfoFunc(fn interface{}) *core.FuncInfo {
- ensureMapping()
v := reflect.ValueOf(fn)
if v.Kind() != reflect.Func {
panic(fmt.Errorf("given type is not a func: %T", fn))
@@ -58,7 +60,6 @@ func InfoFunc(fn interface{}) *core.FuncInfo {
return funcPCMapping[pc]
}
func InfoVar(addr interface{}) *core.FuncInfo {
- ensureMapping()
v := reflect.ValueOf(addr)
if v.Kind() != reflect.Ptr {
panic(fmt.Errorf("given type is not a pointer: %T", addr))
@@ -69,7 +70,6 @@ func InfoVar(addr interface{}) *core.FuncInfo {
// maybe rename to FuncForPC
func InfoPC(pc uintptr) *core.FuncInfo {
- ensureMapping()
return funcPCMapping[pc]
}
@@ -119,7 +119,6 @@ func GetFuncByFullName(fullName string) *core.FuncInfo {
}
func GetTypeMethods(typ reflect.Type) map[string]*core.FuncInfo {
- ensureTypeMapping()
return typeMethodMapping[typ]
}
@@ -288,7 +287,6 @@ func ensureMapping() {
var mappingTypeOnce sync.Once
func ensureTypeMapping() {
- ensureMapping()
mappingTypeOnce.Do(func() {
typeMethodMapping = make(map[reflect.Type]map[string]*core.FuncInfo)
for _, funcInfo := range funcInfos {
diff --git a/runtime/mock/mock.go b/runtime/mock/mock.go
index 296e79af..efc4d2e2 100644
--- a/runtime/mock/mock.go
+++ b/runtime/mock/mock.go
@@ -84,11 +84,6 @@ func getMethodByName(instance interface{}, method string) (recvPtr interface{},
return addr.Interface(), fn, 0, 0
}
-// Deprecated: use Mock instead
-func AddFuncInterceptor(fn interface{}, interceptor Interceptor) func() {
- return Mock(fn, interceptor)
-}
-
// TODO: ensure them run in last?
// no abort, run mocks
// mocks are special in that they on run in pre stage
diff --git a/runtime/test/debug/debug_test.go b/runtime/test/debug/debug_test.go
index 468ba9f4..365b2b3b 100644
--- a/runtime/test/debug/debug_test.go
+++ b/runtime/test/debug/debug_test.go
@@ -6,31 +6,18 @@
package debug
import (
+ "fmt"
"testing"
-)
-
-const good = 2
-const reason = "test"
-func TestPatchConstOperationShouldCompileAndSkipMock(t *testing.T) {
- reasons := getReasons("good")
- if len(reasons) != 2 || reasons[0] != "ok" || reasons[1] != "good" {
- t.Fatalf("bad reason: %v", reasons)
- }
+ "github.com/xhd2015/xgo/runtime/trap"
+)
- getReasons2 := func(good string) (reason []string) {
- reason = append(reason, "ok")
- reason = append(reason, good)
- return
- }
- reasons2 := getReasons2("good")
- if len(reasons2) != 2 || reasons2[0] != "ok" || reasons2[1] != "good" {
- t.Fatalf("bad reason2: %v", reasons2)
- }
+func ToString[T any](v T) string {
+ return fmt.Sprint(v)
}
-func getReasons(good string) (reason []string) {
- reason = append(reason, "ok")
- reason = append(reason, good)
- return
+func TestNakedTrapShouldAvoidRecursive(t *testing.T) {
+ trap.InspectPC(ToString[int])
+ // _, fnInfo, funcPC, trappingPC := trap.InspectPC(ToString[int])
+ // _, fnInfoStr, funcPCStr, trappingPCStr := trap.InspectPC(ToString[string])
}
diff --git a/runtime/test/debug/sub/sub.go b/runtime/test/debug/sub/sub.go
deleted file mode 100644
index e13aafbc..00000000
--- a/runtime/test/debug/sub/sub.go
+++ /dev/null
@@ -1,10 +0,0 @@
-package sub
-
-const LabelPrefix = "label:"
-
-const Version = "v2"
-
-const (
- LabelPrefix2 = "label2:"
- LabelPrefix3 = "label3:"
-)
diff --git a/runtime/test/func_list/func_list_stdlib_test.go b/runtime/test/func_list/func_list_stdlib_test.go
index 5a2da93d..c9d42da0 100644
--- a/runtime/test/func_list/func_list_stdlib_test.go
+++ b/runtime/test/func_list/func_list_stdlib_test.go
@@ -1,6 +1,7 @@
package func_list
import (
+ "encoding/json"
"io"
"io/ioutil"
"net"
@@ -20,6 +21,7 @@ var _ = ioutil.ReadAll
var _ = ioutil.ReadFile
var _ = ioutil.ReadDir
var _ = io.ReadAll
+var _ json.Encoder
// go run ./cmd/xgo test --project-dir runtime -run TestListStdlib -v ./test/func_list
func TestListStdlib(t *testing.T) {
@@ -71,6 +73,9 @@ func TestListStdlib(t *testing.T) {
"net.DialUDP": true,
"net.DialUnix": true,
"net.DialTimeout": true,
+
+ //json
+ "encoding/json.newTypeEncoder": true,
}
found, missing := getMissing(funcs, stdPkgs, false)
if len(missing) > 0 {
diff --git a/runtime/test/mock_res/main.go b/runtime/test/mock_res/main.go
index 14008368..921dde61 100644
--- a/runtime/test/mock_res/main.go
+++ b/runtime/test/mock_res/main.go
@@ -11,7 +11,7 @@ import (
func main() {
before := add(5, 2)
fmt.Printf("before mock: add(5,2)=%d\n", before)
- mock.AddFuncInterceptor(add, func(ctx context.Context, fn *core.FuncInfo, args core.Object, results core.Object) error {
+ mock.Mock(add, func(ctx context.Context, fn *core.FuncInfo, args core.Object, results core.Object) error {
a := args.GetField("a").Value().(int)
b := args.GetField("b").Value().(int)
res := a - b
diff --git a/runtime/test/trace_marshal/marshal_test.go b/runtime/test/trace_marshal/marshal_test.go
new file mode 100644
index 00000000..bdd15adb
--- /dev/null
+++ b/runtime/test/trace_marshal/marshal_test.go
@@ -0,0 +1,88 @@
+package trace_marshal
+
+import (
+ "context"
+ "fmt"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/xhd2015/xgo/runtime/core"
+ "github.com/xhd2015/xgo/runtime/functab"
+ "github.com/xhd2015/xgo/runtime/trace"
+ "github.com/xhd2015/xgo/runtime/trap"
+)
+
+func TestMarshalAnyJSON(t *testing.T) {
+ var nilChan chan int
+ tests := []struct {
+ v interface{}
+ want string
+ err string
+ }{
+ {
+ v: nil,
+ want: "null",
+ },
+ {
+ v: struct{}{},
+ want: "{}",
+ },
+ {
+ v: func() {},
+ want: "{}",
+ },
+ {
+ v: make(chan int),
+ want: "{}",
+ },
+ {
+ v: nilChan,
+ want: "{}",
+ },
+ {
+ v: struct{ A int }{A: 123},
+ want: `{"A":123}`,
+ },
+ {
+ v: getObject(),
+ want: `{"_r0":{}}`,
+ },
+ }
+
+ for i, tt := range tests {
+ t.Run(fmt.Sprintf("case_%d", i), func(t *testing.T) {
+ got, err := trace.MarshalAnyJSON(tt.v)
+ var errMsg string
+ if err != nil {
+ errMsg = err.Error()
+ }
+ if (errMsg != "" && tt.err == "") || !strings.Contains(errMsg, tt.err) {
+ t.Fatalf("expect err msg: %s, actual: %s", tt.err, errMsg)
+ }
+ if tt.want != string(got) {
+ t.Fatalf("expect result: %s, actual: %s", tt.want, got)
+ }
+ })
+ }
+}
+
+func exampleReturnFunc() context.CancelFunc {
+ _, f := context.WithTimeout(context.TODO(), 10*time.Millisecond)
+ return f
+}
+
+func getObject() core.Object {
+ var recordedResult core.Object
+ fnInfo := functab.InfoFunc(exampleReturnFunc)
+ trap.WithInterceptor(&trap.Interceptor{Pre: func(ctx context.Context, f *core.FuncInfo, args, result core.Object) (data interface{}, err error) {
+ if fnInfo != f {
+ return nil, nil
+ }
+ recordedResult = result
+ return nil, nil
+ }}, func() {
+ exampleReturnFunc()
+ })
+ return recordedResult
+}
diff --git a/runtime/test/trace_marshal/with_trace/marshal_with_trace_test.go b/runtime/test/trace_marshal/with_trace/marshal_with_trace_test.go
new file mode 100644
index 00000000..363d6758
--- /dev/null
+++ b/runtime/test/trace_marshal/with_trace/marshal_with_trace_test.go
@@ -0,0 +1,67 @@
+package trace_marshal_with_trace
+
+import (
+ "context"
+ "fmt"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/xhd2015/xgo/runtime/core"
+ "github.com/xhd2015/xgo/runtime/functab"
+ "github.com/xhd2015/xgo/runtime/trace"
+ "github.com/xhd2015/xgo/runtime/trap"
+)
+
+func init() {
+ trace.Enable()
+}
+
+func TestMarshalWithTrace(t *testing.T) {
+ tests := []struct {
+ v interface{}
+ want string
+ err string
+ }{
+ {
+ v: getObject(),
+ want: `{"_r0":{}}`,
+ },
+ }
+
+ for i, tt := range tests {
+ t.Run(fmt.Sprintf("case_%d", i), func(t *testing.T) {
+ got, err := trace.MarshalAnyJSON(tt.v)
+ var errMsg string
+ if err != nil {
+ errMsg = err.Error()
+ }
+ if (errMsg != "" && tt.err == "") || !strings.Contains(errMsg, tt.err) {
+ t.Fatalf("expect err msg: %s, actual: %s", tt.err, errMsg)
+ }
+ if tt.want != string(got) {
+ t.Fatalf("expect result: %s, actual: %s", tt.want, got)
+ }
+ })
+ }
+}
+
+func exampleReturnFunc() context.CancelFunc {
+ _, f := context.WithTimeout(context.TODO(), 10*time.Millisecond)
+ return f
+}
+
+func getObject() core.Object {
+ var recordedResult core.Object
+ fnInfo := functab.InfoFunc(exampleReturnFunc)
+ trap.WithInterceptor(&trap.Interceptor{Pre: func(ctx context.Context, f *core.FuncInfo, args, result core.Object) (data interface{}, err error) {
+ if fnInfo != f {
+ return nil, nil
+ }
+ recordedResult = result
+ return nil, nil
+ }}, func() {
+ exampleReturnFunc()
+ })
+ return recordedResult
+}
diff --git a/runtime/test/trace_panic_peek/trace_panic_peek_test.go b/runtime/test/trace_panic_peek/trace_panic_peek_test.go
index e4034ff1..b267a0ed 100644
--- a/runtime/test/trace_panic_peek/trace_panic_peek_test.go
+++ b/runtime/test/trace_panic_peek/trace_panic_peek_test.go
@@ -2,7 +2,6 @@ package main
import (
"bytes"
- "encoding/json"
"fmt"
"io"
"strings"
@@ -21,7 +20,7 @@ func TestTracePanicPeek(t *testing.T) {
var traceData []byte
trace.Options().OnComplete(func(root *trace.Root) {
var err error
- traceData, err = json.Marshal(root.Export())
+ traceData, err = trace.MarshalAnyJSON(root.Export())
if err != nil {
t.Fatal(err)
}
diff --git a/runtime/test/trap/trap_avoid_recursive_test.go b/runtime/test/trap/trap_avoid_recursive_test.go
new file mode 100644
index 00000000..f2e4a13d
--- /dev/null
+++ b/runtime/test/trap/trap_avoid_recursive_test.go
@@ -0,0 +1,59 @@
+package trap
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "io"
+ "testing"
+
+ "github.com/xhd2015/xgo/runtime/core"
+ "github.com/xhd2015/xgo/runtime/trap"
+)
+
+// prints: pre->call_f->post
+// no repeation
+func TestNakedTrapShouldAvoidRecursive(t *testing.T) {
+ var recurseBuf bytes.Buffer
+ trap.AddInterceptor(&trap.Interceptor{
+ Pre: func(ctx context.Context, f *core.FuncInfo, args, result core.Object) (data interface{}, err error) {
+ fmt.Fprintf(&recurseBuf, "pre\n")
+ return nil, nil
+ },
+ Post: func(ctx context.Context, f *core.FuncInfo, args, result core.Object, data interface{}) (err error) {
+ fmt.Fprintf(&recurseBuf, "post\n")
+ return nil
+ },
+ })
+ f(&recurseBuf)
+ output := recurseBuf.String()
+ expect := "pre\ncall_f\npost\n"
+ if output != expect {
+ t.Fatalf("expect no recursive trap, output to be %q, actual: %q", expect, output)
+ }
+}
+
+func f(recurseBuf io.Writer) {
+ fmt.Fprintf(recurseBuf, "call_f\n")
+}
+
+func TestDeferredFuncShouldBeExecutedWhenAbort(t *testing.T) {
+ var recurseBuf bytes.Buffer
+ trap.AddInterceptor(&trap.Interceptor{
+ Pre: func(ctx context.Context, f *core.FuncInfo, args, result core.Object) (data interface{}, err error) {
+ fmt.Fprintf(&recurseBuf, "pre\n")
+ return nil, trap.ErrAbort
+ },
+ Post: func(ctx context.Context, f *core.FuncInfo, args, result core.Object, data interface{}) (err error) {
+ fmt.Fprintf(&recurseBuf, "post\n")
+ return nil
+ },
+ })
+ // NOTE: f's body is skipped
+ f(&recurseBuf)
+ output := recurseBuf.String()
+ expect := "pre\npost\n"
+ if output != expect {
+ t.Fatalf("expect no recursive trap, output to be %q, actual: %q", expect, output)
+ }
+}
diff --git a/runtime/test/trap_args/trap_args_test.go b/runtime/test/trap_args/trap_args_test.go
index 8ab4fb93..12a119d1 100644
--- a/runtime/test/trap_args/trap_args_test.go
+++ b/runtime/test/trap_args/trap_args_test.go
@@ -98,7 +98,7 @@ func callAndCheck(fn func(), check func(trapCtx context.Context, f *core.FuncInf
fn()
if callCount != 2 {
- fmt.Errorf("expect call trap twice, actual: %d", callCount)
+ return fmt.Errorf("expect call trap twice, actual: %d", callCount)
}
return nil
}
diff --git a/runtime/trace/marshal.go b/runtime/trace/marshal.go
new file mode 100644
index 00000000..bc699493
--- /dev/null
+++ b/runtime/trace/marshal.go
@@ -0,0 +1,62 @@
+package trace
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "os"
+ "reflect"
+
+ "github.com/xhd2015/xgo/runtime/core"
+ "github.com/xhd2015/xgo/runtime/functab"
+ "github.com/xhd2015/xgo/runtime/trap"
+)
+
+// MarshalAnyJSON marshals aribitray go value `v`
+// to JSON, when it encounters unmarshalable values
+// like func, chan, it will bypass these values.
+func MarshalAnyJSON(v interface{}) ([]byte, error) {
+ fn := functab.Info("encoding/json", "newTypeEncoder")
+ if fn == nil {
+ fmt.Fprintf(os.Stderr, "WARNING: encoding/json.newTypeEncoder not trapped(requires xgo).\n")
+ return json.Marshal(v)
+ }
+
+ // get the unmarshalable function
+ unmarshalable := getMarshaler(fn.Func, reflect.TypeOf(unmarshalableFunc))
+ var data []byte
+ var err error
+ // mock the encoding json
+ trap.WithOverride(&trap.Interceptor{
+ Post: func(ctx context.Context, f *core.FuncInfo, args, result core.Object, data interface{}) error {
+ if f != fn {
+ return nil
+ }
+ resField := result.GetFieldIndex(0)
+
+ // if unmarshalable, replace with an empty struct
+ if reflect.ValueOf(resField.Value()).Pointer() == reflect.ValueOf(unmarshalable).Pointer() {
+ resField.Set(getMarshaler(fn.Func, reflect.TypeOf(struct{}{})))
+ }
+ return nil
+ },
+ }, func() {
+ data, err = json.Marshal(v)
+ })
+ return data, err
+}
+
+func unmarshalableFunc() {}
+
+// newTypeEncoder signature: func(x Type,allowAddr bool) func()
+func getMarshaler(newTypeEncoder interface{}, v reflect.Type) interface{} {
+ var res interface{}
+ trap.Direct(func() {
+ results := reflect.ValueOf(newTypeEncoder).Call([]reflect.Value{
+ reflect.ValueOf(v),
+ reflect.ValueOf(false),
+ })
+ res = results[0].Interface()
+ })
+ return res
+}
diff --git a/runtime/trace/trace.go b/runtime/trace/trace.go
index 4902e811..cb884a64 100644
--- a/runtime/trace/trace.go
+++ b/runtime/trace/trace.go
@@ -2,7 +2,6 @@ package trace
import (
"context"
- "encoding/json"
"fmt"
"os"
"path/filepath"
@@ -17,8 +16,6 @@ import (
"github.com/xhd2015/xgo/runtime/trap"
)
-const __XGO_SKIP_TRAP = true
-
// hold goroutine stacks, keyed by goroutine ptr
var stackMap sync.Map // uintptr(goroutine) -> *Root
var testInfoMapping sync.Map // uintptr(goroutine) -> *testInfo
@@ -70,19 +67,27 @@ func __xgo_link_peek_panic() interface{} {
return nil
}
-var enabledGlobal int32
+var enabledGlobally bool
+var interceptorSet int32
-func Enable() {
- if getTraceOutput() == "off" {
- return
- }
+// Enable setup the trace interceptor
+// if called from init, the interceptor is enabled
+// globally. Otherwise locally
+func Enable() func() {
if __xgo_link_init_finished() {
- panic("Enable must be called from init")
- }
- if !atomic.CompareAndSwapInt32(&enabledGlobal, 0, 1) {
- return
+ var name string
+ key := uintptr(__xgo_link_getcurg())
+ tinfo, ok := testInfoMapping.Load(key)
+ if ok {
+ name = tinfo.(*testInfo).name
+ }
+ return enableLocal(&collectOpts{name: name})
}
+ enabledGlobally = true
setupInterceptor()
+ return func() {
+ panic("global trace cannot be turned off")
+ }
}
// executes f and collect its trace
@@ -119,9 +124,12 @@ func (c *collectOpts) Collect(f func()) {
collect(f, c)
}
-func setupInterceptor() func() {
+func setupInterceptor() {
+ if !atomic.CompareAndSwapInt32(&interceptorSet, 0, 1) {
+ return
+ }
// collect trace
- return trap.AddInterceptor(&trap.Interceptor{
+ trap.AddInterceptor(&trap.Interceptor{
Pre: func(ctx context.Context, f *core.FuncInfo, args core.Object, results core.Object) (interface{}, error) {
key := uintptr(__xgo_link_getcurg())
localOptStack, ok := collectingMap.Load(key)
@@ -131,6 +139,8 @@ func setupInterceptor() func() {
if len(l.list) > 0 {
localOpts = l.list[len(l.list)-1]
}
+ } else if !enabledGlobally {
+ return nil, nil
}
stack := &Stack{
FuncInfo: f,
@@ -168,6 +178,7 @@ func setupInterceptor() func() {
localOpts.root = root
}
// NOTE: for initial stack, the data is nil
+ // this will signal Post to emit a trace
return nil, nil
}
var root *Root
@@ -183,7 +194,6 @@ func setupInterceptor() func() {
return prevTop, nil
},
Post: func(ctx context.Context, f *core.FuncInfo, args core.Object, results core.Object, data interface{}) error {
- trap.Skip()
key := uintptr(__xgo_link_getcurg())
localOptStack, ok := collectingMap.Load(key)
@@ -193,6 +203,8 @@ func setupInterceptor() func() {
if len(l.list) > 0 {
localOpts = l.list[len(l.list)-1]
}
+ } else if !enabledGlobally {
+ return nil
}
var root *Root
if localOpts != nil {
@@ -227,25 +239,16 @@ func setupInterceptor() func() {
}
root.Top.End = int64(time.Since(root.Begin))
if data == nil {
+ root.Top = nil
// stack finished
if localOpts != nil {
- if localOpts.onComplete != nil {
- localOpts.onComplete(root)
- return nil
- }
- err := emitTrace(localOpts.name, root)
- if err != nil {
- return err
- }
+ // handled by local options
return nil
}
// global
stackMap.Delete(key)
- err := emitTrace("", root)
- if err != nil {
- return err
- }
+ emitTraceNoErr("", root)
return nil
}
// pop stack
@@ -262,28 +265,51 @@ type optStack struct {
}
func collect(f func(), collOpts *collectOpts) {
- if atomic.LoadInt32(&enabledGlobal) == 0 {
- cancel := setupInterceptor()
- defer cancel()
+ cancel := enableLocal(collOpts)
+ defer cancel()
+ f()
+}
+
+func enableLocal(collOpts *collectOpts) func() {
+ if collOpts == nil {
+ collOpts = &collectOpts{}
}
+ setupInterceptor()
key := uintptr(__xgo_link_getcurg())
if collOpts.name == "" {
collOpts.name = fmt.Sprintf("g_%x", uint(key))
}
+ if collOpts.root == nil {
+ collOpts.root = &Root{
+ Top: &Stack{},
+ Begin: time.Now(),
+ }
+ }
+ top := collOpts.root.Top
act, _ := collectingMap.LoadOrStore(key, &optStack{})
opts := act.(*optStack)
// push
opts.list = append(opts.list, collOpts)
- defer func() {
+ return func() {
// pop
opts.list = opts.list[:len(opts.list)-1]
if len(opts.list) == 0 {
collectingMap.Delete(key)
}
- }()
- f()
+
+ root := collOpts.root
+ root.Children = top.Children
+ root.Top = nil
+ // root.Children =
+ // call complete
+ if collOpts.onComplete != nil {
+ collOpts.onComplete(root)
+ } else {
+ emitTraceNoErr(collOpts.name, root)
+ }
+ }
}
func getTraceOutput() string {
@@ -310,21 +336,20 @@ func fmtStack(root *Root) (data []byte, err error) {
if marshalStack != nil {
return marshalStack(root)
}
- return json.Marshal(root.Export())
+ return MarshalAnyJSON(root.Export())
+}
+
+func emitTraceNoErr(name string, root *Root) {
+ emitTrace(name, root)
}
// this should also be marked as trap.Skip()
// TODO: may add callback for this
func emitTrace(name string, root *Root) error {
- if name == "" {
- key := uintptr(__xgo_link_getcurg())
- tinfo, ok := testInfoMapping.Load(key)
- if ok {
- name = tinfo.(*testInfo).name
- }
- }
-
xgoTraceOutput := getTraceOutput()
+ if xgoTraceOutput == "off" {
+ return nil
+ }
useStdout := xgoTraceOutput == "stdout"
subName := name
if name == "" {
diff --git a/runtime/trap/inspect.go b/runtime/trap/inspect.go
index 3766a59f..a9e782e7 100644
--- a/runtime/trap/inspect.go
+++ b/runtime/trap/inspect.go
@@ -1,7 +1,6 @@
package trap
import (
- "context"
"fmt"
"reflect"
"strings"
@@ -22,6 +21,8 @@ func Inspect(f interface{}) (recvPtr interface{}, funcInfo *core.FuncInfo) {
return
}
+type inspectingFunc func(f *core.FuncInfo, recv interface{}, pc uintptr)
+
func InspectPC(f interface{}) (recvPtr interface{}, funcInfo *core.FuncInfo, funcPC uintptr, trappingPC uintptr) {
fn := reflect.ValueOf(f)
// try as a variable
@@ -49,13 +50,13 @@ func InspectPC(f interface{}) (recvPtr interface{}, funcInfo *core.FuncInfo, fun
// is implemented as closure, so we need to distinguish
// between true closure and generic function
var origFullName string
+ var needRecv bool
if !maybeClosureGeneric {
origFullName = __xgo_link_get_pc_name(funcPC)
fullName := origFullName
- var isMethod bool
if strings.HasSuffix(fullName, methodSuffix) {
- isMethod = true
+ needRecv = true
fullName = fullName[:len(fullName)-len(methodSuffix)]
}
@@ -66,23 +67,24 @@ func InspectPC(f interface{}) (recvPtr interface{}, funcInfo *core.FuncInfo, fun
if funcInfo.Closure && GenericImplIsClosure {
// maybe closure generic
maybeClosureGeneric = true
- } else if !isMethod && !funcInfo.Generic {
+ } else if !needRecv && !funcInfo.Generic {
// plain function(not method, not generic)
return nil, funcInfo, funcPC, 0
}
}
- WithInterceptor(&Interceptor{
- Pre: func(ctx context.Context, f *core.FuncInfo, args, result core.Object) (data interface{}, err error) {
- trappingPC = GetTrappingPC()
- funcInfo = f
- if !maybeClosureGeneric {
- // closure cannot have receiver pointer
- recvPtr = args.GetFieldIndex(0).Ptr()
- }
- return nil, ErrAbort
- },
- }, func() {
+ key := uintptr(__xgo_link_getcurg())
+ ensureTrapInstall()
+ inspectingMap.Store(key, inspectingFunc(func(f *core.FuncInfo, recv interface{}, pc uintptr) {
+ trappingPC = pc
+ funcInfo = f
+ if needRecv {
+ // closure cannot have receiver pointer
+ recvPtr = recv
+ }
+ }))
+ defer inspectingMap.Delete(key)
+ callFn := func() {
fnType := fn.Type()
nargs := fnType.NumIn()
args := make([]reflect.Value, nargs)
@@ -94,8 +96,9 @@ func InspectPC(f interface{}) (recvPtr interface{}, funcInfo *core.FuncInfo, fun
} else {
fn.CallSlice(args)
}
- })
- if !maybeClosureGeneric && recvPtr == nil {
+ }
+ callFn()
+ if needRecv && recvPtr == nil {
if origFullName == "" {
origFullName = __xgo_link_get_pc_name(funcPC)
}
diff --git a/runtime/trap/interceptor.go b/runtime/trap/interceptor.go
index 65706281..db53e5bd 100644
--- a/runtime/trap/interceptor.go
+++ b/runtime/trap/interceptor.go
@@ -44,12 +44,13 @@ type Interceptor struct {
}
var interceptors []*Interceptor
-var localInterceptors sync.Map // goroutine ptr -> *interceptorList
+var localInterceptors sync.Map // goroutine ptr -> *interceptorGroup
func AddInterceptor(interceptor *Interceptor) func() {
ensureTrapInstall()
if __xgo_link_init_finished() {
- return addLocalInterceptor(interceptor)
+ dispose, _ := addLocalInterceptor(interceptor, false)
+ return dispose
}
interceptors = append(interceptors, interceptor)
return func() {
@@ -68,22 +69,42 @@ func AddInterceptor(interceptor *Interceptor) func() {
// even from init because it will be soon cleared
// without causing concurrent issues.
func WithInterceptor(interceptor *Interceptor, f func()) {
- dispose := addLocalInterceptor(interceptor)
+ dispose, _ := addLocalInterceptor(interceptor, false)
defer dispose()
f()
}
+// WithOverride override local and global interceptors
+// in current goroutine temporarily, it returns a function
+// that can be used to cancel the override.
+func WithOverride(interceptor *Interceptor, f func()) {
+ _, disposeGroup := addLocalInterceptor(interceptor, true)
+ defer disposeGroup()
+ f()
+}
+
func GetInterceptors() []*Interceptor {
return interceptors
}
func GetLocalInterceptors() []*Interceptor {
+ group := getLocalInterceptorGroup()
+ if group == nil {
+ return nil
+ }
+ gi := group.currentGroupInterceptors()
+ if gi == nil {
+ return nil
+ }
+ return gi.list
+}
+func getLocalInterceptorGroup() *interceptorGroup {
key := uintptr(__xgo_link_getcurg())
val, ok := localInterceptors.Load(key)
if !ok {
return nil
}
- return val.(*interceptorList).interceptors
+ return val.(*interceptorGroup)
}
func ClearLocalInterceptors() {
@@ -91,43 +112,70 @@ func ClearLocalInterceptors() {
}
func GetAllInterceptors() []*Interceptor {
- locals := GetLocalInterceptors()
+ res, _ := getAllInterceptors()
+ return res
+}
+
+func getAllInterceptors() ([]*Interceptor, int) {
+ group := getLocalInterceptorGroup()
+ var locals []*Interceptor
+ var g int
+ if group != nil {
+ gi := group.currentGroupInterceptors()
+ if gi != nil {
+ g = group.currentGroup()
+ if gi.override {
+ return gi.list, g
+ }
+ locals = gi.list
+ }
+ }
global := GetInterceptors()
if len(locals) == 0 {
- return global
+ return global, g
}
if len(global) == 0 {
- return locals
+ return locals, g
}
// run locals first(in reversed order)
- return append(global, locals...)
+ return append(global[:len(global):len(global)], locals...), g
}
// returns a function to dispose the key
// NOTE: if not called correctly,there might be memory leak
-func addLocalInterceptor(interceptor *Interceptor) func() {
+func addLocalInterceptor(interceptor *Interceptor, override bool) (removeInterceptor func(), removeGroup func()) {
ensureTrapInstall()
key := uintptr(__xgo_link_getcurg())
- list := &interceptorList{}
+ list := &interceptorGroup{}
val, loaded := localInterceptors.LoadOrStore(key, list)
if loaded {
- list = val.(*interceptorList)
+ list = val.(*interceptorGroup)
}
- list.interceptors = append(list.interceptors, interceptor)
+ // ensure at least one group
+ if override || list.groupsEmpty() {
+ list.enterNewGroup(override)
+ }
+ g := list.currentGroup()
+ list.appendToCurrentGroup(interceptor)
- removed := false
+ removedInterceptor := false
// used to remove the local interceptor
- return func() {
- if removed {
+ removeInterceptor = func() {
+ if removedInterceptor {
panic(fmt.Errorf("remove interceptor more than once"))
}
curKey := uintptr(__xgo_link_getcurg())
if key != curKey {
panic(fmt.Errorf("remove interceptor from another goroutine"))
}
- removed = true
+ curG := list.currentGroup()
+ if curG != g {
+ panic(fmt.Errorf("interceptor group changed: previous=%d, current=%d", g, curG))
+ }
+ interceptors := list.groups[g].list
+ removedInterceptor = true
var idx int = -1
- for i, intc := range list.interceptors {
+ for i, intc := range interceptors {
if intc == interceptor {
idx = i
break
@@ -136,20 +184,79 @@ func addLocalInterceptor(interceptor *Interceptor) func() {
if idx < 0 {
panic(fmt.Errorf("interceptor leaked"))
}
- n := len(list.interceptors)
+ n := len(interceptors)
for i := idx; i < n-1; i++ {
- list.interceptors[i] = list.interceptors[i+1]
+ interceptors[i] = interceptors[i+1]
}
- list.interceptors = list.interceptors[:n-1]
- if len(list.interceptors) == 0 {
- // remove the entry from map to prevent memory leak
- localInterceptors.Delete(key)
+ interceptors = interceptors[:n-1]
+ list.groups[g].list = interceptors
+ }
+
+ removedGroup := false
+ removeGroup = func() {
+ if removedGroup {
+ panic(fmt.Errorf("remove group more than once"))
}
+ curKey := uintptr(__xgo_link_getcurg())
+ if key != curKey {
+ panic(fmt.Errorf("remove group from another goroutine"))
+ }
+ curG := list.currentGroup()
+ if curG != g {
+ panic(fmt.Errorf("interceptor group changed: previous=%d, current=%d", g, curG))
+ }
+ list.exitGroup()
+ removedInterceptor = true
}
+
+ return removeInterceptor, removeGroup
+}
+
+type interceptorGroup struct {
+ groups []*interceptorList
}
type interceptorList struct {
- interceptors []*Interceptor
+ override bool
+ list []*Interceptor
+}
+
+func (c *interceptorList) append(interceptor *Interceptor) {
+ c.list = append(c.list, interceptor)
+}
+
+func (c *interceptorGroup) appendToCurrentGroup(interceptor *Interceptor) {
+ g := c.currentGroup()
+ c.groups[g].append(interceptor)
+}
+
+func (c *interceptorGroup) groupsEmpty() bool {
+ return len(c.groups) == 0
+}
+func (c *interceptorGroup) currentGroup() int {
+ n := len(c.groups)
+ return n - 1
+}
+func (c *interceptorGroup) currentGroupInterceptors() *interceptorList {
+ g := c.currentGroup()
+ if g < 0 {
+ return nil
+ }
+ return c.groups[g]
+}
+
+func (c *interceptorGroup) enterNewGroup(override bool) {
+ c.groups = append(c.groups, &interceptorList{
+ override: override,
+ list: make([]*Interceptor, 0, 1),
+ })
+}
+func (c *interceptorGroup) exitGroup() {
+ n := len(c.groups)
+ if n == 0 {
+ panic("exit no group")
+ }
+ c.groups = c.groups[:n-1]
}
func clearLocalInterceptorsAndMark() {
@@ -157,5 +264,5 @@ func clearLocalInterceptorsAndMark() {
localInterceptors.Delete(key)
bypassMapping.Delete(key)
- clearTrappingMark()
+ clearTrappingMarkAllGroup()
}
diff --git a/runtime/trap/trap.go b/runtime/trap/trap.go
index 81a5955f..131c3de0 100644
--- a/runtime/trap/trap.go
+++ b/runtime/trap/trap.go
@@ -15,10 +15,20 @@ var setupOnce sync.Once
func ensureTrapInstall() {
setupOnce.Do(func() {
- __xgo_link_set_trap(trapFunc)
- __xgo_link_set_trap_var(trapVar)
+ // do not capture trap before init finished
+ if __xgo_link_init_finished() {
+ __xgo_link_set_trap(trapFunc)
+ __xgo_link_set_trap_var(trapVar)
+ } else {
+ // deferred
+ __xgo_link_on_init_finished(func() {
+ __xgo_link_set_trap(trapFunc)
+ __xgo_link_set_trap_var(trapVar)
+ })
+ }
})
}
+
func init() {
__xgo_link_on_gonewproc(func(g uintptr) {
if isByPassing() {
@@ -31,9 +41,11 @@ func init() {
copyInterceptors := make([]*Interceptor, len(interceptors))
copy(copyInterceptors, interceptors)
- // inherit interceptors
- localInterceptors.Store(g, &interceptorList{
- interceptors: copyInterceptors,
+ // inherit interceptors of last group
+ localInterceptors.Store(g, &interceptorGroup{
+ groups: []*interceptorList{{
+ list: copyInterceptors,
+ }},
})
})
}
@@ -48,6 +60,9 @@ func __xgo_link_set_trap_var(trap func(pkgPath string, name string, tmpVarAddr i
func __xgo_link_on_gonewproc(f func(g uintptr)) {
fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_on_gonewproc(requires xgo).")
}
+func __xgo_link_on_init_finished(f func()) {
+ fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_on_init_finished(requires xgo).")
+}
// Skip serves as mark to tell xgo not insert
// trap instructions for the function that
@@ -60,17 +75,26 @@ func Skip() {}
var trappingMark sync.Map // -> struct{}{}
var trappingPC sync.Map // -> PC
+var inspectingMap sync.Map // -> interceptor
+
// link to runtime
// xgo:notrap
func trapFunc(pkgPath string, identityName string, generic bool, pc uintptr, recv interface{}, args []interface{}, results []interface{}) (func(), bool) {
- interceptors := GetAllInterceptors()
- n := len(interceptors)
- if n == 0 {
+ if isByPassing() {
return nil, false
}
- // setup context
- setTrappingPC(pc)
- defer clearTrappingPC()
+ inspectingFn, inspecting := inspectingMap.Load(uintptr(__xgo_link_getcurg()))
+ var interceptors []*Interceptor
+ var group int
+ var n int
+ if !inspecting {
+ // never trap any function from runtime
+ interceptors, group = getAllInterceptors()
+ n = len(interceptors)
+ if n == 0 {
+ return nil, false
+ }
+ }
// NOTE: this may return nil for generic template
var f *core.FuncInfo
@@ -90,11 +114,23 @@ func trapFunc(pkgPath string, identityName string, generic bool, pc uintptr, rec
// let go to the next interceptor
return nil, false
}
- return trap(f, interceptors, recv, args, results)
+ if inspecting {
+ inspectingFn.(inspectingFunc)(f, recv, pc)
+ // abort,never really call the target
+ return nil, true
+ }
+
+ // setup context
+ setTrappingPC(pc)
+ defer clearTrappingPC()
+ return trap(f, interceptors, group, recv, args, results)
}
func trapVar(pkgPath string, name string, tmpVarAddr interface{}, takeAddr bool) {
- interceptors := GetAllInterceptors()
+ if isByPassing() {
+ return
+ }
+ interceptors, group := getAllInterceptors()
n := len(interceptors)
if n == 0 {
return
@@ -111,18 +147,15 @@ func trapVar(pkgPath string, name string, tmpVarAddr interface{}, takeAddr bool)
return
}
// NOTE: stop always ignored because this is a simple get
- post, _ := trap(fnInfo, interceptors, nil, nil, []interface{}{tmpVarAddr})
+ post, _ := trap(fnInfo, interceptors, group, nil, nil, []interface{}{tmpVarAddr})
if post != nil {
// NOTE: must in defer, because in post we
// may capture panic
defer post()
}
}
-func trap(f *core.FuncInfo, interceptors []*Interceptor, recv interface{}, args []interface{}, results []interface{}) (func(), bool) {
- if isByPassing() {
- return nil, false
- }
- dispose := setTrappingMark()
+func trap(f *core.FuncInfo, interceptors []*Interceptor, group int, recv interface{}, args []interface{}, results []interface{}) (func(), bool) {
+ dispose := setTrappingMark(group)
if dispose == nil {
return nil, false
}
@@ -215,6 +248,8 @@ func trap(f *core.FuncInfo, interceptors []*Interceptor, recv interface{}, args
ctx = context.TODO()
}
+ var firstPreErr error
+
abortIdx := -1
n := len(interceptors)
dataList := make([]interface{}, n)
@@ -227,50 +262,31 @@ func trap(f *core.FuncInfo, interceptors []*Interceptor, recv interface{}, args
data, err := interceptor.Pre(ctx, f, req, resObject)
dataList[i] = data
if err != nil {
- if err == ErrAbort {
- abortIdx = i
- // aborted
- break
- }
- // handle error gracefully
- if perr != nil {
- *perr = err
- return nil, true
- } else {
- panic(err)
- }
- }
- }
- if abortIdx >= 0 {
- // run Post immediately
- for i := abortIdx; i < n; i++ {
- interceptor := interceptors[i]
- if interceptor.Post == nil {
- continue
- }
- err := interceptor.Post(ctx, f, req, resObject, dataList[i])
- if err != nil {
- if err == ErrAbort {
- return nil, true
- }
- if perr != nil {
- *perr = err
- return nil, true
- } else {
- panic(err)
- }
- }
+ // always break on error
+ firstPreErr = err
+ abortIdx = i
+ break
}
- return nil, true
}
+ // not really an error
+ if firstPreErr == ErrAbort {
+ firstPreErr = nil
+ }
+ // always run post in defer
return func() {
- dispose := setTrappingMark()
+ dispose := setTrappingMark(group)
if dispose == nil {
return
}
defer dispose()
- for i := 0; i < n; i++ {
+
+ var lastPostErr error = firstPreErr
+ idx := 0
+ if abortIdx != -1 {
+ idx = abortIdx
+ }
+ for i := idx; i < n; i++ {
interceptor := interceptors[i]
if interceptor.Post == nil {
continue
@@ -280,15 +296,17 @@ func trap(f *core.FuncInfo, interceptors []*Interceptor, recv interface{}, args
if err == ErrAbort {
return
}
- if perr != nil {
- *perr = err
- return
- } else {
- panic(err)
- }
+ lastPostErr = err
+ }
+ }
+ if lastPostErr != nil {
+ if perr != nil {
+ *perr = lastPostErr
+ } else {
+ panic(lastPostErr)
}
}
- }, false
+ }, abortIdx != -1
}
func GetTrappingPC() uintptr {
@@ -299,18 +317,30 @@ func GetTrappingPC() uintptr {
}
return val.(uintptr)
}
-func setTrappingMark() func() {
+
+type trappingGroup struct {
+ m map[int]bool
+}
+
+func setTrappingMark(group int) func() {
key := uintptr(__xgo_link_getcurg())
- _, trapping := trappingMark.LoadOrStore(key, struct{}{})
- if trapping {
- return nil
+ g := &trappingGroup{}
+ v, loaded := trappingMark.LoadOrStore(key, g)
+ if loaded {
+ g = v.(*trappingGroup)
+ if g.m[group] {
+ return nil
+ }
+ } else {
+ g.m = make(map[int]bool, 1)
}
+ g.m[group] = true
return func() {
- trappingMark.Delete(key)
+ g.m[group] = false
}
}
-func clearTrappingMark() {
+func clearTrappingMarkAllGroup() {
key := uintptr(__xgo_link_getcurg())
trappingMark.Delete(key)
}
diff --git a/script/build-release/fixup.go b/script/build-release/fixup.go
index dbe3f424..ad2770c7 100644
--- a/script/build-release/fixup.go
+++ b/script/build-release/fixup.go
@@ -1,6 +1,7 @@
package main
import (
+ "os"
"strings"
"github.com/xhd2015/xgo/script/build-release/revision"
@@ -8,22 +9,46 @@ import (
)
// fixup src dir to prepare for release build
-func fixupSrcDir(targetDir string, rev string) error {
- err := updateRevisions(targetDir, false, rev)
+func fixupSrcDir(targetDir string, rev string) (restore func() error, err error) {
+ restore, err = updateRevisions(targetDir, false, rev)
if err != nil {
- return err
+ return restore, err
}
- return nil
+ return restore, nil
}
-func updateRevisions(targetDir string, unlink bool, rev string) error {
+func stageFile(file string) (restore func() error, err error) {
+ content, err := os.ReadFile(file)
+ if err != nil {
+ return nil, err
+ }
+ return func() error {
+ return os.WriteFile(file, content, 0755)
+ }, nil
+}
+
+func updateRevisions(targetDir string, unlink bool, rev string) (restore func() error, err error) {
// unlink files because all files are symlink
files := revision.GetVersionFiles(targetDir)
+ var restoreFiles []func() error
+ for _, file := range files {
+ r, err := stageFile(file)
+ if err != nil {
+ return nil, err
+ }
+ restoreFiles = append(restoreFiles, r)
+ }
+ restore = func() error {
+ for _, r := range restoreFiles {
+ r()
+ }
+ return nil
+ }
if unlink {
for _, file := range files {
err := unlinkFile(file)
if err != nil {
- return err
+ return restore, err
}
}
}
@@ -31,10 +56,10 @@ func updateRevisions(targetDir string, unlink bool, rev string) error {
for _, file := range files {
err := revision.PatchVersionFile(file, rev, false)
if err != nil {
- return err
+ return restore, err
}
}
- return nil
+ return restore, nil
}
func gitListWorkingTreeChangedFiles(dir string) ([]string, error) {
diff --git a/script/build-release/main.go b/script/build-release/main.go
index a87144ed..1c61b811 100644
--- a/script/build-release/main.go
+++ b/script/build-release/main.go
@@ -173,7 +173,10 @@ func buildRelease(releaseDirName string, installLocal bool, localName string, de
rev += fmt.Sprintf("_DEV_%s", time.Now().UTC().Format("2006-01-02T15:04:05Z"))
}
- err = fixupSrcDir(tmpSrcDir, rev)
+ restore, err := fixupSrcDir(tmpSrcDir, rev)
+ if restore != nil {
+ defer restore()
+ }
if err != nil {
return err
}
diff --git a/script/run-test/main.go b/script/run-test/main.go
index 2a690208..e0aa8b43 100644
--- a/script/run-test/main.go
+++ b/script/run-test/main.go
@@ -15,9 +15,9 @@ import (
// usage:
//
// go run ./script/run-test/ --include go1.19.13
-// go run ./script/run-test/ --include go1.19.13 -count=1
+// go run ./script/run-test/ -count=1 --include go1.19.13
// go run ./script/run-test/ --include go1.19.13 -run TestHelloWorld -v
-// go run ./script/run-test/ --include go1.17.13 --include go1.18.10 --include go1.19.13 --include go1.20.14 --include go1.21.8 --include go1.22.1 -count=1
+// go run ./script/run-test/ -count=1 --include go1.17.13 --include go1.18.10 --include go1.19.13 --include go1.20.14 --include go1.21.8 --include go1.22.2
// go run ./script/run-test/ -cover -coverpkg github.com/xhd2015/xgo/runtime/... -coverprofile covers/cover.out --include go1.21.8
// runtime test:
@@ -34,10 +34,13 @@ import (
// TODO: remove duplicate test between xgo test and runtime test
var runtimeSubTests = []string{
- "trace_panic_peek",
"func_list",
- "trap_inspect_func",
"trap",
+ "trap_inspect_func",
+ "trap_args",
+ "trace",
+ "trace_marshal",
+ "trace_panic_peek",
"mock_func",
"mock_method",
"mock_by_name",
@@ -45,7 +48,6 @@ var runtimeSubTests = []string{
"mock_stdlib",
"mock_generic",
"mock_var",
- "trap_args",
"patch",
"patch_const",
}
@@ -397,7 +399,7 @@ func doRunTest(goroot string, kind testKind, args []string, tests []string) erro
)
}
case testKind_runtimeSubTest:
- testArgs = []string{"run", "./cmd/xgo", "test", "--project-dir", "runtime/test", "-tags", "dev"}
+ testArgs = []string{"run", "-tags", "dev", "./cmd/xgo", "test", "--project-dir", "runtime/test"}
testArgs = append(testArgs, args...)
if len(tests) > 0 {
testArgs = append(testArgs, tests...)
diff --git a/test/testdata/mock/mock.go b/test/testdata/mock/mock.go
index 1e03f7d6..fe1dc332 100644
--- a/test/testdata/mock/mock.go
+++ b/test/testdata/mock/mock.go
@@ -11,7 +11,7 @@ import (
func main() {
if os.Getenv("XGO_TEST_HAS_INSTRUMENT") != "false" {
- mock.AddFuncInterceptor(hello, func(ctx context.Context, fn *core.FuncInfo, args, results core.Object) error {
+ mock.Mock(hello, func(ctx context.Context, fn *core.FuncInfo, args, results core.Object) error {
a := args.GetField("a")
a.Set("mock:" + a.Value().(string))
return mock.ErrCallOld
diff --git a/test/testdata/mock_res/main.go b/test/testdata/mock_res/main.go
index 417e2c3e..ee0603e0 100644
--- a/test/testdata/mock_res/main.go
+++ b/test/testdata/mock_res/main.go
@@ -13,7 +13,7 @@ func main() {
before := add(5, 2)
fmt.Printf("before mock: add(5,2)=%d\n", before)
if os.Getenv("XGO_TEST_HAS_INSTRUMENT") != "false" {
- mock.AddFuncInterceptor(add, func(ctx context.Context, fn *core.FuncInfo, args core.Object, results core.Object) error {
+ mock.Mock(add, func(ctx context.Context, fn *core.FuncInfo, args core.Object, results core.Object) error {
a := args.GetField("a").Value().(int)
b := args.GetField("b").Value().(int)
res := a - b
diff --git a/test/trace_test.go b/test/trace_test.go
index 94476bc0..d48f932f 100644
--- a/test/trace_test.go
+++ b/test/trace_test.go
@@ -16,7 +16,7 @@ func TestTraceJSONOutput(t *testing.T) {
t.Fatal(getErrMsg(err))
}
- // t.Logf("%s", output)
+ t.Logf("%s", output)
expectLines := []string{
// output
"A\nB\nC\nC\n",
@@ -53,7 +53,7 @@ func TestTracePanicCapture(t *testing.T) {
t.Fatal(getErrMsg(err))
}
- // t.Logf("%s", output)
+ t.Logf("%s", output)
// output
expectOutputLines := []string{
diff --git a/test/trap_test.go b/test/trap_test.go
index 4dde9c3d..3cdfc674 100644
--- a/test/trap_test.go
+++ b/test/trap_test.go
@@ -24,7 +24,7 @@ func TestTrap(t *testing.T) {
// go test -run TestTrapNormalBuildShouldWarn -v ./test
func TestTrapNormalBuildShouldWarn(t *testing.T) {
t.Parallel()
- expectOrigStderr := "WARNING: failed to link __xgo_link_set_trap(requires xgo)."
+ expectOrigStderr := "WARNING: failed to link __xgo_link_init_finished(requires xgo)."
var origStderr bytes.Buffer
runAndCheckInstrumentOutput(t, "./testdata/trap", func(output string) error {