From 75b44bd9be623a06a6763e29f314bfe964c8d941 Mon Sep 17 00:00:00 2001 From: xhd2015 Date: Sun, 14 Apr 2024 21:20:46 +0800 Subject: [PATCH] allow trap and cancel while init, see https://github.com/xhd2015/xgo/issues/55#issuecomment-2054059666 --- cmd/xgo/patch/runtime_def.go | 24 ++ cmd/xgo/version.go | 4 +- patch/syntax/helper_code.go | 11 +- patch/syntax/helper_code_gen.go | 11 +- runtime/core/version.go | 4 +- runtime/functab/functab.go | 3 +- runtime/test/trap/trap_init/trap_init_test.go | 55 +++++ runtime/trap/interceptor.go | 214 ++++++++++-------- 8 files changed, 203 insertions(+), 123 deletions(-) create mode 100644 runtime/test/trap/trap_init/trap_init_test.go diff --git a/cmd/xgo/patch/runtime_def.go b/cmd/xgo/patch/runtime_def.go index 180bb410..37985879 100644 --- a/cmd/xgo/patch/runtime_def.go +++ b/cmd/xgo/patch/runtime_def.go @@ -92,6 +92,14 @@ if os.Getenv("XGO_COMPILER_ENABLE")=="true" { } p.file = file noders = append(noders, p) + + // move to head + n := len(noders) + for i:=n-1;i>0;i--{ + noders[i]=noders[i-1] + } + noders[0]=p + return file }) } @@ -114,6 +122,14 @@ if os.Getenv("XGO_COMPILER_ENABLE")=="true" { } p.file = file noders = append(noders, p) + + // move to head + n := len(noders) + for i:=n-1;i>0;i--{ + noders[i]=noders[i-1] + } + noders[0]=p + return file }) } @@ -136,6 +152,14 @@ if os.Getenv("XGO_COMPILER_ENABLE")=="true" { } p.file = file noders = append(noders, p) + + // move to head + n := len(noders) + for i:=n-1;i>0;i--{ + noders[i]=noders[i-1] + } + noders[0]=p + return file }) } diff --git a/cmd/xgo/version.go b/cmd/xgo/version.go index d6762236..0ff967ea 100644 --- a/cmd/xgo/version.go +++ b/cmd/xgo/version.go @@ -3,8 +3,8 @@ package main import "fmt" const VERSION = "1.0.24" -const REVISION = "b0dd1873bc3e5e5464f55c7b01a077712dc4c818+1" -const NUMBER = 180 +const REVISION = "a9cbbd937997b1473a5f70f0131927adcdbf3b79+1" +const NUMBER = 181 func getRevision() string { revSuffix := "" diff --git a/patch/syntax/helper_code.go b/patch/syntax/helper_code.go index 42070170..2ccc65c5 100644 --- a/patch/syntax/helper_code.go +++ b/patch/syntax/helper_code.go @@ -33,16 +33,9 @@ type __xgo_local_func_stub struct { Line int } -// ensure early init -var _ = func() bool { - __xgo_trap_skip() +func init() { __xgo_link_generate_init_regs_body() - return true -}() - -// func init() { -// __xgo_link_generate_init_regs_body() -// } +} // TODO: ensure safety for this func __xgo_link_generate_init_regs_body() { diff --git a/patch/syntax/helper_code_gen.go b/patch/syntax/helper_code_gen.go index 3fdd7430..3844a7a3 100755 --- a/patch/syntax/helper_code_gen.go +++ b/patch/syntax/helper_code_gen.go @@ -62,16 +62,9 @@ type __xgo_local_func_stub struct { Line int } -// ensure early init -var _ = func() bool { - __xgo_trap_skip() +func init() { __xgo_link_generate_init_regs_body() - return true -}() - -// func init() { -// __xgo_link_generate_init_regs_body() -// } +} // TODO: ensure safety for this func __xgo_link_generate_init_regs_body() { diff --git a/runtime/core/version.go b/runtime/core/version.go index 058a343d..55d37bd0 100644 --- a/runtime/core/version.go +++ b/runtime/core/version.go @@ -7,8 +7,8 @@ import ( ) const VERSION = "1.0.24" -const REVISION = "b0dd1873bc3e5e5464f55c7b01a077712dc4c818+1" -const NUMBER = 180 +const REVISION = "a9cbbd937997b1473a5f70f0131927adcdbf3b79+1" +const NUMBER = 181 // these fields will be filled by compiler const XGO_VERSION = "" diff --git a/runtime/functab/functab.go b/runtime/functab/functab.go index 835ed7b5..a268ada1 100644 --- a/runtime/functab/functab.go +++ b/runtime/functab/functab.go @@ -156,13 +156,14 @@ func registerFuncInfo(fnInfo interface{}) { } // fmt.Fprintf(os.Stderr, "empty name\n",pkgPath) } + pkgPath := rv.FieldByName("PkgPath").String() + // fmt.Printf("register: %s %s\n", pkgPath, identityName) var fnKind core.Kind fnKindV := rv.FieldByName("Kind") if fnKindV.IsValid() { fnKind = core.Kind(fnKindV.Int()) } varField := rv.FieldByName("Var") - pkgPath := rv.FieldByName("PkgPath").String() recvTypeName := rv.FieldByName("RecvTypeName").String() recvPtr := rv.FieldByName("RecvPtr").Bool() name := rv.FieldByName("Name").String() diff --git a/runtime/test/trap/trap_init/trap_init_test.go b/runtime/test/trap/trap_init/trap_init_test.go new file mode 100644 index 00000000..d82412b1 --- /dev/null +++ b/runtime/test/trap/trap_init/trap_init_test.go @@ -0,0 +1,55 @@ +package trap_init + +import ( + "bytes" + "context" + "fmt" + "testing" + + "github.com/xhd2015/xgo/runtime/core" + "github.com/xhd2015/xgo/runtime/trap" +) + +var trapBuf bytes.Buffer +var initA string + +func init() { + cancel := trap.AddInterceptor(&trap.Interceptor{ + Pre: func(ctx context.Context, f *core.FuncInfo, args, result core.Object) (data interface{}, err error) { + trapBuf.WriteString(fmt.Sprintf("call %s\n", f.IdentityName)) + if f.IdentityName == "B" { + result.GetFieldIndex(0).Set("mock B") + return nil, trap.ErrAbort + } + return + }, + }) + defer cancel() + initA = A() +} + +func TestTrapInsideInit(t *testing.T) { + str := trapBuf.String() + expectTrapBuf := "call A\ncall B\n" + if str != expectTrapBuf { + t.Fatalf("expect trap buf: %q, actual: %q", expectTrapBuf, str) + } + expectInitA := "A:mock B" + if initA != expectInitA { + t.Fatalf("expect initA: %q, actual: %q", expectInitA, initA) + } + + // check if the interceptor is cancelled + a := A() + expectA := "A:B" + if a != expectA { + t.Fatalf("expect a: %q, actual: %q", expectA, a) + } +} + +func A() string { + return "A:" + B() +} +func B() string { + return "B" +} diff --git a/runtime/trap/interceptor.go b/runtime/trap/interceptor.go index 18a30862..4342d45e 100644 --- a/runtime/trap/interceptor.go +++ b/runtime/trap/interceptor.go @@ -43,45 +43,15 @@ type Interceptor struct { Post func(ctx context.Context, f *core.FuncInfo, args core.Object, result core.Object, data interface{}) error } -type interceptorManager struct { - head []*Interceptor // always executed first - tail []*Interceptor - funcMapping map[*core.FuncInfo][]*Interceptor // nested mapping -} - -func (c *interceptorManager) copy() *interceptorManager { - if c == nil { - return nil - } - head := make([]*Interceptor, len(c.head)) - tail := make([]*Interceptor, len(c.tail)) - copy(head, c.head) - copy(tail, c.tail) - - var funcMapping map[*core.FuncInfo][]*Interceptor - if c.funcMapping != nil { - funcMapping = make(map[*core.FuncInfo][]*Interceptor, len(c.funcMapping)) - for f, list := range c.funcMapping { - cpList := make([]*Interceptor, len(list)) - copy(cpList, list) - funcMapping[f] = cpList - } - } - - return &interceptorManager{ - head: head, - tail: tail, - funcMapping: funcMapping, - } -} - -var interceptors = &interceptorManager{} +var globalInterceptors = &interceptorManager{} var localInterceptors sync.Map // goroutine ptr -> *interceptorGroup +// AddInterceptor add a general interceptor, disallowing re-entrant func AddInterceptor(interceptor *Interceptor) func() { return addInterceptor(nil, interceptor, false) } +// AddFuncInterceptor add func interceptor, allowing f to be re-entrant func AddFuncInterceptor(f interface{}, interceptor *Interceptor) func() { _, fnInfo, pc, _ := InspectPC(f) if fnInfo == nil { @@ -101,24 +71,6 @@ func AddInterceptorHead(interceptor *Interceptor) func() { return addInterceptor(nil, interceptor, true) } -func addInterceptor(f *core.FuncInfo, interceptor *Interceptor, head bool) func() { - ensureTrapInstall() - if __xgo_link_init_finished() { - dispose, _ := addLocalInterceptor(f, interceptor, false, head) - return dispose - } - Ignore(interceptor.Pre) - Ignore(interceptor.Post) - if head { - interceptors.head = append(interceptors.head, interceptor) - } else { - interceptors.tail = append(interceptors.tail, interceptor) - } - return func() { - panic("global interceptor cannot be cancelled, if you want to cancel a global interceptor, use WithInterceptor") - } -} - // WithInterceptor executes given f with interceptor // setup. It can be used from init phase safely. // it clears the interceptor after f finishes. @@ -149,6 +101,89 @@ func WithFuncOverride(funcInfo *core.FuncInfo, interceptor *Interceptor, f func( f() } +func addInterceptor(f *core.FuncInfo, interceptor *Interceptor, head bool) func() { + ensureTrapInstall() + if __xgo_link_init_finished() { + dispose, _ := addLocalInterceptor(f, interceptor, false, head) + return dispose + } + Ignore(interceptor.Pre) + Ignore(interceptor.Post) + + globalInterceptors.append(f, interceptor, false) + return func() { + if __xgo_link_init_finished() { + // to ensure lock free + panic("global interceptor cannot be cancelled after init, if you want to cancel a global interceptor, use WithInterceptor") + } + globalInterceptors.removeInterceptor(f, interceptor, false) + } +} + +type interceptorManager struct { + head []*Interceptor // always executed first + tail []*Interceptor + funcMapping map[*core.FuncInfo][]*Interceptor // nested mapping +} + +func (c *interceptorManager) copy() *interceptorManager { + if c == nil { + return nil + } + head := make([]*Interceptor, len(c.head)) + tail := make([]*Interceptor, len(c.tail)) + copy(head, c.head) + copy(tail, c.tail) + + var funcMapping map[*core.FuncInfo][]*Interceptor + if c.funcMapping != nil { + funcMapping = make(map[*core.FuncInfo][]*Interceptor, len(c.funcMapping)) + for f, list := range c.funcMapping { + cpList := make([]*Interceptor, len(list)) + copy(cpList, list) + funcMapping[f] = cpList + } + } + + return &interceptorManager{ + head: head, + tail: tail, + funcMapping: funcMapping, + } +} + +func (c *interceptorManager) append(f *core.FuncInfo, interceptor *Interceptor, head bool) { + if f != nil { + if c.funcMapping == nil { + c.funcMapping = make(map[*core.FuncInfo][]*Interceptor, 1) + } + c.funcMapping[f] = append(c.funcMapping[f], interceptor) + return + } + if head { + c.head = append(c.head, interceptor) + } else { + c.tail = append(c.tail, interceptor) + } +} + +func (c *interceptorManager) removeInterceptor(f *core.FuncInfo, interceptor *Interceptor, head bool) { + if f == nil { + if head { + c.head = dropInterceptor(c.head, interceptor) + } else { + c.tail = dropInterceptor(c.tail, interceptor) + } + } else { + newInterceptor := dropInterceptor(c.funcMapping[f], interceptor) + if newInterceptor == nil { + delete(c.funcMapping, f) + } else { + c.funcMapping[f] = newInterceptor + } + } +} + func mergeInterceptors(groups ...[]*Interceptor) []*Interceptor { n := 0 for _, g := range groups { @@ -210,8 +245,8 @@ func getAllInterceptors(f *core.FuncInfo, needCommon bool) ([]*Interceptor, int) } if needCommon && !override { - globalHead = interceptors.head - globalTail = interceptors.tail + globalHead = globalInterceptors.head + globalTail = globalInterceptors.tail } // run locals first(in reversed order) @@ -252,37 +287,8 @@ func addLocalInterceptor(f *core.FuncInfo, interceptor *Interceptor, override bo if curG != g { panic(fmt.Errorf("interceptor group changed: previous=%d, current=%d", g, curG)) } - manager := list.groups[g].list - - if f == nil { - interceptors := manager.tail - if head { - interceptors = manager.head - } - removedInterceptor = true - var idx int = -1 - for i, intc := range interceptors { - if intc == interceptor { - idx = i - break - } - } - if idx < 0 { - panic(fmt.Errorf("interceptor leaked")) - } - n := len(interceptors) - for i := idx; i < n-1; i++ { - interceptors[i] = interceptors[i+1] - } - interceptors = interceptors[:n-1] - if head { - manager.head = interceptors - } else { - manager.tail = interceptors - } - } else { - delete(manager.funcMapping, f) - } + removedInterceptor = true + list.groups[g].list.removeInterceptor(f, interceptor, head) } removedGroup := false @@ -305,6 +311,29 @@ func addLocalInterceptor(f *core.FuncInfo, interceptor *Interceptor, override bo return removeInterceptor, removeGroup } +func dropInterceptor(interceptors []*Interceptor, interceptor *Interceptor) []*Interceptor { + n := len(interceptors) + idx := -1 + for i := 0; i < n; i++ { + if interceptors[i] == interceptor { + idx = i + break + } + } + if idx < 0 { + panic("interceptor not found before removed") + } + + for i := idx + 1; i < n; i++ { + interceptors[i-1] = interceptors[i] + } + interceptors = interceptors[:n-1] + if len(interceptors) == 0 { + return nil + } + return interceptors +} + type interceptorGroup struct { groups []*interceptorList } @@ -314,24 +343,9 @@ type interceptorList struct { list *interceptorManager } -func (c *interceptorList) append(f *core.FuncInfo, interceptor *Interceptor, head bool) { - if f != nil { - if c.list.funcMapping == nil { - c.list.funcMapping = make(map[*core.FuncInfo][]*Interceptor, 1) - } - c.list.funcMapping[f] = append(c.list.funcMapping[f], interceptor) - return - } - if head { - c.list.head = append(c.list.head, interceptor) - } else { - c.list.tail = append(c.list.tail, interceptor) - } -} - func (c *interceptorGroup) appendToCurrentGroup(f *core.FuncInfo, interceptor *Interceptor, head bool) { g := c.currentGroup() - c.groups[g].append(f, interceptor, head) + c.groups[g].list.append(f, interceptor, head) } func (c *interceptorGroup) groupsEmpty() bool {