diff --git a/README.md b/README.md index a70bd164..715c9f16 100644 --- a/README.md +++ b/README.md @@ -240,6 +240,8 @@ func main(){ } ``` +Trap also have a helper function called `Direct(fn)`, which can be used to bypass any trap and mock interceptors, calling directly into the original function. + ## Mock Mock simplifies the process of setting up Trap interceptors. @@ -271,7 +273,7 @@ mock.Mock(v.Method, interceptor) mock.Mock(closure, interceptor) ``` -Arguments: +Parameters: - If `fn` is a simple function(i.e. a package level function, or a function owned by a type, or a closure(yes, we do support mocking closures)),then all call to that function will be intercepted, - If `fn` is a method(i.e. `file.Read`),then only call to the instance will be intercepted, other instances will not be affected @@ -323,6 +325,47 @@ func TestMethodMock(t *testing.T){ **Notice for mocking stdlib**: due to performance and security impact, only a few packages and functions of stdlib can be mocked, the list can be found at [runtime/mock/stdlib.md](./runtime/mock/stdlib.md). If you want to mock additional stdlib functions, please discussion in [Issue#6](https://github.com/xhd2015/xgo/issues/6). +## Patch +The `runtime/mock` package also provides another api: +- `Patch(fn,replacer) func()` + +Parameters: +- `fn` same as described in [Mock](#mock) section +- `replacer` another function that will replace `fn` + +NOTE: `fn` and `replacer` should have the same signature. + +Return: +- a `func()` can be used to remove the replacer earlier before current goroutine exits + +Patch replaces the given `fn` with `replacer` in current goroutine. It will remove the replacer once current goroutine exits. + +Example: +```go +package patch_test + +import ( + "testing" + + "github.com/xhd2015/xgo/runtime/mock" +) + +func greet(s string) string { + return "hello " + s +} + +func TestPatchFunc(t *testing.T) { + mock.Patch(greet, func(s string) string { + return "mock " + s + }) + + res := greet("world") + if res != "mock world" { + t.Fatalf("expect patched result to be %q, actual: %q", "mock world", res) + } +} +``` + ## Trace It is painful when debugging with a deep call stack. diff --git a/README_zh_cn.md b/README_zh_cn.md index 64b76d7f..d0882b28 100644 --- a/README_zh_cn.md +++ b/README_zh_cn.md @@ -234,6 +234,8 @@ func main(){ } ``` +Trap还提供了一个`Direct(fn)`的函数, 用于跳过拦截器, 直接调用到原始的函数。 + ## Mock Mock简化了设置拦截器的步骤, 并允许仅对特定的函数进行拦截。 @@ -317,6 +319,46 @@ func TestMethodMock(t *testing.T){ **关于标准库Mock的注意事项**: 出于性能和安全考虑, 标准库中只有一部分包和函数能被Mock, 这个List可以在[runtime/mock/stdlib.md](./runtime/mock/stdlib.md)找到. 如果你需要Mock的标准库函数不在列表中, 可以在[Issue#6](https://github.com/xhd2015/xgo/issues/6)中进行评论。 +## Patch +`runtime/mock`还提供了另一个API: +- `Patch(fn,replacer) func()` + +参数: +- `fn` 与[Mock](#mock)中的第一个参数相同 +- `replacer`一个用来替换`fn`的函数 + +注意: `replacer`应当和`fn`具有同样的签名。 + +返回值: +- 一个`func()`, 用来提前移除`replacer` + +Patch将`fn`替换为`replacer`,这个替换仅对当前goroutine生效.在当前Goroutine退出后, `replacer`被自动移除。 + +例子: +```go +package patch_test + +import ( + "testing" + + "github.com/xhd2015/xgo/runtime/mock" +) + +func greet(s string) string { + return "hello " + s +} + +func TestPatchFunc(t *testing.T) { + mock.Patch(greet, func(s string) string { + return "mock " + s + }) + + res := greet("world") + if res != "mock world" { + t.Fatalf("expect patched result to be %q, actual: %q", "mock world", res) + } +} +``` ## Trace 在调试一个非常深的调用栈时, 通常会感觉非常痛苦, 并且效率低下。 diff --git a/cmd/xgo/patch_runtime.go b/cmd/xgo/patch_runtime.go index fd64901a..05f56bb1 100644 --- a/cmd/xgo/patch_runtime.go +++ b/cmd/xgo/patch_runtime.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "path/filepath" + "strings" "github.com/xhd2015/xgo/cmd/xgo/patch" "github.com/xhd2015/xgo/support/filecopy" @@ -20,6 +21,10 @@ func patchRuntimeAndTesting(goroot string) error { if err != nil { return err } + err = patchRuntimeTime(goroot) + if err != nil { + return err + } return nil } @@ -80,6 +85,7 @@ func addRuntimeFunctions(goroot string, goVersion *goinfo.GoVersion, xgoSrc stri } func patchRuntimeProc(goroot string) error { + procFile := filepath.Join(goroot, "src", "runtime", "proc.go") anchors := []string{ "func main() {", "doInit(", "runtime_inittask", ")", // first doInit for runtime @@ -87,8 +93,7 @@ func patchRuntimeProc(goroot string) error { "close(main_init_done)", "\n", } - procGo := filepath.Join(goroot, "src", "runtime", "proc.go") - err := editFile(procGo, func(content string) (string, error) { + err := editFile(procFile, func(content string) (string, error) { content = addContentAfter(content, "/**/", "/**/", anchors, patch.RuntimeProcPatch) // goexit1() is called for every exited goroutine @@ -132,3 +137,42 @@ func patchRuntimeTesting(goroot string) error { return content, nil }) } + +// only required if need to mock time.Sleep +func patchRuntimeTime(goroot string) error { + runtimeTimeFile := filepath.Join(goroot, "src", "runtime", "time.go") + timeSleepFile := filepath.Join(goroot, "src", "time", "sleep.go") + + err := editFile(runtimeTimeFile, func(content string) (string, error) { + content = replaceContentAfter(content, + "/**/", "/**/", + []string{}, + "//go:linkname timeSleep time.Sleep\nfunc timeSleep(ns int64) {", + "//go:linkname timeSleep time.runtimeSleep\nfunc timeSleep(ns int64) {", + ) + return content, nil + }) + if err != nil { + return err + } + + err = editFile(timeSleepFile, func(content string) (string, error) { + content = replaceContentAfter(content, + "/**/", "/**/", + []string{}, + "func Sleep(d Duration)", + strings.Join([]string{ + "func runtimeSleep(d Duration)", + "func Sleep(d Duration){", + " runtimeSleep(d)", + "}", + }, "\n"), + ) + return content, nil + }) + if err != nil { + return err + } + + return nil +} diff --git a/cmd/xgo/version.go b/cmd/xgo/version.go index 2094c7d4..7e49a239 100644 --- a/cmd/xgo/version.go +++ b/cmd/xgo/version.go @@ -3,8 +3,8 @@ package main import "fmt" const VERSION = "1.0.15" -const REVISION = "05e21215595deccfc08d49cb2d502e0d48b3cf4b+1" -const NUMBER = 151 +const REVISION = "2861a46387df90bcadae7651dc6e0d2db8ab0148+1" +const NUMBER = 152 func getRevision() string { return fmt.Sprintf("%s %s BUILD_%d", VERSION, REVISION, NUMBER) diff --git a/patch/ctxt/ctx.go b/patch/ctxt/ctx.go index ee98a77a..f9ed3ea4 100644 --- a/patch/ctxt/ctx.go +++ b/patch/ctxt/ctx.go @@ -25,7 +25,7 @@ func SkipPackageTrap() bool { // allow http pkgPath := GetPkgPath() - if pkgPath == "net/http" || pkgPath == "net" || pkgPath == "time" || pkgPath == "os" || pkgPath == "os/exec" { + if _, ok := stdWhitelist[pkgPath]; ok { return false } return true @@ -45,11 +45,21 @@ func SkipPackageTrap() bool { } var stdWhitelist = map[string]map[string]bool{ + // "runtime": map[string]bool{ + // "timeSleep": true, + // }, "os": map[string]bool{ // starts with Get + "OpenFile": true, }, "time": map[string]bool{ - "Now": true, + "Now": true, + // time.Sleep is special: + // if trapped like normal functions + // runtime/time.go:178:6: ns escapes to heap, not allowed in runtime + // there are special handling of this, see cmd/xgo/patch_runtime patchRuntimeTime + "Sleep": true, // NOTE: time.Sleep links to runtime.timeSleep + "NewTicker": true, "Time.Format": true, }, "os/exec": map[string]bool{ diff --git a/patch/syntax/rewrite.go b/patch/syntax/rewrite.go index 503efa0c..e73653b6 100644 --- a/patch/syntax/rewrite.go +++ b/patch/syntax/rewrite.go @@ -41,6 +41,10 @@ func rewriteStdAndGenericFuncs(funcDecls []*DeclInfo, pkgPath string) { if fn.Closure { continue } + if fn.FuncDecl.Body == nil { + // no body, may be linked + continue + } // stdlib and generic if !base.Flag.Std { @@ -48,7 +52,6 @@ func rewriteStdAndGenericFuncs(funcDecls []*DeclInfo, pkgPath string) { continue } } - fnDecl := fn.FuncDecl pos := fn.FuncDecl.Pos() diff --git a/patch/syntax/syntax.go b/patch/syntax/syntax.go index 7117cf64..d14a2923 100644 --- a/patch/syntax/syntax.go +++ b/patch/syntax/syntax.go @@ -255,7 +255,7 @@ func shouldTrap() bool { } pkgPath := xgo_ctxt.GetPkgPath() - if pkgPath == "" || pkgPath == "runtime" || strings.HasPrefix(pkgPath, "runtime/") || strings.HasPrefix(pkgPath, "internal/") || isSkippableSpecialPkg() { + if pkgPath == "" || strings.HasPrefix(pkgPath, "runtime/") || strings.HasPrefix(pkgPath, "internal/") || isSkippableSpecialPkg() { // runtime/internal should not be rewritten // internal/api has problem with the function register return false diff --git a/patch/trap.go b/patch/trap.go index 3290da6d..a9d6e521 100644 --- a/patch/trap.go +++ b/patch/trap.go @@ -301,6 +301,11 @@ func CanInsertTrapOrLink(fn *ir.Func) (string, bool) { return linkName, false // ir.Dump("after:", fn) } + // disable all stdlib IR rewrite + if base.Flag.Std { + // NOTE: stdlib are rewritten by source + return "", false + } if xgo_ctxt.SkipPackageTrap() { return "", false } @@ -339,12 +344,6 @@ func CanInsertTrapOrLink(fn *ir.Func) (string, bool) { return "", false } - // disable part of stdlibs - if base.Flag.Std { - // NOTE: stdlib are rewritten by source - return "", false - } - // func marked nosplit will skip trap because // inserting traps when -gcflags=-N -l enabled // would cause stack overflow 792 bytes diff --git a/runtime/core/version.go b/runtime/core/version.go index b4139837..0334f3f2 100644 --- a/runtime/core/version.go +++ b/runtime/core/version.go @@ -7,8 +7,8 @@ import ( ) const VERSION = "1.0.15" -const REVISION = "05e21215595deccfc08d49cb2d502e0d48b3cf4b+1" -const NUMBER = 151 +const REVISION = "2861a46387df90bcadae7651dc6e0d2db8ab0148+1" +const NUMBER = 152 // these fields will be filled by compiler const XGO_VERSION = "" diff --git a/runtime/mock/stdlib.md b/runtime/mock/stdlib.md index f93edda1..29651a9e 100644 --- a/runtime/mock/stdlib.md +++ b/runtime/mock/stdlib.md @@ -9,9 +9,12 @@ So only a limited list of stdlib functions can be mocked. However, if there lack ## `os` - `Getenv` - `Getwd` +- `OpenFile` ## `time` - `Now` +- `Sleep` +- `NewTicker` - `Time.Format` ## `os/exec` @@ -33,4 +36,57 @@ So only a limited list of stdlib functions can be mocked. However, if there lack - `DialIP` - `DialUDP` - `DialUnix` -- `DialTimeout` \ No newline at end of file +- `DialTimeout` + + +# Examples +> Check [../test/mock_stdlib/mock_stdlib_test.go](../test/mock_stdlib/mock_stdlib_test.go) for more details. +```go +package mock_stdlib + +import ( + "context" + "testing" + "time" + + "github.com/xhd2015/xgo/runtime/core" + "github.com/xhd2015/xgo/runtime/mock" +) + +func TestMockTimeSleep(t *testing.T) { + begin := time.Now() + sleepDur := 1 * time.Second + var haveCalledMock bool + var mockArg time.Duration + mock.Mock(time.Sleep, func(ctx context.Context, fn *core.FuncInfo, args, results core.Object) error { + haveCalledMock = true + mockArg = args.GetFieldIndex(0).Value().(time.Duration) + return nil + }) + time.Sleep(sleepDur) + + // 37.275µs + cost := time.Since(begin) + + if !haveCalledMock { + t.Fatalf("expect haveCalledMock, actually not") + } + if mockArg != sleepDur { + t.Fatalf("expect mockArg to be %v, actual: %v", sleepDur, mockArg) + } + if cost > 100*time.Millisecond { + t.Fatalf("expect time.Sleep mocked, actual cost: %v", cost) + } +} +``` + +Run:`xgo test -v ./` +Output: +```sh +=== RUN TestMockTimeSleep +--- PASS: TestMockTimeSleep (0.00s) +PASS +ok github.com/xhd2015/xgo/runtime/test/mock_stdlib 0.725s +``` + +Note we call `time.Sleep` with `1s`, but it returns within few micro-seonds. \ No newline at end of file diff --git a/runtime/test/debug/debug_test.go b/runtime/test/debug/debug_test.go index 2941a03d..94aa7caf 100644 --- a/runtime/test/debug/debug_test.go +++ b/runtime/test/debug/debug_test.go @@ -7,37 +7,34 @@ package debug import ( "context" - "fmt" "testing" + "time" "github.com/xhd2015/xgo/runtime/core" - "github.com/xhd2015/xgo/runtime/trap" + "github.com/xhd2015/xgo/runtime/mock" ) -func TestNewGoroutineShouldInheritInterceptor(t *testing.T) { - trap.AddInterceptor(&trap.Interceptor{ - Pre: func(ctx context.Context, f *core.FuncInfo, args, result core.Object) (data interface{}, err error) { - if f.IdentityName == "greet" { - result.GetFieldIndex(0).Set("mock " + args.GetFieldIndex(0).Value().(string)) - return nil, trap.ErrAbort - } - return - }, +func TestMockTimeSleep(t *testing.T) { + begin := time.Now() + sleepDur := 1 * time.Second + var haveCalledMock bool + var mockArg time.Duration + mock.Mock(time.Sleep, func(ctx context.Context, fn *core.FuncInfo, args, results core.Object) error { + haveCalledMock = true + mockArg = args.GetFieldIndex(0).Value().(time.Duration) + return nil }) - - done := make(chan struct{}) - go func() { - s := greet("world") - if s != "mock world" { - panic(fmt.Errorf("expect greet returns %q, actual: %q", "mock world", s)) - } - - close(done) - }() - - <-done -} - -func greet(s string) string { - return "hello " + s + time.Sleep(sleepDur) + + cost := time.Since(begin) + + if !haveCalledMock { + t.Fatalf("expect haveCalledMock, actually not") + } + if mockArg != sleepDur { + t.Fatalf("expect mockArg to be %v, actual: %v", sleepDur, mockArg) + } + if cost > 100*time.Millisecond { + t.Fatalf("expect time.Sleep mocked, actual cost: %v", cost) + } } diff --git a/runtime/test/func_list/func_list_stdlib_test.go b/runtime/test/func_list/func_list_stdlib_test.go index 5dfe9d5b..0fadfde5 100644 --- a/runtime/test/func_list/func_list_stdlib_test.go +++ b/runtime/test/func_list/func_list_stdlib_test.go @@ -21,11 +21,14 @@ func TestListStdlib(t *testing.T) { stdPkgs := map[string]bool{ // os - "os.Getenv": true, - "os.Getwd": true, + "os.Getenv": true, + "os.Getwd": true, + "os.OpenFile": true, // time "time.Now": true, + "time.Sleep": true, + "time.NewTicker": true, "time.Time.Format": true, // exec diff --git a/runtime/test/mock_stdlib/mock_stdlib_test.go b/runtime/test/mock_stdlib/mock_stdlib_test.go index 61d560a6..38f7c441 100644 --- a/runtime/test/mock_stdlib/mock_stdlib_test.go +++ b/runtime/test/mock_stdlib/mock_stdlib_test.go @@ -10,10 +10,6 @@ import ( "github.com/xhd2015/xgo/runtime/mock" ) -// func TEST() { -// panic("debug") -// } - // go run ./cmd/xgo test --project-dir runtime -run TestMockTimeNow -v ./test/mock_stdlib // go run ./script/run-test/ --include go1.22.1 --xgo-runtime-test-only -run TestMockTimeNow -v ./test/mock_stdlib func TestMockTimeNow(t *testing.T) { @@ -54,3 +50,36 @@ func TestMockHTTP(t *testing.T) { t.Fatalf("expect http.DefaultClient.Do to have been mocked, actually not mocked") } } + +// execution log(NOTE the cost is not 1s): +// +// === RUN TestMockTimeSleep +// --- PASS: TestMockTimeSleep (0.00s) +// PASS +// ok github.com/xhd2015/xgo/runtime/test/mock_stdlib 0.725s +func TestMockTimeSleep(t *testing.T) { + begin := time.Now() + sleepDur := 1 * time.Second + var haveCalledMock bool + var mockArg time.Duration + mock.Mock(time.Sleep, func(ctx context.Context, fn *core.FuncInfo, args, results core.Object) error { + haveCalledMock = true + mockArg = args.GetFieldIndex(0).Value().(time.Duration) + return nil + }) + time.Sleep(sleepDur) + + // 37.275µs + cost := time.Since(begin) + + if !haveCalledMock { + t.Fatalf("expect haveCalledMock, actually not") + } + if mockArg != sleepDur { + t.Fatalf("expect mockArg to be %v, actual: %v", sleepDur, mockArg) + } + // t.Logf("cost: %v", cost) + if cost > 100*time.Millisecond { + t.Fatalf("expect time.Sleep mocked, actual cost: %v", cost) + } +} diff --git a/runtime/test/trap/trap_direct_test.go b/runtime/test/trap/trap_direct_test.go new file mode 100644 index 00000000..0db8d94b --- /dev/null +++ b/runtime/test/trap/trap_direct_test.go @@ -0,0 +1,43 @@ +package trap + +import ( + "context" + "testing" + + "github.com/xhd2015/xgo/runtime/core" + "github.com/xhd2015/xgo/runtime/trap" +) + +func TestDirectShouldByPassTrap(t *testing.T) { + trap.AddInterceptor(&trap.Interceptor{ + Pre: func(ctx context.Context, f *core.FuncInfo, args, result core.Object) (data interface{}, err error) { + if f.IdentityName == "direct" { + panic("direct should be bypassed") + } + if f.IdentityName == "nonByPass" { + result.GetFieldIndex(0).Set("mock nonByPass") + return nil, trap.ErrAbort + } + return nil, nil + }, + }) + var resDirect string + trap.Direct(func() { + resDirect = direct() + }) + if resDirect != "direct" { + t.Fatalf("expect direct() to be %q, actual: %q", "direct", resDirect) + } + + nonRes := nonByPass() + if nonRes != "mock nonByPass" { + t.Fatalf("expect nonByPass() to be %q, actual: %q", "mock nonByPass", nonRes) + } +} + +func direct() string { + return "direct" +} +func nonByPass() string { + return "nonByPass" +} diff --git a/runtime/trap/direct.go b/runtime/trap/direct.go new file mode 100644 index 00000000..9b44bdfc --- /dev/null +++ b/runtime/trap/direct.go @@ -0,0 +1,20 @@ +package trap + +import "sync" + +var bypassMapping sync.Map // -> struct{}{} + +// Direct make a call to fn, without +// any trap and mock interceptors +func Direct(fn func()) { + key := uintptr(__xgo_link_getcurg()) + bypassMapping.Store(key, struct{}{}) + defer bypassMapping.Delete(key) + fn() +} + +func isByPassing() bool { + key := uintptr(__xgo_link_getcurg()) + _, ok := bypassMapping.Load(key) + return ok +} diff --git a/runtime/trap/inspect.go b/runtime/trap/inspect.go index 5567158f..d6d3b815 100644 --- a/runtime/trap/inspect.go +++ b/runtime/trap/inspect.go @@ -12,7 +12,7 @@ import ( const methodSuffix = "-fm" -// Inspect make a call to f to capture its receiver pointer if is +// Inspect make a call to f to capture its receiver pointer if it // is bound method // It can be used to get the unwrapped innermost function of a method // wrapper. diff --git a/runtime/trap/interceptor.go b/runtime/trap/interceptor.go index c28a598a..65706281 100644 --- a/runtime/trap/interceptor.go +++ b/runtime/trap/interceptor.go @@ -155,6 +155,7 @@ type interceptorList struct { func clearLocalInterceptorsAndMark() { key := uintptr(__xgo_link_getcurg()) localInterceptors.Delete(key) + bypassMapping.Delete(key) clearTrappingMark() } diff --git a/runtime/trap/trap.go b/runtime/trap/trap.go index 2af55f9b..8d8d0a95 100644 --- a/runtime/trap/trap.go +++ b/runtime/trap/trap.go @@ -20,6 +20,9 @@ func ensureTrapInstall() { } func init() { __xgo_link_on_gonewproc(func(g uintptr) { + if isByPassing() { + return + } interceptors := GetLocalInterceptors() if len(interceptors) == 0 { return @@ -55,6 +58,9 @@ var trappingPC sync.Map // -> PC // link to runtime // xgo:notrap func trapImpl(pkgPath string, identityName string, generic bool, pc uintptr, recv interface{}, args []interface{}, results []interface{}) (func(), bool) { + if isByPassing() { + return nil, false + } dispose := setTrappingMark() if dispose == nil { return nil, false