aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDamian Gryski <[email protected]>2023-02-25 13:40:08 -0800
committerGitHub <[email protected]>2023-02-25 22:40:08 +0100
commit476621736c95ce65a75a27ee2749e9ff2b296608 (patch)
tree30b9098ecf6e65efd90a270a3151204f8737b118
parent7b44fcd865370d5ecf7471b04ad1b4e013e63708 (diff)
downloadtinygo-476621736c95ce65a75a27ee2749e9ff2b296608.tar.gz
tinygo-476621736c95ce65a75a27ee2749e9ff2b296608.zip
compiler: zero struct padding during map operations
Fixes #3358
-rw-r--r--compiler/compiler_test.go1
-rw-r--r--compiler/map.go79
-rw-r--r--compiler/testdata/zeromap.go37
-rw-r--r--compiler/testdata/zeromap.ll170
4 files changed, 286 insertions, 1 deletions
diff --git a/compiler/compiler_test.go b/compiler/compiler_test.go
index f8221a08e..74e213dc6 100644
--- a/compiler/compiler_test.go
+++ b/compiler/compiler_test.go
@@ -49,6 +49,7 @@ func TestCompiler(t *testing.T) {
{"goroutine.go", "cortex-m-qemu", "tasks"},
{"channel.go", "", ""},
{"gc.go", "", ""},
+ {"zeromap.go", "", ""},
}
if goMinor >= 20 {
tests = append(tests, testCase{"go1.20.go", "", ""})
diff --git a/compiler/map.go b/compiler/map.go
index 9d162bfc0..632ff2470 100644
--- a/compiler/map.go
+++ b/compiler/map.go
@@ -89,6 +89,7 @@ func (b *builder) createMapLookup(keyType, valueType types.Type, m, key llvm.Val
// growth.
mapKeyAlloca, mapKeyPtr, mapKeySize := b.createTemporaryAlloca(key.Type(), "hashmap.key")
b.CreateStore(key, mapKeyAlloca)
+ b.zeroUndefBytes(b.getLLVMType(keyType), mapKeyAlloca)
// Fetch the value from the hashmap.
params := []llvm.Value{m, mapKeyPtr, mapValuePtr, mapValueSize}
commaOkValue = b.createRuntimeCall("hashmapBinaryGet", params, "")
@@ -133,6 +134,7 @@ func (b *builder) createMapUpdate(keyType types.Type, m, key, value llvm.Value,
// key can be compared with runtime.memequal
keyAlloca, keyPtr, keySize := b.createTemporaryAlloca(key.Type(), "hashmap.key")
b.CreateStore(key, keyAlloca)
+ b.zeroUndefBytes(b.getLLVMType(keyType), keyAlloca)
params := []llvm.Value{m, keyPtr, valuePtr}
b.createRuntimeCall("hashmapBinarySet", params, "")
b.emitLifetimeEnd(keyPtr, keySize)
@@ -161,6 +163,7 @@ func (b *builder) createMapDelete(keyType types.Type, m, key llvm.Value, pos tok
} else if hashmapIsBinaryKey(keyType) {
keyAlloca, keyPtr, keySize := b.createTemporaryAlloca(key.Type(), "hashmap.key")
b.CreateStore(key, keyAlloca)
+ b.zeroUndefBytes(b.getLLVMType(keyType), keyAlloca)
params := []llvm.Value{m, keyPtr}
b.createRuntimeCall("hashmapBinaryDelete", params, "")
b.emitLifetimeEnd(keyPtr, keySize)
@@ -240,7 +243,8 @@ func (b *builder) createMapIteratorNext(rangeVal ssa.Value, llvmRangeVal, it llv
}
// Returns true if this key type does not contain strings, interfaces etc., so
-// can be compared with runtime.memequal.
+// can be compared with runtime.memequal. Note that padding bytes are undef
+// and can alter two "equal" structs being equal when compared with memequal.
func hashmapIsBinaryKey(keyType types.Type) bool {
switch keyType := keyType.(type) {
case *types.Basic:
@@ -263,3 +267,76 @@ func hashmapIsBinaryKey(keyType types.Type) bool {
return false
}
}
+
+func (b *builder) zeroUndefBytes(llvmType llvm.Type, ptr llvm.Value) error {
+ // We know that hashmapIsBinaryKey is true, so we only have to handle those types that can show up there.
+ // To zero all undefined bytes, we iterate over all the fields in the type. For each element, compute the
+ // offset of that element. If it's Basic type, there are no internal padding bytes. For compound types, we recurse to ensure
+ // we handle nested types. Next, we determine if there are any padding bytes before the next
+ // element and zero those as well.
+
+ zero := llvm.ConstInt(b.ctx.Int32Type(), 0, false)
+
+ switch llvmType.TypeKind() {
+ case llvm.IntegerTypeKind:
+ // no padding bytes
+ return nil
+ case llvm.PointerTypeKind:
+ // mo padding bytes
+ return nil
+ case llvm.ArrayTypeKind:
+ llvmArrayType := llvmType
+ llvmElemType := llvmType.ElementType()
+
+ for i := 0; i < llvmArrayType.ArrayLength(); i++ {
+ idx := llvm.ConstInt(b.uintptrType, uint64(i), false)
+ elemPtr := b.CreateInBoundsGEP(llvmArrayType, ptr, []llvm.Value{zero, idx}, "")
+
+ // zero any padding bytes in this element
+ b.zeroUndefBytes(llvmElemType, elemPtr)
+ }
+
+ case llvm.StructTypeKind:
+ llvmStructType := llvmType
+ numFields := llvmStructType.StructElementTypesCount()
+ llvmElementTypes := llvmStructType.StructElementTypes()
+
+ for i := 0; i < numFields; i++ {
+ idx := llvm.ConstInt(b.ctx.Int32Type(), uint64(i), false)
+ elemPtr := b.CreateInBoundsGEP(llvmStructType, ptr, []llvm.Value{zero, idx}, "")
+
+ // zero any padding bytes in this field
+ llvmElemType := llvmElementTypes[i]
+ b.zeroUndefBytes(llvmElemType, elemPtr)
+
+ // zero any padding bytes before the next field, if any
+ offset := b.targetData.ElementOffset(llvmStructType, i)
+ storeSize := b.targetData.TypeStoreSize(llvmElemType)
+ fieldEndOffset := offset + storeSize
+
+ var nextOffset uint64
+ if i < numFields-1 {
+ nextOffset = b.targetData.ElementOffset(llvmStructType, i+1)
+ } else {
+ // Last field? Next offset is the total size of the allcoate struct.
+ nextOffset = b.targetData.TypeAllocSize(llvmStructType)
+ }
+
+ if fieldEndOffset != nextOffset {
+ n := llvm.ConstInt(b.uintptrType, nextOffset-fieldEndOffset, false)
+ llvmStoreSize := llvm.ConstInt(b.uintptrType, storeSize, false)
+ gepPtr := elemPtr
+ if gepPtr.Type() != b.i8ptrType {
+ gepPtr = b.CreateBitCast(gepPtr, b.i8ptrType, "") // LLVM 14
+ }
+ paddingStart := b.CreateInBoundsGEP(b.ctx.Int8Type(), gepPtr, []llvm.Value{llvmStoreSize}, "")
+ if paddingStart.Type() != b.i8ptrType {
+ paddingStart = b.CreateBitCast(paddingStart, b.i8ptrType, "") // LLVM 14
+ }
+ b.createRuntimeCall("memzero", []llvm.Value{paddingStart, n}, "")
+ }
+ }
+ }
+
+ return nil
+}
diff --git a/compiler/testdata/zeromap.go b/compiler/testdata/zeromap.go
new file mode 100644
index 000000000..6cf9f611b
--- /dev/null
+++ b/compiler/testdata/zeromap.go
@@ -0,0 +1,37 @@
+package main
+
+type hasPadding struct {
+ b1 bool
+ i int
+ b2 bool
+}
+
+type nestedPadding struct {
+ b bool
+ hasPadding
+ i int
+}
+
+//go:noinline
+func testZeroGet(m map[hasPadding]int, s hasPadding) int {
+ return m[s]
+}
+
+//go:noinline
+func testZeroSet(m map[hasPadding]int, s hasPadding) {
+ m[s] = 5
+}
+
+//go:noinline
+func testZeroArrayGet(m map[[2]hasPadding]int, s [2]hasPadding) int {
+ return m[s]
+}
+
+//go:noinline
+func testZeroArraySet(m map[[2]hasPadding]int, s [2]hasPadding) {
+ m[s] = 5
+}
+
+func main() {
+
+}
diff --git a/compiler/testdata/zeromap.ll b/compiler/testdata/zeromap.ll
new file mode 100644
index 000000000..a04ad242f
--- /dev/null
+++ b/compiler/testdata/zeromap.ll
@@ -0,0 +1,170 @@
+; ModuleID = 'zeromap.go'
+source_filename = "zeromap.go"
+target datalayout = "e-m:e-p:32:32-p10:8:8-p20:8:8-i64:64-n32:64-S128-ni:1:10:20"
+target triple = "wasm32-unknown-wasi"
+
+%main.hasPadding = type { i1, i32, i1 }
+
+declare noalias nonnull ptr @runtime.alloc(i32, ptr, ptr) #0
+
+declare void @runtime.trackPointer(ptr nocapture readonly, ptr, ptr) #0
+
+; Function Attrs: nounwind
+define hidden void @main.init(ptr %context) unnamed_addr #1 {
+entry:
+ ret void
+}
+
+; Function Attrs: noinline nounwind
+define hidden i32 @main.testZeroGet(ptr dereferenceable_or_null(40) %m, i1 %s.b1, i32 %s.i, i1 %s.b2, ptr %context) unnamed_addr #2 {
+entry:
+ %hashmap.key = alloca %main.hasPadding, align 8
+ %hashmap.value = alloca i32, align 4
+ %s = alloca %main.hasPadding, align 8
+ %0 = insertvalue %main.hasPadding zeroinitializer, i1 %s.b1, 0
+ %1 = insertvalue %main.hasPadding %0, i32 %s.i, 1
+ %2 = insertvalue %main.hasPadding %1, i1 %s.b2, 2
+ %stackalloc = alloca i8, align 1
+ store %main.hasPadding zeroinitializer, ptr %s, align 8
+ call void @runtime.trackPointer(ptr nonnull %s, ptr nonnull %stackalloc, ptr undef) #4
+ store %main.hasPadding %2, ptr %s, align 8
+ call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %hashmap.value)
+ call void @llvm.lifetime.start.p0(i64 12, ptr nonnull %hashmap.key)
+ store %main.hasPadding %2, ptr %hashmap.key, align 8
+ %3 = getelementptr inbounds i8, ptr %hashmap.key, i32 1
+ call void @runtime.memzero(ptr nonnull %3, i32 3, ptr undef) #4
+ %4 = getelementptr inbounds i8, ptr %hashmap.key, i32 9
+ call void @runtime.memzero(ptr nonnull %4, i32 3, ptr undef) #4
+ %5 = call i1 @runtime.hashmapBinaryGet(ptr %m, ptr nonnull %hashmap.key, ptr nonnull %hashmap.value, i32 4, ptr undef) #4
+ call void @llvm.lifetime.end.p0(i64 12, ptr nonnull %hashmap.key)
+ %6 = load i32, ptr %hashmap.value, align 4
+ call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %hashmap.value)
+ ret i32 %6
+}
+
+; Function Attrs: argmemonly nocallback nofree nosync nounwind willreturn
+declare void @llvm.lifetime.start.p0(i64 immarg, ptr nocapture) #3
+
+declare void @runtime.memzero(ptr, i32, ptr) #0
+
+declare i1 @runtime.hashmapBinaryGet(ptr dereferenceable_or_null(40), ptr, ptr, i32, ptr) #0
+
+; Function Attrs: argmemonly nocallback nofree nosync nounwind willreturn
+declare void @llvm.lifetime.end.p0(i64 immarg, ptr nocapture) #3
+
+; Function Attrs: noinline nounwind
+define hidden void @main.testZeroSet(ptr dereferenceable_or_null(40) %m, i1 %s.b1, i32 %s.i, i1 %s.b2, ptr %context) unnamed_addr #2 {
+entry:
+ %hashmap.key = alloca %main.hasPadding, align 8
+ %hashmap.value = alloca i32, align 4
+ %s = alloca %main.hasPadding, align 8
+ %0 = insertvalue %main.hasPadding zeroinitializer, i1 %s.b1, 0
+ %1 = insertvalue %main.hasPadding %0, i32 %s.i, 1
+ %2 = insertvalue %main.hasPadding %1, i1 %s.b2, 2
+ %stackalloc = alloca i8, align 1
+ store %main.hasPadding zeroinitializer, ptr %s, align 8
+ call void @runtime.trackPointer(ptr nonnull %s, ptr nonnull %stackalloc, ptr undef) #4
+ store %main.hasPadding %2, ptr %s, align 8
+ call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %hashmap.value)
+ store i32 5, ptr %hashmap.value, align 4
+ call void @llvm.lifetime.start.p0(i64 12, ptr nonnull %hashmap.key)
+ store %main.hasPadding %2, ptr %hashmap.key, align 8
+ %3 = getelementptr inbounds i8, ptr %hashmap.key, i32 1
+ call void @runtime.memzero(ptr nonnull %3, i32 3, ptr undef) #4
+ %4 = getelementptr inbounds i8, ptr %hashmap.key, i32 9
+ call void @runtime.memzero(ptr nonnull %4, i32 3, ptr undef) #4
+ call void @runtime.hashmapBinarySet(ptr %m, ptr nonnull %hashmap.key, ptr nonnull %hashmap.value, ptr undef) #4
+ call void @llvm.lifetime.end.p0(i64 12, ptr nonnull %hashmap.key)
+ call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %hashmap.value)
+ ret void
+}
+
+declare void @runtime.hashmapBinarySet(ptr dereferenceable_or_null(40), ptr, ptr, ptr) #0
+
+; Function Attrs: noinline nounwind
+define hidden i32 @main.testZeroArrayGet(ptr dereferenceable_or_null(40) %m, [2 x %main.hasPadding] %s, ptr %context) unnamed_addr #2 {
+entry:
+ %hashmap.key = alloca [2 x %main.hasPadding], align 8
+ %hashmap.value = alloca i32, align 4
+ %s1 = alloca [2 x %main.hasPadding], align 8
+ %stackalloc = alloca i8, align 1
+ store %main.hasPadding zeroinitializer, ptr %s1, align 8
+ %s1.repack2 = getelementptr inbounds [2 x %main.hasPadding], ptr %s1, i32 0, i32 1
+ store %main.hasPadding zeroinitializer, ptr %s1.repack2, align 4
+ call void @runtime.trackPointer(ptr nonnull %s1, ptr nonnull %stackalloc, ptr undef) #4
+ %s.elt = extractvalue [2 x %main.hasPadding] %s, 0
+ store %main.hasPadding %s.elt, ptr %s1, align 8
+ %s1.repack3 = getelementptr inbounds [2 x %main.hasPadding], ptr %s1, i32 0, i32 1
+ %s.elt4 = extractvalue [2 x %main.hasPadding] %s, 1
+ store %main.hasPadding %s.elt4, ptr %s1.repack3, align 4
+ call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %hashmap.value)
+ call void @llvm.lifetime.start.p0(i64 24, ptr nonnull %hashmap.key)
+ %s.elt7 = extractvalue [2 x %main.hasPadding] %s, 0
+ store %main.hasPadding %s.elt7, ptr %hashmap.key, align 8
+ %hashmap.key.repack8 = getelementptr inbounds [2 x %main.hasPadding], ptr %hashmap.key, i32 0, i32 1
+ %s.elt9 = extractvalue [2 x %main.hasPadding] %s, 1
+ store %main.hasPadding %s.elt9, ptr %hashmap.key.repack8, align 4
+ %0 = getelementptr inbounds i8, ptr %hashmap.key, i32 1
+ call void @runtime.memzero(ptr nonnull %0, i32 3, ptr undef) #4
+ %1 = getelementptr inbounds i8, ptr %hashmap.key, i32 9
+ call void @runtime.memzero(ptr nonnull %1, i32 3, ptr undef) #4
+ %2 = getelementptr inbounds i8, ptr %hashmap.key, i32 13
+ call void @runtime.memzero(ptr nonnull %2, i32 3, ptr undef) #4
+ %3 = getelementptr inbounds i8, ptr %hashmap.key, i32 21
+ call void @runtime.memzero(ptr nonnull %3, i32 3, ptr undef) #4
+ %4 = call i1 @runtime.hashmapBinaryGet(ptr %m, ptr nonnull %hashmap.key, ptr nonnull %hashmap.value, i32 4, ptr undef) #4
+ call void @llvm.lifetime.end.p0(i64 24, ptr nonnull %hashmap.key)
+ %5 = load i32, ptr %hashmap.value, align 4
+ call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %hashmap.value)
+ ret i32 %5
+}
+
+; Function Attrs: noinline nounwind
+define hidden void @main.testZeroArraySet(ptr dereferenceable_or_null(40) %m, [2 x %main.hasPadding] %s, ptr %context) unnamed_addr #2 {
+entry:
+ %hashmap.key = alloca [2 x %main.hasPadding], align 8
+ %hashmap.value = alloca i32, align 4
+ %s1 = alloca [2 x %main.hasPadding], align 8
+ %stackalloc = alloca i8, align 1
+ store %main.hasPadding zeroinitializer, ptr %s1, align 8
+ %s1.repack2 = getelementptr inbounds [2 x %main.hasPadding], ptr %s1, i32 0, i32 1
+ store %main.hasPadding zeroinitializer, ptr %s1.repack2, align 4
+ call void @runtime.trackPointer(ptr nonnull %s1, ptr nonnull %stackalloc, ptr undef) #4
+ %s.elt = extractvalue [2 x %main.hasPadding] %s, 0
+ store %main.hasPadding %s.elt, ptr %s1, align 8
+ %s1.repack3 = getelementptr inbounds [2 x %main.hasPadding], ptr %s1, i32 0, i32 1
+ %s.elt4 = extractvalue [2 x %main.hasPadding] %s, 1
+ store %main.hasPadding %s.elt4, ptr %s1.repack3, align 4
+ call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %hashmap.value)
+ store i32 5, ptr %hashmap.value, align 4
+ call void @llvm.lifetime.start.p0(i64 24, ptr nonnull %hashmap.key)
+ %s.elt7 = extractvalue [2 x %main.hasPadding] %s, 0
+ store %main.hasPadding %s.elt7, ptr %hashmap.key, align 8
+ %hashmap.key.repack8 = getelementptr inbounds [2 x %main.hasPadding], ptr %hashmap.key, i32 0, i32 1
+ %s.elt9 = extractvalue [2 x %main.hasPadding] %s, 1
+ store %main.hasPadding %s.elt9, ptr %hashmap.key.repack8, align 4
+ %0 = getelementptr inbounds i8, ptr %hashmap.key, i32 1
+ call void @runtime.memzero(ptr nonnull %0, i32 3, ptr undef) #4
+ %1 = getelementptr inbounds i8, ptr %hashmap.key, i32 9
+ call void @runtime.memzero(ptr nonnull %1, i32 3, ptr undef) #4
+ %2 = getelementptr inbounds i8, ptr %hashmap.key, i32 13
+ call void @runtime.memzero(ptr nonnull %2, i32 3, ptr undef) #4
+ %3 = getelementptr inbounds i8, ptr %hashmap.key, i32 21
+ call void @runtime.memzero(ptr nonnull %3, i32 3, ptr undef) #4
+ call void @runtime.hashmapBinarySet(ptr %m, ptr nonnull %hashmap.key, ptr nonnull %hashmap.value, ptr undef) #4
+ call void @llvm.lifetime.end.p0(i64 24, ptr nonnull %hashmap.key)
+ call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %hashmap.value)
+ ret void
+}
+
+; Function Attrs: nounwind
+define hidden void @main.main(ptr %context) unnamed_addr #1 {
+entry:
+ ret void
+}
+
+attributes #0 = { "target-features"="+bulk-memory,+nontrapping-fptoint,+sign-ext" }
+attributes #1 = { nounwind "target-features"="+bulk-memory,+nontrapping-fptoint,+sign-ext" }
+attributes #2 = { noinline nounwind "target-features"="+bulk-memory,+nontrapping-fptoint,+sign-ext" }
+attributes #3 = { argmemonly nocallback nofree nosync nounwind willreturn }
+attributes #4 = { nounwind }