Skip to content

Commit

Permalink
feat: suport generics mocking (need go1.18+)
Browse files Browse the repository at this point in the history
Change-Id: I496c16ff2fb9338bc02d852b81986669fdcb5c4f
  • Loading branch information
Sychorius committed Jul 26, 2023
1 parent 2c5cfce commit a5850d0
Show file tree
Hide file tree
Showing 8 changed files with 205 additions and 9 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -1,27 +1,27 @@
name: Tests

on: [ pull_request ]
on: [pull_request]

jobs:
unit-benchmark-test:
strategy:
matrix:
go: ["1.13", "1.14", "1.15", "1.16", "1.17", "1.18", "1.19", "1.20"]
os: [ linux ] # should be [ macOS, linux, windows ], but currently we don't have macOS and windows runners
arch: [ X64, ARM64 ]
os: [linux] # should be [ macOS, linux, windows ], but currently we don't have macOS and windows runners
arch: [X64, ARM64]
exclude:
- os: Linux
arch: ARM64
- os: Windows
arch: ARM64
runs-on: [ "${{ matrix.os }}", "${{ matrix.arch }}" ]
runs-on: ["${{ matrix.os }}", "${{ matrix.arch }}"]
steps:
- uses: actions/checkout@v3
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: ${{ matrix.go }}
- name: Unit Test
run: go test -gcflags="all=-l -N" -covermode=atomic -coverprofile=coverage.out ./...
run: MOCKEY_DEBUG=true go test -gcflags="all=-l -N" -covermode=atomic -coverprofile=coverage.out ./...
- name: Benchmark
run: go test -bench=. -benchmem -run=none ./...
8 changes: 7 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
module github.com/bytedance/mockey

go 1.13
go 1.18

require (
github.com/smartystreets/goconvey v1.6.4
golang.org/x/arch v0.0.0-20201008161808-52c3e6f60cff
)

require (
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 // indirect
github.com/jtolds/gls v4.20.0+incompatible // indirect
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d // indirect
)
45 changes: 45 additions & 0 deletions internal/monkey/inst/disasm_amd64.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
package inst

import (
"unsafe"

"github.com/bytedance/mockey/internal/monkey/common"
"github.com/bytedance/mockey/internal/tool"
"golang.org/x/arch/x86/x86asm"
)
Expand All @@ -35,3 +38,45 @@ func Disassemble(code []byte, required int, checkLen bool) int {
}
return pos
}

func GetGenericJumpAddr(addr uintptr, maxScan uint64) uintptr {
code := common.BytesOf(addr, int(maxScan))
var pos uint64
var err error
var inst x86asm.Inst

for pos < maxScan {
inst, err = x86asm.Decode(code[pos:], 64)
tool.Assert(err == nil, err)
// if inst.Op == arm64asm.BL {
args := []interface{}{inst.Op}
for i := range inst.Args {
args = append(args, inst.Args[i])
}
tool.DebugPrintf("%v\t%v\t%v\t%v\t%v\t%v\n", args...)

if inst.Op == x86asm.CALL {
rel := int32(inst.Args[0].(x86asm.Rel))
tool.DebugPrintf("found: CALL, raw is: %x, rel: %v\n", inst.String(), rel)
return calcAddr(uintptr(unsafe.Pointer(&code[0]))+uintptr(pos+uint64(inst.Len)), rel)
}
tool.Assert(inst.Op != x86asm.RET, "!!!FOUND RET!!!")
pos += uint64(inst.Len)
}
tool.Assert(false, "CALL op not found")
return 0
}

func calcAddr(from uintptr, rel int32) uintptr {
tool.DebugPrintf("calc CALL addr, from: %x(%v) CALL: %x\n", from, from, rel)

var dest uintptr
if rel < 0 {
dest = from - uintptr(uint32(-rel))
} else {
dest = from + uintptr(rel)
}

tool.DebugPrintf("L->H:%v rel: %v from: %x(%v) dest: %x(%v), distance: %v\n", rel > 0, rel, from, from, dest, dest, from-dest)
return dest
}
52 changes: 52 additions & 0 deletions internal/monkey/inst/disasm_arm64.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,62 @@
package inst

import (
"unsafe"

"github.com/bytedance/mockey/internal/monkey/common"
"github.com/bytedance/mockey/internal/tool"
"golang.org/x/arch/arm64/arm64asm"
)

func Disassemble(code []byte, required int, checkLen bool) int {
tool.Assert(len(code) > required, "function is too short to patch")
return required
}

func GetGenericJumpAddr(addr uintptr, maxScan uint64) uintptr {
code := common.BytesOf(addr, int(maxScan))
var pos uint64
var err error
var inst arm64asm.Inst

for pos < maxScan {
inst, err = arm64asm.Decode(code[pos:])
tool.Assert(err == nil, err)
// if inst.Op == arm64asm.BL {
args := []interface{}{inst.Op}
for i := range inst.Args {
args = append(args, inst.Args[i])
}
tool.DebugPrintf("%v\t%v\t%v\t%v\t%v\t%v\n", args...)

if inst.Op == arm64asm.BL {
tool.DebugPrintf("found: BL, raw is: %x\n", inst.Enc)
return calcAddr(uintptr(unsafe.Pointer(&code[0]))+uintptr(pos), inst.Enc)
}
pos += uint64(unsafe.Sizeof(inst.Enc))
tool.Assert(inst.Op != arm64asm.RET, "!!!FOUND RET!!!")
}
tool.Assert(false, "BL op not found")
return 0
}

