diff options
author | Damian Gryski <[email protected]> | 2023-02-25 13:40:08 -0800 |
---|---|---|
committer | GitHub <[email protected]> | 2023-02-25 22:40:08 +0100 |
commit | 476621736c95ce65a75a27ee2749e9ff2b296608 (patch) | |
tree | 30b9098ecf6e65efd90a270a3151204f8737b118 | |
parent | 7b44fcd865370d5ecf7471b04ad1b4e013e63708 (diff) | |
download | tinygo-476621736c95ce65a75a27ee2749e9ff2b296608.tar.gz tinygo-476621736c95ce65a75a27ee2749e9ff2b296608.zip |
compiler: zero struct padding during map operations
Fixes #3358
-rw-r--r-- | compiler/compiler_test.go | 1 | ||||
-rw-r--r-- | compiler/map.go | 79 | ||||
-rw-r--r-- | compiler/testdata/zeromap.go | 37 | ||||
-rw-r--r-- | compiler/testdata/zeromap.ll | 170 |
4 files changed, 286 insertions, 1 deletions
diff --git a/compiler/compiler_test.go b/compiler/compiler_test.go index f8221a08e..74e213dc6 100644 --- a/compiler/compiler_test.go +++ b/compiler/compiler_test.go @@ -49,6 +49,7 @@ func TestCompiler(t *testing.T) { {"goroutine.go", "cortex-m-qemu", "tasks"}, {"channel.go", "", ""}, {"gc.go", "", ""}, + {"zeromap.go", "", ""}, } if goMinor >= 20 { tests = append(tests, testCase{"go1.20.go", "", ""}) diff --git a/compiler/map.go b/compiler/map.go index 9d162bfc0..632ff2470 100644 --- a/compiler/map.go +++ b/compiler/map.go @@ -89,6 +89,7 @@ func (b *builder) createMapLookup(keyType, valueType types.Type, m, key llvm.Val // growth. mapKeyAlloca, mapKeyPtr, mapKeySize := b.createTemporaryAlloca(key.Type(), "hashmap.key") b.CreateStore(key, mapKeyAlloca) + b.zeroUndefBytes(b.getLLVMType(keyType), mapKeyAlloca) // Fetch the value from the hashmap. params := []llvm.Value{m, mapKeyPtr, mapValuePtr, mapValueSize} commaOkValue = b.createRuntimeCall("hashmapBinaryGet", params, "") @@ -133,6 +134,7 @@ func (b *builder) createMapUpdate(keyType types.Type, m, key, value llvm.Value, // key can be compared with runtime.memequal keyAlloca, keyPtr, keySize := b.createTemporaryAlloca(key.Type(), "hashmap.key") b.CreateStore(key, keyAlloca) + b.zeroUndefBytes(b.getLLVMType(keyType), keyAlloca) params := []llvm.Value{m, keyPtr, valuePtr} b.createRuntimeCall("hashmapBinarySet", params, "") b.emitLifetimeEnd(keyPtr, keySize) @@ -161,6 +163,7 @@ func (b *builder) createMapDelete(keyType types.Type, m, key llvm.Value, pos tok } else if hashmapIsBinaryKey(keyType) { keyAlloca, keyPtr, keySize := b.createTemporaryAlloca(key.Type(), "hashmap.key") b.CreateStore(key, keyAlloca) + b.zeroUndefBytes(b.getLLVMType(keyType), keyAlloca) params := []llvm.Value{m, keyPtr} b.createRuntimeCall("hashmapBinaryDelete", params, "") b.emitLifetimeEnd(keyPtr, keySize) @@ -240,7 +243,8 @@ func (b *builder) createMapIteratorNext(rangeVal ssa.Value, llvmRangeVal, it llv } // Returns true if this key type does not contain strings, interfaces etc., so -// can be compared with runtime.memequal. +// can be compared with runtime.memequal. Note that padding bytes are undef +// and can alter two "equal" structs being equal when compared with memequal. func hashmapIsBinaryKey(keyType types.Type) bool { switch keyType := keyType.(type) { case *types.Basic: @@ -263,3 +267,76 @@ func hashmapIsBinaryKey(keyType types.Type) bool { return false } } + +func (b *builder) zeroUndefBytes(llvmType llvm.Type, ptr llvm.Value) error { + // We know that hashmapIsBinaryKey is true, so we only have to handle those types that can show up there. + // To zero all undefined bytes, we iterate over all the fields in the type. For each element, compute the + // offset of that element. If it's Basic type, there are no internal padding bytes. For compound types, we recurse to ensure + // we handle nested types. Next, we determine if there are any padding bytes before the next + // element and zero those as well. + + zero := llvm.ConstInt(b.ctx.Int32Type(), 0, false) + + switch llvmType.TypeKind() { + case llvm.IntegerTypeKind: + // no padding bytes + return nil + case llvm.PointerTypeKind: + // mo padding bytes + return nil + case llvm.ArrayTypeKind: + llvmArrayType := llvmType + llvmElemType := llvmType.ElementType() + + for i := 0; i < llvmArrayType.ArrayLength(); i++ { + idx := llvm.ConstInt(b.uintptrType, uint64(i), false) + elemPtr := b.CreateInBoundsGEP(llvmArrayType, ptr, []llvm.Value{zero, idx}, "") + + // zero any padding bytes in this element + b.zeroUndefBytes(llvmElemType, elemPtr) + } + + case llvm.StructTypeKind: + llvmStructType := llvmType + numFields := llvmStructType.StructElementTypesCount() + llvmElementTypes := llvmStructType.StructElementTypes() + + for i := 0; i < numFields; i++ { + idx := llvm.ConstInt(b.ctx.Int32Type(), uint64(i), false) + elemPtr := b.CreateInBoundsGEP(llvmStructType, ptr, []llvm.Value{zero, idx}, "") + + // zero any padding bytes in this field + llvmElemType := llvmElementTypes[i] + b.zeroUndefBytes(llvmElemType, elemPtr) + + // zero any padding bytes before the next field, if any + offset := b.targetData.ElementOffset(llvmStructType, i) + storeSize := b.targetData.TypeStoreSize(llvmElemType) + fieldEndOffset := offset + storeSize + + var nextOffset uint64 + if i < numFields-1 { + nextOffset = b.targetData.ElementOffset(llvmStructType, i+1) + } else { + // Last field? Next offset is the total size of the allcoate struct. + nextOffset = b.targetData.TypeAllocSize(llvmStructType) + } + + if fieldEndOffset != nextOffset { + n := llvm.ConstInt(b.uintptrType, nextOffset-fieldEndOffset, false) + llvmStoreSize := llvm.ConstInt(b.uintptrType, storeSize, false) + gepPtr := elemPtr + if gepPtr.Type() != b.i8ptrType { + gepPtr = b.CreateBitCast(gepPtr, b.i8ptrType, "") // LLVM 14 + } + paddingStart := b.CreateInBoundsGEP(b.ctx.Int8Type(), gepPtr, []llvm.Value{llvmStoreSize}, "") + if paddingStart.Type() != b.i8ptrType { + paddingStart = b.CreateBitCast(paddingStart, b.i8ptrType, "") // LLVM 14 + } + b.createRuntimeCall("memzero", []llvm.Value{paddingStart, n}, "") + } + } + } + + return nil +} diff --git a/compiler/testdata/zeromap.go b/compiler/testdata/zeromap.go new file mode 100644 index 000000000..6cf9f611b --- /dev/null +++ b/compiler/testdata/zeromap.go @@ -0,0 +1,37 @@ +package main + +type hasPadding struct { + b1 bool + i int + b2 bool +} + +type nestedPadding struct { + b bool + hasPadding + i int +} + +//go:noinline +func testZeroGet(m map[hasPadding]int, s hasPadding) int { + return m[s] +} + +//go:noinline +func testZeroSet(m map[hasPadding]int, s hasPadding) { + m[s] = 5 +} + +//go:noinline +func testZeroArrayGet(m map[[2]hasPadding]int, s [2]hasPadding) int { + return m[s] +} + +//go:noinline +func testZeroArraySet(m map[[2]hasPadding]int, s [2]hasPadding) { + m[s] = 5 +} + +func main() { + +} diff --git a/compiler/testdata/zeromap.ll b/compiler/testdata/zeromap.ll new file mode 100644 index 000000000..a04ad242f --- /dev/null +++ b/compiler/testdata/zeromap.ll @@ -0,0 +1,170 @@ +; ModuleID = 'zeromap.go' +source_filename = "zeromap.go" +target datalayout = "e-m:e-p:32:32-p10:8:8-p20:8:8-i64:64-n32:64-S128-ni:1:10:20" +target triple = "wasm32-unknown-wasi" + +%main.hasPadding = type { i1, i32, i1 } + +declare noalias nonnull ptr @runtime.alloc(i32, ptr, ptr) #0 + +declare void @runtime.trackPointer(ptr nocapture readonly, ptr, ptr) #0 + +; Function Attrs: nounwind +define hidden void @main.init(ptr %context) unnamed_addr #1 { +entry: + ret void +} + +; Function Attrs: noinline nounwind +define hidden i32 @main.testZeroGet(ptr dereferenceable_or_null(40) %m, i1 %s.b1, i32 %s.i, i1 %s.b2, ptr %context) unnamed_addr #2 { +entry: + %hashmap.key = alloca %main.hasPadding, align 8 + %hashmap.value = alloca i32, align 4 + %s = alloca %main.hasPadding, align 8 + %0 = insertvalue %main.hasPadding zeroinitializer, i1 %s.b1, 0 + %1 = insertvalue %main.hasPadding %0, i32 %s.i, 1 + %2 = insertvalue %main.hasPadding %1, i1 %s.b2, 2 + %stackalloc = alloca i8, align 1 + store %main.hasPadding zeroinitializer, ptr %s, align 8 + call void @runtime.trackPointer(ptr nonnull %s, ptr nonnull %stackalloc, ptr undef) #4 + store %main.hasPadding %2, ptr %s, align 8 + call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %hashmap.value) + call void @llvm.lifetime.start.p0(i64 12, ptr nonnull %hashmap.key) + store %main.hasPadding %2, ptr %hashmap.key, align 8 + %3 = getelementptr inbounds i8, ptr %hashmap.key, i32 1 + call void @runtime.memzero(ptr nonnull %3, i32 3, ptr undef) #4 + %4 = getelementptr inbounds i8, ptr %hashmap.key, i32 9 + call void @runtime.memzero(ptr nonnull %4, i32 3, ptr undef) #4 + %5 = call i1 @runtime.hashmapBinaryGet(ptr %m, ptr nonnull %hashmap.key, ptr nonnull %hashmap.value, i32 4, ptr undef) #4 + call void @llvm.lifetime.end.p0(i64 12, ptr nonnull %hashmap.key) + %6 = load i32, ptr %hashmap.value, align 4 + call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %hashmap.value) + ret i32 %6 +} + +; Function Attrs: argmemonly nocallback nofree nosync nounwind willreturn +declare void @llvm.lifetime.start.p0(i64 immarg, ptr nocapture) #3 + +declare void @runtime.memzero(ptr, i32, ptr) #0 + +declare i1 @runtime.hashmapBinaryGet(ptr dereferenceable_or_null(40), ptr, ptr, i32, ptr) #0 + +; Function Attrs: argmemonly nocallback nofree nosync nounwind willreturn +declare void @llvm.lifetime.end.p0(i64 immarg, ptr nocapture) #3 + +; Function Attrs: noinline nounwind +define hidden void @main.testZeroSet(ptr dereferenceable_or_null(40) %m, i1 %s.b1, i32 %s.i, i1 %s.b2, ptr %context) unnamed_addr #2 { +entry: + %hashmap.key = alloca %main.hasPadding, align 8 + %hashmap.value = alloca i32, align 4 + %s = alloca %main.hasPadding, align 8 + %0 = insertvalue %main.hasPadding zeroinitializer, i1 %s.b1, 0 + %1 = insertvalue %main.hasPadding %0, i32 %s.i, 1 + %2 = insertvalue %main.hasPadding %1, i1 %s.b2, 2 + %stackalloc = alloca i8, align 1 + store %main.hasPadding zeroinitializer, ptr %s, align 8 + call void @runtime.trackPointer(ptr nonnull %s, ptr nonnull %stackalloc, ptr undef) #4 + store %main.hasPadding %2, ptr %s, align 8 + call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %hashmap.value) + store i32 5, ptr %hashmap.value, align 4 + call void @llvm.lifetime.start.p0(i64 12, ptr nonnull %hashmap.key) + store %main.hasPadding %2, ptr %hashmap.key, align 8 + %3 = getelementptr inbounds i8, ptr %hashmap.key, i32 1 + call void @runtime.memzero(ptr nonnull %3, i32 3, ptr undef) #4 + %4 = getelementptr inbounds i8, ptr %hashmap.key, i32 9 + call void @runtime.memzero(ptr nonnull %4, i32 3, ptr undef) #4 + call void @runtime.hashmapBinarySet(ptr %m, ptr nonnull %hashmap.key, ptr nonnull %hashmap.value, ptr undef) #4 + call void @llvm.lifetime.end.p0(i64 12, ptr nonnull %hashmap.key) + call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %hashmap.value) + ret void +} + +declare void @runtime.hashmapBinarySet(ptr dereferenceable_or_null(40), ptr, ptr, ptr) #0 + +; Function Attrs: noinline nounwind +define hidden i32 @main.testZeroArrayGet(ptr dereferenceable_or_null(40) %m, [2 x %main.hasPadding] %s, ptr %context) unnamed_addr #2 { +entry: + %hashmap.key = alloca [2 x %main.hasPadding], align 8 + %hashmap.value = alloca i32, align 4 + %s1 = alloca [2 x %main.hasPadding], align 8 + %stackalloc = alloca i8, align 1 + store %main.hasPadding zeroinitializer, ptr %s1, align 8 + %s1.repack2 = getelementptr inbounds [2 x %main.hasPadding], ptr %s1, i32 0, i32 1 + store %main.hasPadding zeroinitializer, ptr %s1.repack2, align 4 + call void @runtime.trackPointer(ptr nonnull %s1, ptr nonnull %stackalloc, ptr undef) #4 + %s.elt = extractvalue [2 x %main.hasPadding] %s, 0 + store %main.hasPadding %s.elt, ptr %s1, align 8 + %s1.repack3 = getelementptr inbounds [2 x %main.hasPadding], ptr %s1, i32 0, i32 1 + %s.elt4 = extractvalue [2 x %main.hasPadding] %s, 1 + store %main.hasPadding %s.elt4, ptr %s1.repack3, align 4 + call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %hashmap.value) + call void @llvm.lifetime.start.p0(i64 24, ptr nonnull %hashmap.key) + %s.elt7 = extractvalue [2 x %main.hasPadding] %s, 0 + store %main.hasPadding %s.elt7, ptr %hashmap.key, align 8 + %hashmap.key.repack8 = getelementptr inbounds [2 x %main.hasPadding], ptr %hashmap.key, i32 0, i32 1 + %s.elt9 = extractvalue [2 x %main.hasPadding] %s, 1 + store %main.hasPadding %s.elt9, ptr %hashmap.key.repack8, align 4 + %0 = getelementptr inbounds i8, ptr %hashmap.key, i32 1 + call void @runtime.memzero(ptr nonnull %0, i32 3, ptr undef) #4 + %1 = getelementptr inbounds i8, ptr %hashmap.key, i32 9 + call void @runtime.memzero(ptr nonnull %1, i32 3, ptr undef) #4 + %2 = getelementptr inbounds i8, ptr %hashmap.key, i32 13 + call void @runtime.memzero(ptr nonnull %2, i32 3, ptr undef) #4 + %3 = getelementptr inbounds i8, ptr %hashmap.key, i32 21 + call void @runtime.memzero(ptr nonnull %3, i32 3, ptr undef) #4 + %4 = call i1 @runtime.hashmapBinaryGet(ptr %m, ptr nonnull %hashmap.key, ptr nonnull %hashmap.value, i32 4, ptr undef) #4 + call void @llvm.lifetime.end.p0(i64 24, ptr nonnull %hashmap.key) + %5 = load i32, ptr %hashmap.value, align 4 + call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %hashmap.value) + ret i32 %5 +} + +; Function Attrs: noinline nounwind +define hidden void @main.testZeroArraySet(ptr dereferenceable_or_null(40) %m, [2 x %main.hasPadding] %s, ptr %context) unnamed_addr #2 { +entry: + %hashmap.key = alloca [2 x %main.hasPadding], align 8 + %hashmap.value = alloca i32, align 4 + %s1 = alloca [2 x %main.hasPadding], align 8 + %stackalloc = alloca i8, align 1 + store %main.hasPadding zeroinitializer, ptr %s1, align 8 + %s1.repack2 = getelementptr inbounds [2 x %main.hasPadding], ptr %s1, i32 0, i32 1 + store %main.hasPadding zeroinitializer, ptr %s1.repack2, align 4 + call void @runtime.trackPointer(ptr nonnull %s1, ptr nonnull %stackalloc, ptr undef) #4 + %s.elt = extractvalue [2 x %main.hasPadding] %s, 0 + store %main.hasPadding %s.elt, ptr %s1, align 8 + %s1.repack3 = getelementptr inbounds [2 x %main.hasPadding], ptr %s1, i32 0, i32 1 + %s.elt4 = extractvalue [2 x %main.hasPadding] %s, 1 + store %main.hasPadding %s.elt4, ptr %s1.repack3, align 4 + call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %hashmap.value) + store i32 5, ptr %hashmap.value, align 4 + call void @llvm.lifetime.start.p0(i64 24, ptr nonnull %hashmap.key) + %s.elt7 = extractvalue [2 x %main.hasPadding] %s, 0 + store %main.hasPadding %s.elt7, ptr %hashmap.key, align 8 + %hashmap.key.repack8 = getelementptr inbounds [2 x %main.hasPadding], ptr %hashmap.key, i32 0, i32 1 + %s.elt9 = extractvalue [2 x %main.hasPadding] %s, 1 + store %main.hasPadding %s.elt9, ptr %hashmap.key.repack8, align 4 + %0 = getelementptr inbounds i8, ptr %hashmap.key, i32 1 + call void @runtime.memzero(ptr nonnull %0, i32 3, ptr undef) #4 + %1 = getelementptr inbounds i8, ptr %hashmap.key, i32 9 + call void @runtime.memzero(ptr nonnull %1, i32 3, ptr undef) #4 + %2 = getelementptr inbounds i8, ptr %hashmap.key, i32 13 + call void @runtime.memzero(ptr nonnull %2, i32 3, ptr undef) #4 + %3 = getelementptr inbounds i8, ptr %hashmap.key, i32 21 + call void @runtime.memzero(ptr nonnull %3, i32 3, ptr undef) #4 + call void @runtime.hashmapBinarySet(ptr %m, ptr nonnull %hashmap.key, ptr nonnull %hashmap.value, ptr undef) #4 + call void @llvm.lifetime.end.p0(i64 24, ptr nonnull %hashmap.key) + call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %hashmap.value) + ret void +} + +; Function Attrs: nounwind +define hidden void @main.main(ptr %context) unnamed_addr #1 { +entry: + ret void +} + +attributes #0 = { "target-features"="+bulk-memory,+nontrapping-fptoint,+sign-ext" } +attributes #1 = { nounwind "target-features"="+bulk-memory,+nontrapping-fptoint,+sign-ext" } +attributes #2 = { noinline nounwind "target-features"="+bulk-memory,+nontrapping-fptoint,+sign-ext" } +attributes #3 = { argmemonly nocallback nofree nosync nounwind willreturn } +attributes #4 = { nounwind } |