Skip to content

Commit

Permalink
reflect: implement NumMethod and Implements
Browse files Browse the repository at this point in the history
These two functions are needed by the encoding/json package.
  • Loading branch information
aykevl committed Feb 19, 2020
1 parent 3f74e3c commit dd5298f
Show file tree
Hide file tree
Showing 12 changed files with 184 additions and 31 deletions.
17 changes: 15 additions & 2 deletions compiler/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,27 @@ func (c *Compiler) getTypeCode(typ types.Type) llvm.Value {
structGlobal := c.makeStructTypeFields(typ)
references = llvm.ConstBitCast(structGlobal, global.Type())
}
if !references.IsNil() {
methodSet := types.NewMethodSet(typ)
numMethods := 0
for i := 0; i < methodSet.Len(); i++ {
if methodSet.At(i).Obj().Exported() {
numMethods++
}
}
if !references.IsNil() || numMethods != 0 {
// Set the 'references' field of the runtime.typecodeID struct.
globalValue := llvm.ConstNull(global.Type().ElementType())
globalValue = llvm.ConstInsertValue(globalValue, references, []uint32{0})
if !references.IsNil() {
globalValue = llvm.ConstInsertValue(globalValue, references, []uint32{0})
}
if length != 0 {
lengthValue := llvm.ConstInt(c.uintptrType, uint64(length), false)
globalValue = llvm.ConstInsertValue(globalValue, lengthValue, []uint32{1})
}
if numMethods != 0 {
numMethodsValue := llvm.ConstInt(c.uintptrType, uint64(numMethods), false)
globalValue = llvm.ConstInsertValue(globalValue, numMethodsValue, []uint32{2})
}
global.SetInitializer(globalValue)
global.SetLinkage(llvm.PrivateLinkage)
}
Expand Down
12 changes: 11 additions & 1 deletion src/reflect/sidetables.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,22 @@ import (
"unsafe"
)

// This stores the number of methods (for Type.NumMethod()) for each named basic
// type. It is indexed by the named type number.
//go:extern reflect.namedBasicNumMethodSidetable
var namedBasicNumMethodSidetable byte

// This stores a varint for each named type. Named types are identified by their
// name instead of by their type. The named types stored in this struct are
// non-basic types: pointer, struct, and channel.
// non-basic types: pointer, struct, channel, and interface.
//go:extern reflect.namedNonBasicTypesSidetable
var namedNonBasicTypesSidetable uintptr

// This stores the number of methods (for Type.NumMethods()) for each named
// non-basic type. It is indexed by the named type number.
//go:extern reflect.namedNonBasicNumMethodSidetable
var namedNonBasicNumMethodSidetable byte

//go:extern reflect.structTypesSidetable
var structTypesSidetable byte

Expand Down
70 changes: 65 additions & 5 deletions src/reflect/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,20 @@ func (t Type) Kind() Kind {
}
}

// isBasic returns true if (and only if) this is a basic type.
func (t Type) isBasic() bool {
return t%2 == 0
}

// isNamed returns true if (and only if) this is a named type.
func (t Type) isNamed() bool {
if t.isBasic() {
return t>>6 != 0
} else {
return (t>>4)%2 != 0
}
}