func calcAddr(from uintptr, bl uint32) uintptr {
tool.DebugPrintf("calc BL addr, from: %x(%v) bl: %x\n", from, from, bl)
offset := bl << 8 >> 8
flag := (offset << 9 >> 9) == offset // 是否小于0

var dest uintptr
if flag {
// L -> H
// (dest - cur) / 4 = offset
// dest = cur + offset * 4
dest = from + uintptr(offset*4)
} else {
// H -> L
// (cur - dest) / 4 = (0x00ffffff - offset + 1)
// dest = cur - (0x00ffffff - offset + 1) * 4
dest = from - uintptr((0x00ffffff-offset+1)*4)
}
tool.DebugPrintf("2th complement, L->H:%v offset: %x from: %x(%v) dest: %x(%v), distance: %v\n", flag, offset, from, from, dest, dest, from-dest)
return dest
}
8 changes: 6 additions & 2 deletions internal/monkey/patch.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,17 @@ func (p *Patch) Unpatch() {

// PatchValue replace the target function with a hook function, and stores the target function in the proxy function
// for future restore. Target and hook are values of function. Proxy is a value of proxy function pointer.
func PatchValue(target, hook, proxy reflect.Value, unsafe bool) *Patch {
func PatchValue(target, hook, proxy reflect.Value, unsafe, generic bool) *Patch {
tool.Assert(hook.Kind() == reflect.Func, "'%s' is not a function", hook.Kind())
tool.Assert(proxy.Kind() == reflect.Ptr, "'%v' is not a function pointer", proxy.Kind())
tool.Assert(hook.Type() == target.Type(), "'%v' and '%s' mismatch", hook.Type(), target.Type())
tool.Assert(proxy.Elem().Type() == target.Type(), "'*%v' and '%s' mismatch", proxy.Elem().Type(), target.Type())

targetAddr := target.Pointer()
if generic {
// we assume that generic call/bl op is located in first 200 bytes of codes from targetAddr
targetAddr = inst.GetGenericJumpAddr(targetAddr, 200)
}
// The first few bytes of the target function code
const bufSize = 64
targetCodeBuf := common.BytesOf(targetAddr, bufSize)
Expand Down Expand Up @@ -76,5 +80,5 @@ func PatchValue(target, hook, proxy reflect.Value, unsafe bool) *Patch {
func PatchFunc(fn, hook, proxy interface{}, unsafe bool) *Patch {
vv := reflect.ValueOf(fn)
tool.Assert(vv.Kind() == reflect.Func, "'%v' is not a function", fn)
return PatchValue(vv, reflect.ValueOf(hook), reflect.ValueOf(proxy), unsafe)
return PatchValue(vv, reflect.ValueOf(hook), reflect.ValueOf(proxy), unsafe, false)
}
3 changes: 2 additions & 1 deletion mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ type MockBuilder struct {
filterGoroutine FilterGoroutineType
gId int64
unsafe bool
generic bool
}

func Mock(target interface{}) *MockBuilder {
Expand Down Expand Up @@ -299,7 +300,7 @@ func (mocker *Mocker) Patch() *Mocker {
if mocker.isPatched {
return mocker
}
mocker.patch = monkey.PatchValue(mocker.target, mocker.hook, reflect.ValueOf(mocker.proxy), mocker.builder.unsafe)
mocker.patch = monkey.PatchValue(mocker.target, mocker.hook, reflect.ValueOf(mocker.proxy), mocker.builder.unsafe, mocker.builder.generic)
mocker.isPatched = true
addToGlobal(mocker)

Expand Down
23 changes: 23 additions & 0 deletions mock_generics.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright 2022 ByteDance Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package mockey

func MockGeneric(target interface{}) *MockBuilder {
builder := Mock(target)
builder.generic = true
return builder
}
65 changes: 65 additions & 0 deletions mock_generics_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
//go:build go1.18
// +build go1.18

/*
* Copyright 2022 ByteDance Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package mockey

import (
"testing"

"github.com/smartystreets/goconvey/convey"
)

func sum[T int | float64](l, r T) T {
return l + r
}

type generic[T int | string] struct {
a T
}

func (g generic[T]) Value() T {
return g.a
}

func (g generic[T]) Value2() T {
return g.a + g.a
}

func TestGeneric(t *testing.T) {
PatchConvey("generic", t, func() {
PatchConvey("func", func() {
MockGeneric(sum[int]).To(func(a, b int) int {
return 999
}).Build()
MockGeneric(sum[float64]).Return(888).Build()
convey.So(sum[int](1, 2), convey.ShouldEqual, 999)
convey.So(sum[float64](1, 2), convey.ShouldEqual, 888)
})
PatchConvey("type", func() {
MockGeneric((generic[int]).Value).Return(999).Build()
MockGeneric(GetMethod(generic[string]{}, "Value2")).To(func() string {
return "mock"
}).Build()
gi := generic[int]{a: 123}
gs := generic[string]{a: "abc"}
convey.So(gi.Value(), convey.ShouldEqual, 999)
convey.So(gs.Value2(), convey.ShouldEqual, "mock")
})
})
}

0 comments on commit a5850d0

Please sign in to comment.