aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorJaden Weiss <[email protected]>2020-04-10 09:51:08 -0400
committerAyke <[email protected]>2020-04-12 16:54:40 +0200
commitbb5f7534e5451edfbeb2a2a0fa59ab5a3c77c672 (patch)
tree2d5a6b904fc633a0f8116ae46957e5e88680e1d5
parent3862d6e8a2fd8d2fc69ce59f43fb7eeb787e2538 (diff)
downloadtinygo-bb5f7534e5451edfbeb2a2a0fa59ab5a3c77c672.tar.gz
tinygo-bb5f7534e5451edfbeb2a2a0fa59ab5a3c77c672.zip
transform (coroutines): remove map iteration from coroutine lowering pass
The coroutine lowering pass had issues where it iterated over maps, sometimes resulting in non-deterministic output. This change removes many of the maps and ensures that the transformations are deterministic.
-rw-r--r--transform/coroutines.go154
1 files changed, 101 insertions, 53 deletions
diff --git a/transform/coroutines.go b/transform/coroutines.go
index 4e989f1ff..5700e64a4 100644
--- a/transform/coroutines.go
+++ b/transform/coroutines.go
@@ -115,13 +115,22 @@ type asyncFunc struct {
// callers is a set of all functions which call this async function.
callers map[llvm.Value]struct{}
- // returns is a map of terminal basic blocks to their return kinds.
- returns map[llvm.BasicBlock]returnKind
+ // returns is a list of returns in the function, along with metadata.
+ returns []asyncReturn
- // calls is the set of all calls in the asyncFunc.
- // normalCalls is the set of all intermideate suspending calls in the asyncFunc.
- // tailCalls is the set of all tail calls in the asyncFunc.
- calls, normalCalls, tailCalls map[llvm.Value]struct{}
+ // calls is a list of all calls in the asyncFunc.
+ // normalCalls is a list of all intermideate suspending calls in the asyncFunc.
+ // tailCalls is a list of all tail calls in the asyncFunc.
+ calls, normalCalls, tailCalls []llvm.Value
+}
+
+// asyncReturn is a metadata container for a return from an asynchronous function.
+type asyncReturn struct {
+ // block is the basic block terminated by the return.
+ block llvm.BasicBlock
+
+ // kind is the kind of the return.
+ kind returnKind
}
// coroutineLoweringPass is a goroutine lowering pass which is used with the "coroutines" scheduler.
@@ -135,6 +144,8 @@ type coroutineLoweringPass struct {
// The map keys are function pointers.
asyncFuncs map[llvm.Value]*asyncFunc
+ asyncFuncsOrdered []*asyncFunc
+
// calls is a slice of all of the async calls in the module.
calls []llvm.Value
@@ -159,14 +170,15 @@ type coroutineLoweringPass struct {
// A function is considered asynchronous if it calls an asynchronous function or intrinsic.
func (c *coroutineLoweringPass) findAsyncFuncs() {
asyncs := map[llvm.Value]*asyncFunc{}
+ asyncsOrdered := []llvm.Value{}
calls := []llvm.Value{}
// Use a breadth-first search to find all async functions.
worklist := []llvm.Value{c.pause}
for len(worklist) > 0 {
// Pop a function off the worklist.
- fn := worklist[len(worklist)-1]
- worklist = worklist[:len(worklist)-1]
+ fn := worklist[0]
+ worklist = worklist[1:]
// Get task pointer argument.
task := fn.LastParam()
@@ -204,6 +216,7 @@ func (c *coroutineLoweringPass) findAsyncFuncs() {
// Mark the caller as async.
// Use nil as a temporary value. It will be replaced later.
asyncs[caller] = nil
+ asyncsOrdered = append(asyncsOrdered, caller)
// Put the caller on the worklist.
worklist = append(worklist, caller)
@@ -216,7 +229,19 @@ func (c *coroutineLoweringPass) findAsyncFuncs() {
}
}
+ // Flip the order of the async functions so that the top ones are lowered first.
+ for i := 0; i < len(asyncsOrdered)/2; i++ {
+ asyncsOrdered[i], asyncsOrdered[len(asyncsOrdered)-(i+1)] = asyncsOrdered[len(asyncsOrdered)-(i+1)], asyncsOrdered[i]
+ }
+
+ // Map the elements of asyncsOrdered to *asyncFunc.
+ asyncFuncsOrdered := make([]*asyncFunc, len(asyncsOrdered))
+ for i, v := range asyncsOrdered {
+ asyncFuncsOrdered[i] = asyncs[v]
+ }
+
c.asyncFuncs = asyncs
+ c.asyncFuncsOrdered = asyncFuncsOrdered
c.calls = calls
}
@@ -386,7 +411,7 @@ func (c *coroutineLoweringPass) isAsyncCall(call llvm.Value) bool {
// analyzeFuncReturns analyzes and classifies the returns of a function.
func (c *coroutineLoweringPass) analyzeFuncReturns(fn *asyncFunc) {
- returns := map[llvm.BasicBlock]returnKind{}
+ returns := []asyncReturn{}
if fn.fn == c.pause {
// Skip pause.
fn.returns = returns
@@ -410,28 +435,49 @@ func (c *coroutineLoweringPass) analyzeFuncReturns(fn *asyncFunc) {
case !c.isAsyncCall(prev):
// This is not any form of asynchronous tail call.
if isVoid {
- returns[bb] = returnVoid
+ returns = append(returns, asyncReturn{
+ block: bb,
+ kind: returnVoid,
+ })
} else {
- returns[bb] = returnNormal
+ returns = append(returns, asyncReturn{
+ block: bb,
+ kind: returnNormal,
+ })
}
case isVoid:
if prev.CalledValue().Type().ElementType().ReturnType().TypeKind() == llvm.VoidTypeKind {
// This is a tail call to a void-returning function from a function with a void return.
- returns[bb] = returnVoidTail
+ returns = append(returns, asyncReturn{
+ block: bb,
+ kind: returnVoidTail,
+ })
} else {
// This is a tail call to a value-returning function from a function with a void return.
// The returned value will be ditched.
- returns[bb] = returnDitchedTail
+ returns = append(returns, asyncReturn{
+ block: bb,
+ kind: returnDitchedTail,
+ })
}
case last.Operand(0) == prev:
// This is a regular tail call. The return of the callee is returned to the parent.
- returns[bb] = returnTail
+ returns = append(returns, asyncReturn{
+ block: bb,
+ kind: returnTail,
+ })
case prev.CalledValue().Type().ElementType().ReturnType().TypeKind() == llvm.VoidTypeKind:
// This is a tail call that returns a previous value after waiting on a void function.
- returns[bb] = returnDelayedValue
+ returns = append(returns, asyncReturn{
+ block: bb,
+ kind: returnDelayedValue,
+ })
default:
// This is a tail call that returns a value that is available before the function call.
- returns[bb] = returnAlternateTail
+ returns = append(returns, asyncReturn{
+ block: bb,
+ kind: returnAlternateTail,
+ })
}
case llvm.Unreachable:
prev := llvm.PrevInstruction(last)
@@ -442,7 +488,10 @@ func (c *coroutineLoweringPass) analyzeFuncReturns(fn *asyncFunc) {
}
// This is an asyncnhronous tail call to function that does not return.
- returns[bb] = returnDeadTail
+ returns = append(returns, asyncReturn{
+ block: bb,
+ kind: returnDeadTail,
+ })
}
}
@@ -451,7 +500,7 @@ func (c *coroutineLoweringPass) analyzeFuncReturns(fn *asyncFunc) {
// returnAnalysisPass runs an analysis pass which classifies the returns of all async functions.
func (c *coroutineLoweringPass) returnAnalysisPass() {
- for _, async := range c.asyncFuncs {
+ for _, async := range c.asyncFuncsOrdered {
c.analyzeFuncReturns(async)
}
}
@@ -459,38 +508,37 @@ func (c *coroutineLoweringPass) returnAnalysisPass() {
// categorizeCalls categorizes all asynchronous calls into regular vs. async and matches them to their callers.
func (c *coroutineLoweringPass) categorizeCalls() {
// Sort calls into their respective callers.
- for _, async := range c.asyncFuncs {
- async.calls = map[llvm.Value]struct{}{}
- }
for _, call := range c.calls {
- c.asyncFuncs[call.InstructionParent().Parent()].calls[call] = struct{}{}
+ caller := c.asyncFuncs[call.InstructionParent().Parent()]
+ caller.calls = append(caller.calls, call)
}
// Seperate regular and tail calls.
- for _, async := range c.asyncFuncs {
- // Find all tail calls (of any kind).
+ for _, async := range c.asyncFuncsOrdered {
+ // Search returns for tail calls.
tails := map[llvm.Value]struct{}{}
- for ret, kind := range async.returns {
- switch kind {
+ for _, ret := range async.returns {
+ switch ret.kind {
case returnVoidTail, returnTail, returnDeadTail, returnAlternateTail, returnDitchedTail, returnDelayedValue:
// This is a tail return. The previous instruction is a tail call.
- tails[llvm.PrevInstruction(ret.LastInstruction())] = struct{}{}
+ tails[llvm.PrevInstruction(ret.block.LastInstruction())] = struct{}{}
}
}
- // Find all regular calls.
- regulars := map[llvm.Value]struct{}{}
- for call := range async.calls {
+ // Seperate tail calls and regular calls.
+ normalCalls, tailCalls := []llvm.Value{}, []llvm.Value{}
+ for _, call := range async.calls {
if _, ok := tails[call]; ok {
// This is a tail call.
- continue
+ tailCalls = append(tailCalls, call)
+ } else {
+ // This is a regular call.
+ normalCalls = append(normalCalls, call)
}
-
- regulars[call] = struct{}{}
}
- async.tailCalls = tails
- async.normalCalls = regulars
+ async.normalCalls = normalCalls
+ async.tailCalls = tailCalls
}
}
@@ -513,8 +561,8 @@ func (c *coroutineLoweringPass) lowerFuncsPass() {
}
func (async *asyncFunc) hasValueStoreReturn() bool {
- for _, kind := range async.returns {
- switch kind {
+ for _, ret := range async.returns {
+ switch ret.kind {
case returnNormal, returnAlternateTail, returnDelayedValue:
return true
}
@@ -550,18 +598,18 @@ func (c *coroutineLoweringPass) lowerFuncFast(fn *asyncFunc) {
}
// Lower returns.
- for ret, kind := range fn.returns {
+ for _, ret := range fn.returns {
// Get terminator.
- terminator := ret.LastInstruction()
+ terminator := ret.block.LastInstruction()
// Get tail call if applicable.
var call llvm.Value
- switch kind {
+ switch ret.kind {
case returnVoidTail, returnTail, returnDeadTail, returnAlternateTail, returnDitchedTail, returnDelayedValue:
call = llvm.PrevInstruction(terminator)
}
- switch kind {
+ switch ret.kind {
case returnNormal:
c.builder.SetInsertPointBefore(terminator)
@@ -718,8 +766,8 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
c.builder.CreateBr(suspend)
// Restore old state before tail calls.
- for call := range fn.tailCalls {
- if fn.returns[call.InstructionParent()] == returnDeadTail {
+ for _, call := range fn.tailCalls {
+ if !llvm.NextInstruction(call).IsAUnreachableInst().IsNil() {
// Callee never returns, so the state restore is ineffectual.
continue
}
@@ -729,18 +777,18 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
}
// Lower returns.
- for ret, kind := range fn.returns {
+ for _, ret := range fn.returns {
// Get terminator instruction.
- terminator := ret.LastInstruction()
+ terminator := ret.block.LastInstruction()
// Get tail call if applicable.
var call llvm.Value
- switch kind {
+ switch ret.kind {
case returnVoidTail, returnTail, returnDeadTail, returnAlternateTail, returnDitchedTail, returnDelayedValue:
call = llvm.PrevInstruction(terminator)
}
- switch kind {
+ switch ret.kind {
case returnNormal:
c.builder.SetInsertPointBefore(terminator)
@@ -760,7 +808,7 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
c.builder.SetInsertPointBefore(call)
// Store return value.
- c.builder.CreateStore(ret.LastInstruction().Operand(0), retPtr)
+ c.builder.CreateStore(terminator.Operand(0), retPtr)
// Heap-allocate a return buffer for the discarded return.
alternateBuf := c.heapAlloc(call.Type(), "ret.alternate")
@@ -775,7 +823,7 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
c.builder.SetInsertPointBefore(call)
// Store return value.
- c.builder.CreateStore(ret.LastInstruction().Operand(0), retPtr)
+ c.builder.CreateStore(terminator.Operand(0), retPtr)
}
// Delete call if it is a pause, because it has already been lowered.
@@ -785,12 +833,12 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) {
// Replace terminator with branch to cleanup.
terminator.EraseFromParentAsInstruction()
- c.builder.SetInsertPointAtEnd(ret)
+ c.builder.SetInsertPointAtEnd(ret.block)
c.builder.CreateBr(cleanup)
}
// Lower regular calls.
- for call := range fn.normalCalls {
+ for _, call := range fn.normalCalls {
// Lower return value of call.
c.lowerCallReturn(fn, call)
@@ -882,8 +930,8 @@ func (c *coroutineLoweringPass) lowerStart(start llvm.Value) {
} else {
// Check for any undead returns.
var undead bool
- for _, kind := range c.asyncFuncs[fn].returns {
- if kind != returnDeadTail {
+ for _, ret := range c.asyncFuncs[fn].returns {
+ if ret.kind != returnDeadTail {
// This return results in a value being eventually stored.
undead = true
break