aboutsummaryrefslogtreecommitdiffhomepage
path: root/interp
diff options
context:
space:
mode:
Diffstat (limited to 'interp')
-rw-r--r--interp/interpreter.go84
-rw-r--r--interp/memory.go11
-rw-r--r--interp/testdata/consteval.ll15
-rw-r--r--interp/testdata/consteval.out.ll1
4 files changed, 73 insertions, 38 deletions
diff --git a/interp/interpreter.go b/interp/interpreter.go
index c438b4b61..83fd2cd9a 100644
--- a/interp/interpreter.go
+++ b/interp/interpreter.go
@@ -723,46 +723,9 @@ func (r *runner) run(fn *function, params []value, parentMem *memoryView, indent
locals[inst.localIndex] = newagg
case llvm.ICmp:
predicate := llvm.IntPredicate(operands[2].(literalValue).value.(uint8))
- var result bool
lhs := operands[0]
rhs := operands[1]
- switch predicate {
- case llvm.IntEQ, llvm.IntNE:
- lhsPointer, lhsErr := lhs.asPointer(r)
- rhsPointer, rhsErr := rhs.asPointer(r)
- if (lhsErr == nil) != (rhsErr == nil) {
- // Fast path: only one is a pointer, so they can't be equal.
- result = false
- } else if lhsErr == nil {
- // Both must be nil, so both are pointers.
- // Compare them directly.
- result = lhsPointer.equal(rhsPointer)
- } else {
- // Fall back to generic comparison.
- result = lhs.asRawValue(r).equal(rhs.asRawValue(r))
- }
- if predicate == llvm.IntNE {
- result = !result
- }
- case llvm.IntUGT:
- result = lhs.Uint() > rhs.Uint()
- case llvm.IntUGE:
- result = lhs.Uint() >= rhs.Uint()
- case llvm.IntULT:
- result = lhs.Uint() < rhs.Uint()
- case llvm.IntULE:
- result = lhs.Uint() <= rhs.Uint()
- case llvm.IntSGT:
- result = lhs.Int() > rhs.Int()
- case llvm.IntSGE:
- result = lhs.Int() >= rhs.Int()
- case llvm.IntSLT:
- result = lhs.Int() < rhs.Int()
- case llvm.IntSLE:
- result = lhs.Int() <= rhs.Int()
- default:
- return nil, mem, r.errorAt(inst, errors.New("interp: unsupported icmp"))
- }
+ result := r.interpretICmp(lhs, rhs, predicate)
if result {
locals[inst.localIndex] = literalValue{uint8(1)}
} else {
@@ -948,6 +911,51 @@ func (r *runner) run(fn *function, params []value, parentMem *memoryView, indent
return nil, mem, r.errorAt(bb.instructions[len(bb.instructions)-1], errors.New("interp: reached end of basic block without terminator"))
}
+// Interpret an icmp instruction. Doesn't have side effects, only returns the
+// output of the comparison.
+func (r *runner) interpretICmp(lhs, rhs value, predicate llvm.IntPredicate) bool {
+ switch predicate {
+ case llvm.IntEQ, llvm.IntNE:
+ var result bool
+ lhsPointer, lhsErr := lhs.asPointer(r)
+ rhsPointer, rhsErr := rhs.asPointer(r)
+ if (lhsErr == nil) != (rhsErr == nil) {
+ // Fast path: only one is a pointer, so they can't be equal.
+ result = false
+ } else if lhsErr == nil {
+ // Both must be nil, so both are pointers.
+ // Compare them directly.
+ result = lhsPointer.equal(rhsPointer)
+ } else {
+ // Fall back to generic comparison.
+ result = lhs.asRawValue(r).equal(rhs.asRawValue(r))
+ }
+ if predicate == llvm.IntNE {
+ result = !result
+ }
+ return result
+ case llvm.IntUGT:
+ return lhs.Uint() > rhs.Uint()
+ case llvm.IntUGE:
+ return lhs.Uint() >= rhs.Uint()
+ case llvm.IntULT:
+ return lhs.Uint() < rhs.Uint()
+ case llvm.IntULE:
+ return lhs.Uint() <= rhs.Uint()
+ case llvm.IntSGT:
+ return lhs.Int() > rhs.Int()
+ case llvm.IntSGE:
+ return lhs.Int() >= rhs.Int()
+ case llvm.IntSLT:
+ return lhs.Int() < rhs.Int()
+ case llvm.IntSLE:
+ return lhs.Int() <= rhs.Int()
+ default:
+ // _should_ be unreachable, until LLVM adds new icmp operands (unlikely)
+ panic("interp: unsupported icmp")
+ }
+}
+
func (r *runner) runAtRuntime(fn *function, inst instruction, locals []value, mem *memoryView, indent string) *Error {
numOperands := inst.llvmInst.OperandsCount()
operands := make([]llvm.Value, numOperands)
diff --git a/interp/memory.go b/interp/memory.go
index 82ab716d4..248825f64 100644
--- a/interp/memory.go
+++ b/interp/memory.go
@@ -1031,6 +1031,17 @@ func (v *rawValue) set(llvmValue llvm.Value, r *runner) {
for i := uint32(0); i < ptrSize; i++ {
v.buf[i] = ptrValue.pointer
}
+ case llvm.ICmp:
+ size := r.targetData.TypeAllocSize(llvmValue.Operand(0).Type())
+ lhs := newRawValue(uint32(size))
+ rhs := newRawValue(uint32(size))
+ lhs.set(llvmValue.Operand(0), r)
+ rhs.set(llvmValue.Operand(1), r)
+ if r.interpretICmp(lhs, rhs, llvmValue.IntPredicate()) {
+ v.buf[0] = 1 // result is true
+ } else {
+ v.buf[0] = 0 // result is false
+ }
default:
llvmValue.Dump()
println()
diff --git a/interp/testdata/consteval.ll b/interp/testdata/consteval.ll
index 9afb9ff7e..d0c0e3b66 100644
--- a/interp/testdata/consteval.ll
+++ b/interp/testdata/consteval.ll
@@ -3,6 +3,7 @@ target triple = "x86_64--linux"
@intToPtrResult = global i8 0
@ptrToIntResult = global i8 0
+@icmpResult = global i8 0
@someArray = internal global {i16, i8, i8} zeroinitializer
@someArrayPointer = global i8* zeroinitializer
@@ -15,6 +16,7 @@ define internal void @main.init() {
call void @testIntToPtr()
call void @testPtrToInt()
call void @testConstGEP()
+ call void @testICmp()
ret void
}
@@ -48,3 +50,16 @@ define internal void @testConstGEP() {
store i8* getelementptr inbounds (i8, i8* bitcast ({i16, i8, i8}* @someArray to i8*), i32 2), i8** @someArrayPointer
ret void
}
+
+define internal void @testICmp() {
+ br i1 icmp eq (i64 ptrtoint (i8* @ptrToIntResult to i64), i64 0), label %equal, label %unequal
+equal:
+ ; should not be reached
+ store i8 1, i8* @icmpResult
+ ret void
+unequal:
+ ; should be reached
+ store i8 2, i8* @icmpResult
+ ret void
+ ret void
+}
diff --git a/interp/testdata/consteval.out.ll b/interp/testdata/consteval.out.ll
index 5fac449e4..08d74c857 100644
--- a/interp/testdata/consteval.out.ll
+++ b/interp/testdata/consteval.out.ll
@@ -3,6 +3,7 @@ target triple = "x86_64--linux"
@intToPtrResult = local_unnamed_addr global i8 2
@ptrToIntResult = local_unnamed_addr global i8 2
+@icmpResult = local_unnamed_addr global i8 2
@someArray = internal global { i16, i8, i8 } zeroinitializer
@someArrayPointer = local_unnamed_addr global i8* getelementptr inbounds ({ i16, i8, i8 }, { i16, i8, i8 }* @someArray, i64 0, i32 1)