aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAyke van Laethem <[email protected]>2023-03-16 15:06:01 +0100
committerRon Evans <[email protected]>2023-03-21 22:22:03 +0100
commit523c6c0e3b09b5dc613be3eca7d604093cbd14b7 (patch)
tree5c5a4b27d8d38aed3973c799f0a8378036f14ea3
parent17f5fb1071afa78ca41220a197098506377b2396 (diff)
downloadtinygo-523c6c0e3b09b5dc613be3eca7d604093cbd14b7.tar.gz
tinygo-523c6c0e3b09b5dc613be3eca7d604093cbd14b7.zip
compiler: correctly generate code for local named types
It is possible to create function-local named types: func foo() any { type named int return named(0) } This patch makes sure they don't alias with named types declared at the package scope. Bug originally found by Damian Gryski while working on reflect support.
-rw-r--r--compiler/compiler.go1
-rw-r--r--compiler/func.go3
-rw-r--r--compiler/interface.go101
-rw-r--r--testdata/interface.go35
4 files changed, 116 insertions, 24 deletions
diff --git a/compiler/compiler.go b/compiler/compiler.go
index 9e05ca524..53576fc31 100644
--- a/compiler/compiler.go
+++ b/compiler/compiler.go
@@ -71,6 +71,7 @@ type compilerContext struct {
difiles map[string]llvm.Metadata
ditypes map[types.Type]llvm.Metadata
llvmTypes typeutil.Map
+ interfaceTypes typeutil.Map
machine llvm.TargetMachine
targetData llvm.TargetData
intType llvm.Type
diff --git a/compiler/func.go b/compiler/func.go
index 743a4f083..3ac42e749 100644
--- a/compiler/func.go
+++ b/compiler/func.go
@@ -32,7 +32,8 @@ func (c *compilerContext) createFuncValue(builder llvm.Builder, funcPtr, context
// global reference is not real, it is only used during func lowering to assign
// signature types to functions and will then be removed.
func (c *compilerContext) getFuncSignatureID(sig *types.Signature) llvm.Value {
- sigGlobalName := "reflect/types.funcid:" + getTypeCodeName(sig)
+ s, _ := getTypeCodeName(sig)
+ sigGlobalName := "reflect/types.funcid:" + s
sigGlobal := c.mod.NamedGlobal(sigGlobalName)
if sigGlobal.IsNil() {
sigGlobal = llvm.AddGlobal(c.mod, c.ctx.Int8Type(), sigGlobalName)
diff --git a/compiler/interface.go b/compiler/interface.go
index 2d6c4a7f7..edd10138b 100644
--- a/compiler/interface.go
+++ b/compiler/interface.go
@@ -118,8 +118,23 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
if _, ok := typ.Underlying().(*types.Interface); ok {
hasMethodSet = false
}
- globalName := "reflect/types.type:" + getTypeCodeName(typ)
- global := c.mod.NamedGlobal(globalName)
+ typeCodeName, isLocal := getTypeCodeName(typ)
+ globalName := "reflect/types.type:" + typeCodeName
+ var global llvm.Value
+ if isLocal {
+ // This type is a named type inside a function, like this:
+ //
+ // func foo() any {
+ // type named int
+ // return named(0)
+ // }
+ if obj := c.interfaceTypes.At(typ); obj != nil {
+ global = obj.(llvm.Value)
+ }
+ } else {
+ // Regular type (named or otherwise).
+ global = c.mod.NamedGlobal(globalName)
+ }
if global.IsNil() {
var typeFields []llvm.Value
// Define the type fields. These must match the structs in
@@ -203,6 +218,9 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
}
globalType := types.NewStruct(typeFieldTypes, nil)
global = llvm.AddGlobal(c.mod, c.getLLVMType(globalType), globalName)
+ if isLocal {
+ c.interfaceTypes.Set(typ, global)
+ }
metabyte := getTypeKind(typ)
switch typ := typ.(type) {
case *types.Basic:
@@ -330,7 +348,11 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
alignment := c.targetData.TypeAllocSize(c.i8ptrType)
globalValue := c.ctx.ConstStruct(typeFields, false)
global.SetInitializer(globalValue)
- global.SetLinkage(llvm.LinkOnceODRLinkage)
+ if isLocal {
+ global.SetLinkage(llvm.InternalLinkage)
+ } else {
+ global.SetLinkage(llvm.LinkOnceODRLinkage)
+ }
global.SetGlobalConstant(true)
global.SetAlignment(int(alignment))
if c.Debug {
@@ -411,57 +433,84 @@ var basicTypeNames = [...]string{
// getTypeCodeName returns a name for this type that can be used in the
// interface lowering pass to assign type codes as expected by the reflect
// package. See getTypeCodeNum.
-func getTypeCodeName(t types.Type) string {
+func getTypeCodeName(t types.Type) (string, bool) {
switch t := t.(type) {
case *types.Named:
- return "named:" + t.String()
+ // Note: check for `t.Obj().Pkg() != nil` for Go 1.18 only.
+ if t.Obj().Pkg() != nil && t.Obj().Parent() != t.Obj().Pkg().Scope() {
+ return "named:" + t.String() + "$local", true
+ }
+ return "named:" + t.String(), false
case *types.Array:
- return "array:" + strconv.FormatInt(t.Len(), 10) + ":" + getTypeCodeName(t.Elem())
+ s, isLocal := getTypeCodeName(t.Elem())
+ return "array:" + strconv.FormatInt(t.Len(), 10) + ":" + s, isLocal
case *types.Basic:
- return "basic:" + basicTypeNames[t.Kind()]
+ return "basic:" + basicTypeNames[t.Kind()], false
case *types.Chan:
- return "chan:" + getTypeCodeName(t.Elem())
+ s, isLocal := getTypeCodeName(t.Elem())
+ return "chan:" + s, isLocal
case *types.Interface:
+ isLocal := false
methods := make([]string, t.NumMethods())
for i := 0; i < t.NumMethods(); i++ {
name := t.Method(i).Name()
if !token.IsExported(name) {
name = t.Method(i).Pkg().Path() + "." + name
}
- methods[i] = name + ":" + getTypeCodeName(t.Method(i).Type())
+ s, local := getTypeCodeName(t.Method(i).Type())
+ if local {
+ isLocal = true
+ }
+ methods[i] = name + ":" + s
}
- return "interface:" + "{" + strings.Join(methods, ",") + "}"
+ return "interface:" + "{" + strings.Join(methods, ",") + "}", isLocal
case *types.Map:
- keyType := getTypeCodeName(t.Key())
- elemType := getTypeCodeName(t.Elem())
- return "map:" + "{" + keyType + "," + elemType + "}"
+ keyType, keyLocal := getTypeCodeName(t.Key())
+ elemType, elemLocal := getTypeCodeName(t.Elem())
+ return "map:" + "{" + keyType + "," + elemType + "}", keyLocal || elemLocal
case *types.Pointer:
- return "pointer:" + getTypeCodeName(t.Elem())
+ s, isLocal := getTypeCodeName(t.Elem())
+ return "pointer:" + s, isLocal
case *types.Signature:
+ isLocal := false
params := make([]string, t.Params().Len())
for i := 0; i < t.Params().Len(); i++ {
- params[i] = getTypeCodeName(t.Params().At(i).Type())
+ s, local := getTypeCodeName(t.Params().At(i).Type())
+ if local {
+ isLocal = true
+ }
+ params[i] = s
}
results := make([]string, t.Results().Len())
for i := 0; i < t.Results().Len(); i++ {
- results[i] = getTypeCodeName(t.Results().At(i).Type())
+ s, local := getTypeCodeName(t.Results().At(i).Type())
+ if local {
+ isLocal = true
+ }
+ results[i] = s
}
- return "func:" + "{" + strings.Join(params, ",") + "}{" + strings.Join(results, ",") + "}"
+ return "func:" + "{" + strings.Join(params, ",") + "}{" + strings.Join(results, ",") + "}", isLocal
case *types.Slice:
- return "slice:" + getTypeCodeName(t.Elem())
+ s, isLocal := getTypeCodeName(t.Elem())
+ return "slice:" + s, isLocal
case *types.Struct:
elems := make([]string, t.NumFields())
+ isLocal := false
for i := 0; i < t.NumFields(); i++ {
embedded := ""
if t.Field(i).Embedded() {
embedded = "#"
}
- elems[i] = embedded + t.Field(i).Name() + ":" + getTypeCodeName(t.Field(i).Type())
+ s, local := getTypeCodeName(t.Field(i).Type())
+ if local {
+ isLocal = true
+ }
+ elems[i] = embedded + t.Field(i).Name() + ":" + s
if t.Tag(i) != "" {
elems[i] += "`" + t.Tag(i) + "`"
}
}
- return "struct:" + "{" + strings.Join(elems, ",") + "}"
+ return "struct:" + "{" + strings.Join(elems, ",") + "}", isLocal
default:
panic("unknown type: " + t.String())
}
@@ -564,7 +613,11 @@ func (b *builder) createTypeAssert(expr *ssa.TypeAssert) llvm.Value {
commaOk = b.CreateCall(fn.GlobalValueType(), fn, []llvm.Value{actualTypeNum}, "")
} else {
- globalName := "reflect/types.typeid:" + getTypeCodeName(expr.AssertedType)
+ assertedTypeGlobal := b.getTypeCode(expr.AssertedType)
+ if !assertedTypeGlobal.IsAConstantExpr().IsNil() {
+ assertedTypeGlobal = assertedTypeGlobal.Operand(0) // resolve the GEP operation
+ }
+ globalName := "reflect/types.typeid:" + strings.TrimPrefix(assertedTypeGlobal.Name(), "reflect/types.type:")
assertedTypeCodeGlobal := b.mod.NamedGlobal(globalName)
if assertedTypeCodeGlobal.IsNil() {
// Create a new typecode global.
@@ -640,7 +693,8 @@ func (c *compilerContext) getMethodsString(itf *types.Interface) string {
// getInterfaceImplementsfunc returns a declared function that works as a type
// switch. The interface lowering pass will define this function.
func (c *compilerContext) getInterfaceImplementsFunc(assertedType types.Type) llvm.Value {
- fnName := getTypeCodeName(assertedType.Underlying()) + ".$typeassert"
+ s, _ := getTypeCodeName(assertedType.Underlying())
+ fnName := s + ".$typeassert"
llvmFn := c.mod.NamedFunction(fnName)
if llvmFn.IsNil() {
llvmFnType := llvm.FunctionType(c.ctx.Int1Type(), []llvm.Type{c.i8ptrType}, false)
@@ -656,7 +710,8 @@ func (c *compilerContext) getInterfaceImplementsFunc(assertedType types.Type) ll
// thunk is declared, not defined: it will be defined by the interface lowering
// pass.
func (c *compilerContext) getInvokeFunction(instr *ssa.CallCommon) llvm.Value {
- fnName := getTypeCodeName(instr.Value.Type().Underlying()) + "." + instr.Method.Name() + "$invoke"
+ s, _ := getTypeCodeName(instr.Value.Type().Underlying())
+ fnName := s + "." + instr.Method.Name() + "$invoke"
llvmFn := c.mod.NamedFunction(fnName)
if llvmFn.IsNil() {
sig := instr.Method.Type().(*types.Signature)
diff --git a/testdata/interface.go b/testdata/interface.go
index d13399f36..7820538a4 100644
--- a/testdata/interface.go
+++ b/testdata/interface.go
@@ -93,6 +93,12 @@ func main() {
a int
b int
}{3, 6}},
+ {true, named1(), named1()},
+ {true, named2(), named2()},
+ {false, named1(), named2()},
+ {false, named2(), named3()},
+ {true, namedptr1(), namedptr1()},
+ {false, namedptr1(), namedptr2()},
}
for i, tc := range interfaceEqualTests {
if (tc.lhs == tc.rhs) != tc.equal {
@@ -277,3 +283,32 @@ func (f FooByte) Byte() byte { return byte(f) }
type Byter interface {
Byte() uint8
}
+
+// Make sure that named types inside functions do not alias with any other named
+// functions.
+
+type named int
+
+func named1() any {
+ return named(0)
+}
+
+func named2() any {
+ type named int
+ return named(0)
+}
+
+func named3() any {
+ type named int
+ return named(0)
+}
+
+func namedptr1() interface{} {
+ type Test int
+ return (*Test)(nil)
+}
+
+func namedptr2() interface{} {
+ type Test byte
+ return (*Test)(nil)
+}