aboutsummaryrefslogtreecommitdiffhomepage
path: root/transform/interface-lowering.go
diff options
context:
space:
mode:
Diffstat (limited to 'transform/interface-lowering.go')
-rw-r--r--transform/interface-lowering.go118
1 files changed, 84 insertions, 34 deletions
diff --git a/transform/interface-lowering.go b/transform/interface-lowering.go
index 55d1af39b..e57463715 100644
--- a/transform/interface-lowering.go
+++ b/transform/interface-lowering.go
@@ -13,8 +13,7 @@ package transform
//
// typeAssert:
// Replaced with an icmp instruction so it can be directly used in a type
-// switch. This is very easy to optimize for LLVM: it will often translate a
-// type switch into a regular switch statement.
+// switch.
//
// interface type assert:
// These functions are defined by creating a big type switch over all the
@@ -54,10 +53,11 @@ type methodInfo struct {
// typeInfo describes a single concrete Go type, which can be a basic or a named
// type. If it is a named type, it may have methods.
type typeInfo struct {
- name string
- typecode llvm.Value
- methodSet llvm.Value
- methods []*methodInfo
+ name string
+ typecode llvm.Value
+ typecodeGEP llvm.Value
+ methodSet llvm.Value
+ methods []*methodInfo
}
// getMethod looks up the method on this type with the given signature and
@@ -91,6 +91,8 @@ type lowerInterfacesPass struct {
difiles map[string]llvm.Metadata
ctx llvm.Context
uintptrType llvm.Type
+ targetData llvm.TargetData
+ i8ptrType llvm.Type
types map[string]*typeInfo
signatures map[string]*signatureInfo
interfaces map[string]*interfaceInfo
@@ -101,14 +103,17 @@ type lowerInterfacesPass struct {
// before LLVM can work on them. This is done so that a few cleanup passes can
// run before assigning the final type codes.
func LowerInterfaces(mod llvm.Module, config *compileopts.Config) error {
+ ctx := mod.Context()
targetData := llvm.NewTargetData(mod.DataLayout())
defer targetData.Dispose()
p := &lowerInterfacesPass{
mod: mod,
config: config,
- builder: mod.Context().NewBuilder(),
- ctx: mod.Context(),
+ builder: ctx.NewBuilder(),
+ ctx: ctx,
+ targetData: targetData,
uintptrType: mod.Context().IntType(targetData.PointerSize() * 8),
+ i8ptrType: llvm.PointerType(ctx.Int8Type(), 0),
types: make(map[string]*typeInfo),
signatures: make(map[string]*signatureInfo),
interfaces: make(map[string]*interfaceInfo),
@@ -151,11 +156,26 @@ func (p *lowerInterfacesPass) run() error {
}
p.types[name] = t
initializer := global.Initializer()
- if initializer.IsNil() {
- continue
+ firstField := p.builder.CreateExtractValue(initializer, 0, "")
+ if firstField.Type() != p.ctx.Int8Type() {
+ // This type has a method set at index 0. Change the GEP to
+ // point to index 1 (the meta byte).
+ t.typecodeGEP = llvm.ConstGEP(global.GlobalValueType(), global, []llvm.Value{
+ llvm.ConstInt(p.ctx.Int32Type(), 0, false),
+ llvm.ConstInt(p.ctx.Int32Type(), 1, false),
+ })
+ methodSet := stripPointerCasts(firstField)
+ if !strings.HasSuffix(methodSet.Name(), "$methodset") {
+ panic("expected method set")
+ }
+ p.addTypeMethods(t, methodSet)
+ } else {
+ // This type has no method set.
+ t.typecodeGEP = llvm.ConstGEP(global.GlobalValueType(), global, []llvm.Value{
+ llvm.ConstInt(p.ctx.Int32Type(), 0, false),
+ llvm.ConstInt(p.ctx.Int32Type(), 0, false),
+ })
}
- methodSet := p.builder.CreateExtractValue(initializer, 2, "")
- p.addTypeMethods(t, methodSet)
}
}
}
@@ -266,10 +286,10 @@ func (p *lowerInterfacesPass) run() error {
actualType := use.Operand(0)
name := strings.TrimPrefix(use.Operand(1).Name(), "reflect/types.typeid:")
if t, ok := p.types[name]; ok {
- // The type exists in the program, so lower to a regular integer
+ // The type exists in the program, so lower to a regular pointer
// comparison.
p.builder.SetInsertPointBefore(use)
- commaOk := p.builder.CreateICmp(llvm.IntEQ, llvm.ConstPtrToInt(t.typecode, p.uintptrType), actualType, "typeassert.ok")
+ commaOk := p.builder.CreateICmp(llvm.IntEQ, t.typecodeGEP, actualType, "typeassert.ok")
use.ReplaceAllUsesWith(commaOk)
} else {
// The type does not exist in the program, so lower to a constant
@@ -283,15 +303,45 @@ func (p *lowerInterfacesPass) run() error {
}
// Remove all method sets, which are now unnecessary and inhibit later
- // optimizations if they are left in place. Also remove references to the
- // interface type assert functions just to be sure.
- zeroUintptr := llvm.ConstNull(p.uintptrType)
+ // optimizations if they are left in place.
+ zero := llvm.ConstInt(p.ctx.Int32Type(), 0, false)
for _, t := range p.types {
- initializer := t.typecode.Initializer()
- methodSet := p.builder.CreateExtractValue(initializer, 2, "")
- initializer = p.builder.CreateInsertValue(initializer, llvm.ConstNull(methodSet.Type()), 2, "")
- initializer = p.builder.CreateInsertValue(initializer, zeroUintptr, 4, "")
- t.typecode.SetInitializer(initializer)
+ if !t.methodSet.IsNil() {
+ initializer := t.typecode.Initializer()
+ var newInitializerFields []llvm.Value
+ for i := 1; i < initializer.Type().StructElementTypesCount(); i++ {
+ newInitializerFields = append(newInitializerFields, p.builder.CreateExtractValue(initializer, i, ""))
+ }
+ newInitializer := p.ctx.ConstStruct(newInitializerFields, false)
+ typecodeName := t.typecode.Name()
+ newGlobal := llvm.AddGlobal(p.mod, newInitializer.Type(), typecodeName+".tmp")
+ newGlobal.SetInitializer(newInitializer)
+ newGlobal.SetLinkage(t.typecode.Linkage())
+ newGlobal.SetGlobalConstant(true)
+ newGlobal.SetAlignment(t.typecode.Alignment())
+ for _, use := range getUses(t.typecode) {
+ if !use.IsAConstantExpr().IsNil() {
+ opcode := use.Opcode()
+ if opcode == llvm.GetElementPtr && use.OperandsCount() == 3 {
+ if use.Operand(1).ZExtValue() == 0 && use.Operand(2).ZExtValue() == 1 {
+ gep := p.builder.CreateInBoundsGEP(newGlobal.GlobalValueType(), newGlobal, []llvm.Value{zero, zero}, "")
+ use.ReplaceAllUsesWith(gep)
+ }
+ }
+ }
+ }
+ // Fallback.
+ if hasUses(t.typecode) {
+ bitcast := llvm.ConstBitCast(newGlobal, p.i8ptrType)
+ negativeOffset := -int64(p.targetData.TypeAllocSize(p.i8ptrType))
+ gep := p.builder.CreateInBoundsGEP(p.ctx.Int8Type(), bitcast, []llvm.Value{llvm.ConstInt(p.ctx.Int32Type(), uint64(negativeOffset), true)}, "")
+ bitcast2 := llvm.ConstBitCast(gep, t.typecode.Type())
+ t.typecode.ReplaceAllUsesWith(bitcast2)
+ }
+ t.typecode.EraseFromParentAsGlobal()
+ newGlobal.SetName(typecodeName)
+ t.typecode = newGlobal
+ }
}
return nil
@@ -301,22 +351,22 @@ func (p *lowerInterfacesPass) run() error {
// retrieves the signatures and the references to the method functions
// themselves for later type<->interface matching.
func (p *lowerInterfacesPass) addTypeMethods(t *typeInfo, methodSet llvm.Value) {
- if !t.methodSet.IsNil() || methodSet.IsNull() {
+ if !t.methodSet.IsNil() {
// no methods or methods already read
return
}
- if !methodSet.IsAConstantExpr().IsNil() && methodSet.Opcode() == llvm.GetElementPtr {
- methodSet = methodSet.Operand(0) // get global from GEP, for LLVM 14 (non-opaque pointers)
- }
// This type has methods, collect all methods of this type.
t.methodSet = methodSet
set := methodSet.Initializer() // get value from global
- for i := 0; i < set.Type().ArrayLength(); i++ {
- methodData := p.builder.CreateExtractValue(set, i, "")
- signatureGlobal := p.builder.CreateExtractValue(methodData, 0, "")
+ signatures := p.builder.CreateExtractValue(set, 1, "")
+ wrappers := p.builder.CreateExtractValue(set, 2, "")
+ numMethods := signatures.Type().ArrayLength()
+ for i := 0; i < numMethods; i++ {
+ signatureGlobal := p.builder.CreateExtractValue(signatures, i, "")
+ function := p.builder.CreateExtractValue(wrappers, i, "")
+ function = stripPointerCasts(function) // strip bitcasts
signatureName := signatureGlobal.Name()
- function := p.builder.CreateExtractValue(methodData, 1, "").Operand(0)
signature := p.getSignature(signatureName)
method := &methodInfo{
function: function,
@@ -401,7 +451,7 @@ func (p *lowerInterfacesPass) defineInterfaceImplementsFunc(fn llvm.Value, itf *
actualType := fn.Param(0)
for _, typ := range itf.types {
nextBlock := p.ctx.AddBasicBlock(fn, typ.name+".next")
- cmp := p.builder.CreateICmp(llvm.IntEQ, actualType, llvm.ConstPtrToInt(typ.typecode, p.uintptrType), typ.name+".icmp")
+ cmp := p.builder.CreateICmp(llvm.IntEQ, actualType, typ.typecodeGEP, typ.name+".icmp")
p.builder.CreateCondBr(cmp, thenBlock, nextBlock)
p.builder.SetInsertPointAtEnd(nextBlock)
}
@@ -440,7 +490,7 @@ func (p *lowerInterfacesPass) defineInterfaceMethodFunc(fn llvm.Value, itf *inte
params[i] = fn.Param(i + 1)
}
params = append(params,
- llvm.Undef(llvm.PointerType(p.ctx.Int8Type(), 0)),
+ llvm.Undef(p.i8ptrType),
)
// Start chain in the entry block.
@@ -472,7 +522,7 @@ func (p *lowerInterfacesPass) defineInterfaceMethodFunc(fn llvm.Value, itf *inte
// Create type check (if/else).
bb := p.ctx.AddBasicBlock(fn, typ.name)
next := p.ctx.AddBasicBlock(fn, typ.name+".next")
- cmp := p.builder.CreateICmp(llvm.IntEQ, actualType, llvm.ConstPtrToInt(typ.typecode, p.uintptrType), typ.name+".icmp")
+ cmp := p.builder.CreateICmp(llvm.IntEQ, actualType, typ.typecodeGEP, typ.name+".icmp")
p.builder.CreateCondBr(cmp, bb, next)
// The function we will redirect to when the interface has this type.
@@ -522,7 +572,7 @@ func (p *lowerInterfacesPass) defineInterfaceMethodFunc(fn llvm.Value, itf *inte
// method on a nil interface.
nilPanic := p.mod.NamedFunction("runtime.nilPanic")
p.builder.CreateCall(nilPanic.GlobalValueType(), nilPanic, []llvm.Value{
- llvm.Undef(llvm.PointerType(p.ctx.Int8Type(), 0)),
+ llvm.Undef(p.i8ptrType),
}, "")
p.builder.CreateUnreachable()
}