// Elem returns the element type for channel, slice and array types, the
// pointed-to value for pointer types, and the key type for map types.
func (t Type) Elem() Type {
Expand All @@ -166,7 +180,7 @@ func (t Type) Elem() Type {
func (t Type) stripPrefix() Type {
// Look at the 'n' bit in the type code (see the top of this file) to see
// whether this is a named type.
if (t>>4)%2 != 0 {
if t.isNamed() {
// This is a named type. The data is stored in a sidetable.
namedTypeNum := t >> 5
n := *(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&namedNonBasicTypesSidetable)) + uintptr(namedTypeNum)*unsafe.Sizeof(uintptr(0))))
Expand All @@ -184,7 +198,11 @@ func (t Type) Field(i int) StructField {
}
structIdentifier := t.stripPrefix()

numField, p := readVarint(unsafe.Pointer(uintptr(unsafe.Pointer(&structTypesSidetable)) + uintptr(structIdentifier)))
// Skip past the NumMethod field.
_, p := readVarint(unsafe.Pointer(uintptr(unsafe.Pointer(&structTypesSidetable)) + uintptr(structIdentifier)))

// Read the NumField field.
numField, p := readVarint(p)
if uint(i) >= uint(numField) {
panic("reflect: field index out of range")
}
Expand Down Expand Up @@ -283,7 +301,10 @@ func (t Type) NumField() int {
panic(&TypeError{"NumField"})
}
structIdentifier := t.stripPrefix()
n, _ := readVarint(unsafe.Pointer(uintptr(unsafe.Pointer(&structTypesSidetable)) + uintptr(structIdentifier)))
// Skip past the NumMethod field.
_, p := readVarint(unsafe.Pointer(uintptr(unsafe.Pointer(&structTypesSidetable)) + uintptr(structIdentifier)))
// Read the NumField field.
n, _ := readVarint(p)
return int(n)
}

Expand Down Expand Up @@ -401,10 +422,49 @@ func (t Type) AssignableTo(u Type) bool {
}

func (t Type) Implements(u Type) bool {
if t.Kind() != Interface {
if u.Kind() != Interface {
panic("reflect: non-interface type passed to Type.Implements")
}
return u.AssignableTo(t)
if u.NumMethod() == 0 {
// Empty interface, therefore implements any type.
return true
}
// Type u has methods.
if t.NumMethod() == 0 {
// Type t has no methods, while u has. Therefore, t cannot implement u.
return false
}
// Now we'd have to compare individual methods.
// This has not yet been implemented.
panic("unimplemented: (reflect.Type).Implements()")
}

func (t Type) NumMethod() int {
if t.isBasic() {
if !t.isNamed() {
// Not a named type, so can't have methods.
return 0
}
// Named type methods are stored in a sidetable.
namedTypeNum := t >> 6
return int(*(*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(&namedBasicNumMethodSidetable)) + uintptr(namedTypeNum) - 1)))
} else {
if !t.isNamed() {
// Most non-named types cannot have methods. However, structs can
// because they can embed named values.
if t.Kind() == Struct {
structIdentifier := t.stripPrefix()
numMethod, _ := readVarint(unsafe.Pointer(uintptr(unsafe.Pointer(&structTypesSidetable)) + uintptr(structIdentifier)))
return int(numMethod)
}
// Not a named type so can't have methods on that, and also not a
// struct type that may have embedded fields.
return 0
}
// Named non-basic type, so read the number of methods from a sidetable.
namedTypeNum := t >> 5
return int(*(*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(&namedNonBasicNumMethodSidetable)) + uintptr(namedTypeNum))))
}
}

// Comparable returns whether values of this type can be compared to each other.
Expand Down
4 changes: 4 additions & 0 deletions src/reflect/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,10 @@ func maskAndShift(value, offset, size uintptr) uintptr {
return (uintptr(value) >> (offset * 8)) & mask
}

func (v Value) NumMethod() int {
return v.Type().NumMethod()
}

func (v Value) MapKeys() []Value {
panic("unimplemented: (reflect.Value).MapKeys()")
}
Expand Down
3 changes: 3 additions & 0 deletions src/runtime/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ type typecodeID struct {

// The array length, for array types.
length uintptr

// Number of methods on this type. Always 0 for non-named types.
numMethods uintptr
}

// structField is used by the compiler to pass information to the interface
Expand Down
12 changes: 12 additions & 0 deletions testdata/reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ type (
}
)

// To test Type.NumMethod.
func (n myint) foo() {}
func (n myint) bar() {}
func (n myint) Baz() {}

