about summary refs log tree commit diff homepage
path: root/ir
diff options
context:
space:
mode:
author: Ayke van Laethem <[email protected]> 2018-09-22 20:25:50 +0200
committer: Ayke van Laethem <[email protected]> 2018-09-22 20:32:07 +0200
commitb75a02e66dd240685b281b45bc87d37788fb0a7c (patch)
tree056f039cbb6705537dfd718d6f769f6b4d51e0a4 /ir
parent473e71b5731f88edfd07c4c07da32204f74e8389 (diff)
downloadtinygo-b75a02e66dd240685b281b45bc87d37788fb0a7c.tar.gz
tinygo-b75a02e66dd240685b281b45bc87d37788fb0a7c.zip
compiler: refactor IR parts into separate package
Diffstat (limited to 'ir')
-rw-r--r-- ir/interpreter.go 533
-rw-r--r-- ir/ir.go 387
-rw-r--r-- ir/passes.go 427
3 files changed, 1347 insertions, 0 deletions
diff --git a/ir/interpreter.go b/ir/interpreter.go
new file mode 100644
index 000000000..078ad1ac3
--- /dev/null
+++ b/ir/interpreter.go
@@ -0,0 +1,533 @@
+package ir
+
+// This file provides functionality to interpret very basic Go SSA, for
+// compile-time initialization of globals.
+
+import (
+ "errors"
+ "fmt"
+ "go/constant"
+ "go/token"
+ "go/types"
+ "strings"
+
+ "golang.org/x/tools/go/ssa"
+)
+
+// ErrCGoWrapper is returned by the interpreter when a value refers to a CGo
+// wrapper global. It is a signal to skip the instruction, not a real error.
+var ErrCGoWrapper = errors.New("tinygo internal: cgo wrapper")
+
+// ignoreInitCalls is the set of static callees (keyed by their String() form)
+// that are not executed during interpretation. A call to one of them is
+// replaced with the zero value(s) of its result type(s).
+var ignoreInitCalls = map[string]struct{}{
+	"syscall.runtime_envs":   {},
+	"syscall/js.predefValue": {},
+	"(syscall/js.Value).Get": {},
+	"(syscall/js.Value).New": {},
+	"(syscall/js.Value).Int": {},
+	"os.init$1":              {},
+}
+
+// Interpret evaluates the leading instructions of block at compile time, for
+// as far as the interpreter understands them, and removes the interpreted
+// instructions from the block.
+//
+// An instruction that refers to a CGo wrapper global (signalled with
+// ErrCGoWrapper) is dropped together with everything before it, after which
+// interpretation restarts on the remaining instructions. Any other error is
+// returned to the caller. When dumpSSA is set, the package path and every
+// visited instruction are printed.
+func (p *Program) Interpret(block *ssa.BasicBlock, dumpSSA bool) error {
+	if dumpSSA {
+		fmt.Printf("\ninterpret: %s\n", block.Parent().Pkg.Pkg.Path())
+	}
+	for {
+		stopped, err := p.interpret(block.Instrs, nil, nil, nil, dumpSSA)
+		if err != ErrCGoWrapper {
+			// Either everything understood so far was interpreted, or a
+			// real error occurred: keep the instructions from the stopping
+			// point onwards.
+			block.Instrs = block.Instrs[stopped:]
+			return err
+		}
+		// Skip the instruction that touched the CGo wrapper and try again.
+		block.Instrs = block.Instrs[stopped+1:]
+	}
+}
+
+// interpret executes the given instructions one by one for as long as it
+// understands them. It returns the index (into instrs) of the first
+// instruction it could not handle, or len(instrs) when all of them were
+// interpreted.
+//
+// paramKeys and paramValues provide the initial parameter bindings
+// (paramValues[n] is bound to paramKeys[n]). The values of a Return
+// instruction are written to results. When a non-nil error is returned, the
+// returned index is the index of the instruction that caused it, so the caller
+// can slice instrs at that position.
+func (p *Program) interpret(instrs []ssa.Instruction, paramKeys []*ssa.Parameter, paramValues []Value, results []Value, dumpSSA bool) (int, error) {
+	locals := map[ssa.Value]Value{}
+	for paramIndex, key := range paramKeys {
+		locals[key] = paramValues[paramIndex]
+	}
+	for i, instr := range instrs {
+		if _, ok := instr.(*ssa.DebugRef); ok {
+			continue
+		}
+		if dumpSSA {
+			if val, ok := instr.(ssa.Value); ok && val.Name() != "" {
+				fmt.Printf("\t%s: %s = %s\n", instr.Parent().RelString(nil), val.Name(), val.String())
+			} else {
+				fmt.Printf("\t%s: %s\n", instr.Parent().RelString(nil), instr.String())
+			}
+		}
+		switch instr := instr.(type) {
+		case *ssa.Alloc:
+			alloc, err := p.getZeroValue(instr.Type().Underlying().(*types.Pointer).Elem())
+			if err != nil {
+				return i, err
+			}
+			locals[instr] = &PointerValue{nil, &alloc}
+		case *ssa.BinOp:
+			if typ, ok := instr.Type().(*types.Basic); ok && typ.Kind() == types.String {
+				// Concatenate two strings.
+				// This happens in the time package, for example.
+				x, err := p.getValue(instr.X, locals)
+				if err != nil {
+					return i, err
+				}
+				y, err := p.getValue(instr.Y, locals)
+				if err != nil {
+					return i, err
+				}
+				xstr := constant.StringVal(x.(*ConstValue).Expr.Value)
+				ystr := constant.StringVal(y.(*ConstValue).Expr.Value)
+				locals[instr] = &ConstValue{ssa.NewConst(constant.MakeString(xstr+ystr), types.Typ[types.String])}
+			} else {
+				return i, errors.New("init: unknown binop: " + instr.String())
+			}
+		case *ssa.Call:
+			common := instr.Common()
+			callee := common.StaticCallee()
+			if callee == nil {
+				return i, nil // don't understand dynamic dispatch
+			}
+			if _, ok := ignoreInitCalls[callee.String()]; ok {
+				// These calls are not needed and can be ignored, for the time
+				// being. Their results are replaced with zero values.
+				zeroResults := make([]Value, callee.Signature.Results().Len())
+				// Use a distinct index name: shadowing i here would make the
+				// error return below report a result index instead of the
+				// instruction index.
+				for resultIndex := range zeroResults {
+					var err error
+					zeroResults[resultIndex], err = p.getZeroValue(callee.Signature.Results().At(resultIndex).Type())
+					if err != nil {
+						return i, err
+					}
+				}
+				if len(zeroResults) == 1 {
+					locals[instr] = zeroResults[0]
+				} else if len(zeroResults) > 1 {
+					locals[instr] = &StructValue{Fields: zeroResults}
+				}
+				continue
+			}
+			if callee.String() == "os.NewFile" {
+				// Emulate the creation of os.Stdin, os.Stdout and os.Stderr.
+				resultPtrType := callee.Signature.Results().At(0).Type().(*types.Pointer)
+				resultStructOuterType := resultPtrType.Elem().Underlying().(*types.Struct)
+				if resultStructOuterType.NumFields() != 1 {
+					panic("expected 1 field in os.File struct")
+				}
+				fileInnerPtrType := resultStructOuterType.Field(0).Type().(*types.Pointer)
+				fileInnerType := fileInnerPtrType.Elem().(*types.Named)
+				fileInnerStructType := fileInnerType.Underlying().(*types.Struct)
+				fileInner, err := p.getZeroValue(fileInnerType) // os.file
+				if err != nil {
+					return i, err
+				}
+				for fieldIndex := 0; fieldIndex < fileInnerStructType.NumFields(); fieldIndex++ {
+					field := fileInnerStructType.Field(fieldIndex)
+					if field.Name() == "name" {
+						// Set the 'name' field.
+						name, err := p.getValue(common.Args[1], locals)
+						if err != nil {
+							return i, err
+						}
+						fileInner.(*StructValue).Fields[fieldIndex] = name
+					} else if field.Type().String() == "internal/poll.FD" {
+						// Set the file descriptor field.
+						field := field.Type().Underlying().(*types.Struct)
+						for subfieldIndex := 0; subfieldIndex < field.NumFields(); subfieldIndex++ {
+							subfield := field.Field(subfieldIndex)
+							if subfield.Name() == "Sysfd" {
+								sysfd, err := p.getValue(common.Args[0], locals)
+								if err != nil {
+									return i, err
+								}
+								// Re-type the constant to the type of the Sysfd field.
+								sysfd = &ConstValue{Expr: ssa.NewConst(sysfd.(*ConstValue).Expr.Value, subfield.Type())}
+								fileInner.(*StructValue).Fields[fieldIndex].(*StructValue).Fields[subfieldIndex] = sysfd
+							}
+						}
+					}
+				}
+				fileInnerPtr := &PointerValue{fileInnerPtrType, &fileInner}                                   // *os.file
+				var fileOuter Value = &StructValue{Type: resultPtrType.Elem(), Fields: []Value{fileInnerPtr}} // os.File
+				result := &PointerValue{resultPtrType.Elem(), &fileOuter}                                     // *os.File
+				locals[instr] = result
+				continue
+			}
+			if canInterpret(callee) {
+				params := make([]Value, len(common.Args))
+				// As above, the argument index must not shadow i: errors have
+				// to be reported at the index of the call instruction.
+				for argIndex, arg := range common.Args {
+					val, err := p.getValue(arg, locals)
+					if err != nil {
+						return i, err
+					}
+					params[argIndex] = val
+				}
+				callResults := make([]Value, callee.Signature.Results().Len())
+				subi, err := p.interpret(callee.Blocks[0].Instrs, callee.Params, params, callResults, dumpSSA)
+				if err != nil {
+					return i, err
+				}
+				if subi != len(callee.Blocks[0].Instrs) {
+					return i, errors.New("init: could not interpret all instructions of subroutine")
+				}
+				if len(callResults) == 1 {
+					locals[instr] = callResults[0]
+				} else {
+					panic("unimplemented: not exactly 1 result")
+				}
+				continue
+			}
+			if callee.Object() == nil || callee.Object().Name() == "init" {
+				return i, nil // arrived at the init#num functions
+			}
+			return i, errors.New("todo: init call: " + callee.String())
+		case *ssa.ChangeType:
+			x, err := p.getValue(instr.X, locals)
+			if err != nil {
+				return i, err
+			}
+			// The only case when we need to bitcast is when casting between named
+			// struct types, as those are actually different in LLVM. Let's just
+			// bitcast all struct types for ease of use.
+			if _, ok := instr.Type().Underlying().(*types.Struct); ok {
+				return i, errors.New("todo: init: " + instr.String())
+			}
+			locals[instr] = x
+		case *ssa.Convert:
+			x, err := p.getValue(instr.X, locals)
+			if err != nil {
+				return i, err
+			}
+			typeFrom := instr.X.Type().Underlying()
+			switch typeTo := instr.Type().Underlying().(type) {
+			case *types.Basic:
+				if typeTo.Kind() == types.String {
+					return i, errors.New("todo: init: cannot convert string")
+				}
+
+				if _, ok := typeFrom.(*types.Pointer); ok && typeTo.Kind() == types.UnsafePointer {
+					// *T -> unsafe.Pointer
+					locals[instr] = &PointerBitCastValue{typeTo, x}
+				} else if typeFrom, ok := typeFrom.(*types.Basic); ok {
+					if typeFrom.Kind() == types.UnsafePointer && typeTo.Kind() == types.Uintptr {
+						// unsafe.Pointer -> uintptr
+						locals[instr] = &PointerToUintptrValue{x}
+					} else if typeFrom.Info()&types.IsInteger != 0 && typeTo.Info()&types.IsInteger != 0 {
+						// integer -> integer: keep the constant, change its type
+						locals[instr] = &ConstValue{Expr: ssa.NewConst(x.(*ConstValue).Expr.Value, typeTo)}
+					} else {
+						return i, errors.New("todo: init: unknown basic-to-basic convert: " + instr.String())
+					}
+				} else {
+					return i, errors.New("todo: init: unknown basic convert: " + instr.String())
+				}
+			case *types.Pointer:
+				if typeFrom, ok := typeFrom.(*types.Basic); ok && typeFrom.Kind() == types.UnsafePointer {
+					// unsafe.Pointer -> *T
+					locals[instr] = &PointerBitCastValue{typeTo, x}
+				} else {
+					panic("expected unsafe pointer conversion")
+				}
+			default:
+				return i, errors.New("todo: init: unknown convert: " + instr.String())
+			}
+		case *ssa.DebugRef:
+			// ignore
+		case *ssa.Extract:
+			tuple, err := p.getValue(instr.Tuple, locals)
+			if err != nil {
+				return i, err
+			}
+			locals[instr] = tuple.(*StructValue).Fields[instr.Index]
+		case *ssa.FieldAddr:
+			x, err := p.getValue(instr.X, locals)
+			if err != nil {
+				return i, err
+			}
+			var structVal *StructValue
+			switch x := x.(type) {
+			case *GlobalValue:
+				structVal = x.Global.initializer.(*StructValue)
+			case *PointerValue:
+				structVal = (*x.Elem).(*StructValue)
+			default:
+				panic("expected a pointer")
+			}
+			// The resulting pointer aliases the field, so later stores
+			// through it update the struct value in place.
+			locals[instr] = &PointerValue{nil, &structVal.Fields[instr.Field]}
+		case *ssa.IndexAddr:
+			x, err := p.getValue(instr.X, locals)
+			if err != nil {
+				return i, err
+			}
+			if cnst, ok := instr.Index.(*ssa.Const); ok {
+				index, _ := constant.Int64Val(cnst.Value)
+				switch xPtr := x.(type) {
+				case *GlobalValue:
+					x = xPtr.Global.initializer
+				case *PointerValue:
+					x = *xPtr.Elem
+				default:
+					panic("expected a pointer")
+				}
+				switch x := x.(type) {
+				case *ArrayValue:
+					locals[instr] = &PointerValue{nil, &x.Elems[index]}
+				default:
+					return i, errors.New("todo: init IndexAddr not on an array or struct")
+				}
+			} else {
+				return i, errors.New("todo: init IndexAddr index: " + instr.Index.String())
+			}
+		case *ssa.MakeInterface:
+			// NOTE(review): the operand is looked up in locals directly, not
+			// through getValue, so constants and globals end up as a nil
+			// Elem here - confirm whether that is intended.
+			locals[instr] = &InterfaceValue{instr.X.Type(), locals[instr.X]}
+		case *ssa.MakeMap:
+			locals[instr] = &MapValue{instr.Type().Underlying().(*types.Map), nil, nil}
+		case *ssa.MapUpdate:
+			// Assume no duplicate keys exist. This is most likely true for
+			// autogenerated code, but may not be true when trying to interpret
+			// user code.
+			key, err := p.getValue(instr.Key, locals)
+			if err != nil {
+				return i, err
+			}
+			value, err := p.getValue(instr.Value, locals)
+			if err != nil {
+				return i, err
+			}
+			x := locals[instr.Map].(*MapValue)
+			x.Keys = append(x.Keys, key)
+			x.Values = append(x.Values, value)
+		case *ssa.Return:
+			// The result index must not shadow i, see the Call case.
+			for resultIndex, r := range instr.Results {
+				val, err := p.getValue(r, locals)
+				if err != nil {
+					return i, err
+				}
+				results[resultIndex] = val
+			}
+		case *ssa.Slice:
+			// Turn a just-allocated array into a slice.
+			if instr.Low != nil || instr.High != nil || instr.Max != nil {
+				return i, errors.New("init: slice expression with bounds")
+			}
+			source, err := p.getValue(instr.X, locals)
+			if err != nil {
+				return i, err
+			}
+			switch source := source.(type) {
+			case *PointerValue: // pointer to array
+				array := (*source.Elem).(*ArrayValue)
+				locals[instr] = &SliceValue{instr.Type().Underlying().(*types.Slice), array}
+			default:
+				return i, errors.New("init: unknown slice type")
+			}
+		case *ssa.Store:
+			if addr, ok := instr.Addr.(*ssa.Global); ok {
+				if strings.HasPrefix(instr.Addr.Name(), "__cgofn__cgo_") || strings.HasPrefix(instr.Addr.Name(), "_cgo_") {
+					// Ignore CGo global variables which we don't use.
+					continue
+				}
+				value, err := p.getValue(instr.Val, locals)
+				if err != nil {
+					return i, err
+				}
+				p.GetGlobal(addr).initializer = value
+			} else if addr, ok := locals[instr.Addr]; ok {
+				value, err := p.getValue(instr.Val, locals)
+				if err != nil {
+					return i, err
+				}
+				if addr, ok := addr.(*PointerValue); ok {
+					*(addr.Elem) = value
+				} else {
+					panic("store to non-pointer")
+				}
+			} else {
+				return i, errors.New("todo: init Store: " + instr.String())
+			}
+		case *ssa.UnOp:
+			// Only plain pointer dereferences are supported.
+			if instr.Op != token.MUL || instr.CommaOk {
+				return i, errors.New("init: unknown unop: " + instr.String())
+			}
+			valPtr, err := p.getValue(instr.X, locals)
+			if err != nil {
+				return i, err
+			}
+			switch valPtr := valPtr.(type) {
+			case *GlobalValue:
+				locals[instr] = valPtr.Global.initializer
+			case *PointerValue:
+				locals[instr] = *valPtr.Elem
+			default:
+				panic("expected a pointer")
+			}
+		default:
+			// Unknown instruction: stop here without reporting an error.
+			return i, nil
+		}
+	}
+	return len(instrs), nil
+}
+
+// canInterpret reports whether callee can be evaluated at compile time by
+// Program.interpret. That is the case when the function consists of a single
+// basic block (no control flow), has exactly one result, and contains only
+// instructions from the supported list below.
+func canInterpret(callee *ssa.Function) bool {
+	if len(callee.Blocks) != 1 {
+		// No control flow supported so only one basic block.
+		return false
+	}
+	if callee.Signature.Results().Len() != 1 {
+		// Only exactly one return value supported right now.
+		return false
+	}
+	for _, instr := range callee.Blocks[0].Instrs {
+		switch instr.(type) {
+		case *ssa.Alloc, *ssa.ChangeType, *ssa.Convert, *ssa.DebugRef,
+			*ssa.Extract, *ssa.FieldAddr, *ssa.IndexAddr, *ssa.MakeInterface,
+			*ssa.MakeMap, *ssa.MapUpdate, *ssa.Return, *ssa.Slice, *ssa.Store,
+			*ssa.UnOp:
+			// Accepted instruction kind.
+		default:
+			return false
+		}
+	}
+	return true
+}
+
+// getValue resolves an SSA value to its interpreter value. Constants and
+// functions are wrapped directly. Globals are returned as a GlobalValue, and
+// their initializer is set to the zero value of the global's element type if
+// it has not been set yet. All other values must have been computed earlier
+// and are looked up in locals.
+//
+// ErrCGoWrapper is returned for CGo wrapper globals, which are not used.
+func (p *Program) getValue(value ssa.Value, locals map[ssa.Value]Value) (Value, error) {
+	switch value := value.(type) {
+	case *ssa.Const:
+		return &ConstValue{value}, nil
+	case *ssa.Function:
+		return &FunctionValue{value.Type(), value}, nil
+	case *ssa.Global:
+		name := value.Name()
+		if strings.HasPrefix(name, "__cgofn__cgo_") || strings.HasPrefix(name, "_cgo_") {
+			// Ignore CGo global variables which we don't use.
+			return nil, ErrCGoWrapper
+		}
+		global := p.GetGlobal(value)
+		if global.initializer == nil {
+			zero, err := p.getZeroValue(value.Type().Underlying().(*types.Pointer).Elem())
+			if err != nil {
+				return nil, err
+			}
+			global.initializer = zero
+		}
+		return &GlobalValue{global}, nil
+	}
+
+	local, known := locals[value]
+	if !known {
+		return nil, errors.New("todo: init: unknown value: " + value.String())
+	}
+	return local, nil
+}
+
+// getZeroValue builds the interpreter representation of the zero value of
+// type t. Arrays and structs are expanded recursively, element by element. An
+// error is returned for types whose underlying type is not handled here.
+func (p *Program) getZeroValue(t types.Type) (Value, error) {
+	switch typ := t.Underlying().(type) {
+	case *types.Basic:
+		return &ZeroBasicValue{typ}, nil
+	case *types.Pointer:
+		return &PointerValue{typ, nil}, nil
+	case *types.Signature:
+		return &FunctionValue{typ, nil}, nil
+	case *types.Interface:
+		return &InterfaceValue{typ, nil}, nil
+	case *types.Slice:
+		return &SliceValue{typ, nil}, nil
+	case *types.Map:
+		return &MapValue{typ, nil, nil}, nil
+	case *types.Array:
+		zeros := make([]Value, typ.Len())
+		for n := range zeros {
+			zero, err := p.getZeroValue(typ.Elem())
+			if err != nil {
+				return nil, err
+			}
+			zeros[n] = zero
+		}
+		return &ArrayValue{typ.Elem(), zeros}, nil
+	case *types.Struct:
+		fields := make([]Value, typ.NumFields())
+		for n := range fields {
+			zero, err := p.getZeroValue(typ.Field(n).Type())
+			if err != nil {
+				return nil, err
+			}
+			fields[n] = zero
+		}
+		// Keep the original (possibly named) type, not the underlying one.
+		return &StructValue{t, fields}, nil
+	default:
+		return nil, errors.New("todo: init: unknown global type: " + typ.String())
+	}
+}
+
+// Value is a boxed value for the interpreter; the *XxxValue structs below are the concrete kinds stored in it.
+type Value interface {
+}
+
+type ConstValue struct { // constant taken from the SSA (or built by the interpreter, e.g. string concatenation)
+ Expr *ssa.Const
+}
+
+type ZeroBasicValue struct { // zero value of a basic type, see getZeroValue
+ Type *types.Basic
+}
+
+type PointerValue struct { // pointer to another interpreter value
+ Type types.Type // left nil for pointers created by Alloc, FieldAddr and IndexAddr
+ Elem *Value // pointee; nil for the zero value of a pointer type
+}
+
+type FunctionValue struct { // function used as a value
+ Type types.Type
+ Elem *ssa.Function // nil for the zero value of a function type
+}
+
+type InterfaceValue struct { // value boxed in an interface
+ Type types.Type // type of the boxed operand for MakeInterface, the interface type itself for zero values
+ Elem Value // boxed value; nil for zero values
+}
+
+type PointerBitCastValue struct { // result of a conversion between a pointer and unsafe.Pointer
+ Type types.Type // type converted to
+ Elem Value // value that was converted
+}
+
+type PointerToUintptrValue struct { // result of converting unsafe.Pointer to uintptr
+ Elem Value // value that was converted
+}
+
+type GlobalValue struct { // reference to a global; loads through it read Global.initializer
+ Global *Global
+}
+
+type ArrayValue struct { // array with one interpreter value per element
+ ElemType types.Type
+ Elems []Value
+}
+
+type StructValue struct { // struct with one interpreter value per field; also used to hold multiple call results
+ Type types.Type // types.Struct or types.Named
+ Fields []Value
+}
+
+type SliceValue struct { // slice backed by an interpreter array
+ Type *types.Slice
+ Array *ArrayValue // nil for the zero value of a slice type
+}
+
+type MapValue struct { // map stored as parallel key/value lists in insertion order
+ Type *types.Map
+ Keys []Value // Keys[n] belongs to Values[n]
+ Values []Value
+}
diff --git a/ir/ir.go b/ir/ir.go
new file mode 100644
index 000000000..4f098425d
--- /dev/null
+++ b/ir/ir.go
@@ -0,0 +1,387 @@
+package ir
+
+import (
+ "go/ast"
+ "go/token"
+ "go/types"
+ "sort"
+ "strings"
+
+ "github.com/aykevl/llvm/bindings/go/llvm"
+ "golang.org/x/tools/go/loader"
+ "golang.org/x/tools/go/ssa"
+ "golang.org/x/tools/go/ssa/ssautil"
+)
+
+// This file provides a wrapper around go/ssa values and adds extra
+// functionality to them.
+
+// Program is a view on all functions, types, and globals in a program, with analysis
+// results. It is created by NewProgram and filled by AddPackage.
+type Program struct {
+ Program *ssa.Program // underlying go/ssa program, built in NewProgram
+ mainPkg *ssa.Package // package imported from the main path passed to NewProgram
+ Functions []*Function // every function registered through addFunction, including methods and anonymous functions
+ functionMap map[*ssa.Function]*Function // lookup table behind GetFunction
+ Globals []*Global // every global registered by AddPackage
+ globalMap map[*ssa.Global]*Global // lookup table behind GetGlobal
+ comments map[string]*ast.CommentGroup // doc comments of single-spec var declarations, keyed by "<package path>.<name>"
+ NamedTypes []*NamedType // every named type registered by AddPackage
+ needsScheduler bool
+ goCalls []*ssa.Go
+ typesWithMethods map[string]*TypeWithMethods // see AnalyseInterfaceConversions
+ typesWithoutMethods map[string]int // see AnalyseInterfaceConversions
+ methodSignatureNames map[string]int // see MethodNum
+ interfaces map[string]*Interface // see AnalyseInterfaceConversions
+ fpWithContext map[string]struct{} // see AnalyseFunctionPointers
+}
+
+// Function wraps an *ssa.Function (function or method) with pragma and analysis data.
+type Function struct {
+ *ssa.Function // the wrapped SSA function
+ LLVMFn llvm.Value // NOTE(review): not set anywhere in this file - presumably filled in by the compiler, confirm
+ linkName string // go:linkname or go:export pragma
+ exported bool // go:export
+ nobounds bool // go:nobounds pragma
+ blocking bool // calculated by AnalyseBlockingRecursive
+ flag bool // used by dead code elimination
+ addressTaken bool // used as function pointer, calculated by AnalyseFunctionPointers
+ parents []*Function // calculated by AnalyseCallgraph
+ children []*Function // calculated by AnalyseCallgraph
+}
+
+// Global wraps an *ssa.Global: a global variable, possibly constant.
+type Global struct {
+ *ssa.Global // the wrapped SSA global
+ program *Program // Program this global was added to (set in AddPackage)
+ LLVMGlobal llvm.Value // NOTE(review): not set anywhere in this file - presumably filled in by the compiler, confirm
+ linkName string // go:extern
+ extern bool // go:extern
+ initializer Value // compile-time value; nil until the interpreter first reads or stores the global
+}
+
+// NamedType wraps an *ssa.Type: a type with a name and possibly methods.
+type NamedType struct {
+ *ssa.Type // the wrapped SSA type member
+ LLVMType llvm.Type // NOTE(review): not set anywhere in this file - presumably filled in by the compiler, confirm
+}
+
+// TypeWithMethods is a type that is at some point put in an interface.
+type TypeWithMethods struct {
+ t types.Type
+ Num int
+ Methods map[string]*types.Selection
+}
+
+// Interface is an interface type that is at some point used in a type assert (to check whether
+// it implements another interface).
+type Interface struct {
+ Num int
+ Type *types.Interface
+}
+
+// Create and intialize a new *Program from a *ssa.Program.
+func NewProgram(lprogram *loader.Program, mainPath string) *Program {
+ comments := map[string]*ast.CommentGroup{}
+ for _, pkgInfo := range lprogram.AllPackages {
+ for _, file := range pkgInfo.Files {
+ for _, decl := range file.Decls {
+ switch decl := decl.(type) {
+ case *ast.GenDecl:
+ switch decl.Tok {
+ case token.VAR:
+ if len(decl.Specs) != 1 {
+ continue
+ }
+ for _, spec := range decl.Specs {
+ valueSpec := spec.(*ast.ValueSpec)
+ for _, valueName := range valueSpec.Names {
+ id := pkgInfo.Pkg.Path() + "." + valueName.Name
+ comments[id] = decl.Doc
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ program := ssautil.CreateProgram(lprogram, ssa.SanityCheckFunctions|ssa.BareInits|ssa.GlobalDebug)
+ program.Build()
+
+ return &Program{
+ Program: program,
+ mainPkg: program.ImportedPackage(mainPath),
+ functionMap: make(map[*ssa.Function]*Function),
+ globalMap: make(map[*ssa.Global]*Global),
+ methodSignatureNames: make(map[string]int),
+ interfaces: make(map[string]*Interface),
+ comments: comments,
+ }
+}
+
+// AddPackage registers all members of pkg (functions, named types with their
+// methods, and globals) in this Program. Members are visited in sorted name
+// order and CGo-internal members are skipped. All packages need to be added
+// before any analysis is done for correct results.
+func (p *Program) AddPackage(pkg *ssa.Package) {
+	var names []string
+	for name := range pkg.Members {
+		if !isCGoInternal(name) {
+			names = append(names, name)
+		}
+	}
+	// Map iteration order is random, so sort for a deterministic result.
+	sort.Strings(names)
+
+	for _, name := range names {
+		switch member := pkg.Members[name].(type) {
+		case *ssa.Function:
+			if !isCGoInternal(member.Name()) {
+				p.addFunction(member)
+			}
+		case *ssa.Type:
+			p.NamedTypes = append(p.NamedTypes, &NamedType{Type: member})
+			methods := getAllMethods(pkg.Prog, member.Type())
+			if types.IsInterface(member.Type()) {
+				continue
+			}
+			// Named non-interface type: register each of its methods.
+			for _, method := range methods {
+				p.addFunction(pkg.Prog.MethodValue(method))
+			}
+		case *ssa.Global:
+			global := &Global{program: p, Global: member}
+			if doc := p.comments[global.RelString(nil)]; doc != nil {
+				global.parsePragmas(doc)
+			}
+			p.Globals = append(p.Globals, global)
+			p.globalMap[member] = global
+		case *ssa.NamedConst:
+			// Ignore: these are already resolved.
+		default:
+			panic("unknown member type: " + member.String())
+		}
+	}
+}
+
+// addFunction wraps ssaFn in a Function, parses its pragmas and registers it
+// in both the function list and the lookup map. Anonymous functions declared
+// inside ssaFn are registered recursively.
+func (p *Program) addFunction(ssaFn *ssa.Function) {
+	fn := &Function{Function: ssaFn}
+	fn.parsePragmas()
+	p.functionMap[ssaFn] = fn
+	p.Functions = append(p.Functions, fn)
+
+	for n := range ssaFn.AnonFuncs {
+		p.addFunction(ssaFn.AnonFuncs[n])
+	}
+}
+
+// hasUnsafeImport reports whether pkg directly imports package "unsafe".
+func hasUnsafeImport(pkg *types.Package) bool {
+	imports := pkg.Imports()
+	for n := 0; n < len(imports); n++ {
+		if imports[n] == types.Unsafe {
+			return true
+		}
+	}
+	return false
+}
+
+// GetFunction returns the Function wrapper registered for ssaFn, or nil if
+// the function was never added to this Program.
+func (p *Program) GetFunction(ssaFn *ssa.Function) *Function {
+	fn := p.functionMap[ssaFn]
+	return fn
+}
+
+// GetGlobal returns the Global wrapper registered for ssaGlobal, or nil if
+// the global was never added to this Program.
+func (p *Program) GetGlobal(ssaGlobal *ssa.Global) *Global {
+	global := p.globalMap[ssaGlobal]
+	return global
+}
+
+// SortMethods sorts methods in place, ordered by the method ID that
+// MethodNum assigns to each of them.
+func (p *Program) SortMethods(methods []*types.Selection) {
+	sort.Sort(&methodList{program: p, methods: methods})
+}
+
+// SortFuncs sorts funcs in place, ordered by the method ID that MethodNum
+// assigns to each of them.
+func (p *Program) SortFuncs(funcs []*types.Func) {
+	sort.Sort(&funcList{program: p, funcs: funcs})
+}
+
+// MainPkg returns the SSA package that was looked up from the main import
+// path passed to NewProgram.
+func (p *Program) MainPkg() *ssa.Package {
+	pkg := p.mainPkg
+	return pkg
+}
+
+// parsePragmas reads the //go: compiler directives from the doc comment of
+// the function declaration, if the function has one:
+//
+//   - //go:linkname localname linkname sets the link name, but only when the
+//     package imports "unsafe"
+//   - //go:nobounds disables bounds checking, also only with "unsafe"
+//   - //go:export name sets the link name and marks the function as exported
+func (f *Function) parsePragmas() {
+	syntax := f.Syntax()
+	if syntax == nil {
+		return
+	}
+	decl, ok := syntax.(*ast.FuncDecl)
+	if !ok || decl.Doc == nil {
+		return
+	}
+	for _, comment := range decl.Doc.List {
+		text := comment.Text
+		if !strings.HasPrefix(text, "//go:") {
+			continue
+		}
+		parts := strings.Fields(text)
+		switch parts[0] {
+		case "//go:linkname":
+			// Only enable go:linkname when the package imports "unsafe".
+			// This is a slightly looser requirement than what gc uses: gc
+			// requires the file to import "unsafe", not the package as a
+			// whole.
+			if len(parts) == 3 && parts[1] == f.Name() && hasUnsafeImport(f.Pkg.Pkg) {
+				f.linkName = parts[2]
+			}
+		case "//go:nobounds":
+			// Skip bounds checking in this function. Useful for some
+			// runtime functions.
+			// This is somewhat dangerous and thus only allowed in packages
+			// that import unsafe.
+			if hasUnsafeImport(f.Pkg.Pkg) {
+				f.nobounds = true
+			}
+		case "//go:export":
+			if len(parts) == 2 {
+				f.linkName = parts[1]
+				f.exported = true
+			}
+		}
+	}
+}
+
+// IsNoBounds reports whether the function carries an accepted //go:nobounds
+// pragma.
+func (f *Function) IsNoBounds() bool {
+	noBounds := f.nobounds
+	return noBounds
+}
+
+// IsExported reports whether this function is externally visible, that is,
+// whether it carries a //go:export pragma.
+func (f *Function) IsExported() bool {
+	exported := f.exported
+	return exported
+}
+
+// LinkName returns the link name for this function. A name set by a pragma
+// takes precedence. Otherwise methods use their package-qualified name, CGo
+// wrappers use the bare C function name, and all other functions use their
+// package-qualified name.
+func (f *Function) LinkName() string {
+	switch {
+	case f.linkName != "":
+		return f.linkName
+	case f.Signature.Recv() != nil:
+		// Method on a defined type (which may be a pointer).
+		return f.RelString(nil)
+	}
+	// Bare function: name CGo functions directly.
+	if cname := f.CName(); cname != "" {
+		return cname
+	}
+	return f.RelString(nil)
+}
+
+// Return the name of the C function if this is a CGo wrapper. Otherwise, return
+// a zero-length string.
+func (f *Function) CName() string {
+ name := f.Name()
+ if strings.HasPrefix(name, "_Cfunc_") {
+ return name[len("_Cfunc_"):]
+ }
+ return ""
+}
+
+// parsePragmas reads the //go: pragma comments in doc. The only pragma
+// handled is //go:extern, which marks the global as extern and, when given
+// exactly one argument, uses that argument as the link name.
+func (g *Global) parsePragmas(doc *ast.CommentGroup) {
+	for _, comment := range doc.List {
+		text := comment.Text
+		if !strings.HasPrefix(text, "//go:") {
+			continue
+		}
+		parts := strings.Fields(text)
+		if parts[0] != "//go:extern" {
+			continue
+		}
+		g.extern = true
+		if len(parts) == 2 {
+			g.linkName = parts[1]
+		}
+	}
+}
+
+// LinkName returns the link name for this global: the name given with
+// //go:extern if there is one, the package-qualified name otherwise.
+func (g *Global) LinkName() string {
+	if g.linkName == "" {
+		return g.RelString(nil)
+	}
+	return g.linkName
+}
+
+// IsExtern reports whether this global carries a //go:extern pragma.
+func (g *Global) IsExtern() bool {
+	extern := g.extern
+	return extern
+}
+
+// Initializer returns the compile-time value the interpreter has recorded for
+// this global, or nil if there is none.
+func (g *Global) Initializer() Value {
+	value := g.initializer
+	return value
+}
+
+// methodList implements sort.Interface for []*types.Selection, ordering the
+// methods by the number Program.MethodNum returns for them.
+type methodList struct {
+	methods []*types.Selection
+	program *Program
+}
+
+// Len returns the number of methods in the list.
+func (m *methodList) Len() int {
+	return len(m.methods)
+}
+
+// Less reports whether method i has a lower method number than method j.
+func (m *methodList) Less(i, j int) bool {
+	left := m.methods[i].Obj().(*types.Func)
+	right := m.methods[j].Obj().(*types.Func)
+	return m.program.MethodNum(left) < m.program.MethodNum(right)
+}
+
+// Swap exchanges the methods at positions i and j.
+func (m *methodList) Swap(i, j int) {
+	tmp := m.methods[i]
+	m.methods[i] = m.methods[j]
+	m.methods[j] = tmp
+}
+
+// funcList implements sort.Interface for []*types.Func, ordering the
+// functions by the number Program.MethodNum returns for them.
+type funcList struct {
+	funcs   []*types.Func
+	program *Program
+}
+
+// Len returns the number of functions in the list.
+func (fl *funcList) Len() int {
+	return len(fl.funcs)
+}
+
+// Less reports whether function i has a lower method number than function j.
+func (fl *funcList) Less(i, j int) bool {
+	left, right := fl.funcs[i], fl.funcs[j]
+	return fl.program.MethodNum(left) < fl.program.MethodNum(right)
+}
+
+// Swap exchanges the functions at positions i and j.
+func (fl *funcList) Swap(i, j int) {
+	tmp := fl.funcs[i]
+	fl.funcs[i] = fl.funcs[j]
+	fl.funcs[j] = tmp
+}
+
+// isCGoInternal reports whether name belongs to a CGo-internal member that
+// can be ignored. Such names start with one of:
+//
+//   - "_Cgo_" or "_cgo": CGo-internal functions such as _Cgo_ptr, _Cgo_use,
+//     _cgoCheckResult and _cgo_runtime_cgocall
+//   - "__cgofn__cgo_": CGo function pointers in global scope
+func isCGoInternal(name string) bool {
+	for _, prefix := range [...]string{"_Cgo_", "_cgo", "__cgofn__cgo_"} {
+		if strings.HasPrefix(name, prefix) {
+			return true
+		}
+	}
+	return false
+}
+
+// getAllMethods returns every selection in the method set of typ, in the
+// order in which the method set stores them.
+func getAllMethods(prog *ssa.Program, typ types.Type) []*types.Selection {
+	methodSet := prog.MethodSets.MethodSet(typ)
+	selections := make([]*types.Selection, methodSet.Len())
+	for n := range selections {
+		selections[n] = methodSet.At(n)
+	}
+	return selections
+}
diff --git a/ir/passes.go b/ir/passes.go
new file mode 100644
index 000000000..87175faea
--- /dev/null
+++ b/ir/passes.go
@@ -0,0 +1,427 @@
+package ir
+
+import (
+ "go/types"
+ "sort"
+ "strings"
+
+ "golang.org/x/tools/go/ssa"
+)
+
+// This file implements several optimization passes (analysis + transform) to
+// optimize code in SSA form before it is compiled to LLVM IR. It is based on
+// the IR defined in ir.go.
+
+// MethodSignature renders a method as its name immediately followed by its
+// Signature string; the receiver is not part of the result. The string is used
+// internally to match interfaces and to call the correct method on an
+// interface. Examples:
+//
+//	String() string
+//	Read([]byte) (int, error)
+func MethodSignature(method *types.Func) string {
+	sig := method.Type().(*types.Signature)
+	return method.Name() + Signature(sig)
+}
+
+// Make a readable version of a function (pointer) signature. This string is
+// used internally to match signatures (like in AnalyseFunctionPointers).
+// Examples:
+//
+// () string
+// (string, int) (int, error)
+func Signature(sig *types.Signature) string {
+ s := ""
+ if sig.Params().Len() == 0 {
+ s += "()"
+ } else {
+ s += "("
+ for i := 0; i < sig.Params().Len(); i++ {
+ if i > 0 {
+ s += ", "
+ }
+ s += sig.Params().At(i).Type().String()
+ }
+ s += ")"
+ }
+ if sig.Results().Len() == 0 {
+ // keep as-is
+ } else if sig.Results().Len() == 1 {
+ s += " " + sig.Results().At(0).Type().String()
+ } else {
+ s += " ("
+ for i := 0; i < sig.Results().Len(); i++ {
+ if i > 0 {
+ s += ", "
+ }
+ s += sig.Results().At(i).Type().String()
+ }
+ s += ")"
+ }
+ return s
+}
+
+// InterfaceKey converts an interface type to a single string identifying its
+// method set: the MethodSignature of every method, sorted with sort.Strings
+// and joined with ";" (there is no space after the semicolon). For example:
+//
+//	Close() error;Read([]byte) (int, error)
+//
+// An interface without methods yields the empty string.
+func InterfaceKey(itf *types.Interface) string {
+	n := itf.NumMethods()
+	methodStrings := make([]string, 0, n)
+	for i := 0; i < n; i++ {
+		methodStrings = append(methodStrings, MethodSignature(itf.Method(i)))
+	}
+	// Sorting makes the key independent of the order in which go/types
+	// reports the methods.
+	sort.Strings(methodStrings)
+	return strings.Join(methodStrings, ";")
+}
+
+// Fill in parents of all functions.
+//
+// All packages need to be added before this pass can run, or it will produce
+// incorrect results.
+func (p *Program) AnalyseCallgraph() {
+ for _, f := range p.Functions {
+ // Clear, if AnalyseCallgraph has been called before.
+ f.children = nil
+ f.parents = nil
+
+ for _, block := range f.Blocks {
+ for _, instr := range block.Instrs {
+ switch instr := instr.(type) {
+ case *ssa.Call:
+ if instr.Common().IsInvoke() {
+ continue
+ }
+ switch call := instr.Call.Value.(type) {
+ case *ssa.Builtin:
+ // ignore
+ case *ssa.Function:
+ if isCGoInternal(call.Name()) {
+ continue
+ }
+ child := p.GetFunction(call)
+ if child.CName() != "" {
+ continue // assume non-blocking
+ }
+ if child.RelString(nil) == "time.Sleep" {
+ f.blocking = true
+ }
+ f.children = append(f.children, child)
+ }
+ }
+ }
+ }
+ }
+ for _, f := range p.Functions {
+ for _, child := range f.children {
+ child.parents = append(child.parents, f)
+ }
+ }
+}
+
+// Find all types that are put in an interface.
+func (p *Program) AnalyseInterfaceConversions() {
+ // Clear, if AnalyseTypes has been called before.
+ p.typesWithoutMethods = map[string]int{"nil": 0}
+ p.typesWithMethods = map[string]*TypeWithMethods{}
+
+ for _, f := range p.Functions {
+ for _, block := range f.Blocks {
+ for _, instr := range block.Instrs {
+ switch instr := instr.(type) {
+ case *ssa.MakeInterface:
+ methods := getAllMethods(f.Prog, instr.X.Type())
+ name := instr.X.Type().String()
+ if _, ok := p.typesWithMethods[name]; !ok && len(methods) > 0 {
+ t := &TypeWithMethods{
+ t: instr.X.Type(),
+ Num: len(p.typesWithMethods),
+ Methods: make(map[string]*types.Selection),
+ }
+ for _, sel := range methods {
+ name := MethodSignature(sel.Obj().(*types.Func))
+ t.Methods[name] = sel
+ }
+ p.typesWithMethods[name] = t
+ } else if _, ok := p.typesWithoutMethods[name]; !ok && len(methods) == 0 {
+ p.typesWithoutMethods[name] = len(p.typesWithoutMethods)
+ }
+ }
+ }
+ }
+ }
+}
+
+// Analyse which function pointer signatures need a context parameter.
+// This makes calling function pointers more efficient.
+func (p *Program) AnalyseFunctionPointers() {
+ // Clear, if AnalyseFunctionPointers has been called before.
+ p.fpWithContext = map[string]struct{}{}
+
+ for _, f := range p.Functions {
+ for _, block := range f.Blocks {
+ for _, instr := range block.Instrs {
+ switch instr := instr.(type) {
+ case ssa.CallInstruction:
+ for _, arg := range instr.Common().Args {
+ switch arg := arg.(type) {
+ case *ssa.Function:
+ f := p.GetFunction(arg)
+ f.addressTaken = true
+ }
+ }
+ case *ssa.DebugRef:
+ default:
+ // For anything that isn't a call...
+ for _, operand := range instr.Operands(nil) {
+ if operand == nil || *operand == nil || isCGoInternal((*operand).Name()) {
+ continue
+ }
+ switch operand := (*operand).(type) {
+ case *ssa.Function:
+ f := p.GetFunction(operand)
+ f.addressTaken = true
+ }
+ }
+ }
+ switch instr := instr.(type) {
+ case *ssa.MakeClosure:
+ fn := instr.Fn.(*ssa.Function)
+ sig := Signature(fn.Signature)
+ p.fpWithContext[sig] = struct{}{}
+ }
+ }
+ }
+ }
+}
+
+// Analyse which functions are recursively blocking.
+//
+// Depends on AnalyseCallgraph.
+func (p *Program) AnalyseBlockingRecursive() {
+ worklist := make([]*Function, 0)
+
+ // Fill worklist with directly blocking functions.
+ for _, f := range p.Functions {
+ if f.blocking {
+ worklist = append(worklist, f)
+ }
+ }
+
+ // Keep reducing this worklist by marking a function as recursively blocking
+ // from the worklist and pushing all its parents that are non-blocking.
+ // This is somewhat similar to a worklist in a mark-sweep garbage collector.
+ // The work items are then grey objects.
+ for len(worklist) != 0 {
+ // Pick the topmost.
+ f := worklist[len(worklist)-1]
+ worklist = worklist[:len(worklist)-1]
+ for _, parent := range f.parents {
+ if !parent.blocking {
+ parent.blocking = true
+ worklist = append(worklist, parent)
+ }
+ }
+ }
+}
+
+// AnalyseGoCalls collects every go statement in the program and decides
+// whether a scheduler is needed. A scheduler is only necessary when a go
+// statement starts a blocking function (if the function is not blocking, the
+// go statement can be turned into a regular function call).
+//
+// Depends on AnalyseBlockingRecursive.
+func (p *Program) AnalyseGoCalls() {
+	p.goCalls = nil
+	for _, fn := range p.Functions {
+		for _, bb := range fn.Blocks {
+			for _, instr := range bb.Instrs {
+				if goInstr, ok := instr.(*ssa.Go); ok {
+					p.goCalls = append(p.goCalls, goInstr)
+				}
+			}
+		}
+	}
+
+	for _, goInstr := range p.goCalls {
+		switch callee := goInstr.Call.Value.(type) {
+		case *ssa.Builtin:
+			// Builtins never require a scheduler.
+		case *ssa.Function:
+			if p.functionMap[callee].blocking {
+				p.needsScheduler = true
+			}
+		default:
+			panic("unknown go call function type")
+		}
+	}
+}
+
+// SimpleDCE removes dead code: functions not reachable from main.main, a package initializer or the runtime package
+// are dropped from p.Functions and p.functionMap. This pass makes later analysis passes more useful.
+func (p *Program) SimpleDCE() {
+ // Clear the liveness mark on every function.
+ for _, f := range p.Functions {
+ f.flag = false
+ }
+
+ // Initial set of live functions: main.main, every synthetic package
+ // initializer and every function of the runtime package.
+ main := p.mainPkg.Members["main"].(*ssa.Function)
+ runtimePkg := p.Program.ImportedPackage("runtime")
+ p.GetFunction(main).flag = true
+ worklist := []*ssa.Function{main}
+ for _, f := range p.Functions {
+ if f.Synthetic == "package initializer" || f.Pkg == runtimePkg {
+ if f.flag || isCGoInternal(f.Name()) { // already marked, or a CGo-internal function
+ continue
+ }
+ f.flag = true
+ worklist = append(worklist, f.Function)
+ }
+ }
+
+ // Transitively mark every function referenced from a live function; the worklist is used as a stack.
+ for len(worklist) != 0 {
+ f := worklist[len(worklist)-1]
+ worklist = worklist[:len(worklist)-1]
+ for _, block := range f.Blocks {
+ for _, instr := range block.Instrs {
+ if instr, ok := instr.(*ssa.MakeInterface); ok { // every method of a type put in an interface is kept alive
+ for _, sel := range getAllMethods(p.Program, instr.X.Type()) {
+ fn := p.Program.MethodValue(sel)
+ callee := p.GetFunction(fn)
+ if callee == nil {
+ // TODO: why is this necessary?
+ p.addFunction(fn)
+ callee = p.GetFunction(fn)
+ }
+ if !callee.flag {
+ callee.flag = true
+ worklist = append(worklist, callee.Function)
+ }
+ }
+ }
+ for _, operand := range instr.Operands(nil) { // any *ssa.Function appearing as an operand is live
+ if operand == nil || *operand == nil || isCGoInternal((*operand).Name()) {
+ continue
+ }
+ switch operand := (*operand).(type) {
+ case *ssa.Function:
+ f := p.GetFunction(operand)
+ if f == nil {
+ // FIXME HACK: this function should have been
+ // discovered already. It is not for bound methods.
+ p.addFunction(operand)
+ f = p.GetFunction(operand)
+ }
+ if !f.flag {
+ f.flag = true
+ worklist = append(worklist, operand)
+ }
+ }
+ }
+ }
+ }
+ }
+
+ // Keep only the marked functions in p.Functions and unregister the rest from p.functionMap.
+ livefunctions := []*Function{}
+ for _, f := range p.Functions {
+ if f.flag {
+ livefunctions = append(livefunctions, f)
+ } else {
+ delete(p.functionMap, f.Function)
+ }
+ }
+ p.Functions = livefunctions
+}
+
+// NeedsScheduler reports whether this program needs a scheduler.
+//
+// Depends on AnalyseGoCalls.
+func (p *Program) NeedsScheduler() bool {
+	return p.needsScheduler
+}
+
+// IsBlocking reports whether f blocks. When the program does not need a
+// scheduler the answer is always false and f is not inspected.
+//
+// Depends on AnalyseBlockingRecursive.
+func (p *Program) IsBlocking(f *Function) bool {
+	return p.needsScheduler && f.blocking
+}
+
+// TypeNum returns the type number and whether this type is actually used. Used
+// in interface conversions (type is always used) and type asserts (type may
+// not be used, meaning the assert is always false in this program). An unused
+// type yields (-1, false).
+//
+// May only be used after all packages have been added to the analyser.
+func (p *Program) TypeNum(typ types.Type) (int, bool) {
+	name := typ.String()
+	if num, ok := p.typesWithoutMethods[name]; ok {
+		return num, true
+	}
+	if meta, ok := p.typesWithMethods[name]; ok {
+		// Types with methods are numbered after all method-less types.
+		return len(p.typesWithoutMethods) + meta.Num, true
+	}
+	return -1, false // type is never put in an interface
+}
+
+// InterfaceNum returns the numeric interface ID of this type, for use in type
+// asserts. An interface whose InterfaceKey has not been seen before is
+// registered and receives the next free number.
+func (p *Program) InterfaceNum(itfType *types.Interface) int {
+	key := InterfaceKey(itfType)
+	if known, ok := p.interfaces[key]; ok {
+		return known.Num
+	}
+	id := len(p.interfaces)
+	p.interfaces[key] = &Interface{Num: id, Type: itfType}
+	return id
+}
+
+// MethodNum returns the numeric ID of this method, to be used in method lookups
+// on interfaces for example. IDs are keyed by MethodSignature, so methods with
+// the same name and signature share an ID. A signature that has not been seen
+// before is assigned the next free number.
+func (p *Program) MethodNum(method *types.Func) int {
+	// Build the signature string once and look it up once; this method is
+	// called from sort comparators.
+	name := MethodSignature(method)
+	num, ok := p.methodSignatureNames[name]
+	if !ok {
+		num = len(p.methodSignatureNames)
+		p.methodSignatureNames[name] = num
+	}
+	return num
+}
+
+// FirstDynamicType returns the start index of the first dynamic type that has
+// methods, which equals the number of types without methods. Types without
+// methods always have a lower ID; types with methods have this or a higher ID.
+//
+// May only be used after all packages have been added to the analyser.
+func (p *Program) FirstDynamicType() int {
+	withoutMethods := len(p.typesWithoutMethods)
+	return withoutMethods
+}
+
+// AllDynamicTypes returns all types with methods, sorted by type ID: each type
+// is stored at the index given by its Num field.
+func (p *Program) AllDynamicTypes() []*TypeWithMethods {
+	sorted := make([]*TypeWithMethods, len(p.typesWithMethods))
+	for _, meta := range p.typesWithMethods {
+		sorted[meta.Num] = meta
+	}
+	return sorted
+}
+
+// AllInterfaces returns all interface types, sorted by interface ID: each
+// interface is stored at the index given by its Num field.
+func (p *Program) AllInterfaces() []*Interface {
+	sorted := make([]*Interface, len(p.interfaces))
+	for _, entry := range p.interfaces {
+		sorted[entry.Num] = entry
+	}
+	return sorted
+}
+
+// FunctionNeedsContext reports whether f needs a context parameter: its
+// address must be taken and its signature must be one that needs a context
+// (see SignatureNeedsContext).
+func (p *Program) FunctionNeedsContext(f *Function) bool {
+	return f.addressTaken && p.SignatureNeedsContext(f.Signature)
+}
+
+// SignatureNeedsContext reports whether the Signature string of sig was
+// recorded in fpWithContext.
+func (p *Program) SignatureNeedsContext(sig *types.Signature) bool {
+	if _, recorded := p.fpWithContext[Signature(sig)]; recorded {
+		return true
+	}
+	return false
+}