diff options
Diffstat (limited to 'transform/interface-lowering.go')
-rw-r--r-- | transform/interface-lowering.go | 118 |
1 files changed, 84 insertions, 34 deletions
diff --git a/transform/interface-lowering.go b/transform/interface-lowering.go index 55d1af39b..e57463715 100644 --- a/transform/interface-lowering.go +++ b/transform/interface-lowering.go @@ -13,8 +13,7 @@ package transform // // typeAssert: // Replaced with an icmp instruction so it can be directly used in a type -// switch. This is very easy to optimize for LLVM: it will often translate a -// type switch into a regular switch statement. +// switch. // // interface type assert: // These functions are defined by creating a big type switch over all the @@ -54,10 +53,11 @@ type methodInfo struct { // typeInfo describes a single concrete Go type, which can be a basic or a named // type. If it is a named type, it may have methods. type typeInfo struct { - name string - typecode llvm.Value - methodSet llvm.Value - methods []*methodInfo + name string + typecode llvm.Value + typecodeGEP llvm.Value + methodSet llvm.Value + methods []*methodInfo } // getMethod looks up the method on this type with the given signature and @@ -91,6 +91,8 @@ type lowerInterfacesPass struct { difiles map[string]llvm.Metadata ctx llvm.Context uintptrType llvm.Type + targetData llvm.TargetData + i8ptrType llvm.Type types map[string]*typeInfo signatures map[string]*signatureInfo interfaces map[string]*interfaceInfo @@ -101,14 +103,17 @@ type lowerInterfacesPass struct { // before LLVM can work on them. This is done so that a few cleanup passes can // run before assigning the final type codes. func LowerInterfaces(mod llvm.Module, config *compileopts.Config) error { + ctx := mod.Context() targetData := llvm.NewTargetData(mod.DataLayout()) defer targetData.Dispose() p := &lowerInterfacesPass{ mod: mod, config: config, - builder: mod.Context().NewBuilder(), - ctx: mod.Context(), + builder: ctx.NewBuilder(), + ctx: ctx, + targetData: targetData, uintptrType: mod.Context().IntType(targetData.PointerSize() * 8), + i8ptrType: llvm.PointerType(ctx.Int8Type(), 0), types: make(map[string]*typeInfo), signatures: make(map[string]*signatureInfo), interfaces: make(map[string]*interfaceInfo), @@ -151,11 +156,26 @@ func (p *lowerInterfacesPass) run() error { } p.types[name] = t initializer := global.Initializer() - if initializer.IsNil() { - continue + firstField := p.builder.CreateExtractValue(initializer, 0, "") + if firstField.Type() != p.ctx.Int8Type() { + // This type has a method set at index 0. Change the GEP to + // point to index 1 (the meta byte). + t.typecodeGEP = llvm.ConstGEP(global.GlobalValueType(), global, []llvm.Value{ + llvm.ConstInt(p.ctx.Int32Type(), 0, false), + llvm.ConstInt(p.ctx.Int32Type(), 1, false), + }) + methodSet := stripPointerCasts(firstField) + if !strings.HasSuffix(methodSet.Name(), "$methodset") { + panic("expected method set") + } + p.addTypeMethods(t, methodSet) + } else { + // This type has no method set. + t.typecodeGEP = llvm.ConstGEP(global.GlobalValueType(), global, []llvm.Value{ + llvm.ConstInt(p.ctx.Int32Type(), 0, false), + llvm.ConstInt(p.ctx.Int32Type(), 0, false), + }) } - methodSet := p.builder.CreateExtractValue(initializer, 2, "") - p.addTypeMethods(t, methodSet) } } } @@ -266,10 +286,10 @@ func (p *lowerInterfacesPass) run() error { actualType := use.Operand(0) name := strings.TrimPrefix(use.Operand(1).Name(), "reflect/types.typeid:") if t, ok := p.types[name]; ok { - // The type exists in the program, so lower to a regular integer + // The type exists in the program, so lower to a regular pointer // comparison. p.builder.SetInsertPointBefore(use) - commaOk := p.builder.CreateICmp(llvm.IntEQ, llvm.ConstPtrToInt(t.typecode, p.uintptrType), actualType, "typeassert.ok") + commaOk := p.builder.CreateICmp(llvm.IntEQ, t.typecodeGEP, actualType, "typeassert.ok") use.ReplaceAllUsesWith(commaOk) } else { // The type does not exist in the program, so lower to a constant @@ -283,15 +303,45 @@ func (p *lowerInterfacesPass) run() error { } // Remove all method sets, which are now unnecessary and inhibit later - // optimizations if they are left in place. Also remove references to the - // interface type assert functions just to be sure. - zeroUintptr := llvm.ConstNull(p.uintptrType) + // optimizations if they are left in place. + zero := llvm.ConstInt(p.ctx.Int32Type(), 0, false) for _, t := range p.types { - initializer := t.typecode.Initializer() - methodSet := p.builder.CreateExtractValue(initializer, 2, "") - initializer = p.builder.CreateInsertValue(initializer, llvm.ConstNull(methodSet.Type()), 2, "") - initializer = p.builder.CreateInsertValue(initializer, zeroUintptr, 4, "") - t.typecode.SetInitializer(initializer) + if !t.methodSet.IsNil() { + initializer := t.typecode.Initializer() + var newInitializerFields []llvm.Value + for i := 1; i < initializer.Type().StructElementTypesCount(); i++ { + newInitializerFields = append(newInitializerFields, p.builder.CreateExtractValue(initializer, i, "")) + } + newInitializer := p.ctx.ConstStruct(newInitializerFields, false) + typecodeName := t.typecode.Name() + newGlobal := llvm.AddGlobal(p.mod, newInitializer.Type(), typecodeName+".tmp") + newGlobal.SetInitializer(newInitializer) + newGlobal.SetLinkage(t.typecode.Linkage()) + newGlobal.SetGlobalConstant(true) + newGlobal.SetAlignment(t.typecode.Alignment()) + for _, use := range getUses(t.typecode) { + if !use.IsAConstantExpr().IsNil() { + opcode := use.Opcode() + if opcode == llvm.GetElementPtr && use.OperandsCount() == 3 { + if use.Operand(1).ZExtValue() == 0 && use.Operand(2).ZExtValue() == 1 { + gep := p.builder.CreateInBoundsGEP(newGlobal.GlobalValueType(), newGlobal, []llvm.Value{zero, zero}, "") + use.ReplaceAllUsesWith(gep) + } + } + } + } + // Fallback. + if hasUses(t.typecode) { + bitcast := llvm.ConstBitCast(newGlobal, p.i8ptrType) + negativeOffset := -int64(p.targetData.TypeAllocSize(p.i8ptrType)) + gep := p.builder.CreateInBoundsGEP(p.ctx.Int8Type(), bitcast, []llvm.Value{llvm.ConstInt(p.ctx.Int32Type(), uint64(negativeOffset), true)}, "") + bitcast2 := llvm.ConstBitCast(gep, t.typecode.Type()) + t.typecode.ReplaceAllUsesWith(bitcast2) + } + t.typecode.EraseFromParentAsGlobal() + newGlobal.SetName(typecodeName) + t.typecode = newGlobal + } } return nil @@ -301,22 +351,22 @@ func (p *lowerInterfacesPass) run() error { // retrieves the signatures and the references to the method functions // themselves for later type<->interface matching. func (p *lowerInterfacesPass) addTypeMethods(t *typeInfo, methodSet llvm.Value) { - if !t.methodSet.IsNil() || methodSet.IsNull() { + if !t.methodSet.IsNil() { // no methods or methods already read return } - if !methodSet.IsAConstantExpr().IsNil() && methodSet.Opcode() == llvm.GetElementPtr { - methodSet = methodSet.Operand(0) // get global from GEP, for LLVM 14 (non-opaque pointers) - } // This type has methods, collect all methods of this type. t.methodSet = methodSet set := methodSet.Initializer() // get value from global - for i := 0; i < set.Type().ArrayLength(); i++ { - methodData := p.builder.CreateExtractValue(set, i, "") - signatureGlobal := p.builder.CreateExtractValue(methodData, 0, "") + signatures := p.builder.CreateExtractValue(set, 1, "") + wrappers := p.builder.CreateExtractValue(set, 2, "") + numMethods := signatures.Type().ArrayLength() + for i := 0; i < numMethods; i++ { + signatureGlobal := p.builder.CreateExtractValue(signatures, i, "") + function := p.builder.CreateExtractValue(wrappers, i, "") + function = stripPointerCasts(function) // strip bitcasts signatureName := signatureGlobal.Name() - function := p.builder.CreateExtractValue(methodData, 1, "").Operand(0) signature := p.getSignature(signatureName) method := &methodInfo{ function: function, @@ -401,7 +451,7 @@ func (p *lowerInterfacesPass) defineInterfaceImplementsFunc(fn llvm.Value, itf * actualType := fn.Param(0) for _, typ := range itf.types { nextBlock := p.ctx.AddBasicBlock(fn, typ.name+".next") - cmp := p.builder.CreateICmp(llvm.IntEQ, actualType, llvm.ConstPtrToInt(typ.typecode, p.uintptrType), typ.name+".icmp") + cmp := p.builder.CreateICmp(llvm.IntEQ, actualType, typ.typecodeGEP, typ.name+".icmp") p.builder.CreateCondBr(cmp, thenBlock, nextBlock) p.builder.SetInsertPointAtEnd(nextBlock) } @@ -440,7 +490,7 @@ func (p *lowerInterfacesPass) defineInterfaceMethodFunc(fn llvm.Value, itf *inte params[i] = fn.Param(i + 1) } params = append(params, - llvm.Undef(llvm.PointerType(p.ctx.Int8Type(), 0)), + llvm.Undef(p.i8ptrType), ) // Start chain in the entry block. @@ -472,7 +522,7 @@ func (p *lowerInterfacesPass) defineInterfaceMethodFunc(fn llvm.Value, itf *inte // Create type check (if/else). bb := p.ctx.AddBasicBlock(fn, typ.name) next := p.ctx.AddBasicBlock(fn, typ.name+".next") - cmp := p.builder.CreateICmp(llvm.IntEQ, actualType, llvm.ConstPtrToInt(typ.typecode, p.uintptrType), typ.name+".icmp") + cmp := p.builder.CreateICmp(llvm.IntEQ, actualType, typ.typecodeGEP, typ.name+".icmp") p.builder.CreateCondBr(cmp, bb, next) // The function we will redirect to when the interface has this type. @@ -522,7 +572,7 @@ func (p *lowerInterfacesPass) defineInterfaceMethodFunc(fn llvm.Value, itf *inte // method on a nil interface. nilPanic := p.mod.NamedFunction("runtime.nilPanic") p.builder.CreateCall(nilPanic.GlobalValueType(), nilPanic, []llvm.Value{ - llvm.Undef(llvm.PointerType(p.ctx.Int8Type(), 0)), + llvm.Undef(p.i8ptrType), }, "") p.builder.CreateUnreachable() } |