diff options
author | Ayke van Laethem <[email protected]> | 2018-09-22 20:25:50 +0200 |
---|---|---|
committer | Ayke van Laethem <[email protected]> | 2018-09-22 20:32:07 +0200 |
commit | b75a02e66dd240685b281b45bc87d37788fb0a7c (patch) | |
tree | 056f039cbb6705537dfd718d6f769f6b4d51e0a4 /ir | |
parent | 473e71b5731f88edfd07c4c07da32204f74e8389 (diff) | |
download | tinygo-b75a02e66dd240685b281b45bc87d37788fb0a7c.tar.gz tinygo-b75a02e66dd240685b281b45bc87d37788fb0a7c.zip |
compiler: refactor IR parts into separate package
Diffstat (limited to 'ir')
-rw-r--r-- | ir/interpreter.go | 533 | ||||
-rw-r--r-- | ir/ir.go | 387 | ||||
-rw-r--r-- | ir/passes.go | 427 |
3 files changed, 1347 insertions, 0 deletions
diff --git a/ir/interpreter.go b/ir/interpreter.go new file mode 100644 index 000000000..078ad1ac3 --- /dev/null +++ b/ir/interpreter.go @@ -0,0 +1,533 @@ +package ir + +// This file provides functionality to interpret very basic Go SSA, for +// compile-time initialization of globals. + +import ( + "errors" + "fmt" + "go/constant" + "go/token" + "go/types" + "strings" + + "golang.org/x/tools/go/ssa" +) + +var ErrCGoWrapper = errors.New("tinygo internal: cgo wrapper") // a signal, not an error + +// Ignore these calls (replace with a zero return value) when encountered during +// interpretation. +var ignoreInitCalls = map[string]struct{}{ + "syscall.runtime_envs": struct{}{}, + "syscall/js.predefValue": struct{}{}, + "(syscall/js.Value).Get": struct{}{}, + "(syscall/js.Value).New": struct{}{}, + "(syscall/js.Value).Int": struct{}{}, + "os.init$1": struct{}{}, +} + +// Interpret instructions as far as possible, and drop those instructions from +// the basic block. +func (p *Program) Interpret(block *ssa.BasicBlock, dumpSSA bool) error { + if dumpSSA { + fmt.Printf("\ninterpret: %s\n", block.Parent().Pkg.Pkg.Path()) + } + for { + i, err := p.interpret(block.Instrs, nil, nil, nil, dumpSSA) + if err == ErrCGoWrapper { + // skip this instruction + block.Instrs = block.Instrs[i+1:] + continue + } + block.Instrs = block.Instrs[i:] + return err + } +} + +// Interpret instructions as far as possible, and return the index of the first +// unknown instruction. +func (p *Program) interpret(instrs []ssa.Instruction, paramKeys []*ssa.Parameter, paramValues []Value, results []Value, dumpSSA bool) (int, error) { + locals := map[ssa.Value]Value{} + for i, key := range paramKeys { + locals[key] = paramValues[i] + } + for i, instr := range instrs { + if _, ok := instr.(*ssa.DebugRef); ok { + continue + } + if dumpSSA { + if val, ok := instr.(ssa.Value); ok && val.Name() != "" { + fmt.Printf("\t%s: %s = %s\n", instr.Parent().RelString(nil), val.Name(), val.String()) + } else { + fmt.Printf("\t%s: %s\n", instr.Parent().RelString(nil), instr.String()) + } + } + switch instr := instr.(type) { + case *ssa.Alloc: + alloc, err := p.getZeroValue(instr.Type().Underlying().(*types.Pointer).Elem()) + if err != nil { + return i, err + } + locals[instr] = &PointerValue{nil, &alloc} + case *ssa.BinOp: + if typ, ok := instr.Type().(*types.Basic); ok && typ.Kind() == types.String { + // Concatenate two strings. + // This happens in the time package, for example. + x, err := p.getValue(instr.X, locals) + if err != nil { + return i, err + } + y, err := p.getValue(instr.Y, locals) + if err != nil { + return i, err + } + xstr := constant.StringVal(x.(*ConstValue).Expr.Value) + ystr := constant.StringVal(y.(*ConstValue).Expr.Value) + locals[instr] = &ConstValue{ssa.NewConst(constant.MakeString(xstr+ystr), types.Typ[types.String])} + } else { + return i, errors.New("init: unknown binop: " + instr.String()) + } + case *ssa.Call: + common := instr.Common() + callee := common.StaticCallee() + if callee == nil { + return i, nil // don't understand dynamic dispatch + } + if _, ok := ignoreInitCalls[callee.String()]; ok { + // These calls are not needed and can be ignored, for the time + // being. + results := make([]Value, callee.Signature.Results().Len()) + for i := range results { + var err error + results[i], err = p.getZeroValue(callee.Signature.Results().At(i).Type()) + if err != nil { + return i, err + } + } + if len(results) == 1 { + locals[instr] = results[0] + } else if len(results) > 1 { + locals[instr] = &StructValue{Fields: results} + } + continue + } + if callee.String() == "os.NewFile" { + // Emulate the creation of os.Stdin, os.Stdout and os.Stderr. + resultPtrType := callee.Signature.Results().At(0).Type().(*types.Pointer) + resultStructOuterType := resultPtrType.Elem().Underlying().(*types.Struct) + if resultStructOuterType.NumFields() != 1 { + panic("expected 1 field in os.File struct") + } + fileInnerPtrType := resultStructOuterType.Field(0).Type().(*types.Pointer) + fileInnerType := fileInnerPtrType.Elem().(*types.Named) + fileInnerStructType := fileInnerType.Underlying().(*types.Struct) + fileInner, err := p.getZeroValue(fileInnerType) // os.file + if err != nil { + return i, err + } + for fieldIndex := 0; fieldIndex < fileInnerStructType.NumFields(); fieldIndex++ { + field := fileInnerStructType.Field(fieldIndex) + if field.Name() == "name" { + // Set the 'name' field. + name, err := p.getValue(common.Args[1], locals) + if err != nil { + return i, err + } + fileInner.(*StructValue).Fields[fieldIndex] = name + } else if field.Type().String() == "internal/poll.FD" { + // Set the file descriptor field. + field := field.Type().Underlying().(*types.Struct) + for subfieldIndex := 0; subfieldIndex < field.NumFields(); subfieldIndex++ { + subfield := field.Field(subfieldIndex) + if subfield.Name() == "Sysfd" { + sysfd, err := p.getValue(common.Args[0], locals) + if err != nil { + return i, err + } + sysfd = &ConstValue{Expr: ssa.NewConst(sysfd.(*ConstValue).Expr.Value, subfield.Type())} + fileInner.(*StructValue).Fields[fieldIndex].(*StructValue).Fields[subfieldIndex] = sysfd + } + } + } + } + fileInnerPtr := &PointerValue{fileInnerPtrType, &fileInner} // *os.file + var fileOuter Value = &StructValue{Type: resultPtrType.Elem(), Fields: []Value{fileInnerPtr}} // os.File + result := &PointerValue{resultPtrType.Elem(), &fileOuter} // *os.File + locals[instr] = result + continue + } + if canInterpret(callee) { + params := make([]Value, len(common.Args)) + for i, arg := range common.Args { + val, err := p.getValue(arg, locals) + if err != nil { + return i, err + } + params[i] = val + } + results := make([]Value, callee.Signature.Results().Len()) + subi, err := p.interpret(callee.Blocks[0].Instrs, callee.Params, params, results, dumpSSA) + if err != nil { + return i, err + } + if subi != len(callee.Blocks[0].Instrs) { + return i, errors.New("init: could not interpret all instructions of subroutine") + } + if len(results) == 1 { + locals[instr] = results[0] + } else { + panic("unimplemented: not exactly 1 result") + } + continue + } + if callee.Object() == nil || callee.Object().Name() == "init" { + return i, nil // arrived at the init#num functions + } + return i, errors.New("todo: init call: " + callee.String()) + case *ssa.ChangeType: + x, err := p.getValue(instr.X, locals) + if err != nil { + return i, err + } + // The only case when we need to bitcast is when casting between named + // struct types, as those are actually different in LLVM. Let's just + // bitcast all struct types for ease of use. + if _, ok := instr.Type().Underlying().(*types.Struct); ok { + return i, errors.New("todo: init: " + instr.String()) + } + locals[instr] = x + case *ssa.Convert: + x, err := p.getValue(instr.X, locals) + if err != nil { + return i, err + } + typeFrom := instr.X.Type().Underlying() + switch typeTo := instr.Type().Underlying().(type) { + case *types.Basic: + if typeTo.Kind() == types.String { + return i, errors.New("todo: init: cannot convert string") + } + + if _, ok := typeFrom.(*types.Pointer); ok && typeTo.Kind() == types.UnsafePointer { + locals[instr] = &PointerBitCastValue{typeTo, x} + } else if typeFrom, ok := typeFrom.(*types.Basic); ok { + if typeFrom.Kind() == types.UnsafePointer && typeTo.Kind() == types.Uintptr { + locals[instr] = &PointerToUintptrValue{x} + } else if typeFrom.Info()&types.IsInteger != 0 && typeTo.Info()&types.IsInteger != 0 { + locals[instr] = &ConstValue{Expr: ssa.NewConst(x.(*ConstValue).Expr.Value, typeTo)} + } else { + return i, errors.New("todo: init: unknown basic-to-basic convert: " + instr.String()) + } + } else { + return i, errors.New("todo: init: unknown basic convert: " + instr.String()) + } + case *types.Pointer: + if typeFrom, ok := typeFrom.(*types.Basic); ok && typeFrom.Kind() == types.UnsafePointer { + locals[instr] = &PointerBitCastValue{typeTo, x} + } else { + panic("expected unsafe pointer conversion") + } + default: + return i, errors.New("todo: init: unknown convert: " + instr.String()) + } + case *ssa.DebugRef: + // ignore + case *ssa.Extract: + tuple, err := p.getValue(instr.Tuple, locals) + if err != nil { + return i, err + } + locals[instr] = tuple.(*StructValue).Fields[instr.Index] + case *ssa.FieldAddr: + x, err := p.getValue(instr.X, locals) + if err != nil { + return i, err + } + var structVal *StructValue + switch x := x.(type) { + case *GlobalValue: + structVal = x.Global.initializer.(*StructValue) + case *PointerValue: + structVal = (*x.Elem).(*StructValue) + default: + panic("expected a pointer") + } + locals[instr] = &PointerValue{nil, &structVal.Fields[instr.Field]} + case *ssa.IndexAddr: + x, err := p.getValue(instr.X, locals) + if err != nil { + return i, err + } + if cnst, ok := instr.Index.(*ssa.Const); ok { + index, _ := constant.Int64Val(cnst.Value) + switch xPtr := x.(type) { + case *GlobalValue: + x = xPtr.Global.initializer + case *PointerValue: + x = *xPtr.Elem + default: + panic("expected a pointer") + } + switch x := x.(type) { + case *ArrayValue: + locals[instr] = &PointerValue{nil, &x.Elems[index]} + default: + return i, errors.New("todo: init IndexAddr not on an array or struct") + } + } else { + return i, errors.New("todo: init IndexAddr index: " + instr.Index.String()) + } + case *ssa.MakeInterface: + locals[instr] = &InterfaceValue{instr.X.Type(), locals[instr.X]} + case *ssa.MakeMap: + locals[instr] = &MapValue{instr.Type().Underlying().(*types.Map), nil, nil} + case *ssa.MapUpdate: + // Assume no duplicate keys exist. This is most likely true for + // autogenerated code, but may not be true when trying to interpret + // user code. + key, err := p.getValue(instr.Key, locals) + if err != nil { + return i, err + } + value, err := p.getValue(instr.Value, locals) + if err != nil { + return i, err + } + x := locals[instr.Map].(*MapValue) + x.Keys = append(x.Keys, key) + x.Values = append(x.Values, value) + case *ssa.Return: + for i, r := range instr.Results { + val, err := p.getValue(r, locals) + if err != nil { + return i, err + } + results[i] = val + } + case *ssa.Slice: + // Turn a just-allocated array into a slice. + if instr.Low != nil || instr.High != nil || instr.Max != nil { + return i, errors.New("init: slice expression with bounds") + } + source, err := p.getValue(instr.X, locals) + if err != nil { + return i, err + } + switch source := source.(type) { + case *PointerValue: // pointer to array + array := (*source.Elem).(*ArrayValue) + locals[instr] = &SliceValue{instr.Type().Underlying().(*types.Slice), array} + default: + return i, errors.New("init: unknown slice type") + } + case *ssa.Store: + if addr, ok := instr.Addr.(*ssa.Global); ok { + if strings.HasPrefix(instr.Addr.Name(), "__cgofn__cgo_") || strings.HasPrefix(instr.Addr.Name(), "_cgo_") { + // Ignore CGo global variables which we don't use. + continue + } + value, err := p.getValue(instr.Val, locals) + if err != nil { + return i, err + } + p.GetGlobal(addr).initializer = value + } else if addr, ok := locals[instr.Addr]; ok { + value, err := p.getValue(instr.Val, locals) + if err != nil { + return i, err + } + if addr, ok := addr.(*PointerValue); ok { + *(addr.Elem) = value + } else { + panic("store to non-pointer") + } + } else { + return i, errors.New("todo: init Store: " + instr.String()) + } + case *ssa.UnOp: + if instr.Op != token.MUL || instr.CommaOk { + return i, errors.New("init: unknown unop: " + instr.String()) + } + valPtr, err := p.getValue(instr.X, locals) + if err != nil { + return i, err + } + switch valPtr := valPtr.(type) { + case *GlobalValue: + locals[instr] = valPtr.Global.initializer + case *PointerValue: + locals[instr] = *valPtr.Elem + default: + panic("expected a pointer") + } + default: + return i, nil + } + } + return len(instrs), nil +} + +// Check whether this function can be interpreted at compile time. For that, it +// needs to only contain relatively simple instructions (for example, no control +// flow). +func canInterpret(callee *ssa.Function) bool { + if len(callee.Blocks) != 1 || callee.Signature.Results().Len() != 1 { + // No control flow supported so only one basic block. + // Only exactly one return value supported right now so check that as + // well. + return false + } + for _, instr := range callee.Blocks[0].Instrs { + switch instr.(type) { + // Ignore all functions fully supported by Program.interpret() + // above. + case *ssa.Alloc: + case *ssa.ChangeType: + case *ssa.Convert: + case *ssa.DebugRef: + case *ssa.Extract: + case *ssa.FieldAddr: + case *ssa.IndexAddr: + case *ssa.MakeInterface: + case *ssa.MakeMap: + case *ssa.MapUpdate: + case *ssa.Return: + case *ssa.Slice: + case *ssa.Store: + case *ssa.UnOp: + default: + return false + } + } + return true +} + +func (p *Program) getValue(value ssa.Value, locals map[ssa.Value]Value) (Value, error) { + switch value := value.(type) { + case *ssa.Const: + return &ConstValue{value}, nil + case *ssa.Function: + return &FunctionValue{value.Type(), value}, nil + case *ssa.Global: + if strings.HasPrefix(value.Name(), "__cgofn__cgo_") || strings.HasPrefix(value.Name(), "_cgo_") { + // Ignore CGo global variables which we don't use. + return nil, ErrCGoWrapper + } + g := p.GetGlobal(value) + if g.initializer == nil { + value, err := p.getZeroValue(value.Type().Underlying().(*types.Pointer).Elem()) + if err != nil { + return nil, err + } + g.initializer = value + } + return &GlobalValue{g}, nil + default: + if local, ok := locals[value]; ok { + return local, nil + } else { + return nil, errors.New("todo: init: unknown value: " + value.String()) + } + } +} + +func (p *Program) getZeroValue(t types.Type) (Value, error) { + switch typ := t.Underlying().(type) { + case *types.Array: + elems := make([]Value, typ.Len()) + for i := range elems { + elem, err := p.getZeroValue(typ.Elem()) + if err != nil { + return nil, err + } + elems[i] = elem + } + return &ArrayValue{typ.Elem(), elems}, nil + case *types.Basic: + return &ZeroBasicValue{typ}, nil + case *types.Signature: + return &FunctionValue{typ, nil}, nil + case *types.Interface: + return &InterfaceValue{typ, nil}, nil + case *types.Map: + return &MapValue{typ, nil, nil}, nil + case *types.Pointer: + return &PointerValue{typ, nil}, nil + case *types.Struct: + elems := make([]Value, typ.NumFields()) + for i := range elems { + elem, err := p.getZeroValue(typ.Field(i).Type()) + if err != nil { + return nil, err + } + elems[i] = elem + } + return &StructValue{t, elems}, nil + case *types.Slice: + return &SliceValue{typ, nil}, nil + default: + return nil, errors.New("todo: init: unknown global type: " + typ.String()) + } +} + +// Boxed value for interpreter. +type Value interface { +} + +type ConstValue struct { + Expr *ssa.Const +} + +type ZeroBasicValue struct { + Type *types.Basic +} + +type PointerValue struct { + Type types.Type + Elem *Value +} + +type FunctionValue struct { + Type types.Type + Elem *ssa.Function +} + +type InterfaceValue struct { + Type types.Type + Elem Value +} + +type PointerBitCastValue struct { + Type types.Type + Elem Value +} + +type PointerToUintptrValue struct { + Elem Value +} + +type GlobalValue struct { + Global *Global +} + +type ArrayValue struct { + ElemType types.Type + Elems []Value +} + +type StructValue struct { + Type types.Type // types.Struct or types.Named + Fields []Value +} + +type SliceValue struct { + Type *types.Slice + Array *ArrayValue +} + +type MapValue struct { + Type *types.Map + Keys []Value + Values []Value +} diff --git a/ir/ir.go b/ir/ir.go new file mode 100644 index 000000000..4f098425d --- /dev/null +++ b/ir/ir.go @@ -0,0 +1,387 @@ +package ir + +import ( + "go/ast" + "go/token" + "go/types" + "sort" + "strings" + + "github.com/aykevl/llvm/bindings/go/llvm" + "golang.org/x/tools/go/loader" + "golang.org/x/tools/go/ssa" + "golang.org/x/tools/go/ssa/ssautil" +) + +// This file provides a wrapper around go/ssa values and adds extra +// functionality to them. + +// View on all functions, types, and globals in a program, with analysis +// results. +type Program struct { + Program *ssa.Program + mainPkg *ssa.Package + Functions []*Function + functionMap map[*ssa.Function]*Function + Globals []*Global + globalMap map[*ssa.Global]*Global + comments map[string]*ast.CommentGroup + NamedTypes []*NamedType + needsScheduler bool + goCalls []*ssa.Go + typesWithMethods map[string]*TypeWithMethods // see AnalyseInterfaceConversions + typesWithoutMethods map[string]int // see AnalyseInterfaceConversions + methodSignatureNames map[string]int // see MethodNum + interfaces map[string]*Interface // see AnalyseInterfaceConversions + fpWithContext map[string]struct{} // see AnalyseFunctionPointers +} + +// Function or method. +type Function struct { + *ssa.Function + LLVMFn llvm.Value + linkName string // go:linkname or go:export pragma + exported bool // go:export + nobounds bool // go:nobounds pragma + blocking bool // calculated by AnalyseBlockingRecursive + flag bool // used by dead code elimination + addressTaken bool // used as function pointer, calculated by AnalyseFunctionPointers + parents []*Function // calculated by AnalyseCallgraph + children []*Function // calculated by AnalyseCallgraph +} + +// Global variable, possibly constant. +type Global struct { + *ssa.Global + program *Program + LLVMGlobal llvm.Value + linkName string // go:extern + extern bool // go:extern + initializer Value +} + +// Type with a name and possibly methods. +type NamedType struct { + *ssa.Type + LLVMType llvm.Type +} + +// Type that is at some point put in an interface. +type TypeWithMethods struct { + t types.Type + Num int + Methods map[string]*types.Selection +} + +// Interface type that is at some point used in a type assert (to check whether +// it implements another interface). +type Interface struct { + Num int + Type *types.Interface +} + +// Create and intialize a new *Program from a *ssa.Program. +func NewProgram(lprogram *loader.Program, mainPath string) *Program { + comments := map[string]*ast.CommentGroup{} + for _, pkgInfo := range lprogram.AllPackages { + for _, file := range pkgInfo.Files { + for _, decl := range file.Decls { + switch decl := decl.(type) { + case *ast.GenDecl: + switch decl.Tok { + case token.VAR: + if len(decl.Specs) != 1 { + continue + } + for _, spec := range decl.Specs { + valueSpec := spec.(*ast.ValueSpec) + for _, valueName := range valueSpec.Names { + id := pkgInfo.Pkg.Path() + "." + valueName.Name + comments[id] = decl.Doc + } + } + } + } + } + } + } + + program := ssautil.CreateProgram(lprogram, ssa.SanityCheckFunctions|ssa.BareInits|ssa.GlobalDebug) + program.Build() + + return &Program{ + Program: program, + mainPkg: program.ImportedPackage(mainPath), + functionMap: make(map[*ssa.Function]*Function), + globalMap: make(map[*ssa.Global]*Global), + methodSignatureNames: make(map[string]int), + interfaces: make(map[string]*Interface), + comments: comments, + } +} + +// Add a package to this Program. All packages need to be added first before any +// analysis is done for correct results. +func (p *Program) AddPackage(pkg *ssa.Package) { + memberNames := make([]string, 0) + for name := range pkg.Members { + if isCGoInternal(name) { + continue + } + memberNames = append(memberNames, name) + } + sort.Strings(memberNames) + + for _, name := range memberNames { + member := pkg.Members[name] + switch member := member.(type) { + case *ssa.Function: + if isCGoInternal(member.Name()) { + continue + } + p.addFunction(member) + case *ssa.Type: + t := &NamedType{Type: member} + p.NamedTypes = append(p.NamedTypes, t) + methods := getAllMethods(pkg.Prog, member.Type()) + if !types.IsInterface(member.Type()) { + // named type + for _, method := range methods { + p.addFunction(pkg.Prog.MethodValue(method)) + } + } + case *ssa.Global: + g := &Global{program: p, Global: member} + doc := p.comments[g.RelString(nil)] + if doc != nil { + g.parsePragmas(doc) + } + p.Globals = append(p.Globals, g) + p.globalMap[member] = g + case *ssa.NamedConst: + // Ignore: these are already resolved. + default: + panic("unknown member type: " + member.String()) + } + } +} + +func (p *Program) addFunction(ssaFn *ssa.Function) { + f := &Function{Function: ssaFn} + f.parsePragmas() + p.Functions = append(p.Functions, f) + p.functionMap[ssaFn] = f + + for _, anon := range ssaFn.AnonFuncs { + p.addFunction(anon) + } +} + +// Return true if this package imports "unsafe", false otherwise. +func hasUnsafeImport(pkg *types.Package) bool { + for _, imp := range pkg.Imports() { + if imp == types.Unsafe { + return true + } + } + return false +} + +func (p *Program) GetFunction(ssaFn *ssa.Function) *Function { + return p.functionMap[ssaFn] +} + +func (p *Program) GetGlobal(ssaGlobal *ssa.Global) *Global { + return p.globalMap[ssaGlobal] +} + +// SortMethods sorts the list of methods by method ID. +func (p *Program) SortMethods(methods []*types.Selection) { + m := &methodList{methods: methods, program: p} + sort.Sort(m) +} + +// SortFuncs sorts the list of functions by method ID. +func (p *Program) SortFuncs(funcs []*types.Func) { + m := &funcList{funcs: funcs, program: p} + sort.Sort(m) +} + +func (p *Program) MainPkg() *ssa.Package { + return p.mainPkg +} + +// Parse compiler directives in the preceding comments. +func (f *Function) parsePragmas() { + if f.Syntax() == nil { + return + } + if decl, ok := f.Syntax().(*ast.FuncDecl); ok && decl.Doc != nil { + for _, comment := range decl.Doc.List { + if !strings.HasPrefix(comment.Text, "//go:") { + continue + } + parts := strings.Fields(comment.Text) + switch parts[0] { + case "//go:linkname": + if len(parts) != 3 || parts[1] != f.Name() { + continue + } + // Only enable go:linkname when the package imports "unsafe". + // This is a slightly looser requirement than what gc uses: gc + // requires the file to import "unsafe", not the package as a + // whole. + if hasUnsafeImport(f.Pkg.Pkg) { + f.linkName = parts[2] + } + case "//go:nobounds": + // Skip bounds checking in this function. Useful for some + // runtime functions. + // This is somewhat dangerous and thus only imported in packages + // that import unsafe. + if hasUnsafeImport(f.Pkg.Pkg) { + f.nobounds = true + } + case "//go:export": + if len(parts) != 2 { + continue + } + f.linkName = parts[1] + f.exported = true + } + } + } +} + +func (f *Function) IsNoBounds() bool { + return f.nobounds +} + +// Return true iff this function is externally visible. +func (f *Function) IsExported() bool { + return f.exported +} + +// Return the link name for this function. +func (f *Function) LinkName() string { + if f.linkName != "" { + return f.linkName + } + if f.Signature.Recv() != nil { + // Method on a defined type (which may be a pointer). + return f.RelString(nil) + } else { + // Bare function. + if name := f.CName(); name != "" { + // Name CGo functions directly. + return name + } else { + return f.RelString(nil) + } + } +} + +// Return the name of the C function if this is a CGo wrapper. Otherwise, return +// a zero-length string. +func (f *Function) CName() string { + name := f.Name() + if strings.HasPrefix(name, "_Cfunc_") { + return name[len("_Cfunc_"):] + } + return "" +} + +// Parse //go: pragma comments from the source. +func (g *Global) parsePragmas(doc *ast.CommentGroup) { + for _, comment := range doc.List { + if !strings.HasPrefix(comment.Text, "//go:") { + continue + } + parts := strings.Fields(comment.Text) + switch parts[0] { + case "//go:extern": + g.extern = true + if len(parts) == 2 { + g.linkName = parts[1] + } + } + } +} + +// Return the link name for this global. +func (g *Global) LinkName() string { + if g.linkName != "" { + return g.linkName + } + return g.RelString(nil) +} + +func (g *Global) IsExtern() bool { + return g.extern +} + +func (g *Global) Initializer() Value { + return g.initializer +} + +// Wrapper type to implement sort.Interface for []*types.Selection. +type methodList struct { + methods []*types.Selection + program *Program +} + +func (m *methodList) Len() int { + return len(m.methods) +} + +func (m *methodList) Less(i, j int) bool { + iid := m.program.MethodNum(m.methods[i].Obj().(*types.Func)) + jid := m.program.MethodNum(m.methods[j].Obj().(*types.Func)) + return iid < jid +} + +func (m *methodList) Swap(i, j int) { + m.methods[i], m.methods[j] = m.methods[j], m.methods[i] +} + +// Wrapper type to implement sort.Interface for []*types.Func. +type funcList struct { + funcs []*types.Func + program *Program +} + +func (fl *funcList) Len() int { + return len(fl.funcs) +} + +func (fl *funcList) Less(i, j int) bool { + iid := fl.program.MethodNum(fl.funcs[i]) + jid := fl.program.MethodNum(fl.funcs[j]) + return iid < jid +} + +func (fl *funcList) Swap(i, j int) { + fl.funcs[i], fl.funcs[j] = fl.funcs[j], fl.funcs[i] +} + +// Return true if this is a CGo-internal function that can be ignored. +func isCGoInternal(name string) bool { + if strings.HasPrefix(name, "_Cgo_") || strings.HasPrefix(name, "_cgo") { + // _Cgo_ptr, _Cgo_use, _cgoCheckResult, _cgo_runtime_cgocall + return true // CGo-internal functions + } + if strings.HasPrefix(name, "__cgofn__cgo_") { + return true // CGo function pointer in global scope + } + return false +} + +// Get all methods of a type. +func getAllMethods(prog *ssa.Program, typ types.Type) []*types.Selection { + ms := prog.MethodSets.MethodSet(typ) + methods := make([]*types.Selection, ms.Len()) + for i := 0; i < ms.Len(); i++ { + methods[i] = ms.At(i) + } + return methods +} diff --git a/ir/passes.go b/ir/passes.go new file mode 100644 index 000000000..87175faea --- /dev/null +++ b/ir/passes.go @@ -0,0 +1,427 @@ +package ir + +import ( + "go/types" + "sort" + "strings" + + "golang.org/x/tools/go/ssa" +) + +// This file implements several optimization passes (analysis + transform) to +// optimize code in SSA form before it is compiled to LLVM IR. It is based on +// the IR defined in ir.go. + +// Make a readable version of a method signature (including the function name, +// excluding the receiver name). This string is used internally to match +// interfaces and to call the correct method on an interface. Examples: +// +// String() string +// Read([]byte) (int, error) +func MethodSignature(method *types.Func) string { + return method.Name() + Signature(method.Type().(*types.Signature)) +} + +// Make a readable version of a function (pointer) signature. This string is +// used internally to match signatures (like in AnalyseFunctionPointers). +// Examples: +// +// () string +// (string, int) (int, error) +func Signature(sig *types.Signature) string { + s := "" + if sig.Params().Len() == 0 { + s += "()" + } else { + s += "(" + for i := 0; i < sig.Params().Len(); i++ { + if i > 0 { + s += ", " + } + s += sig.Params().At(i).Type().String() + } + s += ")" + } + if sig.Results().Len() == 0 { + // keep as-is + } else if sig.Results().Len() == 1 { + s += " " + sig.Results().At(0).Type().String() + } else { + s += " (" + for i := 0; i < sig.Results().Len(); i++ { + if i > 0 { + s += ", " + } + s += sig.Results().At(i).Type().String() + } + s += ")" + } + return s +} + +// Convert an interface type to a string of all method strings, separated by +// "; ". For example: "Read([]byte) (int, error); Close() error" +func InterfaceKey(itf *types.Interface) string { + methodStrings := []string{} + for i := 0; i < itf.NumMethods(); i++ { + method := itf.Method(i) + methodStrings = append(methodStrings, MethodSignature(method)) + } + sort.Strings(methodStrings) + return strings.Join(methodStrings, ";") +} + +// Fill in parents of all functions. +// +// All packages need to be added before this pass can run, or it will produce +// incorrect results. +func (p *Program) AnalyseCallgraph() { + for _, f := range p.Functions { + // Clear, if AnalyseCallgraph has been called before. + f.children = nil + f.parents = nil + + for _, block := range f.Blocks { + for _, instr := range block.Instrs { + switch instr := instr.(type) { + case *ssa.Call: + if instr.Common().IsInvoke() { + continue + } + switch call := instr.Call.Value.(type) { + case *ssa.Builtin: + // ignore + case *ssa.Function: + if isCGoInternal(call.Name()) { + continue + } + child := p.GetFunction(call) + if child.CName() != "" { + continue // assume non-blocking + } + if child.RelString(nil) == "time.Sleep" { + f.blocking = true + } + f.children = append(f.children, child) + } + } + } + } + } + for _, f := range p.Functions { + for _, child := range f.children { + child.parents = append(child.parents, f) + } + } +} + +// Find all types that are put in an interface. +func (p *Program) AnalyseInterfaceConversions() { + // Clear, if AnalyseTypes has been called before. + p.typesWithoutMethods = map[string]int{"nil": 0} + p.typesWithMethods = map[string]*TypeWithMethods{} + + for _, f := range p.Functions { + for _, block := range f.Blocks { + for _, instr := range block.Instrs { + switch instr := instr.(type) { + case *ssa.MakeInterface: + methods := getAllMethods(f.Prog, instr.X.Type()) + name := instr.X.Type().String() + if _, ok := p.typesWithMethods[name]; !ok && len(methods) > 0 { + t := &TypeWithMethods{ + t: instr.X.Type(), + Num: len(p.typesWithMethods), + Methods: make(map[string]*types.Selection), + } + for _, sel := range methods { + name := MethodSignature(sel.Obj().(*types.Func)) + t.Methods[name] = sel + } + p.typesWithMethods[name] = t + } else if _, ok := p.typesWithoutMethods[name]; !ok && len(methods) == 0 { + p.typesWithoutMethods[name] = len(p.typesWithoutMethods) + } + } + } + } + } +} + +// Analyse which function pointer signatures need a context parameter. +// This makes calling function pointers more efficient. +func (p *Program) AnalyseFunctionPointers() { + // Clear, if AnalyseFunctionPointers has been called before. + p.fpWithContext = map[string]struct{}{} + + for _, f := range p.Functions { + for _, block := range f.Blocks { + for _, instr := range block.Instrs { + switch instr := instr.(type) { + case ssa.CallInstruction: + for _, arg := range instr.Common().Args { + switch arg := arg.(type) { + case *ssa.Function: + f := p.GetFunction(arg) + f.addressTaken = true + } + } + case *ssa.DebugRef: + default: + // For anything that isn't a call... + for _, operand := range instr.Operands(nil) { + if operand == nil || *operand == nil || isCGoInternal((*operand).Name()) { + continue + } + switch operand := (*operand).(type) { + case *ssa.Function: + f := p.GetFunction(operand) + f.addressTaken = true + } + } + } + switch instr := instr.(type) { + case *ssa.MakeClosure: + fn := instr.Fn.(*ssa.Function) + sig := Signature(fn.Signature) + p.fpWithContext[sig] = struct{}{} + } + } + } + } +} + +// Analyse which functions are recursively blocking. +// +// Depends on AnalyseCallgraph. +func (p *Program) AnalyseBlockingRecursive() { + worklist := make([]*Function, 0) + + // Fill worklist with directly blocking functions. + for _, f := range p.Functions { + if f.blocking { + worklist = append(worklist, f) + } + } + + // Keep reducing this worklist by marking a function as recursively blocking + // from the worklist and pushing all its parents that are non-blocking. + // This is somewhat similar to a worklist in a mark-sweep garbage collector. + // The work items are then grey objects. + for len(worklist) != 0 { + // Pick the topmost. + f := worklist[len(worklist)-1] + worklist = worklist[:len(worklist)-1] + for _, parent := range f.parents { + if !parent.blocking { + parent.blocking = true + worklist = append(worklist, parent) + } + } + } +} + +// Check whether we need a scheduler. A scheduler is only necessary when there +// are go calls that start blocking functions (if they're not blocking, the go +// function can be turned into a regular function call). +// +// Depends on AnalyseBlockingRecursive. +func (p *Program) AnalyseGoCalls() { + p.goCalls = nil + for _, f := range p.Functions { + for _, block := range f.Blocks { + for _, instr := range block.Instrs { + switch instr := instr.(type) { + case *ssa.Go: + p.goCalls = append(p.goCalls, instr) + } + } + } + } + for _, instr := range p.goCalls { + switch instr := instr.Call.Value.(type) { + case *ssa.Builtin: + case *ssa.Function: + if p.functionMap[instr].blocking { + p.needsScheduler = true + } + default: + panic("unknown go call function type") + } + } +} + +// Simple pass that removes dead code. This pass makes later analysis passes +// more useful. +func (p *Program) SimpleDCE() { + // Unmark all functions. + for _, f := range p.Functions { + f.flag = false + } + + // Initial set of live functions. Include main.main, *.init and runtime.* + // functions. + main := p.mainPkg.Members["main"].(*ssa.Function) + runtimePkg := p.Program.ImportedPackage("runtime") + p.GetFunction(main).flag = true + worklist := []*ssa.Function{main} + for _, f := range p.Functions { + if f.Synthetic == "package initializer" || f.Pkg == runtimePkg { + if f.flag || isCGoInternal(f.Name()) { + continue + } + f.flag = true + worklist = append(worklist, f.Function) + } + } + + // Mark all called functions recursively. + for len(worklist) != 0 { + f := worklist[len(worklist)-1] + worklist = worklist[:len(worklist)-1] + for _, block := range f.Blocks { + for _, instr := range block.Instrs { + if instr, ok := instr.(*ssa.MakeInterface); ok { + for _, sel := range getAllMethods(p.Program, instr.X.Type()) { + fn := p.Program.MethodValue(sel) + callee := p.GetFunction(fn) + if callee == nil { + // TODO: why is this necessary? + p.addFunction(fn) + callee = p.GetFunction(fn) + } + if !callee.flag { + callee.flag = true + worklist = append(worklist, callee.Function) + } + } + } + for _, operand := range instr.Operands(nil) { + if operand == nil || *operand == nil || isCGoInternal((*operand).Name()) { + continue + } + switch operand := (*operand).(type) { + case *ssa.Function: + f := p.GetFunction(operand) + if f == nil { + // FIXME HACK: this function should have been + // discovered already. It is not for bound methods. + p.addFunction(operand) + f = p.GetFunction(operand) + } + if !f.flag { + f.flag = true + worklist = append(worklist, operand) + } + } + } + } + } + } + + // Remove unmarked functions. + livefunctions := []*Function{} + for _, f := range p.Functions { + if f.flag { + livefunctions = append(livefunctions, f) + } else { + delete(p.functionMap, f.Function) + } + } + p.Functions = livefunctions +} + +// Whether this function needs a scheduler. +// +// Depends on AnalyseGoCalls. +func (p *Program) NeedsScheduler() bool { + return p.needsScheduler +} + +// Whether this function blocks. Builtins are also accepted for convenience. +// They will always be non-blocking. +// +// Depends on AnalyseBlockingRecursive. +func (p *Program) IsBlocking(f *Function) bool { + if !p.needsScheduler { + return false + } + return f.blocking +} + +// Return the type number and whether this type is actually used. Used in +// interface conversions (type is always used) and type asserts (type may not be +// used, meaning assert is always false in this program). +// +// May only be used after all packages have been added to the analyser. +func (p *Program) TypeNum(typ types.Type) (int, bool) { + if n, ok := p.typesWithoutMethods[typ.String()]; ok { + return n, true + } else if meta, ok := p.typesWithMethods[typ.String()]; ok { + return len(p.typesWithoutMethods) + meta.Num, true + } else { + return -1, false // type is never put in an interface + } +} + +// InterfaceNum returns the numeric interface ID of this type, for use in type +// asserts. +func (p *Program) InterfaceNum(itfType *types.Interface) int { + key := InterfaceKey(itfType) + if itf, ok := p.interfaces[key]; !ok { + num := len(p.interfaces) + p.interfaces[key] = &Interface{Num: num, Type: itfType} + return num + } else { + return itf.Num + } +} + +// MethodNum returns the numeric ID of this method, to be used in method lookups +// on interfaces for example. +func (p *Program) MethodNum(method *types.Func) int { + name := MethodSignature(method) + if _, ok := p.methodSignatureNames[name]; !ok { + p.methodSignatureNames[name] = len(p.methodSignatureNames) + } + return p.methodSignatureNames[MethodSignature(method)] +} + +// The start index of the first dynamic type that has methods. +// Types without methods always have a lower ID and types with methods have this +// or a higher ID. +// +// May only be used after all packages have been added to the analyser. +func (p *Program) FirstDynamicType() int { + return len(p.typesWithoutMethods) +} + +// Return all types with methods, sorted by type ID. +func (p *Program) AllDynamicTypes() []*TypeWithMethods { + l := make([]*TypeWithMethods, len(p.typesWithMethods)) + for _, m := range p.typesWithMethods { + l[m.Num] = m + } + return l +} + +// Return all interface types, sorted by interface ID. +func (p *Program) AllInterfaces() []*Interface { + l := make([]*Interface, len(p.interfaces)) + for _, itf := range p.interfaces { + l[itf.Num] = itf + } + return l +} + +func (p *Program) FunctionNeedsContext(f *Function) bool { + if !f.addressTaken { + return false + } + return p.SignatureNeedsContext(f.Signature) +} + +func (p *Program) SignatureNeedsContext(sig *types.Signature) bool { + _, needsContext := p.fpWithContext[Signature(sig)] + return needsContext +} |