aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAyke van Laethem <[email protected]>2019-01-13 17:05:00 +0100
committerAyke van Laethem <[email protected]>2019-01-21 22:09:37 +0100
commit2e4dd09bbccf5dd8d3908b98887bbd2320346a85 (patch)
tree00eb010c13b1da80021dbeb3f94bdaf517c4cf9f
parent602c2647490fc2352ff283da59a6b537c8e8cacc (diff)
downloadtinygo-2e4dd09bbccf5dd8d3908b98887bbd2320346a85.tar.gz
tinygo-2e4dd09bbccf5dd8d3908b98887bbd2320346a85.zip
compiler: add support for channel operations
Support for channels is not complete. The following pieces are missing: * Channels with values bigger than int. An int in TinyGo can always contain at least a pointer, so pointers are okay to send. * Buffered channels. * The select statement.
-rw-r--r--README.md3
-rw-r--r--compiler/channel.go97
-rw-r--r--compiler/compiler.go20
-rw-r--r--compiler/goroutine-lowering.go123
-rw-r--r--src/runtime/chan.go147
-rw-r--r--src/runtime/scheduler.go7
-rw-r--r--testdata/channel.go95
-rw-r--r--testdata/channel.txt21
8 files changed, 502 insertions, 11 deletions
diff --git a/README.md b/README.md
index e3d7016ec..447f6a335 100644
--- a/README.md
+++ b/README.md
@@ -56,13 +56,14 @@ Currently supported features:
* closures
* bound methods
* complex numbers (except for arithmetic)
+ * channels (with some limitations)
Not yet supported:
+ * select
* complex arithmetic
* garbage collection
* recover
- * channels
* introspection (if it ever gets implemented)
* ...
diff --git a/compiler/channel.go b/compiler/channel.go
new file mode 100644
index 000000000..63fa351b1
--- /dev/null
+++ b/compiler/channel.go
@@ -0,0 +1,97 @@
+package compiler
+
+// This file lowers channel operations (make/send/recv/close) to runtime calls
+// or pseudo-operations that are lowered during goroutine lowering.
+
+import (
+ "go/types"
+
+ "github.com/aykevl/go-llvm"
+ "golang.org/x/tools/go/ssa"
+)
+
+// emitMakeChan returns a new channel value for the given channel type.
+func (c *Compiler) emitMakeChan(expr *ssa.MakeChan) (llvm.Value, error) {
+ valueType, err := c.getLLVMType(expr.Type().(*types.Chan).Elem())
+ if err != nil {
+ return llvm.Value{}, err
+ }
+ if c.targetData.TypeAllocSize(valueType) > c.targetData.TypeAllocSize(c.intType) {
+ // Values bigger than int overflow the data part of the coroutine.
+ // TODO: make the coroutine data part big enough to hold these bigger
+ // values.
+ return llvm.Value{}, c.makeError(expr.Pos(), "todo: channel with values bigger than int")
+ }
+ chanType := c.mod.GetTypeByName("runtime.channel")
+ size := c.targetData.TypeAllocSize(chanType)
+ sizeValue := llvm.ConstInt(c.uintptrType, size, false)
+ ptr := c.createRuntimeCall("alloc", []llvm.Value{sizeValue}, "chan.alloc")
+ ptr = c.builder.CreateBitCast(ptr, llvm.PointerType(chanType, 0), "chan")
+ return ptr, nil
+}
+
+// emitChanSend emits a pseudo chan send operation. It is lowered to the actual
+// channel send operation during goroutine lowering.
+func (c *Compiler) emitChanSend(frame *Frame, instr *ssa.Send) error {
+ valueType, err := c.getLLVMType(instr.Chan.Type().(*types.Chan).Elem())
+ if err != nil {
+ return err
+ }
+ ch, err := c.parseExpr(frame, instr.Chan)
+ if err != nil {
+ return err
+ }
+ chanValue, err := c.parseExpr(frame, instr.X)
+ if err != nil {
+ return err
+ }
+ valueSize := llvm.ConstInt(c.uintptrType, c.targetData.TypeAllocSize(chanValue.Type()), false)
+ valueAlloca := c.builder.CreateAlloca(valueType, "chan.value")
+ c.builder.CreateStore(chanValue, valueAlloca)
+ valueAllocaCast := c.builder.CreateBitCast(valueAlloca, c.i8ptrType, "chan.value.i8ptr")
+ c.createRuntimeCall("chanSendStub", []llvm.Value{llvm.Undef(c.i8ptrType), ch, valueAllocaCast, valueSize}, "")
+ return nil
+}
+
+// emitChanRecv emits a pseudo chan receive operation. It is lowered to the
+// actual channel receive operation during goroutine lowering.
+func (c *Compiler) emitChanRecv(frame *Frame, unop *ssa.UnOp) (llvm.Value, error) {
+ valueType, err := c.getLLVMType(unop.X.Type().(*types.Chan).Elem())
+ if err != nil {
+ return llvm.Value{}, err
+ }
+ valueSize := llvm.ConstInt(c.uintptrType, c.targetData.TypeAllocSize(valueType), false)
+ ch, err := c.parseExpr(frame, unop.X)
+ if err != nil {
+ return llvm.Value{}, err
+ }
+ valueAlloca := c.builder.CreateAlloca(valueType, "chan.value")
+ valueAllocaCast := c.builder.CreateBitCast(valueAlloca, c.i8ptrType, "chan.value.i8ptr")
+ valueOk := c.builder.CreateAlloca(c.ctx.Int1Type(), "chan.comma-ok.alloca")
+ c.createRuntimeCall("chanRecvStub", []llvm.Value{llvm.Undef(c.i8ptrType), ch, valueAllocaCast, valueOk, valueSize}, "")
+ received := c.builder.CreateLoad(valueAlloca, "chan.received")
+ if unop.CommaOk {
+ commaOk := c.builder.CreateLoad(valueOk, "chan.comma-ok")
+ tuple := llvm.Undef(c.ctx.StructType([]llvm.Type{valueType, c.ctx.Int1Type()}, false))
+ tuple = c.builder.CreateInsertValue(tuple, received, 0, "")
+ tuple = c.builder.CreateInsertValue(tuple, commaOk, 1, "")
+ return tuple, nil
+ } else {
+ return received, nil
+ }
+}
+
+// emitChanClose closes the given channel.
+func (c *Compiler) emitChanClose(frame *Frame, param ssa.Value) error {
+ valueType, err := c.getLLVMType(param.Type().(*types.Chan).Elem())
+ valueSize := llvm.ConstInt(c.uintptrType, c.targetData.TypeAllocSize(valueType), false)
+ if err != nil {
+ return err
+ }
+ ch, err := c.parseExpr(frame, param)
+ if err != nil {
+ return err
+ }
+ c.createRuntimeCall("chanClose", []llvm.Value{ch, valueSize}, "")
+ return nil
+}
diff --git a/compiler/compiler.go b/compiler/compiler.go
index b2477bcfd..a7aa6d433 100644
--- a/compiler/compiler.go
+++ b/compiler/compiler.go
@@ -366,6 +366,8 @@ func (c *Compiler) Compile(mainPath string) error {
realMain.SetLinkage(llvm.ExternalLinkage) // keep alive until goroutine lowering
c.mod.NamedFunction("runtime.alloc").SetLinkage(llvm.ExternalLinkage)
c.mod.NamedFunction("runtime.free").SetLinkage(llvm.ExternalLinkage)
+ c.mod.NamedFunction("runtime.chanSend").SetLinkage(llvm.ExternalLinkage)
+ c.mod.NamedFunction("runtime.chanRecv").SetLinkage(llvm.ExternalLinkage)
c.mod.NamedFunction("runtime.sleepTask").SetLinkage(llvm.ExternalLinkage)
c.mod.NamedFunction("runtime.activateTask").SetLinkage(llvm.ExternalLinkage)
c.mod.NamedFunction("runtime.scheduler").SetLinkage(llvm.ExternalLinkage)
@@ -1297,6 +1299,8 @@ func (c *Compiler) parseInstr(frame *Frame, instr ssa.Instruction) error {
}
case *ssa.RunDefers:
return c.emitRunDefers(frame)
+ case *ssa.Send:
+ return c.emitChanSend(frame, instr)
case *ssa.Store:
llvmAddr, err := c.parseExpr(frame, instr.Addr)
if err == ir.ErrCGoWrapper {
@@ -1362,11 +1366,17 @@ func (c *Compiler) parseBuiltin(frame *Frame, args []ssa.Value, callName string,
return llvm.Value{}, err
}
switch args[0].Type().(type) {
+ case *types.Chan:
+ // Channel. Buffered channels haven't been implemented yet so always
+ // return 0.
+ return llvm.ConstInt(c.intType, 0, false), nil
case *types.Slice:
return c.builder.CreateExtractValue(value, 2, "cap"), nil
default:
return llvm.Value{}, c.makeError(pos, "todo: cap: unknown type")
}
+ case "close":
+ return llvm.Value{}, c.emitChanClose(frame, args[0])
case "complex":
r, err := c.parseExpr(frame, args[0])
if err != nil {
@@ -1434,6 +1444,10 @@ func (c *Compiler) parseBuiltin(frame *Frame, args []ssa.Value, callName string,
case *types.Basic, *types.Slice:
// string or slice
llvmLen = c.builder.CreateExtractValue(value, 1, "len")
+ case *types.Chan:
+ // Channel. Buffered channels haven't been implemented yet so always
+ // return 0.
+ llvmLen = llvm.ConstInt(c.intType, 0, false)
case *types.Map:
llvmLen = c.createRuntimeCall("hashmapLen", []llvm.Value{value}, "len")
default:
@@ -2000,12 +2014,12 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) {
default:
panic("unknown lookup type: " + expr.String())
}
-
+ case *ssa.MakeChan:
+ return c.emitMakeChan(expr)
case *ssa.MakeClosure:
// A closure returns a function pointer with context:
// {context, fp}
return c.parseMakeClosure(frame, expr)
-
case *ssa.MakeInterface:
val, err := c.parseExpr(frame, expr.X)
if err != nil {
@@ -2941,6 +2955,8 @@ func (c *Compiler) parseUnOp(frame *Frame, unop *ssa.UnOp) (llvm.Value, error) {
}
case token.XOR: // ^x, toggle all bits in integer
return c.builder.CreateXor(x, llvm.ConstInt(x.Type(), ^uint64(0), false), ""), nil
+ case token.ARROW: // <-x, receive from channel
+ return c.emitChanRecv(frame, unop)
default:
return llvm.Value{}, c.makeError(unop.Pos(), "todo: unknown unop")
}
diff --git a/compiler/goroutine-lowering.go b/compiler/goroutine-lowering.go
index c5abaf1ea..46a191053 100644
--- a/compiler/goroutine-lowering.go
+++ b/compiler/goroutine-lowering.go
@@ -142,6 +142,8 @@ func (c *Compiler) LowerGoroutines() error {
realMain.SetLinkage(llvm.InternalLinkage)
c.mod.NamedFunction("runtime.alloc").SetLinkage(llvm.InternalLinkage)
c.mod.NamedFunction("runtime.free").SetLinkage(llvm.InternalLinkage)
+ c.mod.NamedFunction("runtime.chanSend").SetLinkage(llvm.InternalLinkage)
+ c.mod.NamedFunction("runtime.chanRecv").SetLinkage(llvm.InternalLinkage)
c.mod.NamedFunction("runtime.sleepTask").SetLinkage(llvm.InternalLinkage)
c.mod.NamedFunction("runtime.activateTask").SetLinkage(llvm.InternalLinkage)
c.mod.NamedFunction("runtime.scheduler").SetLinkage(llvm.InternalLinkage)
@@ -162,8 +164,22 @@ func (c *Compiler) LowerGoroutines() error {
// * Set up the coroutine frames for async functions.
// * Transform blocking calls into their async equivalents.
func (c *Compiler) markAsyncFunctions() (needsScheduler bool, err error) {
+ var worklist []llvm.Value
+
sleep := c.mod.NamedFunction("time.Sleep")
- if sleep.IsNil() {
+ if !sleep.IsNil() {
+ worklist = append(worklist, sleep)
+ }
+ chanSendStub := c.mod.NamedFunction("runtime.chanSendStub")
+ if !chanSendStub.IsNil() {
+ worklist = append(worklist, chanSendStub)
+ }
+ chanRecvStub := c.mod.NamedFunction("runtime.chanRecvStub")
+ if !chanRecvStub.IsNil() {
+ worklist = append(worklist, chanRecvStub)
+ }
+
+ if len(worklist) == 0 {
// There are no blocking operations, so no need to transform anything.
return false, c.lowerMakeGoroutineCalls()
}
@@ -173,7 +189,6 @@ func (c *Compiler) markAsyncFunctions() (needsScheduler bool, err error) {
// from the worklist and pushing all its parents that are non-async.
// This is somewhat similar to a worklist in a mark-sweep garbage collector:
// the work items are then grey objects.
- worklist := []llvm.Value{sleep}
asyncFuncs := make(map[llvm.Value]*asyncFunc)
asyncList := make([]llvm.Value, 0, 4)
for len(worklist) != 0 {
@@ -259,6 +274,9 @@ func (c *Compiler) markAsyncFunctions() (needsScheduler bool, err error) {
coroBeginType := llvm.FunctionType(c.i8ptrType, []llvm.Type{c.ctx.TokenType(), c.i8ptrType}, false)
coroBeginFunc := llvm.AddFunction(c.mod, "llvm.coro.begin", coroBeginType)
+ coroPromiseType := llvm.FunctionType(c.i8ptrType, []llvm.Type{c.i8ptrType, c.ctx.Int32Type(), c.ctx.Int1Type()}, false)
+ coroPromiseFunc := llvm.AddFunction(c.mod, "llvm.coro.promise", coroPromiseType)
+
coroSuspendType := llvm.FunctionType(c.ctx.Int8Type(), []llvm.Type{c.ctx.TokenType(), c.ctx.Int1Type()}, false)
coroSuspendFunc := llvm.AddFunction(c.mod, "llvm.coro.suspend", coroSuspendType)
@@ -270,7 +288,7 @@ func (c *Compiler) markAsyncFunctions() (needsScheduler bool, err error) {
// Transform all async functions into coroutines.
for _, f := range asyncList {
- if f == sleep {
+ if f == sleep || f == chanSendStub || f == chanRecvStub {
continue
}
@@ -287,7 +305,7 @@ func (c *Compiler) markAsyncFunctions() (needsScheduler bool, err error) {
for inst := bb.FirstInstruction(); !inst.IsNil(); inst = llvm.NextInstruction(inst) {
if !inst.IsACallInst().IsNil() {
callee := inst.CalledValue()
- if _, ok := asyncFuncs[callee]; !ok || callee == sleep {
+ if _, ok := asyncFuncs[callee]; !ok || callee == sleep || callee == chanSendStub || callee == chanRecvStub {
continue
}
asyncCalls = append(asyncCalls, inst)
@@ -421,6 +439,103 @@ func (c *Compiler) markAsyncFunctions() (needsScheduler bool, err error) {
sleepCall.EraseFromParentAsInstruction()
}
+ // Transform calls to runtime.chanSendStub into channel send operations.
+ for _, sendOp := range getUses(chanSendStub) {
+ // sendOp must be a call instruction.
+ frame := asyncFuncs[sendOp.InstructionParent().Parent()]
+
+ // Send the value over the channel, or block.
+ sendOp.SetOperand(0, frame.taskHandle)
+ sendOp.SetOperand(sendOp.OperandsCount()-1, c.mod.NamedFunction("runtime.chanSend"))
+
+ // Use taskState.data to store the value to send:
+ // *(*valueType)(&coroutine.promise().data) = valueToSend
+ // runtime.chanSend(coroutine, ch)
+ bitcast := sendOp.Operand(2)
+ valueAlloca := bitcast.Operand(0)
+ c.builder.SetInsertPointBefore(valueAlloca)
+ promiseType := c.mod.GetTypeByName("runtime.taskState")
+ promiseRaw := c.builder.CreateCall(coroPromiseFunc, []llvm.Value{
+ frame.taskHandle,
+ llvm.ConstInt(c.ctx.Int32Type(), uint64(c.targetData.PrefTypeAlignment(promiseType)), false),
+ llvm.ConstInt(c.ctx.Int1Type(), 0, false),
+ }, "task.promise.raw")
+ promise := c.builder.CreateBitCast(promiseRaw, llvm.PointerType(promiseType, 0), "task.promise")
+ dataPtr := c.builder.CreateGEP(promise, []llvm.Value{
+ llvm.ConstInt(c.ctx.Int32Type(), 0, false),
+ llvm.ConstInt(c.ctx.Int32Type(), 2, false),
+ }, "task.promise.data")
+ sendOp.SetOperand(2, llvm.Undef(c.i8ptrType))
+ valueAlloca.ReplaceAllUsesWith(c.builder.CreateBitCast(dataPtr, valueAlloca.Type(), ""))
+ bitcast.EraseFromParentAsInstruction()
+ valueAlloca.EraseFromParentAsInstruction()
+
+ // Yield to scheduler.
+ c.builder.SetInsertPointBefore(llvm.NextInstruction(sendOp))
+ continuePoint := c.builder.CreateCall(coroSuspendFunc, []llvm.Value{
+ llvm.ConstNull(c.ctx.TokenType()),
+ llvm.ConstInt(c.ctx.Int1Type(), 0, false),
+ }, "")
+ sw := c.builder.CreateSwitch(continuePoint, frame.suspendBlock, 2)
+ wakeup := c.splitBasicBlock(sw, llvm.NextBasicBlock(c.builder.GetInsertBlock()), "task.sent")
+ sw.AddCase(llvm.ConstInt(c.ctx.Int8Type(), 0, false), wakeup)
+ sw.AddCase(llvm.ConstInt(c.ctx.Int8Type(), 1, false), frame.cleanupBlock)
+ }
+
+ // Transform calls to runtime.chanRecvStub into channel receive operations.
+ for _, recvOp := range getUses(chanRecvStub) {
+ // recvOp must be a call instruction.
+ frame := asyncFuncs[recvOp.InstructionParent().Parent()]
+
+ bitcast := recvOp.Operand(2)
+ commaOk := recvOp.Operand(3)
+ valueAlloca := bitcast.Operand(0)
+
+ // Receive the value over the channel, or block.
+ recvOp.SetOperand(0, frame.taskHandle)
+ recvOp.SetOperand(recvOp.OperandsCount()-1, c.mod.NamedFunction("runtime.chanRecv"))
+ recvOp.SetOperand(2, llvm.Undef(c.i8ptrType))
+ bitcast.EraseFromParentAsInstruction()
+
+ // Yield to scheduler.
+ c.builder.SetInsertPointBefore(llvm.NextInstruction(recvOp))
+ continuePoint := c.builder.CreateCall(coroSuspendFunc, []llvm.Value{
+ llvm.ConstNull(c.ctx.TokenType()),
+ llvm.ConstInt(c.ctx.Int1Type(), 0, false),
+ }, "")
+ sw := c.builder.CreateSwitch(continuePoint, frame.suspendBlock, 2)
+ wakeup := c.splitBasicBlock(sw, llvm.NextBasicBlock(c.builder.GetInsertBlock()), "task.received")
+ c.builder.SetInsertPointAtEnd(recvOp.InstructionParent())
+ sw.AddCase(llvm.ConstInt(c.ctx.Int8Type(), 0, false), wakeup)
+ sw.AddCase(llvm.ConstInt(c.ctx.Int8Type(), 1, false), frame.cleanupBlock)
+
+ // The value to receive is stored in taskState.data:
+ // runtime.chanRecv(coroutine, ch)
+ // promise := coroutine.promise()
+ // valueReceived := *(*valueType)(&promise.data)
+ // ok := promise.commaOk
+ c.builder.SetInsertPointBefore(wakeup.FirstInstruction())
+ promiseType := c.mod.GetTypeByName("runtime.taskState")
+ promiseRaw := c.builder.CreateCall(coroPromiseFunc, []llvm.Value{
+ frame.taskHandle,
+ llvm.ConstInt(c.ctx.Int32Type(), uint64(c.targetData.PrefTypeAlignment(promiseType)), false),
+ llvm.ConstInt(c.ctx.Int1Type(), 0, false),
+ }, "task.promise.raw")
+ promise := c.builder.CreateBitCast(promiseRaw, llvm.PointerType(promiseType, 0), "task.promise")
+ dataPtr := c.builder.CreateGEP(promise, []llvm.Value{
+ llvm.ConstInt(c.ctx.Int32Type(), 0, false),
+ llvm.ConstInt(c.ctx.Int32Type(), 2, false),
+ }, "task.promise.data")
+ valueAlloca.ReplaceAllUsesWith(c.builder.CreateBitCast(dataPtr, valueAlloca.Type(), ""))
+ valueAlloca.EraseFromParentAsInstruction()
+ commaOkPtr := c.builder.CreateGEP(promise, []llvm.Value{
+ llvm.ConstInt(c.ctx.Int32Type(), 0, false),
+ llvm.ConstInt(c.ctx.Int32Type(), 1, false),
+ }, "task.promise.comma-ok")
+ commaOk.ReplaceAllUsesWith(commaOkPtr)
+ recvOp.SetOperand(3, llvm.Undef(commaOk.Type()))
+ }
+
return true, c.lowerMakeGoroutineCalls()
}
diff --git a/src/runtime/chan.go b/src/runtime/chan.go
index abd0b1178..940861fbb 100644
--- a/src/runtime/chan.go
+++ b/src/runtime/chan.go
@@ -2,7 +2,152 @@ package runtime
// This file implements the 'chan' type and send/receive/select operations.
-// dummy
+// A channel can be in one of the following states:
+// empty:
+// No goroutine is waiting on a send or receive operation. The 'blocked'
+// member is nil.
+// recv:
+// A goroutine tries to receive from the channel. This goroutine is stored
+// in the 'blocked' member.
+// send:
+// The reverse of send. A goroutine tries to send to the channel. This
+// goroutine is stored in the 'blocked' member.
+// closed:
+// The channel is closed. Sends will panic, receives will get a zero value
+// plus optionally the indication that the channel is zero (with the
+// commao-ok value in the coroutine).
+//
+// A send/recv transmission is completed by copying from the data element of the
+// sending coroutine to the data element of the receiving coroutine, and setting
+// the 'comma-ok' value to true.
+// A receive operation on a closed channel is completed by zeroing the data
+// element of the receiving coroutine and setting the 'comma-ok' value to false.
+
+import (
+ "unsafe"
+)
type channel struct {
+ state uint8
+ blocked *coroutine
+}
+
+const (
+ chanStateEmpty = iota
+ chanStateRecv
+ chanStateSend
+ chanStateClosed
+)
+
+func chanSendStub(caller *coroutine, ch *channel, _ unsafe.Pointer, size uintptr)
+func chanRecvStub(caller *coroutine, ch *channel, _ unsafe.Pointer, _ *bool, size uintptr)
+
+// chanSend sends a single value over the channel. If this operation can
+// complete immediately (there is a goroutine waiting for a value), it sends the
+// value and re-activates both goroutines. If not, it sets itself as waiting on
+// a value.
+//
+// The unsafe.Pointer value is used during lowering. During IR generation, it
+// points to the to-be-received value. During coroutine lowering, this value is
+// replaced with a read from the coroutine promise.
+func chanSend(sender *coroutine, ch *channel, _ unsafe.Pointer, size uintptr) {
+ if ch == nil {
+ // A nil channel blocks forever. Do not scheduler this goroutine again.
+ return
+ }
+ switch ch.state {
+ case chanStateEmpty:
+ ch.state = chanStateSend
+ ch.blocked = sender
+ case chanStateRecv:
+ receiver := ch.blocked
+ receiverPromise := receiver.promise()
+ senderPromise := sender.promise()
+ memcpy(unsafe.Pointer(&receiverPromise.data), unsafe.Pointer(&senderPromise.data), size)
+ receiverPromise.commaOk = true
+ ch.blocked = receiverPromise.next
+ receiverPromise.next = nil
+ activateTask(receiver)
+ activateTask(sender)
+ if ch.blocked == nil {
+ ch.state = chanStateEmpty
+ }
+ case chanStateClosed:
+ runtimePanic("send on closed channel")
+ case chanStateSend:
+ sender.promise().next = ch.blocked
+ ch.blocked = sender
+ }
+}
+
+// chanRecv receives a single value over a channel. If there is an available
+// sender, it receives the value immediately and re-activates both coroutines.
+// If not, it sets itself as available for receiving. If the channel is closed,
+// it immediately activates itself with a zero value as the result.
+//
+// The two unnamed values exist to help during lowering. The unsafe.Pointer
+// points to the value, and the *bool points to the comma-ok value. Both are
+// replaced by reads from the coroutine promise.
+func chanRecv(receiver *coroutine, ch *channel, _ unsafe.Pointer, _ *bool, size uintptr) {
+ if ch == nil {
+ // A nil channel blocks forever. Do not scheduler this goroutine again.
+ return
+ }
+ switch ch.state {
+ case chanStateSend:
+ sender := ch.blocked
+ receiverPromise := receiver.promise()
+ senderPromise := sender.promise()
+ memcpy(unsafe.Pointer(&receiverPromise.data), unsafe.Pointer(&senderPromise.data), size)
+ receiverPromise.commaOk = true
+ ch.blocked = senderPromise.next
+ senderPromise.next = nil
+ activateTask(receiver)
+ activateTask(sender)
+ if ch.blocked == nil {
+ ch.state = chanStateEmpty
+ }
+ case chanStateEmpty:
+ ch.state = chanStateRecv
+ ch.blocked = receiver
+ case chanStateClosed:
+ receiverPromise := receiver.promise()
+ memzero(unsafe.Pointer(&receiverPromise.data), size)
+ receiverPromise.commaOk = false
+ activateTask(receiver)
+ case chanStateRecv:
+ receiver.promise().next = ch.blocked
+ ch.blocked = receiver
+ }
+}
+
+// chanClose closes the given channel. If this channel has a receiver or is
+// empty, it closes the channel. Else, it panics.
+func chanClose(ch *channel, size uintptr) {
+ if ch == nil {
+ // Not allowed by the language spec.
+ runtimePanic("close of nil channel")
+ }
+ switch ch.state {
+ case chanStateClosed:
+ // Not allowed by the language spec.
+ runtimePanic("close of closed channel")
+ case chanStateSend:
+ // This panic should ideally on the sending side, not in this goroutine.
+ // But when a goroutine tries to send while the channel is being closed,
+ // that is clearly invalid: the send should have been completed already
+ // before the close.
+ runtimePanic("close channel during send")
+ case chanStateRecv:
+ // The receiver must be re-activated with a zero value.
+ receiverPromise := ch.blocked.promise()
+ memzero(unsafe.Pointer(&receiverPromise.data), size)
+ receiverPromise.commaOk = false
+ activateTask(ch.blocked)
+ ch.state = chanStateClosed
+ ch.blocked = nil
+ case chanStateEmpty:
+ // Easy case. No available sender or receiver.
+ ch.state = chanStateClosed
+ }
}
diff --git a/src/runtime/scheduler.go b/src/runtime/scheduler.go
index 52301c186..9f21bc4c8 100644
--- a/src/runtime/scheduler.go
+++ b/src/runtime/scheduler.go
@@ -51,10 +51,11 @@ func makeGoroutine(*uint8) *uint8
// State/promise of a task. Internally represented as:
//
-// {i8* next, i32/i64 data}
+// {i8* next, i1 commaOk, i32/i64 data}
type taskState struct {
- next *coroutine
- data uint
+ next *coroutine
+ commaOk bool // 'comma-ok' flag for channel receive operation
+ data uint
}
// Queues used by the scheduler.
diff --git a/testdata/channel.go b/testdata/channel.go
new file mode 100644
index 000000000..3777b27a0
--- /dev/null
+++ b/testdata/channel.go
@@ -0,0 +1,95 @@
+package main
+
+import "time"
+
+func main() {
+ ch := make(chan int)
+ println("len, cap of channel:", len(ch), cap(ch))
+ go sender(ch)
+
+ n, ok := <-ch
+ println("recv from open channel:", n, ok)
+
+ for n := range ch {
+ if n == 6 {
+ time.Sleep(time.Microsecond)
+ }
+ println("received num:", n)
+ }
+
+ n, ok = <-ch
+ println("recv from closed channel:", n, ok)
+
+ // Test multi-sender.
+ ch = make(chan int)
+ go fastsender(ch)
+ go fastsender(ch)
+ go fastsender(ch)
+ slowreceiver(ch)
+
+ // Test multi-receiver.
+ ch = make(chan int)
+ go fastreceiver(ch)
+ go fastreceiver(ch)
+ go fastreceiver(ch)
+ slowsender(ch)
+
+ // Test iterator style channel.
+ ch = make(chan int)
+ go iterator(ch, 100)
+ sum := 0
+ for i := range ch {
+ sum += i
+ }
+ println("sum(100):", sum)
+
+ // Allow goroutines to exit.
+ time.Sleep(time.Microsecond)
+}
+
+func sender(ch chan int) {
+ for i := 1; i <= 8; i++ {
+ if i == 4 {
+ time.Sleep(time.Microsecond)
+ println("slept")
+ }
+ ch <- i
+ }
+ close(ch)
+}
+
+func fastsender(ch chan int) {
+ ch <- 10
+ ch <- 11
+}
+
+func slowreceiver(ch chan int) {
+ for i := 0; i < 6; i++ {
+ n := <-ch
+ println("got n:", n)
+ time.Sleep(time.Microsecond)
+ }
+}
+
+func slowsender(ch chan int) {
+ for n := 0; n < 6; n++ {
+ time.Sleep(time.Microsecond)
+ ch <- 12 + n
+ }
+}
+
+func fastreceiver(ch chan int) {
+ sum := 0
+ for i := 0; i < 2; i++ {
+ n := <-ch
+ sum += n
+ }
+ println("sum:", sum)
+}
+
+func iterator(ch chan int, top int) {
+ for i := 0; i < top; i++ {
+ ch <- i
+ }
+ close(ch)
+}
diff --git a/testdata/channel.txt b/testdata/channel.txt
new file mode 100644
index 000000000..eaffd7330
--- /dev/null
+++ b/testdata/channel.txt
@@ -0,0 +1,21 @@
+len, cap of channel: 0 0
+recv from open channel: 1 true
+received num: 2
+received num: 3
+slept
+received num: 4
+received num: 5
+received num: 6
+received num: 7
+received num: 8
+recv from closed channel: 0 false
+got n: 10
+got n: 11
+got n: 10
+got n: 11
+got n: 10
+got n: 11
+sum: 25
+sum: 29
+sum: 33
+sum(100): 4950