diff options
Diffstat (limited to 'transform/coroutines.go')
-rw-r--r-- | transform/coroutines.go | 56 |
1 files changed, 53 insertions, 3 deletions
diff --git a/transform/coroutines.go b/transform/coroutines.go index 3756da257..2aeff43c9 100644 --- a/transform/coroutines.go +++ b/transform/coroutines.go @@ -5,6 +5,7 @@ package transform import ( "errors" + "go/token" "strconv" "github.com/tinygo-org/tinygo/compiler/llvmutil" @@ -104,6 +105,31 @@ func LowerCoroutines(mod llvm.Module, needStackSlots bool) error { return nil } +// CoroutinesError is an error returned when coroutine lowering failed, for +// example because an async function is exported. +type CoroutinesError struct { + Msg string + Pos token.Position + Traceback []CoroutinesErrorLine +} + +// CoroutinesErrorLine is a single line of a CoroutinesError traceback. +type CoroutinesErrorLine struct { + Name string // function name + Position token.Position // position in the function +} + +// Error implements the error interface by returning a simple error message +// without the stack. +func (err CoroutinesError) Error() string { + return err.Msg +} + +type asyncCallInfo struct { + fn llvm.Value + call llvm.Value +} + // asyncFunc is a metadata container for an asynchronous function. type asyncFunc struct { // fn is the underlying function pointer. @@ -168,10 +194,11 @@ type coroutineLoweringPass struct { // findAsyncFuncs finds all asynchronous functions. // A function is considered asynchronous if it calls an asynchronous function or intrinsic. -func (c *coroutineLoweringPass) findAsyncFuncs() { +func (c *coroutineLoweringPass) findAsyncFuncs() error { asyncs := map[llvm.Value]*asyncFunc{} asyncsOrdered := []llvm.Value{} calls := []llvm.Value{} + callsAsyncFunction := map[llvm.Value]asyncCallInfo{} // Use a breadth-first search to find all async functions. worklist := []llvm.Value{c.pause} @@ -183,7 +210,18 @@ func (c *coroutineLoweringPass) findAsyncFuncs() { // Get task pointer argument. task := fn.LastParam() if fn != c.pause && (task.IsNil() || task.Name() != "parentHandle") { - panic("trying to make exported function async: " + fn.Name()) + // Exported functions must not do async operations. + err := CoroutinesError{ + Msg: "blocking operation in exported function: " + fn.Name(), + Pos: getPosition(fn), + } + f := fn + for !f.IsNil() && f != c.pause { + data := callsAsyncFunction[f] + err.Traceback = append(err.Traceback, CoroutinesErrorLine{f.Name(), getPosition(data.call)}) + f = data.fn + } + return err } // Search all uses of the function while collecting callers. @@ -218,6 +256,13 @@ func (c *coroutineLoweringPass) findAsyncFuncs() { asyncs[caller] = nil asyncsOrdered = append(asyncsOrdered, caller) + // Track which calls caused this function to be marked async (for + // better diagnostics). + callsAsyncFunction[caller] = asyncCallInfo{ + fn: fn, + call: user, + } + // Put the caller on the worklist. worklist = append(worklist, caller) } @@ -243,6 +288,8 @@ func (c *coroutineLoweringPass) findAsyncFuncs() { c.asyncFuncs = asyncs c.asyncFuncsOrdered = asyncFuncsOrdered c.calls = calls + + return nil } func (c *coroutineLoweringPass) load() error { @@ -302,7 +349,10 @@ func (c *coroutineLoweringPass) load() error { } // Find async functions. - c.findAsyncFuncs() + err := c.findAsyncFuncs() + if err != nil { + return err + } // Get i8* type. c.i8ptr = llvm.PointerType(c.ctx.Int8Type(), 0) |