diff options
author | Ayke van Laethem <[email protected]> | 2023-03-16 15:06:01 +0100 |
---|---|---|
committer | Ron Evans <[email protected]> | 2023-03-21 22:22:03 +0100 |
commit | 523c6c0e3b09b5dc613be3eca7d604093cbd14b7 (patch) | |
tree | 5c5a4b27d8d38aed3973c799f0a8378036f14ea3 | |
parent | 17f5fb1071afa78ca41220a197098506377b2396 (diff) | |
download | tinygo-523c6c0e3b09b5dc613be3eca7d604093cbd14b7.tar.gz tinygo-523c6c0e3b09b5dc613be3eca7d604093cbd14b7.zip |
compiler: correctly generate code for local named types
It is possible to create function-local named types:
func foo() any {
type named int
return named(0)
}
This patch makes sure they don't alias with named types declared at the
package scope.
Bug originally found by Damian Gryski while working on reflect support.
-rw-r--r-- | compiler/compiler.go | 1 | ||||
-rw-r--r-- | compiler/func.go | 3 | ||||
-rw-r--r-- | compiler/interface.go | 101 | ||||
-rw-r--r-- | testdata/interface.go | 35 |
4 files changed, 116 insertions, 24 deletions
diff --git a/compiler/compiler.go b/compiler/compiler.go index 9e05ca524..53576fc31 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -71,6 +71,7 @@ type compilerContext struct { difiles map[string]llvm.Metadata ditypes map[types.Type]llvm.Metadata llvmTypes typeutil.Map + interfaceTypes typeutil.Map machine llvm.TargetMachine targetData llvm.TargetData intType llvm.Type diff --git a/compiler/func.go b/compiler/func.go index 743a4f083..3ac42e749 100644 --- a/compiler/func.go +++ b/compiler/func.go @@ -32,7 +32,8 @@ func (c *compilerContext) createFuncValue(builder llvm.Builder, funcPtr, context // global reference is not real, it is only used during func lowering to assign // signature types to functions and will then be removed. func (c *compilerContext) getFuncSignatureID(sig *types.Signature) llvm.Value { - sigGlobalName := "reflect/types.funcid:" + getTypeCodeName(sig) + s, _ := getTypeCodeName(sig) + sigGlobalName := "reflect/types.funcid:" + s sigGlobal := c.mod.NamedGlobal(sigGlobalName) if sigGlobal.IsNil() { sigGlobal = llvm.AddGlobal(c.mod, c.ctx.Int8Type(), sigGlobalName) diff --git a/compiler/interface.go b/compiler/interface.go index 2d6c4a7f7..edd10138b 100644 --- a/compiler/interface.go +++ b/compiler/interface.go @@ -118,8 +118,23 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value { if _, ok := typ.Underlying().(*types.Interface); ok { hasMethodSet = false } - globalName := "reflect/types.type:" + getTypeCodeName(typ) - global := c.mod.NamedGlobal(globalName) + typeCodeName, isLocal := getTypeCodeName(typ) + globalName := "reflect/types.type:" + typeCodeName + var global llvm.Value + if isLocal { + // This type is a named type inside a function, like this: + // + // func foo() any { + // type named int + // return named(0) + // } + if obj := c.interfaceTypes.At(typ); obj != nil { + global = obj.(llvm.Value) + } + } else { + // Regular type (named or otherwise). + global = c.mod.NamedGlobal(globalName) + } if global.IsNil() { var typeFields []llvm.Value // Define the type fields. These must match the structs in @@ -203,6 +218,9 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value { } globalType := types.NewStruct(typeFieldTypes, nil) global = llvm.AddGlobal(c.mod, c.getLLVMType(globalType), globalName) + if isLocal { + c.interfaceTypes.Set(typ, global) + } metabyte := getTypeKind(typ) switch typ := typ.(type) { case *types.Basic: @@ -330,7 +348,11 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value { alignment := c.targetData.TypeAllocSize(c.i8ptrType) globalValue := c.ctx.ConstStruct(typeFields, false) global.SetInitializer(globalValue) - global.SetLinkage(llvm.LinkOnceODRLinkage) + if isLocal { + global.SetLinkage(llvm.InternalLinkage) + } else { + global.SetLinkage(llvm.LinkOnceODRLinkage) + } global.SetGlobalConstant(true) global.SetAlignment(int(alignment)) if c.Debug { @@ -411,57 +433,84 @@ var basicTypeNames = [...]string{ // getTypeCodeName returns a name for this type that can be used in the // interface lowering pass to assign type codes as expected by the reflect // package. See getTypeCodeNum. -func getTypeCodeName(t types.Type) string { +func getTypeCodeName(t types.Type) (string, bool) { switch t := t.(type) { case *types.Named: - return "named:" + t.String() + // Note: check for `t.Obj().Pkg() != nil` for Go 1.18 only. + if t.Obj().Pkg() != nil && t.Obj().Parent() != t.Obj().Pkg().Scope() { + return "named:" + t.String() + "$local", true + } + return "named:" + t.String(), false case *types.Array: - return "array:" + strconv.FormatInt(t.Len(), 10) + ":" + getTypeCodeName(t.Elem()) + s, isLocal := getTypeCodeName(t.Elem()) + return "array:" + strconv.FormatInt(t.Len(), 10) + ":" + s, isLocal case *types.Basic: - return "basic:" + basicTypeNames[t.Kind()] + return "basic:" + basicTypeNames[t.Kind()], false case *types.Chan: - return "chan:" + getTypeCodeName(t.Elem()) + s, isLocal := getTypeCodeName(t.Elem()) + return "chan:" + s, isLocal case *types.Interface: + isLocal := false methods := make([]string, t.NumMethods()) for i := 0; i < t.NumMethods(); i++ { name := t.Method(i).Name() if !token.IsExported(name) { name = t.Method(i).Pkg().Path() + "." + name } - methods[i] = name + ":" + getTypeCodeName(t.Method(i).Type()) + s, local := getTypeCodeName(t.Method(i).Type()) + if local { + isLocal = true + } + methods[i] = name + ":" + s } - return "interface:" + "{" + strings.Join(methods, ",") + "}" + return "interface:" + "{" + strings.Join(methods, ",") + "}", isLocal case *types.Map: - keyType := getTypeCodeName(t.Key()) - elemType := getTypeCodeName(t.Elem()) - return "map:" + "{" + keyType + "," + elemType + "}" + keyType, keyLocal := getTypeCodeName(t.Key()) + elemType, elemLocal := getTypeCodeName(t.Elem()) + return "map:" + "{" + keyType + "," + elemType + "}", keyLocal || elemLocal case *types.Pointer: - return "pointer:" + getTypeCodeName(t.Elem()) + s, isLocal := getTypeCodeName(t.Elem()) + return "pointer:" + s, isLocal case *types.Signature: + isLocal := false params := make([]string, t.Params().Len()) for i := 0; i < t.Params().Len(); i++ { - params[i] = getTypeCodeName(t.Params().At(i).Type()) + s, local := getTypeCodeName(t.Params().At(i).Type()) + if local { + isLocal = true + } + params[i] = s } results := make([]string, t.Results().Len()) for i := 0; i < t.Results().Len(); i++ { - results[i] = getTypeCodeName(t.Results().At(i).Type()) + s, local := getTypeCodeName(t.Results().At(i).Type()) + if local { + isLocal = true + } + results[i] = s } - return "func:" + "{" + strings.Join(params, ",") + "}{" + strings.Join(results, ",") + "}" + return "func:" + "{" + strings.Join(params, ",") + "}{" + strings.Join(results, ",") + "}", isLocal case *types.Slice: - return "slice:" + getTypeCodeName(t.Elem()) + s, isLocal := getTypeCodeName(t.Elem()) + return "slice:" + s, isLocal case *types.Struct: elems := make([]string, t.NumFields()) + isLocal := false for i := 0; i < t.NumFields(); i++ { embedded := "" if t.Field(i).Embedded() { embedded = "#" } - elems[i] = embedded + t.Field(i).Name() + ":" + getTypeCodeName(t.Field(i).Type()) + s, local := getTypeCodeName(t.Field(i).Type()) + if local { + isLocal = true + } + elems[i] = embedded + t.Field(i).Name() + ":" + s if t.Tag(i) != "" { elems[i] += "`" + t.Tag(i) + "`" } } - return "struct:" + "{" + strings.Join(elems, ",") + "}" + return "struct:" + "{" + strings.Join(elems, ",") + "}", isLocal default: panic("unknown type: " + t.String()) } @@ -564,7 +613,11 @@ func (b *builder) createTypeAssert(expr *ssa.TypeAssert) llvm.Value { commaOk = b.CreateCall(fn.GlobalValueType(), fn, []llvm.Value{actualTypeNum}, "") } else { - globalName := "reflect/types.typeid:" + getTypeCodeName(expr.AssertedType) + assertedTypeGlobal := b.getTypeCode(expr.AssertedType) + if !assertedTypeGlobal.IsAConstantExpr().IsNil() { + assertedTypeGlobal = assertedTypeGlobal.Operand(0) // resolve the GEP operation + } + globalName := "reflect/types.typeid:" + strings.TrimPrefix(assertedTypeGlobal.Name(), "reflect/types.type:") assertedTypeCodeGlobal := b.mod.NamedGlobal(globalName) if assertedTypeCodeGlobal.IsNil() { // Create a new typecode global. @@ -640,7 +693,8 @@ func (c *compilerContext) getMethodsString(itf *types.Interface) string { // getInterfaceImplementsfunc returns a declared function that works as a type // switch. The interface lowering pass will define this function. func (c *compilerContext) getInterfaceImplementsFunc(assertedType types.Type) llvm.Value { - fnName := getTypeCodeName(assertedType.Underlying()) + ".$typeassert" + s, _ := getTypeCodeName(assertedType.Underlying()) + fnName := s + ".$typeassert" llvmFn := c.mod.NamedFunction(fnName) if llvmFn.IsNil() { llvmFnType := llvm.FunctionType(c.ctx.Int1Type(), []llvm.Type{c.i8ptrType}, false) @@ -656,7 +710,8 @@ func (c *compilerContext) getInterfaceImplementsFunc(assertedType types.Type) ll // thunk is declared, not defined: it will be defined by the interface lowering // pass. func (c *compilerContext) getInvokeFunction(instr *ssa.CallCommon) llvm.Value { - fnName := getTypeCodeName(instr.Value.Type().Underlying()) + "." + instr.Method.Name() + "$invoke" + s, _ := getTypeCodeName(instr.Value.Type().Underlying()) + fnName := s + "." + instr.Method.Name() + "$invoke" llvmFn := c.mod.NamedFunction(fnName) if llvmFn.IsNil() { sig := instr.Method.Type().(*types.Signature) diff --git a/testdata/interface.go b/testdata/interface.go index d13399f36..7820538a4 100644 --- a/testdata/interface.go +++ b/testdata/interface.go @@ -93,6 +93,12 @@ func main() { a int b int }{3, 6}}, + {true, named1(), named1()}, + {true, named2(), named2()}, + {false, named1(), named2()}, + {false, named2(), named3()}, + {true, namedptr1(), namedptr1()}, + {false, namedptr1(), namedptr2()}, } for i, tc := range interfaceEqualTests { if (tc.lhs == tc.rhs) != tc.equal { @@ -277,3 +283,32 @@ func (f FooByte) Byte() byte { return byte(f) } type Byter interface { Byte() uint8 } + +// Make sure that named types inside functions do not alias with any other named +// functions. + +type named int + +func named1() any { + return named(0) +} + +func named2() any { + type named int + return named(0) +} + +func named3() any { + type named int + return named(0) +} + +func namedptr1() interface{} { + type Test int + return (*Test)(nil) +} + +func namedptr2() interface{} { + type Test byte + return (*Test)(nil) +} |