Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: wrong parameter value when using To/When on generic functions #40

Merged
merged 1 commit into from
Oct 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/monkey/fn/copy_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func Copy(targetPtr, oriFn interface{}) {
targetType := targetVal.Type().Elem()
tool.Assert(targetType.Kind() == reflect.Func, "'%v' is not a function pointer", targetPtr)
oriVal := reflect.ValueOf(oriFn)
tool.Assert(tool.CheckFuncArgs(targetType, oriVal.Type(), 0), "target and ori not match")
tool.Assert(tool.CheckFuncArgs(targetType, oriVal.Type(), 0, 0), "target and ori not match")

oriAddr := oriVal.Pointer()
tool.DebugPrintf("Copy: copy start for %v\n", runtime.FuncForPC(oriAddr).Name())
Expand Down
2 changes: 0 additions & 2 deletions internal/monkey/patch.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ func (p *Patch) Unpatch() {
func PatchValue(target, hook, proxy reflect.Value, unsafe, generic bool) *Patch {
tool.Assert(hook.Kind() == reflect.Func, "'%s' is not a function", hook.Kind())
tool.Assert(proxy.Kind() == reflect.Ptr, "'%v' is not a function pointer", proxy.Kind())
tool.Assert(hook.Type() == target.Type(), "'%v' and '%s' mismatch", hook.Type(), target.Type())
tool.Assert(proxy.Elem().Type() == target.Type(), "'*%v' and '%s' mismatch", proxy.Elem().Type(), target.Type())

targetAddr := target.Pointer()
if generic {
Expand Down
7 changes: 0 additions & 7 deletions internal/tool/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,6 @@ import (
"reflect"
)

func ReflectCallWithShiftOne(f reflect.Value, args []reflect.Value, shift bool) []reflect.Value {
if shift {
return ReflectCall(f, args[1:])
}
return ReflectCall(f, args)
}

func ReflectCall(f reflect.Value, args []reflect.Value) []reflect.Value {
if f.Type().IsVariadic() {
newArgs := make([]reflect.Value, 0)
Expand Down
8 changes: 4 additions & 4 deletions internal/tool/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ func CheckReturnType(fn interface{}, results ...interface{}) {
}
}

func CheckFuncArgs(a, b reflect.Type, shift int) bool {
if a.NumIn() == b.NumIn()+shift {
for i := shift; i < a.NumIn(); i++ {
if a.In(i) != b.In(i-shift) {
func CheckFuncArgs(a, b reflect.Type, shiftA, shiftB int) bool {
if a.NumIn()-shiftA == b.NumIn()-shiftB {
for indexA, indexB := shiftA, shiftB; indexA < a.NumIn(); indexA, indexB = indexA+1, indexB+1 {
if a.In(indexA) != b.In(indexB) {
return false
}
}
Expand Down
105 changes: 65 additions & 40 deletions mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,31 +34,35 @@ const (
)

type Mocker struct {
target reflect.Value // 目标函数
hook reflect.Value // mock函数
proxy interface{} // mock之后,原函数地址
target reflect.Value // mock target value
hook reflect.Value // mock hook
proxy interface{} // proxy function to origin
times int64
mockTimes int64
patch *monkey.Patch
lock sync.Mutex
isPatched bool
builder *MockBuilder

outerCaller tool.CallerInfo // Mocker 的外部调用位置
outerCaller tool.CallerInfo
}

type MockBuilder struct {
target interface{} // 目标函数
// hook interface{} // mock函数
proxyCaller interface{} // mock之后,原函数地址
// when interface{} // 条件函数
conditions []*mockCondition // 条件转移
target interface{} // mock target
proxyCaller interface{} // origin function caller hook
conditions []*mockCondition // mock conditions
filterGoroutine FilterGoroutineType
gId int64
unsafe bool
generic bool
}

// Mock mocks target function
//
// If target is a generic method or method of generic types, you need add a genericOpt, like this:
Sychorius marked this conversation as resolved.
Show resolved Hide resolved
//
// func f[int, float64](x int, y T1) T2
// Mock(f[int, float64], OptGeneric)
func Mock(target interface{}, opt ...optionFn) *MockBuilder {
tool.AssertFunc(target)

Expand All @@ -79,11 +83,38 @@ func MockUnsafe(target interface{}) *MockBuilder {
return Mock(target, OptUnsafe)
}

func (builder *MockBuilder) hookType() reflect.Type {
targetType := reflect.TypeOf(builder.target)
if builder.generic {
targetIn := []reflect.Type{genericInfoType}
for i := 0; i < targetType.NumIn(); i++ {
targetIn = append(targetIn, targetType.In(i))
}
targetOut := []reflect.Type{}
for i := 0; i < targetType.NumOut(); i++ {
targetOut = append(targetOut, targetType.Out(i))
}
return reflect.FuncOf(targetIn, targetOut, targetType.IsVariadic())
}
return targetType
}

func (builder *MockBuilder) resetCondition() *MockBuilder {
builder.conditions = []*mockCondition{builder.newCondition()} // at least 1 condition is needed
return builder
}

// Origin add an origin hook which can be used to call un-mocked origin function
//
// For example:
//
// origin := Fun // only need the same type
// mock := func(p string) string {
// return origin(p + "mocked")
// }
// mock2 := Mock(Fun).To(mock).Origin(&origin).Build()
//
// Origin only works when call origin hook directly, target will still be mocked in recursive call
func (builder *MockBuilder) Origin(funcPtr interface{}) *MockBuilder {
tool.Assert(builder.proxyCaller == nil, "re-set builder origin")
return builder.origin(funcPtr)
Expand Down Expand Up @@ -187,15 +218,15 @@ func (builder *MockBuilder) Build() *Mocker {
return &mocker
}

func (mocker *Mocker) checkReceiver(target reflect.Type, hook interface{}) bool {
func (mocker *Mocker) missReceiver(target reflect.Type, hook interface{}) bool {
hType := reflect.TypeOf(hook)
tool.Assert(hType.Kind() == reflect.Func, "Param(%v) a is not a func", hType.Kind())
tool.Assert(target.IsVariadic() == hType.IsVariadic(), "target:%v, hook:%v args not match", target, hook)
// has receiver
if tool.CheckFuncArgs(target, hType, 0) {
if tool.CheckFuncArgs(target, hType, 0, 0) {
return false
}
if tool.CheckFuncArgs(target, hType, 1) {
if tool.CheckFuncArgs(target, hType, 1, 0) {
return true
}
tool.Assert(false, "target:%v, hook:%v args not match", target, hook)
Expand All @@ -205,40 +236,36 @@ func (mocker *Mocker) checkReceiver(target reflect.Type, hook interface{}) bool
func (mocker *Mocker) buildHook() {
proxySetter := mocker.buildProxy()

origin := reflect.ValueOf(mocker.proxy).Elem()
originExec := func(args []reflect.Value) []reflect.Value {
return tool.ReflectCall(origin, args)
return tool.ReflectCall(reflect.ValueOf(mocker.proxy).Elem(), args)
}

match := []func(args []reflect.Value) bool{}
exec := []func(args []reflect.Value) []reflect.Value{}

for _, condition := range mocker.builder.conditions {
when := condition.when
hook := condition.hook

if when == nil {
for i := range mocker.builder.conditions {
condition := mocker.builder.conditions[i]
if condition.when == nil {
// when condition is not set, just go into hook exec
match = append(match, func(args []reflect.Value) bool { return true })
} else {
missWhenReceiver := mocker.checkReceiver(mocker.target.Type(), when)
match = append(match, func(args []reflect.Value) bool {
return tool.ReflectCallWithShiftOne(reflect.ValueOf(when), args, missWhenReceiver)[0].Bool()
return tool.ReflectCall(reflect.ValueOf(condition.when), args)[0].Bool()
})
}

if hook == nil {
if condition.hook == nil {
// hook condition is not set, just go into original exec
exec = append(exec, originExec)
} else {
missHookReceiver := mocker.checkReceiver(mocker.target.Type(), hook)
exec = append(exec, func(args []reflect.Value) []reflect.Value {
mocker.mock()
return tool.ReflectCallWithShiftOne(reflect.ValueOf(hook), args, missHookReceiver)
return tool.ReflectCall(reflect.ValueOf(condition.hook), args)
})
}
}

mockerHook := reflect.MakeFunc(mocker.target.Type(), func(args []reflect.Value) []reflect.Value {
mockerHook := reflect.MakeFunc(mocker.builder.hookType(), func(args []reflect.Value) []reflect.Value {
proxySetter(args) // 设置origin调用proxy

mocker.access()
Expand Down Expand Up @@ -267,29 +294,27 @@ func (mocker *Mocker) buildHook() {
mocker.hook = mockerHook
}

// buildProx create a proxyCaller which could call origin directly
func (mocker *Mocker) buildProxy() func(args []reflect.Value) {
proxy := reflect.New(mocker.target.Type())
proxy := reflect.New(mocker.builder.hookType())

proxyCallerSetter := func(args []reflect.Value) {}
missProxyReceiver := false
if mocker.builder.proxyCaller != nil {
pVal := reflect.ValueOf(mocker.builder.proxyCaller)
tool.Assert(pVal.Kind() == reflect.Ptr && pVal.Elem().Kind() == reflect.Func, "origin receiver must be a function pointer")
pElem := pVal.Elem()
missProxyReceiver = mocker.checkReceiver(mocker.target.Type(), pElem.Interface())

if missProxyReceiver {
proxyCallerSetter = func(args []reflect.Value) {
pElem.Set(reflect.MakeFunc(pElem.Type(), func(innerArgs []reflect.Value) (results []reflect.Value) {
return tool.ReflectCall(proxy.Elem(), append(args[0:1], innerArgs...))
}))
}
} else {
proxyCallerSetter = func(args []reflect.Value) {
pElem.Set(reflect.MakeFunc(pElem.Type(), func(innerArgs []reflect.Value) (results []reflect.Value) {
return tool.ReflectCall(proxy.Elem(), innerArgs)
}))
}
shift := 0
if mocker.builder.generic {
shift += 1
}
if mocker.missReceiver(mocker.target.Type(), pElem.Interface()) {
shift += 1
}
proxyCallerSetter = func(args []reflect.Value) {
pElem.Set(reflect.MakeFunc(pElem.Type(), func(innerArgs []reflect.Value) (results []reflect.Value) {
return tool.ReflectCall(proxy.Elem(), append(args[0:shift], innerArgs...))
}))
}
}
mocker.proxy = proxy.Interface()
Expand Down
85 changes: 68 additions & 17 deletions mock_condition.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,19 @@ func (m *mockCondition) SetWhenForce(when interface{}) {
tool.Assert(wVal.Type().NumOut() == 1, "when func ret value not bool")
out1 := wVal.Type().Out(0)
tool.Assert(out1.Kind() == reflect.Bool, "when func ret value not bool")
checkReceiver(reflect.TypeOf(m.builder.target), when) // inputs must be in same or has an extra self receiver
m.when = when

hookType := m.builder.hookType()
inTypes := []reflect.Type{}
for i := 0; i < hookType.NumIn(); i++ {
inTypes = append(inTypes, hookType.In(i))
}

hasGeneric, hasReceiver := m.checkGenericAndReceiver(wVal.Type())
whenType := reflect.FuncOf(inTypes, []reflect.Type{out1}, hookType.IsVariadic())
m.when = reflect.MakeFunc(whenType, func(args []reflect.Value) (results []reflect.Value) {
results = tool.ReflectCall(wVal, m.adaptArgsForReflectCall(args, hasGeneric, hasReceiver))
return
}).Interface()
}

func (m *mockCondition) SetReturn(results ...interface{}) {
Expand All @@ -61,15 +72,15 @@ func (m *mockCondition) SetReturnForce(results ...interface{}) {
}
}

targetType := reflect.TypeOf(m.builder.target)
m.hook = reflect.MakeFunc(targetType, func(args []reflect.Value) []reflect.Value {
hookType := m.builder.hookType()
m.hook = reflect.MakeFunc(hookType, func(_ []reflect.Value) []reflect.Value {
results := getResult()
tool.CheckReturnType(m.builder.target, results...)
valueResults := make([]reflect.Value, 0)
for i, result := range results {
rValue := reflect.Zero(targetType.Out(i))
rValue := reflect.Zero(hookType.Out(i))
if result != nil {
rValue = reflect.ValueOf(result).Convert(targetType.Out(i))
rValue = reflect.ValueOf(result).Convert(hookType.Out(i))
}
valueResults = append(valueResults, rValue)
}
Expand All @@ -85,20 +96,60 @@ func (m *mockCondition) SetTo(to interface{}) {
func (m *mockCondition) SetToForce(to interface{}) {
hType := reflect.TypeOf(to)
tool.Assert(hType.Kind() == reflect.Func, "to a is not a func")
m.hook = to
hasGeneric, hasReceiver := m.checkGenericAndReceiver(hType)
tool.Assert(m.builder.generic || !hasGeneric, "non-generic function should not have 'GenericInfo' as first argument")
m.hook = reflect.MakeFunc(m.builder.hookType(), func(args []reflect.Value) (results []reflect.Value) {
results = tool.ReflectCall(reflect.ValueOf(to), m.adaptArgsForReflectCall(args, hasGeneric, hasReceiver))
return
}).Interface()
}

func checkReceiver(target reflect.Type, hook interface{}) bool {
hType := reflect.TypeOf(hook)
tool.Assert(hType.Kind() == reflect.Func, "Param(%v) a is not a func", hType.Kind())
tool.Assert(target.IsVariadic() == hType.IsVariadic(), "target:%v, hook:%v args not match", target, hook)
// checkGenericAndReceiver check if typ has GenericsInfo and selfReceiver as argument
//
// The hook function will looks like func(_ GenericInfo, self *struct, arg0 int ...)
// When we use 'When' or 'To', our input hook function will looks like:
// 1. func(arg0 int ...)
// 2. func(info GenericInfo, arg0 int ...)
// 3. func(self *struct, arg0 int ...)
// 4. func(info GenericInfo, self *struct, arg0 int ...)
//
// All above input hooks are legal, but we need to make an adaptation when calling then
func (m *mockCondition) checkGenericAndReceiver(typ reflect.Type) (bool, bool) {
targetType := reflect.TypeOf(m.builder.target)
tool.Assert(typ.Kind() == reflect.Func, "Param(%v) a is not a func", typ.Kind())
tool.Assert(targetType.IsVariadic() == typ.IsVariadic(), "target:%v, hook:%v args not match", targetType, typ)

shiftTyp := 0
if typ.NumIn() > 0 && typ.In(0) == genericInfoType {
shiftTyp = 1
}

// has receiver
if tool.CheckFuncArgs(target, hType, 0) {
return false
if tool.CheckFuncArgs(targetType, typ, 0, shiftTyp) {
return shiftTyp == 1, true
}

if tool.CheckFuncArgs(targetType, typ, 1, shiftTyp) {
return shiftTyp == 1, false
}
tool.Assert(false, "target:%v, hook:%v args not match", targetType, typ)
return false, false
}

// adaptArgsForReflectCall makes an adaption for reflect call
//
// see (*mockCondition).checkGenericAndReceiver for more info
func (m *mockCondition) adaptArgsForReflectCall(args []reflect.Value, hasGeneric, hasReceiver bool) []reflect.Value {
adaption := []reflect.Value{}
if m.builder.generic {
if hasGeneric {
adaption = append(adaption, args[0])
}
args = args[1:]
}
if tool.CheckFuncArgs(target, hType, 1) {
return true
if !hasReceiver {
args = args[1:]
}
tool.Assert(false, "target:%v, hook:%v args not match", target, hook)
return false
adaption = append(adaption, args...)
return adaption
}
33 changes: 33 additions & 0 deletions mock_generics.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,39 @@

package mockey

import (
"reflect"
"unsafe"
)

// MockGeneric mocks generic function
//
// Target must be generic method or method of generic types
Sychorius marked this conversation as resolved.
Show resolved Hide resolved
func MockGeneric(target interface{}) *MockBuilder {
return Mock(target, OptGeneric)
}

type GenericInfo uintptr
Sychorius marked this conversation as resolved.
Show resolved Hide resolved

var genericInfoType = reflect.TypeOf(GenericInfo(0))

func (g GenericInfo) Equal(other GenericInfo) bool {
return g == other
}

// UsedParamType get the type of used parameter in generic function/struct
//
// For example: assume we have generic function "f[int, float64](x int, y T1) T2" and derived type f[int, float64]:
//
// UsedParamType(0) == reflect.TypeOf(int(0))
// UsedParamType(1) == reflect.TypeOf(float64(0))
//
// If index n is out of range, or the derived types have more complex structure(for example: define an generic struct
// in a generic function using generic types, unused parameterized type etc.), this function may return unexpected value
// or cause unrecoverable runtime error . So it is NOT RECOMMENDED to use this function unless you actually knows what
// you are doing.
func (g GenericInfo) UsedParamType(n uintptr) reflect.Type {
var vt interface{}
*(*uintptr)(unsafe.Pointer(&vt)) = *(*uintptr)(unsafe.Pointer(uintptr(g) + 8*n))
return reflect.TypeOf(vt)
}
Loading