diff --git a/README.md b/README.md index 1566ee54..6c8f9610 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,8 @@ xgo version # 1.0.x ``` +If `xgo` is not found, you may need to check if `$GOPATH/bin` is added to your `PATH` variable. + There are other options,see [doc/INSTALLATION.md](./doc/INSTALLATION.md). # Requirement @@ -60,7 +62,6 @@ xgo version # output # 1.0.x ``` -If `xgo` is not found, you may need to add `~/.xgo/bin` to your `PATH` variable. 2. Init a go project: ```sh @@ -70,7 +71,7 @@ go mod init demo ``` 3. Add `demo_test.go` with following code: ```go -package demo +package demo_test import ( "context" @@ -83,6 +84,7 @@ import ( func MyFunc() string { return "my func" } + func TestFuncMock(t *testing.T) { mock.Mock(MyFunc, func(ctx context.Context, fn *core.FuncInfo, args core.Object, results core.Object) error { results.GetFieldIndex(0).Set("mock func") @@ -303,7 +305,7 @@ func TestMethodMock(t *testing.T){ } ``` -**Notice for mocking stdlib**: due to performance and security impact, only a few packages and functions of stdlib can be mocked, the list can be found at [runtime/mock/stdlib.md](./runtime/mock/stdlib.md). If you want to mock additional stdlib functions, please discussion in [Issue#6](https://github.com/xhd2015/xgo/issues/6). +**Notice for mocking stdlib**: due to performance and security impact, only a few packages and functions of stdlib can be mocked, the list can be found at [runtime/mock/stdlib.md](./runtime/mock/stdlib.md). If you want to mock additional stdlib functions, please file a discussion in [Issue#6](https://github.com/xhd2015/xgo/issues/6). ## Patch The `runtime/mock` package also provides another api: diff --git a/README_zh_cn.md b/README_zh_cn.md index 9e35bdd4..fbb74cd9 100644 --- a/README_zh_cn.md +++ b/README_zh_cn.md @@ -33,6 +33,7 @@ xgo version # 输出: # 1.0.x ``` +如果未找到`xgo`, 你可能需要查看`$GOPATH/bin`是否已经添加到你的`PATH`变量中。 更多安装方式, 参考[doc/INSTALLATION.md](./doc/INSTALLATION.md). @@ -57,7 +58,6 @@ xgo version # 输出 # 1.0.x ``` -如果未找到`xgo`, 你可能需要将`~/.xgo/bin`添加到你的环境变量中。 2. 创建一个demo工程: ```sh @@ -67,7 +67,7 @@ go mod init demo ``` 3. 将下面的内容添加到文件`demo_test.go`中: ```go -package demo +package demo_test import ( "context" @@ -80,6 +80,7 @@ import ( func MyFunc() string { return "my func" } + func TestFuncMock(t *testing.T) { mock.Mock(MyFunc, func(ctx context.Context, fn *core.FuncInfo, args core.Object, results core.Object) error { results.GetFieldIndex(0).Set("mock func") diff --git a/cmd/xgo/version.go b/cmd/xgo/version.go index 1b8d467a..29a880ca 100644 --- a/cmd/xgo/version.go +++ b/cmd/xgo/version.go @@ -3,8 +3,8 @@ package main import "fmt" const VERSION = "1.0.22" -const REVISION = "f34bf4ca38af8b5adcf36b67de5ea6fb853e4823+1" -const NUMBER = 174 +const REVISION = "f174b18f76bff4bd8acddffd03835b760dad03d8+1" +const NUMBER = 175 func getRevision() string { revSuffix := "" diff --git a/doc/INSTALLATION.md b/doc/INSTALLATION.md index e6f3b738..f841334e 100644 --- a/doc/INSTALLATION.md +++ b/doc/INSTALLATION.md @@ -1,5 +1,10 @@ +# Standard installation +```sh +go install github.com/xhd2015/xgo/cmd/xgo@latest +``` + # Install Prebuilt Binaries -Install prebuilt: +For environments like CI, xgo can be installed from pre-built binaries: ```sh # macOS and Linux (and WSL) curl -fsSL https://github.com/xhd2015/xgo/raw/master/install.sh | bash @@ -8,6 +13,8 @@ curl -fsSL https://github.com/xhd2015/xgo/raw/master/install.sh | bash powershell -c "irm github.com/xhd2015/xgo/raw/master/install.ps1|iex" ``` +After installation, `~/.xgo/bin/xgo` will be available. + # Upgrade if you've already installed If you've already installed `xgo`, you can upgrade it with: @@ -22,4 +29,7 @@ If you want to build from source, run with: git clone https://github.com/xhd2015/xgo cd xgo go run ./script/build-release --local + +# check build version +~/.xgo/bin/xgo version ``` \ No newline at end of file diff --git a/patch/README.md b/patch/README.md new file mode 100644 index 00000000..0eb295a0 --- /dev/null +++ b/patch/README.md @@ -0,0 +1,8 @@ +# patch +This directory contains code to patch the go runtime and compiler: +- [./](./) + - patch the compiler IR +- [syntax](syntax) + - patch the compiler AST +- [trap_runtime](trap_runtime) + - patch go runtime \ No newline at end of file diff --git a/patch/ir.go b/patch/ir.go index d62256fa..72ef5ad2 100644 --- a/patch/ir.go +++ b/patch/ir.go @@ -143,8 +143,11 @@ func convToEFace(pos src.XPos, x ir.Node, t *types.Type, ptr bool) *ir.ConvExpr func isFirstStmtSkipTrap(nodes ir.Nodes) bool { // NOTE: for performance reason, only check the first - if len(nodes) > 0 && isCallTo(nodes[0], xgoRuntimeTrapPkg, "Skip") { - return true + if len(nodes) > 0 { + firstNode := nodes[0] + if isCallTo(firstNode, xgoRuntimeTrapPkg, "Skip") || isCallToName(firstNode, "__xgo_trap_skip") { + return true + } } if false { for _, node := range nodes { @@ -157,6 +160,12 @@ func isFirstStmtSkipTrap(nodes ir.Nodes) bool { } func isCallTo(node ir.Node, pkgPath string, name string) bool { + return checkCall(node, pkgPath, name, true) +} +func isCallToName(node ir.Node, name string) bool { + return checkCall(node, "", name, false) +} +func checkCall(node ir.Node, pkgPath string, name string, checkPkg bool) bool { callNode, ok := node.(*ir.CallExpr) if !ok { return false @@ -169,7 +178,16 @@ func isCallTo(node ir.Node, pkgPath string, name string) bool { if sym == nil { return false } - return sym.Pkg != nil && sym.Name == name && sym.Pkg.Path == pkgPath + if sym.Name != name { + return false + } + if !checkPkg { + return true + } + if sym.Pkg != nil && sym.Pkg.Path == pkgPath { + return true + } + return false } func newNilInterface(pos src.XPos) ir.Expr { diff --git a/patch/syntax/helper_code.go b/patch/syntax/helper_code.go index d0e39d8b..42070170 100644 --- a/patch/syntax/helper_code.go +++ b/patch/syntax/helper_code.go @@ -33,9 +33,16 @@ type __xgo_local_func_stub struct { Line int } -func init() { +// ensure early init +var _ = func() bool { + __xgo_trap_skip() __xgo_link_generate_init_regs_body() -} + return true +}() + +// func init() { +// __xgo_link_generate_init_regs_body() +// } // TODO: ensure safety for this func __xgo_link_generate_init_regs_body() { @@ -63,6 +70,8 @@ func __xgo_local_register_func(pkgPath string, identityName string, fn interface __xgo_link_generated_register_func(__xgo_local_func_stub{PkgPath: pkgPath, IdentityName: identityName, Fn: fn, Closure: closure, RecvName: recvName, ArgNames: argNames, ResNames: resNames, File: file, Line: line}) } +func __xgo_trap_skip() {} + // not used // func __xgo_local_register_interface(pkgPath string, interfaceName string, file string, line int) { // __xgo_link_generated_register_func(__xgo_local_func_stub{PkgPath: pkgPath, Interface: true, File: file, Line: line}) diff --git a/patch/syntax/helper_code_gen.go b/patch/syntax/helper_code_gen.go index 04161a4e..3fdd7430 100755 --- a/patch/syntax/helper_code_gen.go +++ b/patch/syntax/helper_code_gen.go @@ -62,9 +62,16 @@ type __xgo_local_func_stub struct { Line int } -func init() { +// ensure early init +var _ = func() bool { + __xgo_trap_skip() __xgo_link_generate_init_regs_body() -} + return true +}() + +// func init() { +// __xgo_link_generate_init_regs_body() +// } // TODO: ensure safety for this func __xgo_link_generate_init_regs_body() { @@ -92,6 +99,8 @@ func __xgo_local_register_func(pkgPath string, identityName string, fn interface __xgo_link_generated_register_func(__xgo_local_func_stub{PkgPath: pkgPath, IdentityName: identityName, Fn: fn, Closure: closure, RecvName: recvName, ArgNames: argNames, ResNames: resNames, File: file, Line: line}) } +func __xgo_trap_skip() {} + // not used // func __xgo_local_register_interface(pkgPath string, interfaceName string, file string, line int) { // __xgo_link_generated_register_func(__xgo_local_func_stub{PkgPath: pkgPath, Interface: true, File: file, Line: line}) diff --git a/patch/trap_runtime/xgo_trap.go b/patch/trap_runtime/xgo_trap.go index 45c27a98..05d912a1 100644 --- a/patch/trap_runtime/xgo_trap.go +++ b/patch/trap_runtime/xgo_trap.go @@ -78,20 +78,29 @@ func __xgo_set_trap_var(trap func(pkgPath string, name string, tmpVarAddr interf // NOTE: runtime has problem when using slice var __xgo_registered_func_infos []interface{} +var __xgo_register_func_callback func(info interface{}) // a function cannot have too many params, so use a struct to wrap them // the client should use reflect to retrieve these fields respectively func __xgo_register_func(info interface{}) { + if __xgo_register_func_callback != nil { + __xgo_register_func_callback(info) + return + } __xgo_registered_func_infos = append(__xgo_registered_func_infos, info) } func __xgo_retrieve_all_funcs_and_clear(f func(info interface{})) { - for _, fn := range __xgo_registered_func_infos { + if __xgo_register_func_callback != nil { + panic("__xgo_register_func_callback already set") + } + __xgo_register_func_callback = f + funcInfos := __xgo_registered_func_infos + __xgo_registered_func_infos = nil // clear + for _, fn := range funcInfos { f(fn) } - // clear - __xgo_registered_func_infos = nil } var __xgo_is_init_finished bool diff --git a/runtime/core/version.go b/runtime/core/version.go index f8c9b730..1091b43d 100644 --- a/runtime/core/version.go +++ b/runtime/core/version.go @@ -7,8 +7,8 @@ import ( ) const VERSION = "1.0.22" -const REVISION = "f34bf4ca38af8b5adcf36b67de5ea6fb853e4823+1" -const NUMBER = 174 +const REVISION = "f174b18f76bff4bd8acddffd03835b760dad03d8+1" +const NUMBER = 175 // these fields will be filled by compiler const XGO_VERSION = "" diff --git a/runtime/functab/functab.go b/runtime/functab/functab.go index 43b544d2..835ed7b5 100644 --- a/runtime/functab/functab.go +++ b/runtime/functab/functab.go @@ -12,9 +12,25 @@ import ( "github.com/xhd2015/xgo/runtime/core" ) +// all func infos +var funcInfos []*core.FuncInfo +var funcInfoMapping map[string]map[string]*core.FuncInfo // pkg -> identifyName -> FuncInfo +var funcPCMapping map[uintptr]*core.FuncInfo // pc->FuncInfo +var varAddrMapping map[uintptr]*core.FuncInfo // addr->FuncInfo +var funcFullNameMapping map[string]*core.FuncInfo // fullName -> FuncInfo +var interfaceMapping map[string]map[string]*core.FuncInfo // pkg -> interfaceName -> FuncInfo +var typeMethodMapping map[reflect.Type]map[string]*core.FuncInfo // reflect.Type -> interfaceName -> FuncInfo + func init() { - __xgo_link_on_init_finished(ensureMapping) - __xgo_link_on_init_finished(ensureTypeMapping) + funcPCMapping = make(map[uintptr]*core.FuncInfo) + funcInfoMapping = make(map[string]map[string]*core.FuncInfo) + funcFullNameMapping = make(map[string]*core.FuncInfo) + interfaceMapping = make(map[string]map[string]*core.FuncInfo) + varAddrMapping = make(map[uintptr]*core.FuncInfo) + + // this will consume all staged func infos in runtime, + // and set registerFuncInfo for later registering + __xgo_link_retrieve_all_funcs_and_clear(registerFuncInfo) } // rewrite at compile time by compiler, the body will be replaced with @@ -24,28 +40,11 @@ func __xgo_link_retrieve_all_funcs_and_clear(f func(fn interface{})) { fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_retrieve_all_funcs_and_clear(requires xgo).") } -func __xgo_link_on_init_finished(f func()) { - fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_on_init_finished(requires xgo).") -} - -func __xgo_link_init_finished() bool { - fmt.Fprintln(os.Stderr, "WARNING: failed to link __xgo_link_init_finished(requires xgo).") - return false -} - func __xgo_link_get_pc_name(pc uintptr) string { fmt.Fprintf(os.Stderr, "WARNING: failed to link __xgo_link_get_pc_name(requires xgo).\n") return "" } -var funcInfos []*core.FuncInfo -var funcInfoMapping map[string]map[string]*core.FuncInfo // pkg -> identifyName -> FuncInfo -var funcPCMapping map[uintptr]*core.FuncInfo // pc->FuncInfo -var varAddrMapping map[uintptr]*core.FuncInfo // addr->FuncInfo -var funcFullNameMapping map[string]*core.FuncInfo // fullName -> FuncInfo -var interfaceMapping map[string]map[string]*core.FuncInfo // pkg -> interfaceName -> FuncInfo -var typeMethodMapping map[reflect.Type]map[string]*core.FuncInfo // reflect.Type -> interfaceName -> FuncInfo - func GetFuncs() []*core.FuncInfo { return funcInfos } @@ -59,6 +58,7 @@ func InfoFunc(fn interface{}) *core.FuncInfo { pc := v.Pointer() return funcPCMapping[pc] } + func InfoVar(addr interface{}) *core.FuncInfo { v := reflect.ValueOf(addr) if v.Kind() != reflect.Ptr { @@ -75,7 +75,6 @@ func InfoPC(pc uintptr) *core.FuncInfo { // maybe rename to FuncForGeneric func Info(pkg string, identityName string) *core.FuncInfo { - ensureMapping() return funcInfoMapping[pkg][identityName] } @@ -85,7 +84,6 @@ func Info(pkg string, identityName string) *core.FuncInfo { // pkg.Recv.Func // pkg.(*Recv).Func func GetFuncByPkg(pkg string, name string) *core.FuncInfo { - ensureMapping() pkgMapping := funcInfoMapping[pkg] if pkgMapping == nil { return nil @@ -119,7 +117,7 @@ func GetFuncByFullName(fullName string) *core.FuncInfo { } func GetTypeMethods(typ reflect.Type) map[string]*core.FuncInfo { - return typeMethodMapping[typ] + return getTypeMethodMapping()[typ] } func getInterfaceOrGenericByFullName(fullName string) *core.FuncInfo { @@ -142,169 +140,168 @@ func getInterfaceOrGenericByFullName(fullName string) *core.FuncInfo { return nil } -var mappingOnce sync.Once - var errType = reflect.TypeOf((*error)(nil)).Elem() var ctxType = reflect.TypeOf((*context.Context)(nil)).Elem() -func ensureMapping() { - mappingOnce.Do(func() { - funcPCMapping = make(map[uintptr]*core.FuncInfo) - funcInfoMapping = make(map[string]map[string]*core.FuncInfo) - funcFullNameMapping = make(map[string]*core.FuncInfo) - interfaceMapping = make(map[string]map[string]*core.FuncInfo) - varAddrMapping = make(map[uintptr]*core.FuncInfo) - __xgo_link_retrieve_all_funcs_and_clear(func(fnInfo interface{}) { - rv := reflect.ValueOf(fnInfo) - if rv.Kind() != reflect.Struct { - panic(fmt.Errorf("expect struct, actual: %s", rv.Kind().String())) - } - closure := rv.FieldByName("Closure").Bool() - identityName := rv.FieldByName("IdentityName").String() - if identityName == "" { - if !closure { - return - } - // fmt.Fprintf(os.Stderr, "empty name\n",pkgPath) - } - var fnKind core.Kind - fnKindV := rv.FieldByName("Kind") - if fnKindV.IsValid() { - fnKind = core.Kind(fnKindV.Int()) - } - varField := rv.FieldByName("Var") - pkgPath := rv.FieldByName("PkgPath").String() - recvTypeName := rv.FieldByName("RecvTypeName").String() - recvPtr := rv.FieldByName("RecvPtr").Bool() - name := rv.FieldByName("Name").String() - interface_ := rv.FieldByName("Interface").Bool() - generic := rv.FieldByName("Generic").Bool() - f := rv.FieldByName("Fn").Interface() - - var firstArgCtx bool - var lastResErr bool - var pc uintptr - var fullName string - if !generic && !interface_ { - if f != nil { - // TODO: move all ctx, err check logic here - ft := reflect.TypeOf(f) - off := 0 - if recvTypeName != "" { - off = 1 - } - if ft.NumIn() > off && ft.In(off).Implements(ctxType) { - firstArgCtx = true - } - // NOTE: use == instead of implements - if ft.NumOut() > 0 && ft.Out(ft.NumOut()-1) == errType { - lastResErr = true - } - pc = getFuncPC(f) - fullName = __xgo_link_get_pc_name(pc) - } else { - if (closure || fnKind == core.Kind_Var || fnKind == core.Kind_VarPtr || fnKind == core.Kind_Const) && identityName != "" { - fullName = pkgPath + "." + identityName - } - } - } - recvName := rv.FieldByName("RecvName").String() - argNames := rv.FieldByName("ArgNames").Interface().([]string) - resNames := rv.FieldByName("ResNames").Interface().([]string) - file := rv.FieldByName("File").String() - line := int(rv.FieldByName("Line").Int()) - - // debug - // fmt.Printf("reg: %s\n", fullName) - // if pkgPath == "main" { - // fmt.Fprintf(os.Stderr, "reg: funcName=%s,pc=%x,generic=%v,genericname=%s\n", funcName, pc, generic, genericName) - // } - // _, recvTypeName, recvPtr, name := core.ParseFuncName(identityName, false) - info := &core.FuncInfo{ - Kind: fnKind, - FullName: fullName, - Pkg: pkgPath, - IdentityName: identityName, - Name: name, - RecvType: recvTypeName, - RecvPtr: recvPtr, - - Interface: interface_, - Generic: generic, - Closure: closure, - - File: file, - Line: line, - - // runtime info - PC: pc, // nil for generic - Func: f, // nil for geneirc - - RecvName: recvName, - ArgNames: argNames, - ResNames: resNames, - - // brief info - FirstArgCtx: firstArgCtx, - LastResultErr: lastResErr, - } - if varField.IsValid() { - info.Var = varField.Interface() - } - funcInfos = append(funcInfos, info) - if !generic && info.PC != 0 { - funcPCMapping[info.PC] = info - } - if identityName != "" { - pkgMapping := funcInfoMapping[pkgPath] - if pkgMapping == nil { - pkgMapping = make(map[string]*core.FuncInfo, 1) - funcInfoMapping[pkgPath] = pkgMapping - } - pkgMapping[identityName] = info +func registerFuncInfo(fnInfo interface{}) { + rv := reflect.ValueOf(fnInfo) + if rv.Kind() != reflect.Struct { + panic(fmt.Errorf("expect struct, actual: %s", rv.Kind().String())) + } + closure := rv.FieldByName("Closure").Bool() + identityName := rv.FieldByName("IdentityName").String() + if identityName == "" { + if !closure { + return + } + // fmt.Fprintf(os.Stderr, "empty name\n",pkgPath) + } + var fnKind core.Kind + fnKindV := rv.FieldByName("Kind") + if fnKindV.IsValid() { + fnKind = core.Kind(fnKindV.Int()) + } + varField := rv.FieldByName("Var") + pkgPath := rv.FieldByName("PkgPath").String() + recvTypeName := rv.FieldByName("RecvTypeName").String() + recvPtr := rv.FieldByName("RecvPtr").Bool() + name := rv.FieldByName("Name").String() + interface_ := rv.FieldByName("Interface").Bool() + generic := rv.FieldByName("Generic").Bool() + f := rv.FieldByName("Fn").Interface() + + var firstArgCtx bool + var lastResErr bool + var pc uintptr + var fullName string + if !generic && !interface_ { + if f != nil { + // TODO: move all ctx, err check logic here + ft := reflect.TypeOf(f) + off := 0 + if recvTypeName != "" { + off = 1 } - if interface_ && recvTypeName != "" { - pkgMapping := interfaceMapping[pkgPath] - if pkgMapping == nil { - pkgMapping = make(map[string]*core.FuncInfo, 1) - interfaceMapping[pkgPath] = pkgMapping - } - pkgMapping[recvTypeName] = info + if ft.NumIn() > off && ft.In(off).Implements(ctxType) { + firstArgCtx = true } - if fnKind == core.Kind_Var { - if varField.IsValid() { - varAddr := varField.Elem().Pointer() - varAddrMapping[varAddr] = info - } + // NOTE: use == instead of implements + if ft.NumOut() > 0 && ft.Out(ft.NumOut()-1) == errType { + lastResErr = true } - if fullName != "" { - funcFullNameMapping[fullName] = info + pc = getFuncPC(f) + fullName = __xgo_link_get_pc_name(pc) + } else { + if (closure || fnKind == core.Kind_Var || fnKind == core.Kind_VarPtr || fnKind == core.Kind_Const) && identityName != "" { + fullName = pkgPath + "." + identityName } - }) - }) + } + } + recvName := rv.FieldByName("RecvName").String() + argNames := rv.FieldByName("ArgNames").Interface().([]string) + resNames := rv.FieldByName("ResNames").Interface().([]string) + file := rv.FieldByName("File").String() + line := int(rv.FieldByName("Line").Int()) + + // debug + // fmt.Printf("reg: %s\n", fullName) + // if pkgPath == "main" { + // fmt.Fprintf(os.Stderr, "reg: funcName=%s,pc=%x,generic=%v,genericname=%s\n", funcName, pc, generic, genericName) + // } + // _, recvTypeName, recvPtr, name := core.ParseFuncName(identityName, false) + info := &core.FuncInfo{ + Kind: fnKind, + FullName: fullName, + Pkg: pkgPath, + IdentityName: identityName, + Name: name, + RecvType: recvTypeName, + RecvPtr: recvPtr, + + Interface: interface_, + Generic: generic, + Closure: closure, + + File: file, + Line: line, + + // runtime info + PC: pc, // nil for generic + Func: f, // nil for geneirc + + RecvName: recvName, + ArgNames: argNames, + ResNames: resNames, + + // brief info + FirstArgCtx: firstArgCtx, + LastResultErr: lastResErr, + } + if varField.IsValid() { + info.Var = varField.Interface() + } + funcInfos = append(funcInfos, info) + if !generic && info.PC != 0 { + funcPCMapping[info.PC] = info + } + if identityName != "" { + pkgMapping := funcInfoMapping[pkgPath] + if pkgMapping == nil { + pkgMapping = make(map[string]*core.FuncInfo, 1) + funcInfoMapping[pkgPath] = pkgMapping + } + pkgMapping[identityName] = info + } + if interface_ && recvTypeName != "" { + pkgMapping := interfaceMapping[pkgPath] + if pkgMapping == nil { + pkgMapping = make(map[string]*core.FuncInfo, 1) + interfaceMapping[pkgPath] = pkgMapping + } + pkgMapping[recvTypeName] = info + } + if fnKind == core.Kind_Var { + if varField.IsValid() { + varAddr := varField.Elem().Pointer() + varAddrMapping[varAddr] = info + } + } + if fullName != "" { + funcFullNameMapping[fullName] = info + } } var mappingTypeOnce sync.Once -func ensureTypeMapping() { - mappingTypeOnce.Do(func() { - typeMethodMapping = make(map[reflect.Type]map[string]*core.FuncInfo) - for _, funcInfo := range funcInfos { - if funcInfo.Generic || funcInfo.Interface || funcInfo.RecvType == "" { - continue - } - if funcInfo.Func == nil || funcInfo.Name == "" { - continue - } - recvType := reflect.TypeOf(funcInfo.Func).In(0) - methodMapping := typeMethodMapping[recvType] - if methodMapping == nil { - methodMapping = make(map[string]*core.FuncInfo, 1) - typeMethodMapping[recvType] = methodMapping - } - methodMapping[funcInfo.Name] = funcInfo - } - }) +func getTypeMethodMapping() map[reflect.Type]map[string]*core.FuncInfo { + mappingTypeOnce.Do(initTypeMethodMapping) + return typeMethodMapping +} + +func initTypeMethodMapping() { + typeMethodMapping = make(map[reflect.Type]map[string]*core.FuncInfo) + for _, funcInfo := range funcInfos { + registerTypeMethod(funcInfo) + } +} + +func registerTypeMethod(funcInfo *core.FuncInfo) { + if funcInfo.Kind != core.Kind_Func { + return + } + if funcInfo.Generic || funcInfo.Interface || funcInfo.RecvType == "" { + return + } + if funcInfo.Func == nil || funcInfo.Name == "" { + return + } + recvType := reflect.TypeOf(funcInfo.Func).In(0) + methodMapping := typeMethodMapping[recvType] + if methodMapping == nil { + methodMapping = make(map[string]*core.FuncInfo, 1) + typeMethodMapping[recvType] = methodMapping + } + methodMapping[funcInfo.Name] = funcInfo } func getFuncPC(fn interface{}) uintptr { diff --git a/runtime/mock/mock.go b/runtime/mock/mock.go index efc4d2e2..902c0b28 100644 --- a/runtime/mock/mock.go +++ b/runtime/mock/mock.go @@ -93,7 +93,7 @@ func getMethodByName(instance interface{}, method string) (recvPtr interface{}, // - if mockRecvPtr has a value, then only call to that instance will be mocked // - if mockRecvPtr is nil, then all call to the function will be mocked func mock(mockRecvPtr interface{}, mockFnInfo *core.FuncInfo, funcPC uintptr, trappingPC uintptr, interceptor Interceptor) func() { - return trap.AddInterceptor(&trap.Interceptor{ + return trap.AddFuncInfoInterceptor(mockFnInfo, &trap.Interceptor{ Pre: func(ctx context.Context, f *core.FuncInfo, args, result core.Object) (data interface{}, err error) { if f.Kind == core.Kind_Func && f.PC == 0 { if !f.Generic { diff --git a/runtime/mock/patch.go b/runtime/mock/patch.go index 1a376f78..8e1623ed 100644 --- a/runtime/mock/patch.go +++ b/runtime/mock/patch.go @@ -8,6 +8,7 @@ import ( "github.com/xhd2015/xgo/runtime/core" "github.com/xhd2015/xgo/runtime/functab" + "github.com/xhd2015/xgo/runtime/trap" ) // Patch replaces `fn` with `replacer` in current goroutine, @@ -142,6 +143,10 @@ func buildInterceptorFromPatch(recvPtr interface{}, replacer interface{}) func(c } nIn := t.NumIn() + // replacer is usually a closure, + // we can bypass it + trap.Ignore(replacer) + // first arg ctx: true => [recv,args[1:]...] // first arg ctx: false => [recv, args[0:]...] return func(ctx context.Context, fn *core.FuncInfo, args, results core.Object) error { diff --git a/runtime/test/debug/debug_test.go b/runtime/test/debug/debug_test.go index 365b2b3b..c4aad11b 100644 --- a/runtime/test/debug/debug_test.go +++ b/runtime/test/debug/debug_test.go @@ -6,18 +6,35 @@ package debug import ( - "fmt" + "context" + "os" "testing" + "time" + "github.com/xhd2015/xgo/runtime/core" "github.com/xhd2015/xgo/runtime/trap" ) -func ToString[T any](v T) string { - return fmt.Sprint(v) +func TestTimeNowNestedLevel2AllowNested(t *testing.T) { + i := 0 + trap.AddInterceptor(&trap.Interceptor{ + Pre: func(ctx context.Context, f *core.FuncInfo, args, result core.Object) (data interface{}, err error) { + // t.Logf("%s.%s", f.Pkg, f.IdentityName) + i++ + if i > 20 { + os.Exit(1) + } + getTime2() + return + }, + }) + time.Now() } -func TestNakedTrapShouldAvoidRecursive(t *testing.T) { - trap.InspectPC(ToString[int]) - // _, fnInfo, funcPC, trappingPC := trap.InspectPC(ToString[int]) - // _, fnInfoStr, funcPCStr, trappingPCStr := trap.InspectPC(ToString[string]) +func getTime2() time.Time { + return getTime3() +} + +func getTime3() time.Time { + return time.Now() } diff --git a/runtime/test/mock_func/mock_func_test.go b/runtime/test/mock_func/mock_func_test.go index 328f81d5..9a47f570 100644 --- a/runtime/test/mock_func/mock_func_test.go +++ b/runtime/test/mock_func/mock_func_test.go @@ -64,3 +64,32 @@ func TestMockFuncErr(t *testing.T) { t.Fatalf("expect mocked neverErr() to be %v, actual: %v", mockErr, err) } } + +func TestNestedMock(t *testing.T) { + // before mock + beforeMock := A() + beforeMockExpect := "A B" + if beforeMock != beforeMockExpect { + t.Fatalf("expect before mock: %q, actual: %q", beforeMockExpect, beforeMock) + } + mock.Mock(B, func(ctx context.Context, fn *core.FuncInfo, args, results core.Object) error { + results.GetFieldIndex(0).Set("b") + return nil + }) + mock.Mock(A, func(ctx context.Context, fn *core.FuncInfo, args, results core.Object) error { + results.GetFieldIndex(0).Set("a " + B()) + return nil + }) + afterMock := A() + afterMockExpect := "a b" + if afterMock != afterMockExpect { + t.Fatalf("expect after mock: %q, actual: %q", afterMockExpect, afterMock) + } +} + +func A() string { + return "A " + B() +} +func B() string { + return "B" +} diff --git a/runtime/test/stack_trace/update_test.go b/runtime/test/stack_trace/update_test.go index 5328caf0..8b75e2e5 100644 --- a/runtime/test/stack_trace/update_test.go +++ b/runtime/test/stack_trace/update_test.go @@ -7,6 +7,9 @@ import ( ) func init() { + if true { + panic("should not be run") + } trace.Enable() } diff --git a/runtime/test/trace_arg/main.go b/runtime/test/trace_arg/main.go deleted file mode 100644 index 92d30e0a..00000000 --- a/runtime/test/trace_arg/main.go +++ /dev/null @@ -1,34 +0,0 @@ -package main - -import ( - "github.com/xhd2015/xgo/runtime/trace" -) - -func init() { - trace.Enable() -} - -func main() { - ReadAtLeast("hello ") - ReadAtLeast("trap") -} - -func ReadAtLeast(a string) { - print(a) -} - -func ReadAtLeast_trap(a string) { - after, stop := __xgo_trap(nil, []interface{}{&a}, []interface{}{}) - if stop { - } else { - if after != nil { - defer after() - } - print(a) - } -} - -//go:noinline -func __xgo_trap(recv interface{}, args []interface{}, results []interface{}) (func(), bool) { - return nil, false -} diff --git a/runtime/test/trap/trap_avoid_recursive_test.go b/runtime/test/trap/trap_avoid_recursive_test.go index f2e4a13d..e7240d77 100644 --- a/runtime/test/trap/trap_avoid_recursive_test.go +++ b/runtime/test/trap/trap_avoid_recursive_test.go @@ -13,7 +13,7 @@ import ( // prints: pre->call_f->post // no repeation -func TestNakedTrapShouldAvoidRecursive(t *testing.T) { +func TestNakedTrapShouldAvoidRecursiveInterceptor(t *testing.T) { var recurseBuf bytes.Buffer trap.AddInterceptor(&trap.Interceptor{ Pre: func(ctx context.Context, f *core.FuncInfo, args, result core.Object) (data interface{}, err error) { diff --git a/runtime/test/trap/trap_nested_interceptor_test.go b/runtime/test/trap/trap_nested_interceptor_test.go new file mode 100644 index 00000000..8725831c --- /dev/null +++ b/runtime/test/trap/trap_nested_interceptor_test.go @@ -0,0 +1,156 @@ +package trap + +import ( + "bytes" + "context" + "fmt" + "testing" + "time" + + "github.com/xhd2015/xgo/runtime/core" + "github.com/xhd2015/xgo/runtime/functab" + "github.com/xhd2015/xgo/runtime/trap" +) + +func TestNestedTrapShouldBeAllowedBySpecifyingMapping(t *testing.T) { + var traceRecords bytes.Buffer + + // list the names that can be nested + trap.AddInterceptor(&trap.Interceptor{ + Pre: func(ctx context.Context, f *core.FuncInfo, args, result core.Object) (data interface{}, err error) { + traceRecords.WriteString(fmt.Sprintf("call %s\n", f.IdentityName)) + if f.IdentityName == "A0" { + // call A1 inside the interceptor + result.GetFieldIndex(0).Set(A1()) + return nil, trap.ErrAbort + } + return nil, nil + }, + }) + trap.AddFuncInterceptor(A1, &trap.Interceptor{ + Pre: func(ctx context.Context, f *core.FuncInfo, args, result core.Object) (data interface{}, err error) { + traceRecords.WriteString(fmt.Sprintf("call %s\n", f.IdentityName)) + return + }, + }) + trap.AddFuncInterceptor(A2, &trap.Interceptor{ + Pre: func(ctx context.Context, f *core.FuncInfo, args, result core.Object) (data interface{}, err error) { + traceRecords.WriteString(fmt.Sprintf("call %s\n", f.IdentityName)) + return + }, + }) + + result := A0() + + trace := traceRecords.String() + expect := "call A0\ncall A1\ncall A2\n" + if trace != expect { + t.Fatalf("expect trace: %q, actual: %q", expect, trace) + } + expectResult := "A1 A2" + if result != expectResult { + t.Fatalf("expect result: %q, actual: %q", expectResult, result) + } +} +func A0() string { + return "A0" +} + +func A1() string { + return "A1 " + A2() +} + +func A2() string { + return "A2" +} + +func TestNestedTrapPartialAllowShouldTakeEffect(t *testing.T) { + var traceRecords bytes.Buffer + trap.AddInterceptor(&trap.Interceptor{ + Pre: func(ctx context.Context, f *core.FuncInfo, args, result core.Object) (data interface{}, err error) { + traceRecords.WriteString(fmt.Sprintf("call %s\n", f.IdentityName)) + if f.IdentityName == "B0" { + // call A1 inside the interceptor + result.GetFieldIndex(0).Set(B1()) + return nil, trap.ErrAbort + } + return nil, nil + }, + }) + + trap.AddFuncInterceptor(B2, &trap.Interceptor{ + Pre: func(ctx context.Context, f *core.FuncInfo, args, result core.Object) (data interface{}, err error) { + traceRecords.WriteString(fmt.Sprintf("call %s\n", f.IdentityName)) + return + }, + }) + result := B0() + + trace := traceRecords.String() + expect := "call B0\ncall B2\n" + if trace != expect { + t.Fatalf("expect trace: %q, actual: %q", expect, trace) + } + expectResult := "B1 B2" + if result != expectResult { + t.Fatalf("expect result: %q, actual: %q", expectResult, result) + } +} +func B0() string { + return "B0" +} + +func B1() string { + return "B1 " + B2() +} + +// B2 is ignored +func B2() string { + return "B2" +} + +func TestTimeNowNestedLevel1Normal(t *testing.T) { + var r time.Time + trap.AddInterceptor(&trap.Interceptor{ + Pre: func(ctx context.Context, f *core.FuncInfo, args, result core.Object) (data interface{}, err error) { + if r.IsZero() && f == functab.InfoFunc(time.Now) { + r = time.Now() + } + return + }, + }) + now := time.Now() + + diff := now.Sub(r) + if diff < 0 { + t.Fatalf("pre should happen before call, diff: %v", diff) + } + if diff > 1*time.Millisecond { + t.Fatalf("interval too large:%v", diff) + } +} + +func TestTimeNowNestedLevel2Normal(t *testing.T) { + var r time.Time + trap.AddInterceptor(&trap.Interceptor{ + Pre: func(ctx context.Context, f *core.FuncInfo, args, result core.Object) (data interface{}, err error) { + if r.IsZero() && f == functab.InfoFunc(time.Now) { + r = getTime() + } + return + }, + }) + now := time.Now() + + diff := now.Sub(r) + if diff < 0 { + t.Fatalf("pre should happen before call, diff: %v", diff) + } + if diff > 1*time.Millisecond { + t.Fatalf("interval too large:%v", diff) + } +} + +func getTime() time.Time { + return time.Now() +} diff --git a/runtime/test/trap/trap_test.go b/runtime/test/trap/trap_test.go index 1dd0f595..86c3727c 100644 --- a/runtime/test/trap/trap_test.go +++ b/runtime/test/trap/trap_test.go @@ -2,7 +2,6 @@ package trap import ( "context" - "fmt" "testing" "github.com/xhd2015/xgo/runtime/core" @@ -52,10 +51,8 @@ func run() { func A(ctx context.Context) { hasCalledA = true - fmt.Printf("A\n") } func B() { hasCalledB = true - fmt.Printf("B\n") } diff --git a/runtime/trace/marshal.go b/runtime/trace/marshal.go index bc699493..fe8f82d0 100644 --- a/runtime/trace/marshal.go +++ b/runtime/trace/marshal.go @@ -16,27 +16,27 @@ import ( // to JSON, when it encounters unmarshalable values // like func, chan, it will bypass these values. func MarshalAnyJSON(v interface{}) ([]byte, error) { - fn := functab.Info("encoding/json", "newTypeEncoder") - if fn == nil { + funcInfo := functab.Info("encoding/json", "newTypeEncoder") + if funcInfo == nil { fmt.Fprintf(os.Stderr, "WARNING: encoding/json.newTypeEncoder not trapped(requires xgo).\n") return json.Marshal(v) } // get the unmarshalable function - unmarshalable := getMarshaler(fn.Func, reflect.TypeOf(unmarshalableFunc)) + unmarshalable := getMarshaler(funcInfo.Func, reflect.TypeOf(unmarshalableFunc)) var data []byte var err error // mock the encoding json - trap.WithOverride(&trap.Interceptor{ + trap.WithFuncOverride(funcInfo, &trap.Interceptor{ Post: func(ctx context.Context, f *core.FuncInfo, args, result core.Object, data interface{}) error { - if f != fn { + if f != funcInfo { return nil } resField := result.GetFieldIndex(0) // if unmarshalable, replace with an empty struct if reflect.ValueOf(resField.Value()).Pointer() == reflect.ValueOf(unmarshalable).Pointer() { - resField.Set(getMarshaler(fn.Func, reflect.TypeOf(struct{}{}))) + resField.Set(getMarshaler(funcInfo.Func, reflect.TypeOf(struct{}{}))) } return nil }, diff --git a/runtime/trace/trace.go b/runtime/trace/trace.go index 4b69464a..e03bd5b2 100644 --- a/runtime/trace/trace.go +++ b/runtime/trace/trace.go @@ -131,6 +131,10 @@ func setupInterceptor() { // collect trace trap.AddInterceptorHead(&trap.Interceptor{ Pre: func(ctx context.Context, f *core.FuncInfo, args core.Object, results core.Object) (interface{}, error) { + if !__xgo_link_init_finished() { + // do not collect trace while init + return nil, trap.ErrSkip + } key := uintptr(__xgo_link_getcurg()) localOptStack, ok := collectingMap.Load(key) var localOpts *collectOpts @@ -140,7 +144,7 @@ func setupInterceptor() { localOpts = l.list[len(l.list)-1] } } else if !enabledGlobally { - return nil, nil + return nil, trap.ErrSkip } stack := &Stack{ FuncInfo: f, @@ -166,7 +170,7 @@ func setupInterceptor() { // initial stack root := &Root{ Top: stack, - Begin: time.Now(), + Begin: getNow(), Children: []*Stack{ stack, }, @@ -282,7 +286,7 @@ func enableLocal(collOpts *collectOpts) func() { if collOpts.root == nil { collOpts.root = &Root{ Top: &Stack{}, - Begin: time.Now(), + Begin: getNow(), } } top := collOpts.root.Top @@ -312,8 +316,10 @@ func enableLocal(collOpts *collectOpts) func() { } } +var traceOutput = os.Getenv("XGO_TRACE_OUTPUT") + func getTraceOutput() string { - return os.Getenv("XGO_TRACE_OUTPUT") + return traceOutput } var marshalStack func(root *Root) ([]byte, error) @@ -343,6 +349,19 @@ func emitTraceNoErr(name string, root *Root) { emitTrace(name, root) } +func getNow() (now time.Time) { + trap.Direct(func() { + now = time.Now() + }) + return +} +func formatTime(t time.Time, layout string) (output string) { + trap.Direct(func() { + output = t.Format(layout) + }) + return +} + // this should also be marked as trap.Skip() // TODO: may add callback for this func emitTrace(name string, root *Root) error { @@ -357,7 +376,7 @@ func emitTrace(name string, root *Root) error { ghex := fmt.Sprintf("g_%x", __xgo_link_getcurg()) traceID := "t_" + strconv.FormatInt(traceIDNum, 10) if xgoTraceOutput == "" { - traceDir := time.Now().Format("trace_20060102_150405") + traceDir := formatTime(getNow(), "trace_20060102_150405") subName = filepath.Join(traceDir, ghex, traceID) } else if useStdout { subName = fmt.Sprintf("%s/%s", ghex, traceID) @@ -388,5 +407,8 @@ func emitTrace(name string, root *Root) error { if err != nil { return err } - return WriteFile(subFile, traceOut, 0755) + trap.Direct(func() { + err = WriteFile(subFile, traceOut, 0755) + }) + return err } diff --git a/runtime/trap/ignore.go b/runtime/trap/ignore.go new file mode 100644 index 00000000..439720f4 --- /dev/null +++ b/runtime/trap/ignore.go @@ -0,0 +1,28 @@ +package trap + +import ( + "sync" + + "github.com/xhd2015/xgo/runtime/core" +) + +var ignoreMap sync.Map // funcinfo -> bool + +// mark functions that should skip trap +func Ignore(f interface{}) { + if f == nil { + return + } + _, funcInfo := Inspect(f) + if funcInfo == nil { + return + } + ignoreMap.Store(funcInfo, true) +} + +// assume f is not nil + +func funcIgnored(f *core.FuncInfo) bool { + _, ok := ignoreMap.Load(f) + return ok +} diff --git a/runtime/trap/interceptor.go b/runtime/trap/interceptor.go index 172488ff..18a30862 100644 --- a/runtime/trap/interceptor.go +++ b/runtime/trap/interceptor.go @@ -5,15 +5,15 @@ import ( "errors" "fmt" "os" + "runtime" "sync" "unsafe" "github.com/xhd2015/xgo/runtime/core" ) -const __XGO_SKIP_TRAP = true - var ErrAbort error = errors.New("abort trap interceptor") +var ErrSkip error = errors.New("skip trap interceptor") // link by compiler func __xgo_link_getcurg() unsafe.Pointer { @@ -44,8 +44,9 @@ type Interceptor struct { } type interceptorManager struct { - head []*Interceptor - tail []*Interceptor + head []*Interceptor // always executed first + tail []*Interceptor + funcMapping map[*core.FuncInfo][]*Interceptor // nested mapping } func (c *interceptorManager) copy() *interceptorManager { @@ -56,9 +57,21 @@ func (c *interceptorManager) copy() *interceptorManager { tail := make([]*Interceptor, len(c.tail)) copy(head, c.head) copy(tail, c.tail) + + var funcMapping map[*core.FuncInfo][]*Interceptor + if c.funcMapping != nil { + funcMapping = make(map[*core.FuncInfo][]*Interceptor, len(c.funcMapping)) + for f, list := range c.funcMapping { + cpList := make([]*Interceptor, len(list)) + copy(cpList, list) + funcMapping[f] = cpList + } + } + return &interceptorManager{ - head: head, - tail: tail, + head: head, + tail: tail, + funcMapping: funcMapping, } } @@ -66,18 +79,36 @@ var interceptors = &interceptorManager{} var localInterceptors sync.Map // goroutine ptr -> *interceptorGroup func AddInterceptor(interceptor *Interceptor) func() { - return addInterceptor(interceptor, false) + return addInterceptor(nil, interceptor, false) +} + +func AddFuncInterceptor(f interface{}, interceptor *Interceptor) func() { + _, fnInfo, pc, _ := InspectPC(f) + if fnInfo == nil { + panic(fmt.Errorf("failed to add func interceptor: %s", runtime.FuncForPC(pc).Name())) + } + return addInterceptor(fnInfo, interceptor, false) +} + +func AddFuncInfoInterceptor(f *core.FuncInfo, interceptor *Interceptor) func() { + if f == nil { + panic(fmt.Errorf("func cannot be nil")) + } + return addInterceptor(f, interceptor, false) } func AddInterceptorHead(interceptor *Interceptor) func() { - return addInterceptor(interceptor, true) + return addInterceptor(nil, interceptor, true) } -func addInterceptor(interceptor *Interceptor, head bool) func() { + +func addInterceptor(f *core.FuncInfo, interceptor *Interceptor, head bool) func() { ensureTrapInstall() if __xgo_link_init_finished() { - dispose, _ := addLocalInterceptor(interceptor, false, head) + dispose, _ := addLocalInterceptor(f, interceptor, false, head) return dispose } + Ignore(interceptor.Pre) + Ignore(interceptor.Post) if head { interceptors.head = append(interceptors.head, interceptor) } else { @@ -99,7 +130,7 @@ func addInterceptor(interceptor *Interceptor, head bool) func() { // even from init because it will be soon cleared // without causing concurrent issues. func WithInterceptor(interceptor *Interceptor, f func()) { - dispose, _ := addLocalInterceptor(interceptor, false, false) + dispose, _ := addLocalInterceptor(nil, interceptor, false, false) defer dispose() f() } @@ -108,17 +139,14 @@ func WithInterceptor(interceptor *Interceptor, f func()) { // in current goroutine temporarily, it returns a function // that can be used to cancel the override. func WithOverride(interceptor *Interceptor, f func()) { - _, disposeGroup := addLocalInterceptor(interceptor, true, false) + _, disposeGroup := addLocalInterceptor(nil, interceptor, true, false) defer disposeGroup() f() } - -func GetInterceptors() []*Interceptor { - return interceptors.getInterceptors() -} - -func (c *interceptorManager) getInterceptors() []*Interceptor { - return mergeInterceptors(c.tail, c.head) +func WithFuncOverride(funcInfo *core.FuncInfo, interceptor *Interceptor, f func()) { + _, disposeGroup := addLocalInterceptor(funcInfo, interceptor, true, false) + defer disposeGroup() + f() } func mergeInterceptors(groups ...[]*Interceptor) []*Interceptor { @@ -133,13 +161,6 @@ func mergeInterceptors(groups ...[]*Interceptor) []*Interceptor { return list } -func GetLocalInterceptors() []*Interceptor { - g := getLocalInterceptorList() - if g == nil { - return nil - } - return g.getInterceptors() -} func getLocalInterceptorList() *interceptorManager { group := getLocalInterceptorGroup() if group == nil { @@ -160,42 +181,50 @@ func getLocalInterceptorGroup() *interceptorGroup { return val.(*interceptorGroup) } -func ClearLocalInterceptors() { - clearLocalInterceptorsAndMark() -} +// f must not be nil +// if `noCommon` is set, only get f's mapping interceptors +// TODO: may allow trace when set `noLocalCommon` +func getAllInterceptors(f *core.FuncInfo, needCommon bool) ([]*Interceptor, int) { + group := getLocalInterceptorGroup() -func GetAllInterceptors() []*Interceptor { - res, _ := getAllInterceptors() - return res -} + var globalHead []*Interceptor + var globalTail []*Interceptor -func getAllInterceptors() ([]*Interceptor, int) { - group := getLocalInterceptorGroup() var localHead []*Interceptor var localTail []*Interceptor + var localFunc []*Interceptor + + var override bool var g int if group != nil { gi := group.currentGroupInterceptors() if gi != nil { g = group.currentGroup() - if gi.override { - return gi.list.getInterceptors(), g + override = gi.override + if needCommon { + localHead = gi.list.head + localTail = gi.list.tail } - localHead = gi.list.head - localTail = gi.list.tail + localFunc = gi.list.funcMapping[f] } } - globalHead := interceptors.head - globalTail := interceptors.tail + + if needCommon && !override { + globalHead = interceptors.head + globalTail = interceptors.tail + } // run locals first(in reversed order) - return mergeInterceptors(globalTail, localTail, globalHead, localHead), g + return mergeInterceptors(globalTail, localFunc, localTail, globalHead, localHead), g } // returns a function to dispose the key // NOTE: if not called correctly,there might be memory leak -func addLocalInterceptor(interceptor *Interceptor, override bool, head bool) (removeInterceptor func(), removeGroup func()) { +func addLocalInterceptor(f *core.FuncInfo, interceptor *Interceptor, override bool, head bool) (removeInterceptor func(), removeGroup func()) { ensureTrapInstall() + Ignore(interceptor.Pre) + Ignore(interceptor.Post) + key := uintptr(__xgo_link_getcurg()) list := &interceptorGroup{} val, loaded := localInterceptors.LoadOrStore(key, list) @@ -207,7 +236,7 @@ func addLocalInterceptor(interceptor *Interceptor, override bool, head bool) (re list.enterNewGroup(override) } g := list.currentGroup() - list.appendToCurrentGroup(interceptor, head) + list.appendToCurrentGroup(f, interceptor, head) removedInterceptor := false // used to remove the local interceptor @@ -225,30 +254,34 @@ func addLocalInterceptor(interceptor *Interceptor, override bool, head bool) (re } manager := list.groups[g].list - interceptors := manager.tail - if head { - interceptors = manager.head - } - removedInterceptor = true - var idx int = -1 - for i, intc := range interceptors { - if intc == interceptor { - idx = i - break + if f == nil { + interceptors := manager.tail + if head { + interceptors = manager.head + } + removedInterceptor = true + var idx int = -1 + for i, intc := range interceptors { + if intc == interceptor { + idx = i + break + } + } + if idx < 0 { + panic(fmt.Errorf("interceptor leaked")) + } + n := len(interceptors) + for i := idx; i < n-1; i++ { + interceptors[i] = interceptors[i+1] + } + interceptors = interceptors[:n-1] + if head { + manager.head = interceptors + } else { + manager.tail = interceptors } - } - if idx < 0 { - panic(fmt.Errorf("interceptor leaked")) - } - n := len(interceptors) - for i := idx; i < n-1; i++ { - interceptors[i] = interceptors[i+1] - } - interceptors = interceptors[:n-1] - if head { - manager.head = interceptors } else { - manager.tail = interceptors + delete(manager.funcMapping, f) } } @@ -281,7 +314,14 @@ type interceptorList struct { list *interceptorManager } -func (c *interceptorList) append(interceptor *Interceptor, head bool) { +func (c *interceptorList) append(f *core.FuncInfo, interceptor *Interceptor, head bool) { + if f != nil { + if c.list.funcMapping == nil { + c.list.funcMapping = make(map[*core.FuncInfo][]*Interceptor, 1) + } + c.list.funcMapping[f] = append(c.list.funcMapping[f], interceptor) + return + } if head { c.list.head = append(c.list.head, interceptor) } else { @@ -289,9 +329,9 @@ func (c *interceptorList) append(interceptor *Interceptor, head bool) { } } -func (c *interceptorGroup) appendToCurrentGroup(interceptor *Interceptor, head bool) { +func (c *interceptorGroup) appendToCurrentGroup(f *core.FuncInfo, interceptor *Interceptor, head bool) { g := c.currentGroup() - c.groups[g].append(interceptor, head) + c.groups[g].append(f, interceptor, head) } func (c *interceptorGroup) groupsEmpty() bool { @@ -322,11 +362,3 @@ func (c *interceptorGroup) exitGroup() { } c.groups = c.groups[:n-1] } - -func clearLocalInterceptorsAndMark() { - key := uintptr(__xgo_link_getcurg()) - localInterceptors.Delete(key) - bypassMapping.Delete(key) - - clearTrappingMarkAllGroup() -} diff --git a/runtime/trap/trap.go b/runtime/trap/trap.go index 11887f82..85b973e6 100644 --- a/runtime/trap/trap.go +++ b/runtime/trap/trap.go @@ -15,17 +15,20 @@ var setupOnce sync.Once func ensureTrapInstall() { setupOnce.Do(func() { - // do not capture trap before init finished - if __xgo_link_init_finished() { - __xgo_link_set_trap(trapFunc) - __xgo_link_set_trap_var(trapVar) - } else { - // deferred - __xgo_link_on_init_finished(func() { - __xgo_link_set_trap(trapFunc) - __xgo_link_set_trap_var(trapVar) - }) - } + // set trap once needed, no matter it + // is inside init or not + __xgo_link_set_trap(trapFunc) + __xgo_link_set_trap_var(trapVar) + + // // do not capture trap before init finished + // if __xgo_link_init_finished() { + // } else { + // // deferred + // __xgo_link_on_init_finished(func() { + // __xgo_link_set_trap(trapFunc) + // __xgo_link_set_trap_var(trapVar) + // }) + // } }) } @@ -69,11 +72,30 @@ func __xgo_link_on_init_finished(f func()) { // sense at compile time. func Skip() {} -var trappingMark sync.Map // -> struct{}{} -var trappingPC sync.Map // -> PC +var stackMapping sync.Map // -> root var inspectingMap sync.Map // -> interceptor +type root struct { + top *stack + intercepting bool // is executing intercepting? to avoid re-entrance +} + +type stack struct { + parent *stack + funcInfo *core.FuncInfo + stage stage + pc uintptr // the actual pc +} + +type stage int + +const ( + stage_pre stage = 0 + stage_execute stage = 1 + stage_post stage = 2 +) + // link to runtime // xgo:notrap func trapFunc(pkgPath string, identityName string, generic bool, pc uintptr, recv interface{}, args []interface{}, results []interface{}) (func(), bool) { @@ -81,17 +103,6 @@ func trapFunc(pkgPath string, identityName string, generic bool, pc uintptr, rec return nil, false } inspectingFn, inspecting := inspectingMap.Load(uintptr(__xgo_link_getcurg())) - var interceptors []*Interceptor - var group int - var n int - if !inspecting { - // never trap any function from runtime - interceptors, group = getAllInterceptors() - n = len(interceptors) - if n == 0 { - return nil, false - } - } // NOTE: this may return nil for generic template var f *core.FuncInfo @@ -116,22 +127,18 @@ func trapFunc(pkgPath string, identityName string, generic bool, pc uintptr, rec // abort,never really call the target return nil, true } + // f is an interceptor or ignored by user + if funcIgnored(f) { + return nil, false + } - // setup context - setTrappingPC(pc) - defer clearTrappingPC() - return trap(f, interceptors, group, recv, args, results) + return trap(f, pc, recv, args, results) } func trapVar(pkgPath string, name string, tmpVarAddr interface{}, takeAddr bool) { if isByPassing() { return } - interceptors, group := getAllInterceptors() - n := len(interceptors) - if n == 0 { - return - } identityName := name if takeAddr { identityName = "*" + name @@ -144,19 +151,50 @@ func trapVar(pkgPath string, name string, tmpVarAddr interface{}, takeAddr bool) return } // NOTE: stop always ignored because this is a simple get - post, _ := trap(fnInfo, interceptors, group, nil, nil, []interface{}{tmpVarAddr}) + post, _ := trap(fnInfo, 0, nil, nil, []interface{}{tmpVarAddr}) if post != nil { // NOTE: must in defer, because in post we // may capture panic defer post() } } -func trap(f *core.FuncInfo, interceptors []*Interceptor, group int, recv interface{}, args []interface{}, results []interface{}) (func(), bool) { - dispose := setTrappingMark(group) - if dispose == nil { +func trap(f *core.FuncInfo, pc uintptr, recv interface{}, args []interface{}, results []interface{}) (func(), bool) { + // never trap any function from runtime + key := uintptr(__xgo_link_getcurg()) + r := &root{} + rv, loaded := stackMapping.LoadOrStore(key, r) + if loaded { + r = rv.(*root) + } + // fmt.Printf("trap: %s.%s intercepting=%v\n", f.Pkg, f.IdentityName, r.intercepting) + interceptors, _ := getAllInterceptors(f, !r.intercepting) + n := len(interceptors) + if n == 0 { return nil, false } - defer dispose() + + var resetFlag bool + parent := r.top + if !r.intercepting { + resetFlag = true + r.intercepting = true + defer func() { + r.intercepting = false + }() + } + stack := &stack{ + parent: parent, + funcInfo: f, + stage: stage_pre, + pc: pc, + } + r.top = stack + + // dispose := setTrappingMark(group, f) + // if dispose == nil { + // return nil, false + // } + // defer dispose() // retrieve context var ctx context.Context @@ -248,8 +286,8 @@ func trap(f *core.FuncInfo, interceptors []*Interceptor, group int, recv interfa var firstPreErr error abortIdx := -1 - n := len(interceptors) dataList := make([]interface{}, n) + skipIndex := make([]bool, n) for i := n - 1; i >= 0; i-- { interceptor := interceptors[i] if interceptor.Pre == nil { @@ -259,6 +297,10 @@ func trap(f *core.FuncInfo, interceptors []*Interceptor, group int, recv interfa data, err := interceptor.Pre(ctx, f, req, resObject) dataList[i] = data if err != nil { + if err == ErrSkip { + skipIndex[i] = true + continue + } // always break on error firstPreErr = err abortIdx = i @@ -270,13 +312,20 @@ func trap(f *core.FuncInfo, interceptors []*Interceptor, group int, recv interfa if firstPreErr == ErrAbort { firstPreErr = nil } + stack.stage = stage_execute // always run post in defer return func() { - dispose := setTrappingMark(group) - if dispose == nil { - return + stack.stage = stage_post + defer func() { + r.top = parent + }() + + if resetFlag { + r.intercepting = true + defer func() { + r.intercepting = false + }() } - defer dispose() var lastPostErr error = firstPreErr idx := 0 @@ -288,6 +337,9 @@ func trap(f *core.FuncInfo, interceptors []*Interceptor, group int, recv interfa if interceptor.Post == nil { continue } + if skipIndex[i] { + continue + } err := interceptor.Post(ctx, f, req, resObject, dataList[i]) if err != nil { if err == ErrAbort { @@ -308,46 +360,21 @@ func trap(f *core.FuncInfo, interceptors []*Interceptor, group int, recv interfa func GetTrappingPC() uintptr { key := uintptr(__xgo_link_getcurg()) - val, ok := trappingPC.Load(key) + val, ok := stackMapping.Load(key) if !ok { return 0 } - return val.(uintptr) -} - -type trappingGroup struct { - m map[int]bool -} - -func setTrappingMark(group int) func() { - key := uintptr(__xgo_link_getcurg()) - g := &trappingGroup{} - v, loaded := trappingMark.LoadOrStore(key, g) - if loaded { - g = v.(*trappingGroup) - if g.m[group] { - return nil - } - } else { - g.m = make(map[int]bool, 1) - } - g.m[group] = true - return func() { - g.m[group] = false + top := val.(*root).top + if top == nil { + return 0 } + return top.pc } -func clearTrappingMarkAllGroup() { - key := uintptr(__xgo_link_getcurg()) - trappingMark.Delete(key) -} - -func setTrappingPC(pc uintptr) { +func clearLocalInterceptorsAndMark() { key := uintptr(__xgo_link_getcurg()) - trappingPC.Store(key, pc) -} + localInterceptors.Delete(key) + bypassMapping.Delete(key) -func clearTrappingPC() { - key := uintptr(__xgo_link_getcurg()) - trappingPC.Delete(key) + stackMapping.Delete(key) }