diff options
Diffstat (limited to 'interp')
-rw-r--r-- | interp/interpreter.go | 84 | ||||
-rw-r--r-- | interp/memory.go | 11 | ||||
-rw-r--r-- | interp/testdata/consteval.ll | 15 | ||||
-rw-r--r-- | interp/testdata/consteval.out.ll | 1 |
4 files changed, 73 insertions, 38 deletions
diff --git a/interp/interpreter.go b/interp/interpreter.go index c438b4b61..83fd2cd9a 100644 --- a/interp/interpreter.go +++ b/interp/interpreter.go @@ -723,46 +723,9 @@ func (r *runner) run(fn *function, params []value, parentMem *memoryView, indent locals[inst.localIndex] = newagg case llvm.ICmp: predicate := llvm.IntPredicate(operands[2].(literalValue).value.(uint8)) - var result bool lhs := operands[0] rhs := operands[1] - switch predicate { - case llvm.IntEQ, llvm.IntNE: - lhsPointer, lhsErr := lhs.asPointer(r) - rhsPointer, rhsErr := rhs.asPointer(r) - if (lhsErr == nil) != (rhsErr == nil) { - // Fast path: only one is a pointer, so they can't be equal. - result = false - } else if lhsErr == nil { - // Both must be nil, so both are pointers. - // Compare them directly. - result = lhsPointer.equal(rhsPointer) - } else { - // Fall back to generic comparison. - result = lhs.asRawValue(r).equal(rhs.asRawValue(r)) - } - if predicate == llvm.IntNE { - result = !result - } - case llvm.IntUGT: - result = lhs.Uint() > rhs.Uint() - case llvm.IntUGE: - result = lhs.Uint() >= rhs.Uint() - case llvm.IntULT: - result = lhs.Uint() < rhs.Uint() - case llvm.IntULE: - result = lhs.Uint() <= rhs.Uint() - case llvm.IntSGT: - result = lhs.Int() > rhs.Int() - case llvm.IntSGE: - result = lhs.Int() >= rhs.Int() - case llvm.IntSLT: - result = lhs.Int() < rhs.Int() - case llvm.IntSLE: - result = lhs.Int() <= rhs.Int() - default: - return nil, mem, r.errorAt(inst, errors.New("interp: unsupported icmp")) - } + result := r.interpretICmp(lhs, rhs, predicate) if result { locals[inst.localIndex] = literalValue{uint8(1)} } else { @@ -948,6 +911,51 @@ func (r *runner) run(fn *function, params []value, parentMem *memoryView, indent return nil, mem, r.errorAt(bb.instructions[len(bb.instructions)-1], errors.New("interp: reached end of basic block without terminator")) } +// Interpret an icmp instruction. Doesn't have side effects, only returns the +// output of the comparison. +func (r *runner) interpretICmp(lhs, rhs value, predicate llvm.IntPredicate) bool { + switch predicate { + case llvm.IntEQ, llvm.IntNE: + var result bool + lhsPointer, lhsErr := lhs.asPointer(r) + rhsPointer, rhsErr := rhs.asPointer(r) + if (lhsErr == nil) != (rhsErr == nil) { + // Fast path: only one is a pointer, so they can't be equal. + result = false + } else if lhsErr == nil { + // Both must be nil, so both are pointers. + // Compare them directly. + result = lhsPointer.equal(rhsPointer) + } else { + // Fall back to generic comparison. + result = lhs.asRawValue(r).equal(rhs.asRawValue(r)) + } + if predicate == llvm.IntNE { + result = !result + } + return result + case llvm.IntUGT: + return lhs.Uint() > rhs.Uint() + case llvm.IntUGE: + return lhs.Uint() >= rhs.Uint() + case llvm.IntULT: + return lhs.Uint() < rhs.Uint() + case llvm.IntULE: + return lhs.Uint() <= rhs.Uint() + case llvm.IntSGT: + return lhs.Int() > rhs.Int() + case llvm.IntSGE: + return lhs.Int() >= rhs.Int() + case llvm.IntSLT: + return lhs.Int() < rhs.Int() + case llvm.IntSLE: + return lhs.Int() <= rhs.Int() + default: + // _should_ be unreachable, until LLVM adds new icmp operands (unlikely) + panic("interp: unsupported icmp") + } +} + func (r *runner) runAtRuntime(fn *function, inst instruction, locals []value, mem *memoryView, indent string) *Error { numOperands := inst.llvmInst.OperandsCount() operands := make([]llvm.Value, numOperands) diff --git a/interp/memory.go b/interp/memory.go index 82ab716d4..248825f64 100644 --- a/interp/memory.go +++ b/interp/memory.go @@ -1031,6 +1031,17 @@ func (v *rawValue) set(llvmValue llvm.Value, r *runner) { for i := uint32(0); i < ptrSize; i++ { v.buf[i] = ptrValue.pointer } + case llvm.ICmp: + size := r.targetData.TypeAllocSize(llvmValue.Operand(0).Type()) + lhs := newRawValue(uint32(size)) + rhs := newRawValue(uint32(size)) + lhs.set(llvmValue.Operand(0), r) + rhs.set(llvmValue.Operand(1), r) + if r.interpretICmp(lhs, rhs, llvmValue.IntPredicate()) { + v.buf[0] = 1 // result is true + } else { + v.buf[0] = 0 // result is false + } default: llvmValue.Dump() println() diff --git a/interp/testdata/consteval.ll b/interp/testdata/consteval.ll index 9afb9ff7e..d0c0e3b66 100644 --- a/interp/testdata/consteval.ll +++ b/interp/testdata/consteval.ll @@ -3,6 +3,7 @@ target triple = "x86_64--linux" @intToPtrResult = global i8 0 @ptrToIntResult = global i8 0 +@icmpResult = global i8 0 @someArray = internal global {i16, i8, i8} zeroinitializer @someArrayPointer = global i8* zeroinitializer @@ -15,6 +16,7 @@ define internal void @main.init() { call void @testIntToPtr() call void @testPtrToInt() call void @testConstGEP() + call void @testICmp() ret void } @@ -48,3 +50,16 @@ define internal void @testConstGEP() { store i8* getelementptr inbounds (i8, i8* bitcast ({i16, i8, i8}* @someArray to i8*), i32 2), i8** @someArrayPointer ret void } + +define internal void @testICmp() { + br i1 icmp eq (i64 ptrtoint (i8* @ptrToIntResult to i64), i64 0), label %equal, label %unequal +equal: + ; should not be reached + store i8 1, i8* @icmpResult + ret void +unequal: + ; should be reached + store i8 2, i8* @icmpResult + ret void + ret void +} diff --git a/interp/testdata/consteval.out.ll b/interp/testdata/consteval.out.ll index 5fac449e4..08d74c857 100644 --- a/interp/testdata/consteval.out.ll +++ b/interp/testdata/consteval.out.ll @@ -3,6 +3,7 @@ target triple = "x86_64--linux" @intToPtrResult = local_unnamed_addr global i8 2 @ptrToIntResult = local_unnamed_addr global i8 2 +@icmpResult = local_unnamed_addr global i8 2 @someArray = internal global { i16, i8, i8 } zeroinitializer @someArrayPointer = local_unnamed_addr global i8* getelementptr inbounds ({ i16, i8, i8 }, { i16, i8, i8 }* @someArray, i64 0, i32 1) |