diff options
author | Jaden Weiss <[email protected]> | 2020-04-10 09:51:08 -0400 |
---|---|---|
committer | Ayke <[email protected]> | 2020-04-12 16:54:40 +0200 |
commit | bb5f7534e5451edfbeb2a2a0fa59ab5a3c77c672 (patch) | |
tree | 2d5a6b904fc633a0f8116ae46957e5e88680e1d5 /transform | |
parent | 3862d6e8a2fd8d2fc69ce59f43fb7eeb787e2538 (diff) | |
download | tinygo-bb5f7534e5451edfbeb2a2a0fa59ab5a3c77c672.tar.gz tinygo-bb5f7534e5451edfbeb2a2a0fa59ab5a3c77c672.zip |
transform (coroutines): remove map iteration from coroutine lowering pass
The coroutine lowering pass had issues where it iterated over maps, sometimes resulting in non-deterministic output.
This change removes many of the maps and ensures that the transformations are deterministic.
Diffstat (limited to 'transform')
-rw-r--r-- | transform/coroutines.go | 154 |
1 files changed, 101 insertions, 53 deletions
diff --git a/transform/coroutines.go b/transform/coroutines.go index 4e989f1ff..5700e64a4 100644 --- a/transform/coroutines.go +++ b/transform/coroutines.go @@ -115,13 +115,22 @@ type asyncFunc struct { // callers is a set of all functions which call this async function. callers map[llvm.Value]struct{} - // returns is a map of terminal basic blocks to their return kinds. - returns map[llvm.BasicBlock]returnKind + // returns is a list of returns in the function, along with metadata. + returns []asyncReturn - // calls is the set of all calls in the asyncFunc. - // normalCalls is the set of all intermideate suspending calls in the asyncFunc. - // tailCalls is the set of all tail calls in the asyncFunc. - calls, normalCalls, tailCalls map[llvm.Value]struct{} + // calls is a list of all calls in the asyncFunc. + // normalCalls is a list of all intermideate suspending calls in the asyncFunc. + // tailCalls is a list of all tail calls in the asyncFunc. + calls, normalCalls, tailCalls []llvm.Value +} + +// asyncReturn is a metadata container for a return from an asynchronous function. +type asyncReturn struct { + // block is the basic block terminated by the return. + block llvm.BasicBlock + + // kind is the kind of the return. + kind returnKind } // coroutineLoweringPass is a goroutine lowering pass which is used with the "coroutines" scheduler. @@ -135,6 +144,8 @@ type coroutineLoweringPass struct { // The map keys are function pointers. asyncFuncs map[llvm.Value]*asyncFunc + asyncFuncsOrdered []*asyncFunc + // calls is a slice of all of the async calls in the module. calls []llvm.Value @@ -159,14 +170,15 @@ type coroutineLoweringPass struct { // A function is considered asynchronous if it calls an asynchronous function or intrinsic. func (c *coroutineLoweringPass) findAsyncFuncs() { asyncs := map[llvm.Value]*asyncFunc{} + asyncsOrdered := []llvm.Value{} calls := []llvm.Value{} // Use a breadth-first search to find all async functions. worklist := []llvm.Value{c.pause} for len(worklist) > 0 { // Pop a function off the worklist. - fn := worklist[len(worklist)-1] - worklist = worklist[:len(worklist)-1] + fn := worklist[0] + worklist = worklist[1:] // Get task pointer argument. task := fn.LastParam() @@ -204,6 +216,7 @@ func (c *coroutineLoweringPass) findAsyncFuncs() { // Mark the caller as async. // Use nil as a temporary value. It will be replaced later. asyncs[caller] = nil + asyncsOrdered = append(asyncsOrdered, caller) // Put the caller on the worklist. worklist = append(worklist, caller) @@ -216,7 +229,19 @@ func (c *coroutineLoweringPass) findAsyncFuncs() { } } + // Flip the order of the async functions so that the top ones are lowered first. + for i := 0; i < len(asyncsOrdered)/2; i++ { + asyncsOrdered[i], asyncsOrdered[len(asyncsOrdered)-(i+1)] = asyncsOrdered[len(asyncsOrdered)-(i+1)], asyncsOrdered[i] + } + + // Map the elements of asyncsOrdered to *asyncFunc. + asyncFuncsOrdered := make([]*asyncFunc, len(asyncsOrdered)) + for i, v := range asyncsOrdered { + asyncFuncsOrdered[i] = asyncs[v] + } + c.asyncFuncs = asyncs + c.asyncFuncsOrdered = asyncFuncsOrdered c.calls = calls } @@ -386,7 +411,7 @@ func (c *coroutineLoweringPass) isAsyncCall(call llvm.Value) bool { // analyzeFuncReturns analyzes and classifies the returns of a function. func (c *coroutineLoweringPass) analyzeFuncReturns(fn *asyncFunc) { - returns := map[llvm.BasicBlock]returnKind{} + returns := []asyncReturn{} if fn.fn == c.pause { // Skip pause. fn.returns = returns @@ -410,28 +435,49 @@ func (c *coroutineLoweringPass) analyzeFuncReturns(fn *asyncFunc) { case !c.isAsyncCall(prev): // This is not any form of asynchronous tail call. if isVoid { - returns[bb] = returnVoid + returns = append(returns, asyncReturn{ + block: bb, + kind: returnVoid, + }) } else { - returns[bb] = returnNormal + returns = append(returns, asyncReturn{ + block: bb, + kind: returnNormal, + }) } case isVoid: if prev.CalledValue().Type().ElementType().ReturnType().TypeKind() == llvm.VoidTypeKind { // This is a tail call to a void-returning function from a function with a void return. - returns[bb] = returnVoidTail + returns = append(returns, asyncReturn{ + block: bb, + kind: returnVoidTail, + }) } else { // This is a tail call to a value-returning function from a function with a void return. // The returned value will be ditched. - returns[bb] = returnDitchedTail + returns = append(returns, asyncReturn{ + block: bb, + kind: returnDitchedTail, + }) } case last.Operand(0) == prev: // This is a regular tail call. The return of the callee is returned to the parent. - returns[bb] = returnTail + returns = append(returns, asyncReturn{ + block: bb, + kind: returnTail, + }) case prev.CalledValue().Type().ElementType().ReturnType().TypeKind() == llvm.VoidTypeKind: // This is a tail call that returns a previous value after waiting on a void function. - returns[bb] = returnDelayedValue + returns = append(returns, asyncReturn{ + block: bb, + kind: returnDelayedValue, + }) default: // This is a tail call that returns a value that is available before the function call. - returns[bb] = returnAlternateTail + returns = append(returns, asyncReturn{ + block: bb, + kind: returnAlternateTail, + }) } case llvm.Unreachable: prev := llvm.PrevInstruction(last) @@ -442,7 +488,10 @@ func (c *coroutineLoweringPass) analyzeFuncReturns(fn *asyncFunc) { } // This is an asyncnhronous tail call to function that does not return. - returns[bb] = returnDeadTail + returns = append(returns, asyncReturn{ + block: bb, + kind: returnDeadTail, + }) } } @@ -451,7 +500,7 @@ func (c *coroutineLoweringPass) analyzeFuncReturns(fn *asyncFunc) { // returnAnalysisPass runs an analysis pass which classifies the returns of all async functions. func (c *coroutineLoweringPass) returnAnalysisPass() { - for _, async := range c.asyncFuncs { + for _, async := range c.asyncFuncsOrdered { c.analyzeFuncReturns(async) } } @@ -459,38 +508,37 @@ func (c *coroutineLoweringPass) returnAnalysisPass() { // categorizeCalls categorizes all asynchronous calls into regular vs. async and matches them to their callers. func (c *coroutineLoweringPass) categorizeCalls() { // Sort calls into their respective callers. - for _, async := range c.asyncFuncs { - async.calls = map[llvm.Value]struct{}{} - } for _, call := range c.calls { - c.asyncFuncs[call.InstructionParent().Parent()].calls[call] = struct{}{} + caller := c.asyncFuncs[call.InstructionParent().Parent()] + caller.calls = append(caller.calls, call) } // Seperate regular and tail calls. - for _, async := range c.asyncFuncs { - // Find all tail calls (of any kind). + for _, async := range c.asyncFuncsOrdered { + // Search returns for tail calls. tails := map[llvm.Value]struct{}{} - for ret, kind := range async.returns { - switch kind { + for _, ret := range async.returns { + switch ret.kind { case returnVoidTail, returnTail, returnDeadTail, returnAlternateTail, returnDitchedTail, returnDelayedValue: // This is a tail return. The previous instruction is a tail call. - tails[llvm.PrevInstruction(ret.LastInstruction())] = struct{}{} + tails[llvm.PrevInstruction(ret.block.LastInstruction())] = struct{}{} } } - // Find all regular calls. - regulars := map[llvm.Value]struct{}{} - for call := range async.calls { + // Seperate tail calls and regular calls. + normalCalls, tailCalls := []llvm.Value{}, []llvm.Value{} + for _, call := range async.calls { if _, ok := tails[call]; ok { // This is a tail call. - continue + tailCalls = append(tailCalls, call) + } else { + // This is a regular call. + normalCalls = append(normalCalls, call) } - - regulars[call] = struct{}{} } - async.tailCalls = tails - async.normalCalls = regulars + async.normalCalls = normalCalls + async.tailCalls = tailCalls } } @@ -513,8 +561,8 @@ func (c *coroutineLoweringPass) lowerFuncsPass() { } func (async *asyncFunc) hasValueStoreReturn() bool { - for _, kind := range async.returns { - switch kind { + for _, ret := range async.returns { + switch ret.kind { case returnNormal, returnAlternateTail, returnDelayedValue: return true } @@ -550,18 +598,18 @@ func (c *coroutineLoweringPass) lowerFuncFast(fn *asyncFunc) { } // Lower returns. - for ret, kind := range fn.returns { + for _, ret := range fn.returns { // Get terminator. - terminator := ret.LastInstruction() + terminator := ret.block.LastInstruction() // Get tail call if applicable. var call llvm.Value - switch kind { + switch ret.kind { case returnVoidTail, returnTail, returnDeadTail, returnAlternateTail, returnDitchedTail, returnDelayedValue: call = llvm.PrevInstruction(terminator) } - switch kind { + switch ret.kind { case returnNormal: c.builder.SetInsertPointBefore(terminator) @@ -718,8 +766,8 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) { c.builder.CreateBr(suspend) // Restore old state before tail calls. - for call := range fn.tailCalls { - if fn.returns[call.InstructionParent()] == returnDeadTail { + for _, call := range fn.tailCalls { + if !llvm.NextInstruction(call).IsAUnreachableInst().IsNil() { // Callee never returns, so the state restore is ineffectual. continue } @@ -729,18 +777,18 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) { } // Lower returns. - for ret, kind := range fn.returns { + for _, ret := range fn.returns { // Get terminator instruction. - terminator := ret.LastInstruction() + terminator := ret.block.LastInstruction() // Get tail call if applicable. var call llvm.Value - switch kind { + switch ret.kind { case returnVoidTail, returnTail, returnDeadTail, returnAlternateTail, returnDitchedTail, returnDelayedValue: call = llvm.PrevInstruction(terminator) } - switch kind { + switch ret.kind { case returnNormal: c.builder.SetInsertPointBefore(terminator) @@ -760,7 +808,7 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) { c.builder.SetInsertPointBefore(call) // Store return value. - c.builder.CreateStore(ret.LastInstruction().Operand(0), retPtr) + c.builder.CreateStore(terminator.Operand(0), retPtr) // Heap-allocate a return buffer for the discarded return. alternateBuf := c.heapAlloc(call.Type(), "ret.alternate") @@ -775,7 +823,7 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) { c.builder.SetInsertPointBefore(call) // Store return value. - c.builder.CreateStore(ret.LastInstruction().Operand(0), retPtr) + c.builder.CreateStore(terminator.Operand(0), retPtr) } // Delete call if it is a pause, because it has already been lowered. @@ -785,12 +833,12 @@ func (c *coroutineLoweringPass) lowerFuncCoro(fn *asyncFunc) { // Replace terminator with branch to cleanup. terminator.EraseFromParentAsInstruction() - c.builder.SetInsertPointAtEnd(ret) + c.builder.SetInsertPointAtEnd(ret.block) c.builder.CreateBr(cleanup) } // Lower regular calls. - for call := range fn.normalCalls { + for _, call := range fn.normalCalls { // Lower return value of call. c.lowerCallReturn(fn, call) @@ -882,8 +930,8 @@ func (c *coroutineLoweringPass) lowerStart(start llvm.Value) { } else { // Check for any undead returns. var undead bool - for _, kind := range c.asyncFuncs[fn].returns { - if kind != returnDeadTail { + for _, ret := range c.asyncFuncs[fn].returns { + if ret.kind != returnDeadTail { // This return results in a value being eventually stored. undead = true break |