func main() {
println("matching types")
println(reflect.TypeOf(int(3)) == reflect.TypeOf(int(5)))
Expand Down Expand Up @@ -111,6 +116,10 @@ func main() {
&linkedList{
foo: 42,
},
// interfaces
struct {
x interface{}
}{},
} {
showValue(reflect.ValueOf(v), "")
}
Expand Down Expand Up @@ -281,6 +290,9 @@ func showValue(rv reflect.Value, indent string) {
if !rt.Comparable() {
print(" comparable=false")
}
if rt.NumMethod() != 0 {
print(" methods=", rt.NumMethod())
}
println()
switch rt.Kind() {
case reflect.Bool:
Expand Down
16 changes: 12 additions & 4 deletions testdata/reflect.txt
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ reflect type: complex64
complex: (+1.200000e+000+3.000000e-001i)
reflect type: complex128
complex: (+1.300000e+000+4.000000e-001i)
reflect type: int
reflect type: int methods=1
int: 32
reflect type: string
string: foo 3
Expand All @@ -77,7 +77,7 @@ reflect type: ptr
reflect type: ptr
pointer: true interface
nil: false
reflect type: interface settable=true
reflect type: interface settable=true methods=1
interface
nil: true
reflect type: ptr
Expand Down Expand Up @@ -230,12 +230,12 @@ reflect type: map comparable=false
nil: false
reflect type: struct
struct: 0
reflect type: struct
reflect type: struct methods=1
struct: 1
field: 0 error
tag:
embedded: true
reflect type: interface
reflect type: interface methods=1
interface
nil: true
reflect type: struct
Expand Down Expand Up @@ -321,6 +321,14 @@ reflect type: ptr
embedded: false
reflect type: int
int: 42
reflect type: struct
struct: 1
field: 0 x
tag:
embedded: false
reflect type: interface
interface
nil: true

sizes:
int8 1 8
Expand Down
69 changes: 56 additions & 13 deletions transform/reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ type typeCodeAssignmentState struct {
// all. If it is false, namedNonBasicTypesSidetable will contain simple
// monotonically increasing numbers.
needsNamedNonBasicTypesSidetable bool

// These two slices contain the number of methods on named types, and are
// stored in the output binary with similar names.
namedBasicNumMethodSidetable []byte
namedNonBasicNumMethodSidetable []byte
}

// assignTypeCodes is used to assign a type code to each type in the program
Expand All @@ -138,7 +143,7 @@ func assignTypeCodes(mod llvm.Module, typeSlice typeInfoSlice) {
arrayTypes: make(map[string]int),
structTypes: make(map[string]int),
structNames: make(map[string]int),
needsNamedNonBasicTypesSidetable: len(getUses(mod.NamedGlobal("reflect.namedNonBasicTypesSidetable"))) != 0,
needsNamedNonBasicTypesSidetable: len(getUses(mod.NamedGlobal("reflect.namedNonBasicTypesSidetable"))) != 0 || len(getUses(mod.NamedGlobal("reflect.namedNonBasicNumMethodSidetable"))) != 0,
needsStructTypesSidetable: len(getUses(mod.NamedGlobal("reflect.structTypesSidetable"))) != 0,
needsStructNamesSidetable: len(getUses(mod.NamedGlobal("reflect.structNamesSidetable"))) != 0,
needsArrayTypesSidetable: len(getUses(mod.NamedGlobal("reflect.arrayTypesSidetable"))) != 0,
Expand All @@ -157,7 +162,21 @@ func assignTypeCodes(mod llvm.Module, typeSlice typeInfoSlice) {

// Only create this sidetable when it is necessary.
if state.needsNamedNonBasicTypesSidetable {
global := replaceGlobalIntWithArray(mod, "reflect.namedNonBasicTypesSidetable", state.namedNonBasicTypesSidetable)
if len(getUses(mod.NamedGlobal("reflect.namedNonBasicTypesSidetable"))) != 0 {
global := replaceGlobalIntWithArray(mod, "reflect.namedNonBasicTypesSidetable", state.namedNonBasicTypesSidetable)
global.SetLinkage(llvm.InternalLinkage)
global.SetUnnamedAddr(true)
global.SetGlobalConstant(true)
}
if len(getUses(mod.NamedGlobal("reflect.namedNonBasicNumMethodSidetable"))) != 0 {
global := replaceGlobalIntWithArray(mod, "reflect.namedNonBasicNumMethodSidetable", state.namedNonBasicNumMethodSidetable)
global.SetLinkage(llvm.InternalLinkage)
global.SetUnnamedAddr(true)
global.SetGlobalConstant(true)
}
}
if len(getUses(mod.NamedGlobal("reflect.namedBasicNumMethodSidetable"))) != 0 {
global := replaceGlobalIntWithArray(mod, "reflect.namedBasicNumMethodSidetable", state.namedBasicNumMethodSidetable)
global.SetLinkage(llvm.InternalLinkage)
global.SetUnnamedAddr(true)
global.SetGlobalConstant(true)
Expand Down Expand Up @@ -189,9 +208,12 @@ func (state *typeCodeAssignmentState) getTypeCodeNum(typecode llvm.Value) *big.I
// Note: see src/reflect/type.go for bit allocations.
class, value := getClassAndValueFromTypeCode(typecode)
name := ""
var namedNumMethods uint64 // number of methods for a named type
if class == "named" {
name = value
typecode = llvm.ConstExtractValue(typecode.Initializer(), []uint32{0})
initializer := typecode.Initializer()
typecode = llvm.ConstExtractValue(initializer, []uint32{0})
namedNumMethods = llvm.ConstExtractValue(initializer, []uint32{2}).ZExtValue()
class, value = getClassAndValueFromTypeCode(typecode)
}
if class == "basic" {
Expand All @@ -205,7 +227,7 @@ func (state *typeCodeAssignmentState) getTypeCodeNum(typecode llvm.Value) *big.I
}
if name != "" {
// This type is named, set the upper bits to the name ID.
num |= int64(state.getBasicNamedTypeNum(name)) << 5
num |= int64(state.getBasicNamedTypeNum(name, namedNumMethods)) << 5
}
return big.NewInt(num << 1)
} else {
Expand Down Expand Up @@ -248,6 +270,14 @@ func (state *typeCodeAssignmentState) getTypeCodeNum(typecode llvm.Value) *big.I
index := len(state.namedNonBasicTypesSidetable)
state.namedNonBasicTypesSidetable = append(state.namedNonBasicTypesSidetable, 0)
state.namedNonBasicTypes[name] = index
// Also store the number of methods.
if index != len(state.namedNonBasicNumMethodSidetable) {
panic("unexpected side table length")
}
if uint64(byte(namedNumMethods)) != namedNumMethods {
panic("too many methods for type " + name)
}
state.namedNonBasicNumMethodSidetable = append(state.namedNonBasicNumMethodSidetable, byte(namedNumMethods))
// Get the typecode of the underlying type (which could be the
// element type in the case of pointers, for example).
num = state.getNonBasicTypeCode(class, typecode)
Expand Down Expand Up @@ -316,12 +346,19 @@ func getClassAndValueFromTypeCode(typecode llvm.Value) (class, value string) {
// getBasicNamedTypeNum returns an appropriate (unique) number for the given
// named type. If the name already has a number that number is returned, else a
// new number is returned. The number is always non-zero.
func (state *typeCodeAssignmentState) getBasicNamedTypeNum(name string) int {
func (state *typeCodeAssignmentState) getBasicNamedTypeNum(name string, numMethods uint64) int {
if num, ok := state.namedBasicTypes[name]; ok {
return num
}
num := len(state.namedBasicTypes) + 1
state.namedBasicTypes[name] = num
if uint64(byte(numMethods)) != numMethods {
panic("too many methods for type " + name)
}
if len(state.namedBasicNumMethodSidetable) != num-1 {
panic("unexpected side table length")
}
state.namedBasicNumMethodSidetable = append(state.namedBasicNumMethodSidetable, byte(numMethods))
return num
}

Expand Down Expand Up @@ -381,15 +418,21 @@ func (state *typeCodeAssignmentState) getStructTypeNum(typecode llvm.Value) int
}

// Get the fields this struct type contains.
// The struct number will be the start index of
structTypeGlobal := llvm.ConstExtractValue(typecode.Initializer(), []uint32{0}).Operand(0).Initializer()
// The struct number will be the start index into
// reflect.structTypesSidetable.
typecodeID := typecode.Initializer()
structTypeGlobal := llvm.ConstExtractValue(typecodeID, []uint32{0}).Operand(0).Initializer()
numFields := structTypeGlobal.Type().ArrayLength()

// The first data that is stored in the struct sidetable is the number of
// fields this struct contains. This is usually just a single byte because
// most structs don't contain that many fields, but make it a varint just
// to be sure.
buf := makeVarint(uint64(numFields))
numMethods := llvm.ConstExtractValue(typecodeID, []uint32{2}).ZExtValue()

// The first element that is stored in the struct sidetable is the number
// of methods this struct has. It is used by Type.NumMethod().
buf := makeVarint(numMethods)
// The second element that is stored in the struct sidetable is the number
// of fields this struct contains. This is usually just a single byte
// because most structs don't contain that many fields, but make it a varint
// just to be sure.
buf = append(buf, makeVarint(uint64(numFields))...)

// Iterate over every field in the struct.
// Every field is stored sequentially in the struct sidetable. Fields can
Expand Down
Loading

0 comments on commit dd5298f

Please sign in to comment.