Skip to content

Commit

Permalink
add Patch,PatchByName and PatchMethodByName
Browse files Browse the repository at this point in the history
  • Loading branch information
xhd2015 committed Apr 1, 2024
1 parent 4375601 commit bd1a2bd
Show file tree
Hide file tree
Showing 16 changed files with 473 additions and 34 deletions.
4 changes: 2 additions & 2 deletions cmd/xgo/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ package main
import "fmt"

const VERSION = "1.0.11"
const REVISION = "b1b21571259df7c1632d86cef35ba681727a0cde+1"
const NUMBER = 143
const REVISION = "43756010e13cabfae008c1de9d72f98b946b0a09+1"
const NUMBER = 144

func getRevision() string {
return fmt.Sprintf("%s %s BUILD_%d", VERSION, REVISION, NUMBER)
Expand Down
6 changes: 4 additions & 2 deletions patch/syntax/helper_code.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ type __xgo_local_func_stub struct {
ArgNames []string
ResNames []string

// can be retrieved at runtime
// Deprecated
// these two fields can be retrieved at runtime
FirstArgCtx bool // first argument is context.Context or sub type?
LastResErr bool // last res is error or sub type?
// Deprecated
LastResErr bool // last res is error or sub type?

File string
Line int
Expand Down
13 changes: 8 additions & 5 deletions patch/syntax/syntax.go
Original file line number Diff line number Diff line change
Expand Up @@ -476,11 +476,14 @@ func getFuncDeclInfo(fileIndex int, f *syntax.File, file string, fn *syntax.Func
}
var firstArgCtx bool
var lastResErr bool
if len(fn.Type.ParamList) > 0 && hasQualifiedName(fn.Type.ParamList[0].Type, "context", "Context") {
firstArgCtx = true
}
if len(fn.Type.ResultList) > 0 && isName(fn.Type.ResultList[len(fn.Type.ResultList)-1].Type, "error") {
lastResErr = true
if false {
// NOTE: these fields will be retrieved at runtime dynamically
if len(fn.Type.ParamList) > 0 && hasQualifiedName(fn.Type.ParamList[0].Type, "context", "Context") {
firstArgCtx = true
}
if len(fn.Type.ResultList) > 0 && isName(fn.Type.ResultList[len(fn.Type.ResultList)-1].Type, "error") {
lastResErr = true
}
}

return &DeclInfo{
Expand Down
4 changes: 2 additions & 2 deletions runtime/core/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ import (
)

const VERSION = "1.0.11"
const REVISION = "b1b21571259df7c1632d86cef35ba681727a0cde+1"
const NUMBER = 143
const REVISION = "43756010e13cabfae008c1de9d72f98b946b0a09+1"
const NUMBER = 144

// these fields will be filled by compiler
const XGO_VERSION = ""
Expand Down
25 changes: 14 additions & 11 deletions runtime/functab/functab.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,21 +166,24 @@ func ensureMapping() {
generic := rv.FieldByName("Generic").Bool()
f := rv.FieldByName("Fn").Interface()

firstArgCtx := rv.FieldByName("FirstArgCtx").Bool()
lastResErr := rv.FieldByName("LastResErr").Bool()
var firstArgCtx bool
var lastResErr bool
var pc uintptr
var fullName string
if !generic && !interface_ {
if f != nil {
if closure {
// TODO: move all ctx, err check logic here
ft := reflect.TypeOf(f)
if ft.NumIn() > 0 && ft.In(0).Implements(ctxType) {
firstArgCtx = true
}
if ft.NumOut() > 0 && ft.Out(ft.NumOut()-1).Implements(errType) {
lastResErr = true
}
// TODO: move all ctx, err check logic here
ft := reflect.TypeOf(f)
off := 0
if recvTypeName != "" {
off = 1
}
if ft.NumIn() > off && ft.In(off).Implements(ctxType) {
firstArgCtx = true
}
// NOTE: use == instead of implements
if ft.NumOut() > 0 && ft.Out(ft.NumOut()-1) == errType {
lastResErr = true
}
pc = getFuncPC(f)
fullName = __xgo_link_get_pc_name(pc)
Expand Down
71 changes: 71 additions & 0 deletions runtime/mock/patch.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package mock

import (
"context"
"fmt"
"reflect"

"github.com/xhd2015/xgo/runtime/core"
)

func PatchByName(pkgPath string, funcName string, replacer interface{}) func() {
return MockByName(pkgPath, funcName, buildInterceptorFromPatch(replacer))
}

func PatchMethodByName(instance interface{}, method string, replacer interface{}) func() {
return MockMethodByName(instance, method, buildInterceptorFromPatch(replacer))
}

func buildInterceptorFromPatch(replacer interface{}) func(ctx context.Context, fn *core.FuncInfo, args, results core.Object) error {
v := reflect.ValueOf(replacer)
t := v.Type()
if t.Kind() != reflect.Func {
panic(fmt.Errorf("requires func, given %T", replacer))
}
if v.IsNil() {
panic("replacer is nil")
}
nIn := t.NumIn()
return func(ctx context.Context, fn *core.FuncInfo, args, results core.Object) error {
// assemble arguments
callArgs := make([]reflect.Value, nIn)
src := 0
dst := 0
if fn.RecvType != "" {
src++
}
if fn.FirstArgCtx {
callArgs[0] = reflect.ValueOf(ctx)
dst++
}
for i := 0; i < nIn-dst; i++ {
callArgs[dst+i] = reflect.ValueOf(args.GetFieldIndex(src + i).Value())
}

// call the function
var res []reflect.Value
if !t.IsVariadic() {
res = v.Call(callArgs)
} else {
res = v.CallSlice(callArgs)
}

// assign result
nOut := len(res)
for i := 0; i < nOut; i++ {
results.GetFieldIndex(i).Set(res[i].Interface())
}

// check error
if nOut > 0 {
last := res[nOut-1].Interface()
if last != nil {
if err, ok := last.(error); ok {
return err
}
}
}

return nil
}
}
8 changes: 8 additions & 0 deletions runtime/mock/patch_go1.17.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
//go:build !go1.18
// +build !go1.18

package mock

func Patch(fn interface{}, replacer interface{}) func() {
return Mock(fn, buildInterceptorFromPatch(replacer))
}
10 changes: 10 additions & 0 deletions runtime/mock/patch_go1.18.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
//go:build go1.18
// +build go1.18

package mock

// TODO: what if `fn` is a Type function
// instead of an instance method?
func Patch[T any](fn T, replacer T) func() {
return Mock(fn, buildInterceptorFromPatch(replacer))
}
60 changes: 60 additions & 0 deletions runtime/test/patch/patch_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package patch

import (
"strings"
"testing"

"github.com/xhd2015/xgo/runtime/mock"
)

func greet(s string) string {
return "hello " + s
}

func greetVaradic(s ...string) string {
return "hello " + strings.Join(s, ",")
}

func TestPatchSimpleFunc(t *testing.T) {
mock.Patch(greet, func(s string) string {
return "mock " + s
})

res := greet("world")
if res != "mock world" {
t.Fatalf("expect patched result to be %q, actual: %q", "mock world", res)
}
}

func TestPatchVaradicFunc(t *testing.T) {
mock.Patch(greetVaradic, func(s ...string) string {
return "mock " + strings.Join(s, ",")
})

res := greetVaradic("earth", "moon")
if res != "mock earth,moon" {
t.Fatalf("expect patched result to be %q, actual: %q", "mock earth,moon", res)
}
}

type struct_ struct {
s string
}

func (c *struct_) greet() string {
return "hello " + c.s
}

func TestPatchMethod(t *testing.T) {
ins := &struct_{
s: "world",
}
mock.Patch(ins.greet, func() string {
return "mock " + ins.s
})

res := ins.greet()
if res != "mock world" {
t.Fatalf("expect patched result to be %q, actual: %q", "mock world", res)
}
}
31 changes: 31 additions & 0 deletions runtime/test/trap_args/closure_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package trap_args

import (
"context"
"testing"

"github.com/xhd2015/xgo/runtime/core"
)

var gc = func(ctx context.Context) {
panic("gc should be trapped")
}

func TestClosureShouldRetrieveCtxInfoAtTrapTime(t *testing.T) {
ctx := context.Background()
ctx = context.WithValue(ctx, "test", "mock")
callAndCheck(func() {
gc(ctx)
}, func(trapCtx context.Context, f *core.FuncInfo, args, result core.Object) error {
if !f.FirstArgCtx {
t.Fatalf("expect closure also mark firstArgCtx, actually not marked")
}
if trapCtx == nil {
t.Fatalf("expect trapCtx to be non nil, atcual nil")
}
if trapCtx != ctx {
t.Fatalf("expect trapCtx to be the same with ctx, actully different")
}
return nil
})
}
51 changes: 51 additions & 0 deletions runtime/test/trap_args/ctx_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package trap_args

import (
"context"
"testing"

"github.com/xhd2015/xgo/runtime/core"
)

func TestPlainCtxArgCanBeRecognized(t *testing.T) {
ctx := context.Background()
ctx = context.WithValue(ctx, "test", "mock")
callAndCheck(func() {
f2(ctx)
}, func(trapCtx context.Context, f *core.FuncInfo, args, result core.Object) error {
if !f.FirstArgCtx {
t.Fatalf("expect first arg to be context")
}
if trapCtx != ctx {
t.Fatalf("expect context passed unchanged, actually different")
}
ctxVal := trapCtx.Value("test").(string)
if ctxVal != "mock" {
t.Fatalf("expect context value to be %q, actual: %q", "mock", ctxVal)
}
return nil
})
}

func TestCtxVariantCanBeRecognized(t *testing.T) {
ctx := context.Background()
ctx = context.WithValue(ctx, "test", "mock")

myCtx := &MyContext{Context: ctx}

callAndCheck(func() {
f3(myCtx)
}, func(trapCtx context.Context, f *core.FuncInfo, args, result core.Object) error {
if !f.FirstArgCtx {
t.Fatalf("expect first arg to be context")
}
if trapCtx != myCtx {
t.Fatalf("expect context passed unchanged, actually different")
}
ctxVal := trapCtx.Value("test").(string)
if ctxVal != "mock" {
t.Fatalf("expect context value to be %q, actual: %q", "mock", ctxVal)
}
return nil
})
}
69 changes: 69 additions & 0 deletions runtime/test/trap_args/err_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package trap_args

import (
"context"
"errors"
"testing"

"github.com/xhd2015/xgo/runtime/core"
"github.com/xhd2015/xgo/runtime/trap"
)

func plainErr() error {
panic("plainErr should be mocked")
}

func subErr() *Error {
return &Error{"sub error"}
}

func TestPlainErrShouldSetErrRes(t *testing.T) {
mockErr := errors.New("mock err")
var err error
callAndCheck(func() {
err = plainErr()
}, func(trapCtx context.Context, f *core.FuncInfo, args, result core.Object) error {
if !f.LastResultErr {
t.Fatalf("expect f.LastResultErr to be true, actual: false")
}
return mockErr
})

if err != mockErr {
t.Fatalf("expect return err %v, actual %v", mockErr, err)
}
}

func TestSubErrShouldNotSetErrRes(t *testing.T) {
mockErr := errors.New("mock err")
var err *Error
var recoverErr interface{}
func() {
defer func() {
// NOTE: this may have impact
trap.Skip()
recoverErr = recover()
}()
callAndCheck(func() {
err = subErr()
}, func(trapCtx context.Context, f *core.FuncInfo, args, result core.Object) error {
if f.LastResultErr {
t.Fatalf("expect f.LastResultErr to be false, actual: true")
}
// even not pl should fail
return mockErr
})
}()

if recoverErr == nil {
t.Fatalf("expect error via panic, actually no panic")
}

if err != nil {
t.Fatalf("expect return error not set, actual: %v", err)
}

if recoverErr != mockErr {
t.Fatalf("expect panic err to be %v, actual: %v", mockErr, recoverErr)
}
}
Loading

0 comments on commit bd1a2bd

Please sign in to comment.