about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: Andrzej Janik <[email protected]> 2021-09-26 01:24:14 +0200
committer: Andrzej Janik <[email protected]> 2021-09-26 01:24:14 +0200
commit: c23be576e86282c1c4673164cb9e92845cd0517e (patch)
tree: 5b8180d74f63c3fbc3221e1501aa856ec433b863
parent: 370c0bd09ef5b49e327368fb1899c1692bb8eff4 (diff)
download: ZLUDA-c23be576e86282c1c4673164cb9e92845cd0517e.tar.gz
ZLUDA-c23be576e86282c1c4673164cb9e92845cd0517e.zip
Finish fixing shared memory pass
-rw-r--r-- ptx/src/test/spirv_run/atom_add.spvtxt 37
-rw-r--r-- ptx/src/test/spirv_run/atom_add_float.spvtxt 41
-rw-r--r-- ptx/src/test/spirv_run/call.spvtxt 4
-rw-r--r-- ptx/src/test/spirv_run/extern_shared.spvtxt 24
-rw-r--r-- ptx/src/test/spirv_run/extern_shared_call.spvtxt 39
-rw-r--r-- ptx/src/test/spirv_run/mod.rs 1
-rw-r--r-- ptx/src/test/spirv_run/reg_local.spvtxt 19
-rw-r--r-- ptx/src/test/spirv_run/shared_ptr_32.spvtxt 23
-rw-r--r-- ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt 16
-rw-r--r-- ptx/src/test/spirv_run/shared_unify_extern.spvtxt 141
-rw-r--r-- ptx/src/test/spirv_run/shared_unify_private.ptx 32
-rw-r--r-- ptx/src/test/spirv_run/shared_unify_private.spvtxt 84
-rw-r--r-- ptx/src/test/spirv_run/shared_variable.spvtxt 21
-rw-r--r-- ptx/src/test/spirv_run/verify.py 2
-rw-r--r-- ptx/src/translate.rs 248
15 files changed, 500 insertions, 232 deletions
diff --git a/ptx/src/test/spirv_run/atom_add.spvtxt b/ptx/src/test/spirv_run/atom_add.spvtxt
index b4de00a..3609247 100644
--- a/ptx/src/test/spirv_run/atom_add.spvtxt
+++ b/ptx/src/test/spirv_run/atom_add.spvtxt
@@ -7,29 +7,33 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
- %38 = OpExtInstImport "OpenCL.std"
+ OpCapability DenormFlushToZero
+ %42 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
- OpEntryPoint Kernel %1 "atom_add" %4
- OpDecorate %4 Alignment 4
+ OpEntryPoint Kernel %1 "atom_add" %38
+ OpExecutionMode %1 ContractionOff
%void = OpTypeVoid
%uint = OpTypeInt 32 0
%uchar = OpTypeInt 8 0
%uint_1024 = OpConstant %uint 1024
%_arr_uchar_uint_1024 = OpTypeArray %uchar %uint_1024
%_ptr_Workgroup__arr_uchar_uint_1024 = OpTypePointer Workgroup %_arr_uchar_uint_1024
- %4 = OpVariable %_ptr_Workgroup__arr_uchar_uint_1024 Workgroup
+ %38 = OpVariable %_ptr_Workgroup__arr_uchar_uint_1024 Workgroup
%ulong = OpTypeInt 64 0
- %46 = OpTypeFunction %void %ulong %ulong
+ %50 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Generic_uint = OpTypePointer Generic %uint
%ulong_4 = OpConstant %ulong 4
%_ptr_Generic_uchar = OpTypePointer Generic %uchar
+%uint_1024_0 = OpConstant %uint 1024
%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint
+%uint_1024_1 = OpConstant %uint 1024
%uint_1 = OpConstant %uint 1
%uint_0 = OpConstant %uint 0
+%uint_1024_2 = OpConstant %uint 1024
%ulong_4_0 = OpConstant %ulong 4
- %1 = OpFunction %void None %46
+ %1 = OpFunction %void None %50
%9 = OpFunctionParameter %ulong
%10 = OpFunctionParameter %ulong
%36 = OpLabel
@@ -51,19 +55,22 @@
OpStore %7 %13
%16 = OpLoad %ulong %5
%30 = OpConvertUToPtr %_ptr_Generic_uint %16
- %51 = OpBitcast %_ptr_Generic_uchar %30
- %52 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %51 %ulong_4
- %26 = OpBitcast %_ptr_Generic_uint %52
+ %55 = OpBitcast %_ptr_Generic_uchar %30
+ %56 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %55 %ulong_4
+ %26 = OpBitcast %_ptr_Generic_uint %56
%15 = OpLoad %uint %26 Aligned 4
OpStore %8 %15
%17 = OpLoad %uint %7
- %31 = OpBitcast %_ptr_Workgroup_uint %4
+ %39 = OpBitcast %_ptr_Workgroup__arr_uchar_uint_1024 %38
+ %31 = OpBitcast %_ptr_Workgroup_uint %39
OpStore %31 %17 Aligned 4
%19 = OpLoad %uint %8
- %32 = OpBitcast %_ptr_Workgroup_uint %4
+ %40 = OpBitcast %_ptr_Workgroup__arr_uchar_uint_1024 %38
+ %32 = OpBitcast %_ptr_Workgroup_uint %40
%18 = OpAtomicIAdd %uint %32 %uint_1 %uint_0 %19
OpStore %7 %18
- %33 = OpBitcast %_ptr_Workgroup_uint %4
+ %41 = OpBitcast %_ptr_Workgroup__arr_uchar_uint_1024 %38
+ %33 = OpBitcast %_ptr_Workgroup_uint %41
%20 = OpLoad %uint %33 Aligned 4
OpStore %8 %20
%21 = OpLoad %ulong %6
@@ -73,9 +80,9 @@
%23 = OpLoad %ulong %6
%24 = OpLoad %uint %8
%35 = OpConvertUToPtr %_ptr_Generic_uint %23
- %56 = OpBitcast %_ptr_Generic_uchar %35
- %57 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %56 %ulong_4_0
- %28 = OpBitcast %_ptr_Generic_uint %57
+ %63 = OpBitcast %_ptr_Generic_uchar %35
+ %64 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %63 %ulong_4_0
+ %28 = OpBitcast %_ptr_Generic_uint %64
OpStore %28 %24 Aligned 4
OpReturn
OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/atom_add_float.spvtxt b/ptx/src/test/spirv_run/atom_add_float.spvtxt
index 7d25632..9533d83 100644
--- a/ptx/src/test/spirv_run/atom_add_float.spvtxt
+++ b/ptx/src/test/spirv_run/atom_add_float.spvtxt
@@ -7,34 +7,38 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
- %42 = OpExtInstImport "OpenCL.std"
+ OpCapability DenormFlushToZero
+ %46 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
- OpEntryPoint Kernel %1 "atom_add_float" %4
+ OpEntryPoint Kernel %1 "atom_add_float" %42
+ OpExecutionMode %1 ContractionOff
OpDecorate %37 LinkageAttributes "__zluda_ptx_impl__atom_relaxed_gpu_shared_add_f32" Import
- OpDecorate %4 Alignment 4
%void = OpTypeVoid
%float = OpTypeFloat 32
%_ptr_Workgroup_float = OpTypePointer Workgroup %float
- %46 = OpTypeFunction %float %_ptr_Workgroup_float %float
+ %50 = OpTypeFunction %float %_ptr_Workgroup_float %float
%uint = OpTypeInt 32 0
%uchar = OpTypeInt 8 0
%uint_1024 = OpConstant %uint 1024
%_arr_uchar_uint_1024 = OpTypeArray %uchar %uint_1024
%_ptr_Workgroup__arr_uchar_uint_1024 = OpTypePointer Workgroup %_arr_uchar_uint_1024
- %4 = OpVariable %_ptr_Workgroup__arr_uchar_uint_1024 Workgroup
+ %42 = OpVariable %_ptr_Workgroup__arr_uchar_uint_1024 Workgroup
%ulong = OpTypeInt 64 0
- %53 = OpTypeFunction %void %ulong %ulong
+ %57 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Function_float = OpTypePointer Function %float
%_ptr_Generic_float = OpTypePointer Generic %float
%ulong_4 = OpConstant %ulong 4
%_ptr_Generic_uchar = OpTypePointer Generic %uchar
+%uint_1024_0 = OpConstant %uint 1024
+%uint_1024_1 = OpConstant %uint 1024
+%uint_1024_2 = OpConstant %uint 1024
%ulong_4_0 = OpConstant %ulong 4
- %37 = OpFunction %float None %46
+ %37 = OpFunction %float None %50
%39 = OpFunctionParameter %_ptr_Workgroup_float
%40 = OpFunctionParameter %float
OpFunctionEnd
- %1 = OpFunction %void None %53
+ %1 = OpFunction %void None %57
%9 = OpFunctionParameter %ulong
%10 = OpFunctionParameter %ulong
%36 = OpLabel
@@ -56,19 +60,22 @@
OpStore %7 %13
%16 = OpLoad %ulong %5
%30 = OpConvertUToPtr %_ptr_Generic_float %16
- %58 = OpBitcast %_ptr_Generic_uchar %30
- %59 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %58 %ulong_4
- %26 = OpBitcast %_ptr_Generic_float %59
+ %62 = OpBitcast %_ptr_Generic_uchar %30
+ %63 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %62 %ulong_4
+ %26 = OpBitcast %_ptr_Generic_float %63
%15 = OpLoad %float %26 Aligned 4
OpStore %8 %15
%17 = OpLoad %float %7
- %31 = OpBitcast %_ptr_Workgroup_float %4
+ %43 = OpBitcast %_ptr_Workgroup__arr_uchar_uint_1024 %42
+ %31 = OpBitcast %_ptr_Workgroup_float %43
OpStore %31 %17 Aligned 4
%19 = OpLoad %float %8
- %32 = OpBitcast %_ptr_Workgroup_float %4
+ %44 = OpBitcast %_ptr_Workgroup__arr_uchar_uint_1024 %42
+ %32 = OpBitcast %_ptr_Workgroup_float %44
%18 = OpFunctionCall %float %37 %32 %19
OpStore %7 %18
- %33 = OpBitcast %_ptr_Workgroup_float %4
+ %45 = OpBitcast %_ptr_Workgroup__arr_uchar_uint_1024 %42
+ %33 = OpBitcast %_ptr_Workgroup_float %45
%20 = OpLoad %float %33 Aligned 4
OpStore %8 %20
%21 = OpLoad %ulong %6
@@ -78,9 +85,9 @@
%23 = OpLoad %ulong %6
%24 = OpLoad %float %8
%35 = OpConvertUToPtr %_ptr_Generic_float %23
- %60 = OpBitcast %_ptr_Generic_uchar %35
- %61 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %60 %ulong_4_0
- %28 = OpBitcast %_ptr_Generic_float %61
+ %67 = OpBitcast %_ptr_Generic_uchar %35
+ %68 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %67 %ulong_4_0
+ %28 = OpBitcast %_ptr_Generic_float %68
OpStore %28 %24 Aligned 4
OpReturn
OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/call.spvtxt b/ptx/src/test/spirv_run/call.spvtxt
index 6929b1e..c29984e 100644
--- a/ptx/src/test/spirv_run/call.spvtxt
+++ b/ptx/src/test/spirv_run/call.spvtxt
@@ -7,9 +7,13 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
+ OpCapability DenormFlushToZero
%37 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %4 "call"
+ OpExecutionMode %4 ContractionOff
+ OpDecorate %4 LinkageAttributes "call" Export
+ OpDecorate %1 LinkageAttributes "incr" Export
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%40 = OpTypeFunction %void %ulong %ulong
diff --git a/ptx/src/test/spirv_run/extern_shared.spvtxt b/ptx/src/test/spirv_run/extern_shared.spvtxt
index ed1c489..82d86ae 100644
--- a/ptx/src/test/spirv_run/extern_shared.spvtxt
+++ b/ptx/src/test/spirv_run/extern_shared.spvtxt
@@ -7,21 +7,23 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
- %24 = OpExtInstImport "OpenCL.std"
+ OpCapability DenormFlushToZero
+ %27 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
- OpEntryPoint Kernel %2 "extern_shared" %1
+ OpEntryPoint Kernel %2 "extern_shared" %24
OpExecutionMode %2 ContractionOff
- OpDecorate %1 LinkageAttributes "shared_mem" Import
%void = OpTypeVoid
- %uint = OpTypeInt 32 0
-%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint
- %1 = OpVariable %_ptr_Workgroup_uint Workgroup
+ %uchar = OpTypeInt 8 0
+%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar
+ %24 = OpVariable %_ptr_Workgroup_uchar Workgroup
%ulong = OpTypeInt 64 0
- %29 = OpTypeFunction %void %ulong %ulong
+ %32 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
+ %uint = OpTypeInt 32 0
+%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint
%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong
- %2 = OpFunction %void None %29
+ %2 = OpFunction %void None %32
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
%22 = OpLabel
@@ -41,9 +43,11 @@
%12 = OpLoad %ulong %18 Aligned 8
OpStore %7 %12
%14 = OpLoad %ulong %7
- %19 = OpBitcast %_ptr_Workgroup_ulong %1
+ %25 = OpBitcast %_ptr_Workgroup_uint %24
+ %19 = OpBitcast %_ptr_Workgroup_ulong %25
OpStore %19 %14 Aligned 8
- %20 = OpBitcast %_ptr_Workgroup_ulong %1
+ %26 = OpBitcast %_ptr_Workgroup_uint %24
+ %20 = OpBitcast %_ptr_Workgroup_ulong %26
%15 = OpLoad %ulong %20 Aligned 8
OpStore %7 %15
%16 = OpLoad %ulong %6
diff --git a/ptx/src/test/spirv_run/extern_shared_call.spvtxt b/ptx/src/test/spirv_run/extern_shared_call.spvtxt
index 941eb39..3cc78cb 100644
--- a/ptx/src/test/spirv_run/extern_shared_call.spvtxt
+++ b/ptx/src/test/spirv_run/extern_shared_call.spvtxt
@@ -7,38 +7,42 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
- %34 = OpExtInstImport "OpenCL.std"
+ OpCapability DenormFlushToZero
+ %41 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
- OpEntryPoint Kernel %12 "extern_shared_call" %1
+ OpEntryPoint Kernel %12 "extern_shared_call" %37
OpExecutionMode %12 ContractionOff
- OpDecorate %1 Alignment 4
- OpDecorate %1 LinkageAttributes "shared_mem" Import
%void = OpTypeVoid
- %uint = OpTypeInt 32 0
-%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint
- %1 = OpVariable %_ptr_Workgroup_uint Workgroup
- %38 = OpTypeFunction %void
+ %uchar = OpTypeInt 8 0
+%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar
+ %45 = OpTypeFunction %void %_ptr_Workgroup_uchar
%ulong = OpTypeInt 64 0
%_ptr_Function_ulong = OpTypePointer Function %ulong
+ %uint = OpTypeInt 32 0
+%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint
%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong
%ulong_2 = OpConstant %ulong 2
- %42 = OpTypeFunction %void %ulong %ulong
+ %37 = OpVariable %_ptr_Workgroup_uchar Workgroup
+ %51 = OpTypeFunction %void %ulong %ulong
%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
- %2 = OpFunction %void None %38
+ %2 = OpFunction %void None %45
+ %34 = OpFunctionParameter %_ptr_Workgroup_uchar
%11 = OpLabel
%3 = OpVariable %_ptr_Function_ulong Function
- %9 = OpBitcast %_ptr_Workgroup_ulong %1
+ %35 = OpBitcast %_ptr_Workgroup_uint %34
+ %9 = OpBitcast %_ptr_Workgroup_ulong %35
%4 = OpLoad %ulong %9 Aligned 8
OpStore %3 %4
%6 = OpLoad %ulong %3
%5 = OpIAdd %ulong %6 %ulong_2
OpStore %3 %5
%7 = OpLoad %ulong %3
- %10 = OpBitcast %_ptr_Workgroup_ulong %1
+ %36 = OpBitcast %_ptr_Workgroup_uint %34
+ %10 = OpBitcast %_ptr_Workgroup_ulong %36
OpStore %10 %7 Aligned 8
OpReturn
OpFunctionEnd
- %12 = OpFunction %void None %42
+ %12 = OpFunction %void None %51
%18 = OpFunctionParameter %ulong
%19 = OpFunctionParameter %ulong
%32 = OpLabel
@@ -58,10 +62,13 @@
%22 = OpLoad %ulong %28 Aligned 8
OpStore %17 %22
%24 = OpLoad %ulong %17
- %29 = OpBitcast %_ptr_Workgroup_ulong %1
+ %38 = OpBitcast %_ptr_Workgroup_uint %37
+ %29 = OpBitcast %_ptr_Workgroup_ulong %38
OpStore %29 %24 Aligned 8
- %44 = OpFunctionCall %void %2
- %30 = OpBitcast %_ptr_Workgroup_ulong %1
+ %39 = OpBitcast %_ptr_Workgroup_uchar %37
+ %53 = OpFunctionCall %void %2 %39
+ %40 = OpBitcast %_ptr_Workgroup_uint %37
+ %30 = OpBitcast %_ptr_Workgroup_ulong %40
%25 = OpLoad %ulong %30 Aligned 8
OpStore %17 %25
%26 = OpLoad %ulong %16
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index b7fd386..6c073f3 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -222,6 +222,7 @@ test_ptx!(prmt, [0x70c507d6u32, 0x6fbd4b5cu32], [0x6fbdd65cu32]);
test_ptx!(activemask, [0u32], [1u32]);
test_ptx!(membar, [152731u32], [152731u32]);
test_ptx!(shared_unify_extern, [7681u64], [15362u64]);
+test_ptx!(shared_unify_private, [67153u64], [134306u64]);
test_ptx!(func_ptr);
test_ptx!(lanemask_lt);
diff --git a/ptx/src/test/spirv_run/reg_local.spvtxt b/ptx/src/test/spirv_run/reg_local.spvtxt
index 4a69450..ddb6a9e 100644
--- a/ptx/src/test/spirv_run/reg_local.spvtxt
+++ b/ptx/src/test/spirv_run/reg_local.spvtxt
@@ -7,6 +7,7 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
+ OpCapability DenormFlushToZero
%34 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "reg_local"
@@ -51,22 +52,24 @@
OpStore %7 %12
%14 = OpLoad %ulong %7
%19 = OpIAdd %ulong %14 %ulong_1
- %26 = OpPtrCastToGeneric %_ptr_Generic_ulong %4
+ %46 = OpBitcast %_ptr_Function_ulong %4
+ %26 = OpPtrCastToGeneric %_ptr_Generic_ulong %46
%27 = OpCopyObject %ulong %19
OpStore %26 %27 Aligned 8
- %28 = OpPtrCastToGeneric %_ptr_Generic_ulong %4
- %47 = OpBitcast %_ptr_Generic_uchar %28
- %48 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %47 %ulong_0
- %21 = OpBitcast %_ptr_Generic_ulong %48
+ %47 = OpBitcast %_ptr_Function_ulong %4
+ %28 = OpPtrCastToGeneric %_ptr_Generic_ulong %47
+ %49 = OpBitcast %_ptr_Generic_uchar %28
+ %50 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %49 %ulong_0
+ %21 = OpBitcast %_ptr_Generic_ulong %50
%29 = OpLoad %ulong %21 Aligned 8
%15 = OpCopyObject %ulong %29
OpStore %7 %15
%16 = OpLoad %ulong %6
%17 = OpLoad %ulong %7
%30 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %16
- %50 = OpBitcast %_ptr_CrossWorkgroup_uchar %30
- %51 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %50 %ulong_0_0
- %23 = OpBitcast %_ptr_CrossWorkgroup_ulong %51
+ %52 = OpBitcast %_ptr_CrossWorkgroup_uchar %30
+ %53 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %52 %ulong_0_0
+ %23 = OpBitcast %_ptr_CrossWorkgroup_ulong %53
%31 = OpCopyObject %ulong %17
OpStore %23 %31 Aligned 8
OpReturn
diff --git a/ptx/src/test/spirv_run/shared_ptr_32.spvtxt b/ptx/src/test/spirv_run/shared_ptr_32.spvtxt
index 1b2e3dd..020c15b 100644
--- a/ptx/src/test/spirv_run/shared_ptr_32.spvtxt
+++ b/ptx/src/test/spirv_run/shared_ptr_32.spvtxt
@@ -7,26 +7,28 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
- %32 = OpExtInstImport "OpenCL.std"
+ OpCapability DenormFlushToZero
+ %34 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
- OpEntryPoint Kernel %1 "shared_ptr_32" %4
- OpDecorate %4 Alignment 4
+ OpEntryPoint Kernel %1 "shared_ptr_32" %32
+ OpExecutionMode %1 ContractionOff
%void = OpTypeVoid
%uint = OpTypeInt 32 0
%uchar = OpTypeInt 8 0
%uint_128 = OpConstant %uint 128
%_arr_uchar_uint_128 = OpTypeArray %uchar %uint_128
%_ptr_Workgroup__arr_uchar_uint_128 = OpTypePointer Workgroup %_arr_uchar_uint_128
- %4 = OpVariable %_ptr_Workgroup__arr_uchar_uint_128 Workgroup
+ %32 = OpVariable %_ptr_Workgroup__arr_uchar_uint_128 Workgroup
%ulong = OpTypeInt 64 0
- %40 = OpTypeFunction %void %ulong %ulong
+ %42 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Function_uint = OpTypePointer Function %uint
+ %uint_128_0 = OpConstant %uint 128
%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong
%ulong_0 = OpConstant %ulong 0
%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar
- %1 = OpFunction %void None %40
+ %1 = OpFunction %void None %42
%10 = OpFunctionParameter %ulong
%11 = OpFunctionParameter %ulong
%30 = OpLabel
@@ -43,7 +45,8 @@
OpStore %5 %12
%13 = OpLoad %ulong %3 Aligned 8
OpStore %6 %13
- %25 = OpConvertPtrToU %uint %4
+ %33 = OpBitcast %_ptr_Workgroup__arr_uchar_uint_128 %32
+ %25 = OpConvertPtrToU %uint %33
%14 = OpCopyObject %uint %25
OpStore %7 %14
%16 = OpLoad %ulong %5
@@ -56,9 +59,9 @@
OpStore %27 %18 Aligned 8
%20 = OpLoad %uint %7
%28 = OpConvertUToPtr %_ptr_Workgroup_ulong %20
- %46 = OpBitcast %_ptr_Workgroup_uchar %28
- %47 = OpInBoundsPtrAccessChain %_ptr_Workgroup_uchar %46 %ulong_0
- %24 = OpBitcast %_ptr_Workgroup_ulong %47
+ %49 = OpBitcast %_ptr_Workgroup_uchar %28
+ %50 = OpInBoundsPtrAccessChain %_ptr_Workgroup_uchar %49 %ulong_0
+ %24 = OpBitcast %_ptr_Workgroup_ulong %50
%19 = OpLoad %ulong %24 Aligned 8
OpStore %9 %19
%21 = OpLoad %ulong %6
diff --git a/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt b/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt
index 3ebe810..90e04f3 100644
--- a/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt
+++ b/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt
@@ -7,22 +7,21 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
- %30 = OpExtInstImport "OpenCL.std"
+ OpCapability DenormFlushToZero
+ %32 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
- OpEntryPoint Kernel %2 "shared_ptr_take_address" %1
+ OpEntryPoint Kernel %2 "shared_ptr_take_address" %30
OpExecutionMode %2 ContractionOff
- OpDecorate %1 Alignment 4
- OpDecorate %1 LinkageAttributes "shared_mem" Import
%void = OpTypeVoid
%uchar = OpTypeInt 8 0
%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar
- %1 = OpVariable %_ptr_Workgroup_uchar Workgroup
+ %30 = OpVariable %_ptr_Workgroup_uchar Workgroup
%ulong = OpTypeInt 64 0
- %35 = OpTypeFunction %void %ulong %ulong
+ %37 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong
- %2 = OpFunction %void None %35
+ %2 = OpFunction %void None %37
%10 = OpFunctionParameter %ulong
%11 = OpFunctionParameter %ulong
%28 = OpLabel
@@ -39,7 +38,8 @@
OpStore %5 %12
%13 = OpLoad %ulong %4 Aligned 8
OpStore %6 %13
- %23 = OpConvertPtrToU %ulong %1
+ %31 = OpBitcast %_ptr_Workgroup_uchar %30
+ %23 = OpConvertPtrToU %ulong %31
%14 = OpCopyObject %ulong %23
OpStore %7 %14
%16 = OpLoad %ulong %5
diff --git a/ptx/src/test/spirv_run/shared_unify_extern.spvtxt b/ptx/src/test/spirv_run/shared_unify_extern.spvtxt
index 9b62045..2dd2056 100644
--- a/ptx/src/test/spirv_run/shared_unify_extern.spvtxt
+++ b/ptx/src/test/spirv_run/shared_unify_extern.spvtxt
@@ -1,62 +1,79 @@
- OpCapability GenericPointer
- OpCapability Linkage
- OpCapability Addresses
- OpCapability Kernel
- OpCapability Int8
- OpCapability Int16
- OpCapability Int64
- OpCapability Float16
- OpCapability Float64
- %30 = OpExtInstImport "OpenCL.std"
- OpMemoryModel Physical64 OpenCL
- OpEntryPoint Kernel %2 "shared_ptr_take_address" %1
- OpExecutionMode %2 ContractionOff
- OpDecorate %1 Alignment 4
- OpDecorate %1 LinkageAttributes "shared_mem" Import
- %void = OpTypeVoid
- %uchar = OpTypeInt 8 0
-%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar
- %1 = OpVariable %_ptr_Workgroup_uchar Workgroup
- %ulong = OpTypeInt 64 0
- %35 = OpTypeFunction %void %ulong %ulong
-%_ptr_Function_ulong = OpTypePointer Function %ulong
-%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
-%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong
- %2 = OpFunction %void None %35
- %10 = OpFunctionParameter %ulong
- %11 = OpFunctionParameter %ulong
- %28 = OpLabel
- %3 = OpVariable %_ptr_Function_ulong Function
- %4 = OpVariable %_ptr_Function_ulong Function
- %5 = OpVariable %_ptr_Function_ulong Function
- %6 = OpVariable %_ptr_Function_ulong Function
- %7 = OpVariable %_ptr_Function_ulong Function
- %8 = OpVariable %_ptr_Function_ulong Function
- %9 = OpVariable %_ptr_Function_ulong Function
- OpStore %3 %10
- OpStore %4 %11
- %12 = OpLoad %ulong %3 Aligned 8
- OpStore %5 %12
- %13 = OpLoad %ulong %4 Aligned 8
- OpStore %6 %13
- %23 = OpConvertPtrToU %ulong %1
- %14 = OpCopyObject %ulong %23
- OpStore %7 %14
- %16 = OpLoad %ulong %5
- %24 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %16
- %15 = OpLoad %ulong %24 Aligned 8
- OpStore %8 %15
- %17 = OpLoad %ulong %7
- %18 = OpLoad %ulong %8
- %25 = OpConvertUToPtr %_ptr_Workgroup_ulong %17
- OpStore %25 %18 Aligned 8
- %20 = OpLoad %ulong %7
- %26 = OpConvertUToPtr %_ptr_Workgroup_ulong %20
- %19 = OpLoad %ulong %26 Aligned 8
- OpStore %9 %19
- %21 = OpLoad %ulong %6
- %22 = OpLoad %ulong %9
- %27 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %21
- OpStore %27 %22 Aligned 8
- OpReturn
- OpFunctionEnd
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
+ OpCapability Float64
+ OpCapability DenormFlushToZero
+ %41 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %10 "shared_unify_extern" %38
+ OpExecutionMode %10 ContractionOff
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %uchar = OpTypeInt 8 0
+%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar
+ %46 = OpTypeFunction %ulong %_ptr_Workgroup_uchar
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+ %uint = OpTypeInt 32 0
+ %uint_4 = OpConstant %uint 4
+%_arr_uint_uint_4 = OpTypeArray %uint %uint_4
+%_ptr_Workgroup__arr_uint_uint_4 = OpTypePointer Workgroup %_arr_uint_uint_4
+%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong
+ %38 = OpVariable %_ptr_Workgroup_uchar Workgroup
+ %53 = OpTypeFunction %void %ulong %ulong
+%_ptr_Generic_ulong = OpTypePointer Generic %ulong
+%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint
+ %3 = OpFunction %ulong None %46
+ %36 = OpFunctionParameter %_ptr_Workgroup_uchar
+ %9 = OpLabel
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %37 = OpBitcast %_ptr_Workgroup__arr_uint_uint_4 %36
+ %8 = OpBitcast %_ptr_Workgroup_ulong %37
+ %7 = OpLoad %ulong %8 Aligned 8
+ %5 = OpCopyObject %ulong %7
+ OpStore %4 %5
+ %6 = OpLoad %ulong %4
+ OpReturnValue %6
+ OpFunctionEnd
+ %10 = OpFunction %void None %53
+ %17 = OpFunctionParameter %ulong
+ %18 = OpFunctionParameter %ulong
+ %34 = OpLabel
+ %11 = OpVariable %_ptr_Function_ulong Function
+ %12 = OpVariable %_ptr_Function_ulong Function
+ %13 = OpVariable %_ptr_Function_ulong Function
+ %14 = OpVariable %_ptr_Function_ulong Function
+ %15 = OpVariable %_ptr_Function_ulong Function
+ %16 = OpVariable %_ptr_Function_ulong Function
+ OpStore %11 %17
+ OpStore %12 %18
+ %19 = OpLoad %ulong %11 Aligned 8
+ OpStore %13 %19
+ %20 = OpLoad %ulong %12 Aligned 8
+ OpStore %14 %20
+ %22 = OpLoad %ulong %13
+ %30 = OpConvertUToPtr %_ptr_Generic_ulong %22
+ %21 = OpLoad %ulong %30 Aligned 8
+ OpStore %15 %21
+ %23 = OpLoad %ulong %15
+ %39 = OpBitcast %_ptr_Workgroup_uint %38
+ %31 = OpBitcast %_ptr_Workgroup_ulong %39
+ OpStore %31 %23 Aligned 8
+ %40 = OpBitcast %_ptr_Workgroup_uchar %38
+ %32 = OpFunctionCall %ulong %3 %40
+ %24 = OpCopyObject %ulong %32
+ OpStore %16 %24
+ %26 = OpLoad %ulong %16
+ %27 = OpLoad %ulong %15
+ %25 = OpIAdd %ulong %26 %27
+ OpStore %16 %25
+ %28 = OpLoad %ulong %14
+ %29 = OpLoad %ulong %16
+ %33 = OpConvertUToPtr %_ptr_Generic_ulong %28
+ OpStore %33 %29 Aligned 8
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/shared_unify_private.ptx b/ptx/src/test/spirv_run/shared_unify_private.ptx
new file mode 100644
index 0000000..fd31357
--- /dev/null
+++ b/ptx/src/test/spirv_run/shared_unify_private.ptx
@@ -0,0 +1,32 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.func (.reg .b64 out) load_from_shared() // Non-kernel helper: returns in `out` the first 8 bytes of its own .shared array.
+{
+ .shared .b32 shared_mod[4]; // 4 x 32-bit elements of .shared storage, declared at function scope
+ ld.shared.u64 out, [shared_mod]; // 64-bit load from the start of shared_mod
+ ret;
+}
+
+.visible .entry shared_unify_private( // Kernel: stores *input into its .shared array, calls load_from_shared, writes (returned value + *input) to *output.
+ .param .u64 input, // address of the 64-bit input value
+ .param .u64 output // address that receives the 64-bit result
+)
+{
+ .shared .b32 shared_ex[2]; // 2 x 32-bit elements of .shared storage local to this kernel
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u64 temp1;
+ .reg .u64 temp2;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.u64 temp1, [in_addr]; // temp1 = *input
+ st.shared.u64 [shared_ex], temp1; // write the input value to the start of shared_ex
+ call (temp2), load_from_shared; // temp2 = value the helper reads from its own shared_mod array
+ add.u64 temp2, temp2, temp1; // the mod.rs test expects output == 2 * input, i.e. the helper must read back the value stored above
+ st.u64 [out_addr], temp2;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/shared_unify_private.spvtxt b/ptx/src/test/spirv_run/shared_unify_private.spvtxt
new file mode 100644
index 0000000..69bf018
--- /dev/null
+++ b/ptx/src/test/spirv_run/shared_unify_private.spvtxt
@@ -0,0 +1,84 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
+ OpCapability Float64
+ OpCapability DenormFlushToZero
+ %41 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %9 "shared_unify_private" %38
+ OpExecutionMode %9 ContractionOff
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %uchar = OpTypeInt 8 0
+%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar
+ %46 = OpTypeFunction %ulong %_ptr_Workgroup_uchar
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+ %uint = OpTypeInt 32 0
+ %uint_4 = OpConstant %uint 4
+%_arr_uint_uint_4 = OpTypeArray %uint %uint_4
+%_ptr_Workgroup__arr_uint_uint_4 = OpTypePointer Workgroup %_arr_uint_uint_4
+%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong
+ %uint_16 = OpConstant %uint 16
+%_arr_uchar_uint_16 = OpTypeArray %uchar %uint_16
+%_ptr_Workgroup__arr_uchar_uint_16 = OpTypePointer Workgroup %_arr_uchar_uint_16
+ %38 = OpVariable %_ptr_Workgroup__arr_uchar_uint_16 Workgroup
+ %56 = OpTypeFunction %void %ulong %ulong
+%_ptr_Generic_ulong = OpTypePointer Generic %ulong
+ %uint_2 = OpConstant %uint 2
+%_arr_uint_uint_2 = OpTypeArray %uint %uint_2
+%_ptr_Workgroup__arr_uint_uint_2 = OpTypePointer Workgroup %_arr_uint_uint_2
+ %1 = OpFunction %ulong None %46
+ %36 = OpFunctionParameter %_ptr_Workgroup_uchar
+ %8 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %37 = OpBitcast %_ptr_Workgroup__arr_uint_uint_4 %36
+ %7 = OpBitcast %_ptr_Workgroup_ulong %37
+ %6 = OpLoad %ulong %7 Aligned 8
+ %4 = OpCopyObject %ulong %6
+ OpStore %2 %4
+ %5 = OpLoad %ulong %2
+ OpReturnValue %5
+ OpFunctionEnd
+ %9 = OpFunction %void None %56
+ %17 = OpFunctionParameter %ulong
+ %18 = OpFunctionParameter %ulong
+ %34 = OpLabel
+ %10 = OpVariable %_ptr_Function_ulong Function
+ %11 = OpVariable %_ptr_Function_ulong Function
+ %13 = OpVariable %_ptr_Function_ulong Function
+ %14 = OpVariable %_ptr_Function_ulong Function
+ %15 = OpVariable %_ptr_Function_ulong Function
+ %16 = OpVariable %_ptr_Function_ulong Function
+ OpStore %10 %17
+ OpStore %11 %18
+ %19 = OpLoad %ulong %10 Aligned 8
+ OpStore %13 %19
+ %20 = OpLoad %ulong %11 Aligned 8
+ OpStore %14 %20
+ %22 = OpLoad %ulong %13
+ %30 = OpConvertUToPtr %_ptr_Generic_ulong %22
+ %21 = OpLoad %ulong %30 Aligned 8
+ OpStore %15 %21
+ %23 = OpLoad %ulong %15
+ %39 = OpBitcast %_ptr_Workgroup__arr_uint_uint_2 %38
+ %31 = OpBitcast %_ptr_Workgroup_ulong %39
+ OpStore %31 %23 Aligned 8
+ %40 = OpBitcast %_ptr_Workgroup_uchar %38
+ %32 = OpFunctionCall %ulong %1 %40
+ %24 = OpCopyObject %ulong %32
+ OpStore %16 %24
+ %26 = OpLoad %ulong %16
+ %27 = OpLoad %ulong %15
+ %25 = OpIAdd %ulong %26 %27
+ OpStore %16 %25
+ %28 = OpLoad %ulong %14
+ %29 = OpLoad %ulong %16
+ %33 = OpConvertUToPtr %_ptr_Generic_ulong %28
+ OpStore %33 %29 Aligned 8
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/shared_variable.spvtxt b/ptx/src/test/spirv_run/shared_variable.spvtxt
index 49278a8..7a97c0e 100644
--- a/ptx/src/test/spirv_run/shared_variable.spvtxt
+++ b/ptx/src/test/spirv_run/shared_variable.spvtxt
@@ -7,23 +7,26 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
- %25 = OpExtInstImport "OpenCL.std"
+ OpCapability DenormFlushToZero
+ %28 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
- OpEntryPoint Kernel %1 "shared_variable" %4
- OpDecorate %4 Alignment 4
+ OpEntryPoint Kernel %1 "shared_variable" %25
+ OpExecutionMode %1 ContractionOff
%void = OpTypeVoid
%uint = OpTypeInt 32 0
%uchar = OpTypeInt 8 0
%uint_128 = OpConstant %uint 128
%_arr_uchar_uint_128 = OpTypeArray %uchar %uint_128
%_ptr_Workgroup__arr_uchar_uint_128 = OpTypePointer Workgroup %_arr_uchar_uint_128
- %4 = OpVariable %_ptr_Workgroup__arr_uchar_uint_128 Workgroup
+ %25 = OpVariable %_ptr_Workgroup__arr_uchar_uint_128 Workgroup
%ulong = OpTypeInt 64 0
- %33 = OpTypeFunction %void %ulong %ulong
+ %36 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
+ %uint_128_0 = OpConstant %uint 128
%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong
- %1 = OpFunction %void None %33
+ %uint_128_1 = OpConstant %uint 128
+ %1 = OpFunction %void None %36
%9 = OpFunctionParameter %ulong
%10 = OpFunctionParameter %ulong
%23 = OpLabel
@@ -44,9 +47,11 @@
%13 = OpLoad %ulong %19 Aligned 8
OpStore %7 %13
%15 = OpLoad %ulong %7
- %20 = OpBitcast %_ptr_Workgroup_ulong %4
+ %26 = OpBitcast %_ptr_Workgroup__arr_uchar_uint_128 %25
+ %20 = OpBitcast %_ptr_Workgroup_ulong %26
OpStore %20 %15 Aligned 8
- %21 = OpBitcast %_ptr_Workgroup_ulong %4
+ %27 = OpBitcast %_ptr_Workgroup__arr_uchar_uint_128 %25
+ %21 = OpBitcast %_ptr_Workgroup_ulong %27
%16 = OpLoad %ulong %21 Aligned 8
OpStore %8 %16
%17 = OpLoad %ulong %6
diff --git a/ptx/src/test/spirv_run/verify.py b/ptx/src/test/spirv_run/verify.py
index dbfab00..4ef6465 100644
--- a/ptx/src/test/spirv_run/verify.py
+++ b/ptx/src/test/spirv_run/verify.py
@@ -1,7 +1,7 @@
import os, sys, subprocess
def main(path):
- dirs = os.listdir(path)
+ dirs = sorted(os.listdir(path))
for file in dirs:
if not file.endswith(".spvtxt"):
continue
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index e96cdc2..3f27522 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -434,6 +434,7 @@ pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result<Module, Trans
translate_directive(&mut id_defs, &mut ptx_impl_imports, directive).transpose()
})
.collect::<Result<Vec<_>, _>>()?;
+ let directives = hoist_function_globals(directives);
let must_link_ptx_impl = ptx_impl_imports.len() > 0;
let mut directives = ptx_impl_imports
.into_iter()
@@ -458,6 +459,7 @@ pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result<Module, Trans
let mut kernel_info = HashMap::new();
let (build_options, should_flush_denorms) =
emit_denorm_build_string(&call_map, &denorm_information);
+ let (directives, globals_use_map) = get_globals_use_map(directives);
emit_directives(
&mut builder,
&mut map,
@@ -465,6 +467,7 @@ pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result<Module, Trans
opencl_id,
should_flush_denorms,
&call_map,
+ globals_use_map,
directives,
&mut kernel_info,
)?;
@@ -481,6 +484,79 @@ pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result<Module, Trans
})
}
+fn get_globals_use_map<'input>(
+ directives: Vec<Directive<'input>>,
+) -> (
+ Vec<Directive<'input>>,
+ HashMap<ast::MethodName<'input, spirv::Word>, HashSet<spirv::Word>>,
+) {
+ let mut known_globals = HashSet::new();
+ for directive in directives.iter() {
+ match directive {
+ Directive::Variable(_, ast::Variable { name, .. }) => {
+ known_globals.insert(*name);
+ }
+ Directive::Method(..) => {}
+ }
+ }
+ let mut symbol_uses_map = HashMap::new();
+ let directives = directives
+ .into_iter()
+ .map(|directive| match directive {
+ Directive::Variable(..) | Directive::Method(Function { body: None, .. }) => directive,
+ Directive::Method(Function {
+ func_decl,
+ body: Some(mut statements),
+ globals,
+ import_as,
+ tuning,
+ linkage,
+ }) => {
+ let method_name = func_decl.borrow().name;
+ statements = statements
+ .into_iter()
+ .map(|statement| {
+ statement.map_id(&mut |symbol, _| {
+ if known_globals.contains(&symbol) {
+ multi_hash_map_append(&mut symbol_uses_map, method_name, symbol);
+ }
+ symbol
+ })
+ })
+ .collect::<Vec<_>>();
+ Directive::Method(Function {
+ func_decl,
+ body: Some(statements),
+ globals,
+ import_as,
+ tuning,
+ linkage,
+ })
+ }
+ })
+ .collect::<Vec<_>>();
+ (directives, symbol_uses_map)
+}
+
+fn hoist_function_globals(directives: Vec<Directive>) -> Vec<Directive> {
+ let mut result = Vec::with_capacity(directives.len());
+ for directive in directives {
+ match directive {
+ Directive::Method(method) => {
+ for variable in method.globals {
+ result.push(Directive::Variable(ast::LinkingDirective::NONE, variable));
+ }
+ result.push(Directive::Method(Function {
+ globals: Vec::new(),
+ ..method
+ }))
+ }
+ _ => result.push(directive),
+ }
+ }
+ result
+}
+
// TODO: remove this once we have pef-function support for denorms
fn emit_denorm_build_string<'input>(
call_map: &HashMap<&str, HashSet<u32>>,
@@ -531,6 +607,7 @@ fn emit_directives<'input>(
opencl_id: spirv::Word,
should_flush_denorms: bool,
call_map: &HashMap<&'input str, HashSet<spirv::Word>>,
+ globals_use_map: HashMap<ast::MethodName<'input, spirv::Word>, HashSet<spirv::Word>>,
directives: Vec<Directive<'input>>,
kernel_info: &mut HashMap<String, KernelInfo>,
) -> Result<(), TranslateError> {
@@ -559,10 +636,9 @@ fn emit_directives<'input>(
builder,
map,
&id_defs,
- &f.globals,
&*func_decl,
call_map,
- &directives,
+ &globals_use_map,
kernel_info,
)?;
if func_decl.name.is_kernel() {
@@ -626,16 +702,17 @@ fn emit_function_linkage<'input>(
if f.linkage == ast::LinkingDirective::NONE {
return Ok(());
};
- let linking_name = f.import_as.as_deref().map_or_else(
- || match f.func_decl.borrow().name {
- ast::MethodName::Kernel(kernel_name) => Ok(kernel_name),
- ast::MethodName::Func(fn_id) => match id_defs.reverse_variables.get(&fn_id) {
+ let linking_name = match f.func_decl.borrow().name {
+ // According to SPIR-V rules linkage attributes are invalid on kernels
+ ast::MethodName::Kernel(..) => return Ok(()),
+ ast::MethodName::Func(fn_id) => f.import_as.as_deref().map_or_else(
+ || match id_defs.reverse_variables.get(&fn_id) {
Some(fn_name) => Ok(fn_name),
None => Err(error_unknown_symbol()),
},
- },
- Result::Ok,
- )?;
+ Result::Ok,
+ )?,
+ };
emit_linking_decoration(builder, id_defs, Some(linking_name), fn_name, f.linkage);
Ok(())
}
@@ -712,7 +789,7 @@ fn multi_hash_map_append<
entry.get_mut().extend(iter::once(value));
}
hash_map::Entry::Vacant(entry) => {
- entry.insert(Default::default());
+ entry.insert(Default::default()).extend(iter::once(value));
}
}
}
@@ -857,6 +934,9 @@ fn convert_dynamic_shared_memory_usage<'input>(
linkage,
}));
}
+ // Existing .shared globals are now unused, they were replaced by kernel-specific globals
+ Directive::Variable(_, ast::Variable { name, .. })
+ if globals_shared.contains_key(&name) => {}
directive => result.push(directive),
}
}
@@ -873,34 +953,35 @@ fn insert_arguments_remap_statements(
globals_shared: &HashMap<u32, (GlobalSharedSize, ast::Type)>,
statements: Vec<Statement<ast::Instruction<ExpandedArgParams>, ExpandedArgParams>>,
) -> Vec<Statement<ast::Instruction<ExpandedArgParams>, ExpandedArgParams>> {
- let shared_id_param = match method_name {
+ let (shared_id_param, shared_id_type) = match method_name {
ast::MethodName::Kernel(kernel_name) => {
let globals_shared_size = match kernels_to_global_shared.get(kernel_name) {
Some(s) => *s,
None => return statements,
};
let shared_id_param = new_id();
+ func_decl_ref.shared_mem = Some(shared_id_param);
let (linkage, type_) = match globals_shared_size {
GlobalSharedSize::ExternUnsized => (
ast::LinkingDirective::EXTERN,
- ast::Type::Array(ast::ScalarType::U8, Vec::new()),
+ ast::Type::Array(ast::ScalarType::B8, Vec::new()),
),
GlobalSharedSize::Sized(size) => (
ast::LinkingDirective::NONE,
- ast::Type::Array(ast::ScalarType::U8, vec![size as u32]),
+ ast::Type::Array(ast::ScalarType::B8, vec![size as u32]),
),
};
result.push(Directive::Variable(
linkage,
ast::Variable {
align: None,
- v_type: type_,
+ v_type: type_.clone(),
state_space: ast::StateSpace::Shared,
name: shared_id_param,
array_init: Vec::new(),
},
));
- shared_id_param
+ (shared_id_param, Some(type_))
}
ast::MethodName::Func(function_name) => {
if !functions_to_global_shared.contains(&function_name) {
@@ -914,7 +995,7 @@ fn insert_arguments_remap_statements(
name: shared_id_param,
array_init: Vec::new(),
});
- shared_id_param
+ (shared_id_param, None)
}
};
replace_uses_of_shared_memory(
@@ -922,6 +1003,7 @@ fn insert_arguments_remap_statements(
globals_shared,
functions_to_global_shared,
shared_id_param,
+ shared_id_type,
statements,
)
}
@@ -948,6 +1030,7 @@ fn replace_uses_of_shared_memory<'a>(
extern_shared_decls: &HashMap<spirv::Word, (GlobalSharedSize, ast::Type)>,
methods_using_extern_shared: &HashSet<spirv::Word>,
shared_id_param: spirv::Word,
+ shared_id_type: Option<ast::Type>,
statements: Vec<ExpandedStatement>,
) -> Vec<ExpandedStatement> {
let mut result = Vec::with_capacity(statements.len());
@@ -958,6 +1041,22 @@ fn replace_uses_of_shared_memory<'a>(
// because there's simply no way to pass shared ptr
// without converting it to .b64 first
if methods_using_extern_shared.contains(&call.name) {
+ let shared_id_param = match shared_id_type {
+ Some(ref global_shared_defined_type) => {
+ let dst = new_id();
+ result.push(Statement::Conversion(ImplicitConversion {
+ src: shared_id_param,
+ dst,
+ from_type: global_shared_defined_type.clone(),
+ to_type: ast::Type::Scalar(ast::ScalarType::B8),
+ from_space: ast::StateSpace::Shared,
+ to_space: ast::StateSpace::Shared,
+ kind: ConversionKind::PtrToPtr,
+ }));
+ dst
+ }
+ None => shared_id_param,
+ };
call.input_arguments.push((
shared_id_param,
ast::Type::Scalar(ast::ScalarType::B8),
@@ -998,10 +1097,7 @@ fn replace_uses_of_shared_memory<'a>(
// * If it's a kernel -> size of .shared globals in use (direct or indirect)
// * If it's a function -> does it use .shared global (directly or indirectly)
fn resolve_indirect_uses_of_globals_shared<'input>(
- mut methods_use_of_globals_shared: HashMap<
- ast::MethodName<'input, spirv::Word>,
- GlobalSharedSize,
- >,
+ methods_use_of_globals_shared: HashMap<ast::MethodName<'input, spirv::Word>, GlobalSharedSize>,
kernels_methods_call_map: &HashMap<&'input str, HashSet<spirv::Word>>,
) -> (HashMap<&'input str, GlobalSharedSize>, HashSet<spirv::Word>) {
let mut kernel_use = HashMap::new();
@@ -1116,14 +1212,13 @@ fn compute_denorm_information<'input>(
.collect()
}
-fn emit_function_header<'a>(
+fn emit_function_header<'input>(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
- defined_globals: &GlobalStringIdResolver<'a>,
- synthetic_globals: &[ast::Variable<spirv::Word>],
- func_decl: &ast::MethodDeclaration<'a, spirv::Word>,
- call_map: &HashMap<&'a str, HashSet<spirv::Word>>,
- direcitves: &[Directive],
+ defined_globals: &GlobalStringIdResolver<'input>,
+ func_decl: &ast::MethodDeclaration<'input, spirv::Word>,
+ call_map: &HashMap<&'input str, HashSet<spirv::Word>>,
+ globals_use_map: &HashMap<ast::MethodName<'input, spirv::Word>, HashSet<spirv::Word>>,
kernel_info: &mut HashMap<String, KernelInfo>,
) -> Result<spirv::Word, TranslateError> {
if let ast::MethodName::Kernel(name) = func_decl.name {
@@ -1154,38 +1249,28 @@ fn emit_function_header<'a>(
let fn_id = match func_decl.name {
ast::MethodName::Kernel(name) => {
let fn_id = defined_globals.get_id(name)?;
- let mut global_variables = defined_globals
- .variables_type_check
- .iter()
- .filter_map(|(k, t)| t.as_ref().map(|_| *k))
- .collect::<Vec<_>>();
- let mut interface = defined_globals.special_registers.interface();
- for ast::Variable { name, .. } in synthetic_globals {
- interface.push(*name);
- }
- let empty_hash_set = HashSet::new();
- let child_fns = call_map.get(name).unwrap_or(&empty_hash_set);
- for directive in direcitves {
- match directive {
- Directive::Method(Function {
- func_decl, globals, ..
- }) => {
- match (**func_decl).borrow().name {
- ast::MethodName::Func(name) => {
- if child_fns.contains(&name) {
- for var in globals {
- interface.push(var.name);
- }
- }
- }
- ast::MethodName::Kernel(_) => {}
- };
- }
- _ => {}
- }
- }
- global_variables.append(&mut interface);
- builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, global_variables);
+ let interface = globals_use_map
+ .get(&ast::MethodName::Kernel(name))
+ .into_iter()
+ .flatten()
+ .copied()
+ .chain({
+ call_map
+ .get(name)
+ .into_iter()
+ .flat_map(|subfunctions| {
+ subfunctions.iter().flat_map(|subfunction| {
+ globals_use_map
+ .get(&ast::MethodName::Func(*subfunction))
+ .into_iter()
+ .flatten()
+ .copied()
+ })
+ })
+ .into_iter()
+ })
+ .collect::<Vec<spirv::Word>>();
+ builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, interface);
fn_id
}
ast::MethodName::Func(name) => name,
@@ -3571,7 +3656,8 @@ fn emit_variable<'input>(
[dr::Operand::LiteralInt32(align)].iter().cloned(),
);
}
- if var.state_space != ast::StateSpace::Shared || !linking.contains(ast::LinkingDirective::EXTERN)
+ if var.state_space != ast::StateSpace::Shared
+ || !linking.contains(ast::LinkingDirective::EXTERN)
{
emit_linking_decoration(builder, id_defs, None, var.name, linking);
}
@@ -4328,11 +4414,35 @@ fn emit_implicit_conversion(
);
if cv.to_space == ast::StateSpace::Generic && cv.from_space != ast::StateSpace::Generic
{
- builder.ptr_cast_to_generic(result_type, Some(cv.dst), cv.src)?;
+ let src = if cv.from_type != cv.to_type {
+ let temp_type = map.get_or_add(
+ builder,
+ SpirvType::Pointer(
+ Box::new(SpirvType::new(cv.to_type.clone())),
+ cv.from_space.to_spirv(),
+ ),
+ );
+ builder.bitcast(temp_type, None, cv.src)?
+ } else {
+ cv.src
+ };
+ builder.ptr_cast_to_generic(result_type, Some(cv.dst), src)?;
} else if cv.from_space == ast::StateSpace::Generic
&& cv.to_space != ast::StateSpace::Generic
{
- builder.generic_cast_to_ptr(result_type, Some(cv.dst), cv.src)?;
+ let src = if cv.from_type != cv.to_type {
+ let temp_type = map.get_or_add(
+ builder,
+ SpirvType::Pointer(
+ Box::new(SpirvType::new(cv.to_type.clone())),
+ cv.from_space.to_spirv(),
+ ),
+ );
+ builder.bitcast(temp_type, None, cv.src)?
+ } else {
+ cv.src
+ };
+ builder.generic_cast_to_ptr(result_type, Some(cv.dst), src)?;
} else {
builder.bitcast(result_type, Some(cv.dst), cv.src)?;
}
@@ -5035,22 +5145,6 @@ impl SpecialRegistersMap {
}
}
- fn interface(&self) -> Vec<spirv::Word> {
- return Vec::new();
- /*
- self.reg_to_id
- .iter()
- .filter_map(|(sreg, id)| {
- if sreg.normalized_sreg_and_type().is_none() {
- Some(*id)
- } else {
- None
- }
- })
- .collect::<Vec<_>>()
- */
- }
-
fn get(&self, id: spirv::Word) -> Option<PtxSpecialRegister> {
self.id_to_reg.get(&id).copied()
}