aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--compiler/channel.go151
-rw-r--r--compiler/compiler.go27
-rw-r--r--src/runtime/chan.go74
-rw-r--r--testdata/channel.go57
-rw-r--r--testdata/channel.txt7
5 files changed, 291 insertions, 25 deletions
diff --git a/compiler/channel.go b/compiler/channel.go
index 6265d7f63..144297e84 100644
--- a/compiler/channel.go
+++ b/compiler/channel.go
@@ -85,3 +85,154 @@ func (c *Compiler) emitChanClose(frame *Frame, param ssa.Value) {
ch := c.getValue(frame, param)
c.createRuntimeCall("chanClose", []llvm.Value{ch}, "")
}
+
+// emitSelect emits all IR necessary for a select statements. That's a
+// non-trivial amount of code because select is very complex to implement.
+func (c *Compiler) emitSelect(frame *Frame, expr *ssa.Select) llvm.Value {
+ if len(expr.States) == 0 {
+ // Shortcuts for some simple selects.
+ llvmType := c.getLLVMType(expr.Type())
+ if expr.Blocking {
+ // Blocks forever:
+ // select {}
+ c.createRuntimeCall("deadlockStub", nil, "")
+ return llvm.Undef(llvmType)
+ } else {
+ // No-op:
+ // select {
+ // default:
+ // }
+ retval := llvm.Undef(llvmType)
+ retval = c.builder.CreateInsertValue(retval, llvm.ConstInt(c.intType, 0xffffffffffffffff, true), 0, "")
+ return retval // {-1, false}
+ }
+ }
+
+ // This code create a (stack-allocated) slice containing all the select
+ // cases and then calls runtime.chanSelect to perform the actual select
+ // statement.
+ // Simple selects (blocking and with just one case) are already transformed
+ // into regular chan operations during SSA construction so we don't have to
+ // optimize such small selects.
+
+ // Go through all the cases. Create the selectStates slice and and
+ // determine the receive buffer size and alignment.
+ recvbufSize := uint64(0)
+ recvbufAlign := 0
+ hasReceives := false
+ var selectStates []llvm.Value
+ chanSelectStateType := c.getLLVMRuntimeType("chanSelectState")
+ for _, state := range expr.States {
+ ch := c.getValue(frame, state.Chan)
+ selectState := c.getZeroValue(chanSelectStateType)
+ selectState = c.builder.CreateInsertValue(selectState, ch, 0, "")
+ switch state.Dir {
+ case types.RecvOnly:
+ // Make sure the receive buffer is big enough and has the correct alignment.
+ llvmType := c.getLLVMType(state.Chan.Type().(*types.Chan).Elem())
+ if size := c.targetData.TypeAllocSize(llvmType); size > recvbufSize {
+ recvbufSize = size
+ }
+ if align := c.targetData.ABITypeAlignment(llvmType); align > recvbufAlign {
+ recvbufAlign = align
+ }
+ hasReceives = true
+ case types.SendOnly:
+ // Store this value in an alloca and put a pointer to this alloca
+ // in the send state.
+ sendValue := c.getValue(frame, state.Send)
+ alloca := c.createEntryBlockAlloca(sendValue.Type(), "select.send.value")
+ c.builder.CreateStore(sendValue, alloca)
+ ptr := c.builder.CreateBitCast(alloca, c.i8ptrType, "")
+ selectState = c.builder.CreateInsertValue(selectState, ptr, 1, "")
+ default:
+ panic("unreachable")
+ }
+ selectStates = append(selectStates, selectState)
+ }
+
+ // Create a receive buffer, where the received value will be stored.
+ recvbuf := llvm.Undef(c.i8ptrType)
+ if hasReceives {
+ allocaType := llvm.ArrayType(c.ctx.Int8Type(), int(recvbufSize))
+ recvbufAlloca := c.builder.CreateAlloca(allocaType, "select.recvbuf.alloca")
+ recvbufAlloca.SetAlignment(recvbufAlign)
+ recvbuf = c.builder.CreateGEP(recvbufAlloca, []llvm.Value{
+ llvm.ConstInt(c.ctx.Int32Type(), 0, false),
+ llvm.ConstInt(c.ctx.Int32Type(), 0, false),
+ }, "select.recvbuf")
+ }
+
+ // Create the states slice (allocated on the stack).
+ statesAllocaType := llvm.ArrayType(chanSelectStateType, len(selectStates))
+ statesAlloca := c.builder.CreateAlloca(statesAllocaType, "select.states.alloca")
+ for i, state := range selectStates {
+ // Set each slice element to the appropriate channel.
+ gep := c.builder.CreateGEP(statesAlloca, []llvm.Value{
+ llvm.ConstInt(c.ctx.Int32Type(), 0, false),
+ llvm.ConstInt(c.ctx.Int32Type(), uint64(i), false),
+ }, "")
+ c.builder.CreateStore(state, gep)
+ }
+ statesPtr := c.builder.CreateGEP(statesAlloca, []llvm.Value{
+ llvm.ConstInt(c.ctx.Int32Type(), 0, false),
+ llvm.ConstInt(c.ctx.Int32Type(), 0, false),
+ }, "select.states")
+ statesLen := llvm.ConstInt(c.uintptrType, uint64(len(selectStates)), false)
+
+ // Convert the 'blocking' flag on this select into a LLVM value.
+ blockingInt := uint64(0)
+ if expr.Blocking {
+ blockingInt = 1
+ }
+ blockingValue := llvm.ConstInt(c.ctx.Int1Type(), blockingInt, false)
+
+ // Do the select in the runtime.
+ results := c.createRuntimeCall("chanSelect", []llvm.Value{
+ recvbuf,
+ statesPtr, statesLen, statesLen, // []chanSelectState
+ blockingValue,
+ }, "")
+
+ // The result value does not include all the possible received values,
+ // because we can't load them in advance. Instead, the *ssa.Extract
+ // instruction will treat a *ssa.Select specially and load it there inline.
+ // Store the receive alloca in a sidetable until we hit this extract
+ // instruction.
+ if frame.selectRecvBuf == nil {
+ frame.selectRecvBuf = make(map[*ssa.Select]llvm.Value)
+ }
+ frame.selectRecvBuf[expr] = recvbuf
+
+ return results
+}
+
+// getChanSelectResult returns the special values from a *ssa.Extract expression
+// when extracting a value from a select statement (*ssa.Select). Because
+// *ssa.Select cannot load all values in advance, it does this later in the
+// *ssa.Extract expression.
+func (c *Compiler) getChanSelectResult(frame *Frame, expr *ssa.Extract) llvm.Value {
+ if expr.Index == 0 {
+ // index
+ value := c.getValue(frame, expr.Tuple)
+ index := c.builder.CreateExtractValue(value, expr.Index, "")
+ if index.Type().IntTypeWidth() < c.intType.IntTypeWidth() {
+ index = c.builder.CreateSExt(index, c.intType, "")
+ }
+ return index
+ } else if expr.Index == 1 {
+ // comma-ok
+ value := c.getValue(frame, expr.Tuple)
+ return c.builder.CreateExtractValue(value, expr.Index, "")
+ } else {
+ // Select statements are (index, ok, ...) where ... is a number of
+ // received values, depending on how many receive statements there
+ // are. They are all combined into one alloca (because only one
+ // receive can proceed at a time) so we'll get that alloca, bitcast
+ // it to the correct type, and dereference it.
+ recvbuf := frame.selectRecvBuf[expr.Tuple.(*ssa.Select)]
+ typ := llvm.PointerType(c.getLLVMType(expr.Type()), 0)
+ ptr := c.builder.CreateBitCast(recvbuf, typ, "")
+ return c.builder.CreateLoad(ptr, "")
+ }
+}
diff --git a/compiler/compiler.go b/compiler/compiler.go
index f27b8adee..dd7e3d5eb 100644
--- a/compiler/compiler.go
+++ b/compiler/compiler.go
@@ -84,6 +84,7 @@ type Frame struct {
deferFuncs map[*ir.Function]int
deferInvokeFuncs map[string]int
deferClosureFuncs map[*ir.Function]int
+ selectRecvBuf map[*ssa.Select]llvm.Value
}
type Phi struct {
@@ -1445,9 +1446,11 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) {
x := c.getValue(frame, expr.X)
return c.parseConvert(expr.X.Type(), expr.Type(), x, expr.Pos())
case *ssa.Extract:
+ if _, ok := expr.Tuple.(*ssa.Select); ok {
+ return c.getChanSelectResult(frame, expr), nil
+ }
value := c.getValue(frame, expr.Tuple)
- result := c.builder.CreateExtractValue(value, expr.Index, "")
- return result, nil
+ return c.builder.CreateExtractValue(value, expr.Index, ""), nil
case *ssa.Field:
value := c.getValue(frame, expr.X)
if s := expr.X.Type().Underlying().(*types.Struct); s.NumFields() > 2 && s.Field(0).Name() == "C union" {
@@ -1696,25 +1699,7 @@ func (c *Compiler) parseExpr(frame *Frame, expr ssa.Value) (llvm.Value, error) {
c.builder.CreateStore(c.getZeroValue(iteratorType), it)
return it, nil
case *ssa.Select:
- if len(expr.States) == 0 {
- // Shortcuts for some simple selects.
- llvmType := c.getLLVMType(expr.Type())
- if expr.Blocking {
- // Blocks forever:
- // select {}
- c.createRuntimeCall("deadlockStub", nil, "")
- return llvm.Undef(llvmType), nil
- } else {
- // No-op:
- // select {
- // default:
- // }
- retval := llvm.Undef(llvmType)
- retval = c.builder.CreateInsertValue(retval, llvm.ConstInt(c.intType, 0xffffffffffffffff, true), 0, "")
- return retval, nil // {-1, false}
- }
- }
- return llvm.Undef(c.getLLVMType(expr.Type())), c.makeError(expr.Pos(), "unimplemented: "+expr.String())
+ return c.emitSelect(frame, expr), nil
case *ssa.Slice:
if expr.Max != nil {
return llvm.Value{}, c.makeError(expr.Pos(), "todo: full slice expressions (with max): "+expr.Type().String())
diff --git a/src/runtime/chan.go b/src/runtime/chan.go
index d9851df4c..578adf294 100644
--- a/src/runtime/chan.go
+++ b/src/runtime/chan.go
@@ -29,17 +29,27 @@ import (
type channel struct {
elementSize uint16 // the size of one value in this channel
- state uint8
+ state chanState
blocked *coroutine
}
+type chanState uint8
+
const (
- chanStateEmpty = iota
+ chanStateEmpty chanState = iota
chanStateRecv
chanStateSend
chanStateClosed
)
+// chanSelectState is a single channel operation (send/recv) in a select
+// statement. The value pointer is either nil (for receives) or points to the
+// value to send (for sends).
+type chanSelectState struct {
+ ch *channel
+ value unsafe.Pointer
+}
+
func deadlockStub()
// chanSend sends a single value over the channel. If this operation can
@@ -144,3 +154,63 @@ func chanClose(ch *channel) {
ch.state = chanStateClosed
}
}
+
+// chanSelect is the runtime implementation of the select statement. This is
+// perhaps the most complicated statement in the Go spec. It returns the
+// selected index and the 'comma-ok' value.
+//
+// TODO: do this in a round-robin fashion (as specified in the Go spec) instead
+// of picking the first one that can proceed.
+func chanSelect(recvbuf unsafe.Pointer, states []chanSelectState, blocking bool) (uintptr, bool) {
+ // See whether we can receive from one of the channels.
+ for i, state := range states {
+ if state.ch == nil {
+ // A nil channel blocks forever, so don't consider it here.
+ continue
+ }
+ if state.value == nil {
+ // A receive operation.
+ switch state.ch.state {
+ case chanStateSend:
+ // We can receive immediately.
+ sender := state.ch.blocked
+ senderPromise := sender.promise()
+ memcpy(recvbuf, senderPromise.ptr, uintptr(state.ch.elementSize))
+ state.ch.blocked = senderPromise.next
+ senderPromise.next = nil
+ activateTask(sender)
+ if state.ch.blocked == nil {
+ state.ch.state = chanStateEmpty
+ }
+ return uintptr(i), true // commaOk = true
+ case chanStateClosed:
+ // Receive the zero value.
+ memzero(recvbuf, uintptr(state.ch.elementSize))
+ return uintptr(i), false // commaOk = false
+ }
+ } else {
+ // A send operation: state.value is not nil.
+ switch state.ch.state {
+ case chanStateRecv:
+ receiver := state.ch.blocked
+ receiverPromise := receiver.promise()
+ memcpy(receiverPromise.ptr, state.value, uintptr(state.ch.elementSize))
+ receiverPromise.data = 1 // commaOk = true
+ state.ch.blocked = receiverPromise.next
+ receiverPromise.next = nil
+ activateTask(receiver)
+ if state.ch.blocked == nil {
+ state.ch.state = chanStateEmpty
+ }
+ return uintptr(i), false
+ case chanStateClosed:
+ runtimePanic("send on closed channel")
+ }
+ }
+ }
+
+ if !blocking {
+ return ^uintptr(0), false
+ }
+ panic("unimplemented: blocking select")
+}
diff --git a/testdata/channel.go b/testdata/channel.go
index cdf3fac6e..db5b3f297 100644
--- a/testdata/channel.go
+++ b/testdata/channel.go
@@ -48,10 +48,63 @@ func main() {
}
println("sum(100):", sum)
- // Test select
+ // Test simple selects.
go selectDeadlock()
go selectNoOp()
+ // Test select with a single send operation (transformed into chan send).
+ ch = make(chan int)
+ go fastreceiver(ch)
+ select {
+ case ch <- 5:
+ println("select one sent")
+ }
+ close(ch)
+
+ // Test select with a single recv operation (transformed into chan recv).
+ select {
+ case n := <-ch:
+ println("select one n:", n)
+ }
+
+ // Test select recv with channel that has one entry.
+ ch = make(chan int)
+ go func(ch chan int) {
+ ch <- 55
+ }(ch)
+ time.Sleep(time.Millisecond)
+ select {
+ case make(chan int) <- 3:
+ println("unreachable")
+ case n := <-ch:
+ println("select n from chan:", n)
+ case n := <-make(chan int):
+ println("unreachable:", n)
+ }
+
+ // Test select recv with closed channel.
+ close(ch)
+ select {
+ case make(chan int) <- 3:
+ println("unreachable")
+ case n := <-ch:
+ println("select n from closed chan:", n)
+ case n := <-make(chan int):
+ println("unreachable:", n)
+ }
+
+ // Test select send.
+ ch = make(chan int)
+ go fastreceiver(ch)
+ time.Sleep(time.Millisecond)
+ select {
+ case ch <- 235:
+ println("select send")
+ case n := <-make(chan int):
+ println("unreachable:", n)
+ }
+ close(ch)
+
// Allow goroutines to exit.
time.Sleep(time.Microsecond)
}
@@ -68,7 +121,7 @@ func sender(ch chan int) {
}
func sendComplex(ch chan complex128) {
- ch <- 7+10.5i
+ ch <- 7 + 10.5i
}
func fastsender(ch chan int) {
diff --git a/testdata/channel.txt b/testdata/channel.txt
index 9daef1be5..2355bd53c 100644
--- a/testdata/channel.txt
+++ b/testdata/channel.txt
@@ -23,3 +23,10 @@ sum(100): 4950
deadlocking
select no-op
after no-op
+select one sent
+sum: 5
+select one n: 0
+select n from chan: 55
+select n from closed chan: 0
+select send
+sum: 235