From c23be576e86282c1c4673164cb9e92845cd0517e Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sun, 26 Sep 2021 01:24:14 +0200 Subject: Finish fixing shared memory pass --- ptx/src/test/spirv_run/atom_add.spvtxt | 37 +-- ptx/src/test/spirv_run/atom_add_float.spvtxt | 41 ++-- ptx/src/test/spirv_run/call.spvtxt | 4 + ptx/src/test/spirv_run/extern_shared.spvtxt | 24 +- ptx/src/test/spirv_run/extern_shared_call.spvtxt | 39 ++-- ptx/src/test/spirv_run/mod.rs | 1 + ptx/src/test/spirv_run/reg_local.spvtxt | 19 +- ptx/src/test/spirv_run/shared_ptr_32.spvtxt | 23 +- .../test/spirv_run/shared_ptr_take_address.spvtxt | 16 +- ptx/src/test/spirv_run/shared_unify_extern.spvtxt | 141 ++++++------ ptx/src/test/spirv_run/shared_unify_private.ptx | 32 +++ ptx/src/test/spirv_run/shared_unify_private.spvtxt | 84 +++++++ ptx/src/test/spirv_run/shared_variable.spvtxt | 21 +- ptx/src/test/spirv_run/verify.py | 2 +- ptx/src/translate.rs | 248 ++++++++++++++------- 15 files changed, 500 insertions(+), 232 deletions(-) create mode 100644 ptx/src/test/spirv_run/shared_unify_private.ptx create mode 100644 ptx/src/test/spirv_run/shared_unify_private.spvtxt diff --git a/ptx/src/test/spirv_run/atom_add.spvtxt b/ptx/src/test/spirv_run/atom_add.spvtxt index b4de00a..3609247 100644 --- a/ptx/src/test/spirv_run/atom_add.spvtxt +++ b/ptx/src/test/spirv_run/atom_add.spvtxt @@ -7,29 +7,33 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %38 = OpExtInstImport "OpenCL.std" + OpCapability DenormFlushToZero + %42 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "atom_add" %4 - OpDecorate %4 Alignment 4 + OpEntryPoint Kernel %1 "atom_add" %38 + OpExecutionMode %1 ContractionOff %void = OpTypeVoid %uint = OpTypeInt 32 0 %uchar = OpTypeInt 8 0 %uint_1024 = OpConstant %uint 1024 %_arr_uchar_uint_1024 = OpTypeArray %uchar %uint_1024 %_ptr_Workgroup__arr_uchar_uint_1024 = OpTypePointer Workgroup %_arr_uchar_uint_1024 - %4 = OpVariable %_ptr_Workgroup__arr_uchar_uint_1024 Workgroup + %38 = OpVariable %_ptr_Workgroup__arr_uchar_uint_1024 Workgroup %ulong = OpTypeInt 64 0 - %46 = OpTypeFunction %void %ulong %ulong + %50 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint %ulong_4 = OpConstant %ulong 4 %_ptr_Generic_uchar = OpTypePointer Generic %uchar +%uint_1024_0 = OpConstant %uint 1024 %_ptr_Workgroup_uint = OpTypePointer Workgroup %uint +%uint_1024_1 = OpConstant %uint 1024 %uint_1 = OpConstant %uint 1 %uint_0 = OpConstant %uint 0 +%uint_1024_2 = OpConstant %uint 1024 %ulong_4_0 = OpConstant %ulong 4 - %1 = OpFunction %void None %46 + %1 = OpFunction %void None %50 %9 = OpFunctionParameter %ulong %10 = OpFunctionParameter %ulong %36 = OpLabel @@ -51,19 +55,22 @@ OpStore %7 %13 %16 = OpLoad %ulong %5 %30 = OpConvertUToPtr %_ptr_Generic_uint %16 - %51 = OpBitcast %_ptr_Generic_uchar %30 - %52 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %51 %ulong_4 - %26 = OpBitcast %_ptr_Generic_uint %52 + %55 = OpBitcast %_ptr_Generic_uchar %30 + %56 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %55 %ulong_4 + %26 = OpBitcast %_ptr_Generic_uint %56 %15 = OpLoad %uint %26 Aligned 4 OpStore %8 %15 %17 = OpLoad %uint %7 - %31 = OpBitcast %_ptr_Workgroup_uint %4 + %39 = OpBitcast %_ptr_Workgroup__arr_uchar_uint_1024 %38 + %31 = OpBitcast %_ptr_Workgroup_uint %39 OpStore %31 %17 Aligned 4 %19 = OpLoad %uint %8 - %32 = OpBitcast %_ptr_Workgroup_uint %4 + %40 = OpBitcast %_ptr_Workgroup__arr_uchar_uint_1024 %38 + %32 = OpBitcast %_ptr_Workgroup_uint %40 %18 = OpAtomicIAdd %uint %32 %uint_1 %uint_0 %19 OpStore %7 %18 - %33 = OpBitcast %_ptr_Workgroup_uint %4 + %41 = OpBitcast %_ptr_Workgroup__arr_uchar_uint_1024 %38 + %33 = OpBitcast %_ptr_Workgroup_uint %41 %20 = OpLoad %uint %33 Aligned 4 OpStore %8 %20 %21 = OpLoad %ulong %6 @@ -73,9 +80,9 @@ %23 = OpLoad %ulong %6 %24 = OpLoad %uint %8 %35 = OpConvertUToPtr %_ptr_Generic_uint %23 - %56 = OpBitcast %_ptr_Generic_uchar %35 - %57 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %56 %ulong_4_0 - %28 = OpBitcast %_ptr_Generic_uint %57 + %63 = OpBitcast %_ptr_Generic_uchar %35 + %64 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %63 %ulong_4_0 + %28 = OpBitcast %_ptr_Generic_uint %64 OpStore %28 %24 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/atom_add_float.spvtxt b/ptx/src/test/spirv_run/atom_add_float.spvtxt index 7d25632..9533d83 100644 --- a/ptx/src/test/spirv_run/atom_add_float.spvtxt +++ b/ptx/src/test/spirv_run/atom_add_float.spvtxt @@ -7,34 +7,38 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %42 = OpExtInstImport "OpenCL.std" + OpCapability DenormFlushToZero + %46 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "atom_add_float" %4 + OpEntryPoint Kernel %1 "atom_add_float" %42 + OpExecutionMode %1 ContractionOff OpDecorate %37 LinkageAttributes "__zluda_ptx_impl__atom_relaxed_gpu_shared_add_f32" Import - OpDecorate %4 Alignment 4 %void = OpTypeVoid %float = OpTypeFloat 32 %_ptr_Workgroup_float = OpTypePointer Workgroup %float - %46 = OpTypeFunction %float %_ptr_Workgroup_float %float + %50 = OpTypeFunction %float %_ptr_Workgroup_float %float %uint = OpTypeInt 32 0 %uchar = OpTypeInt 8 0 %uint_1024 = OpConstant %uint 1024 %_arr_uchar_uint_1024 = OpTypeArray %uchar %uint_1024 %_ptr_Workgroup__arr_uchar_uint_1024 = OpTypePointer Workgroup %_arr_uchar_uint_1024 - %4 = OpVariable %_ptr_Workgroup__arr_uchar_uint_1024 Workgroup + %42 = OpVariable %_ptr_Workgroup__arr_uchar_uint_1024 Workgroup %ulong = OpTypeInt 64 0 - %53 = OpTypeFunction %void %ulong %ulong + %57 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_Function_float = OpTypePointer Function %float %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 %_ptr_Generic_uchar = OpTypePointer Generic %uchar +%uint_1024_0 = OpConstant %uint 1024 +%uint_1024_1 = OpConstant %uint 1024 +%uint_1024_2 = OpConstant %uint 1024 %ulong_4_0 = OpConstant %ulong 4 - %37 = OpFunction %float None %46 + %37 = OpFunction %float None %50 %39 = OpFunctionParameter %_ptr_Workgroup_float %40 = OpFunctionParameter %float OpFunctionEnd - %1 = OpFunction %void None %53 + %1 = OpFunction %void None %57 %9 = OpFunctionParameter %ulong %10 = OpFunctionParameter %ulong %36 = OpLabel @@ -56,19 +60,22 @@ OpStore %7 %13 %16 = OpLoad %ulong %5 %30 = OpConvertUToPtr %_ptr_Generic_float %16 - %58 = OpBitcast %_ptr_Generic_uchar %30 - %59 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %58 %ulong_4 - %26 = OpBitcast %_ptr_Generic_float %59 + %62 = OpBitcast %_ptr_Generic_uchar %30 + %63 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %62 %ulong_4 + %26 = OpBitcast %_ptr_Generic_float %63 %15 = OpLoad %float %26 Aligned 4 OpStore %8 %15 %17 = OpLoad %float %7 - %31 = OpBitcast %_ptr_Workgroup_float %4 + %43 = OpBitcast %_ptr_Workgroup__arr_uchar_uint_1024 %42 + %31 = OpBitcast %_ptr_Workgroup_float %43 OpStore %31 %17 Aligned 4 %19 = OpLoad %float %8 - %32 = OpBitcast %_ptr_Workgroup_float %4 + %44 = OpBitcast %_ptr_Workgroup__arr_uchar_uint_1024 %42 + %32 = OpBitcast %_ptr_Workgroup_float %44 %18 = OpFunctionCall %float %37 %32 %19 OpStore %7 %18 - %33 = OpBitcast %_ptr_Workgroup_float %4 + %45 = OpBitcast %_ptr_Workgroup__arr_uchar_uint_1024 %42 + %33 = OpBitcast %_ptr_Workgroup_float %45 %20 = OpLoad %float %33 Aligned 4 OpStore %8 %20 %21 = OpLoad %ulong %6 @@ -78,9 +85,9 @@ %23 = OpLoad %ulong %6 %24 = OpLoad %float %8 %35 = OpConvertUToPtr %_ptr_Generic_float %23 - %60 = OpBitcast %_ptr_Generic_uchar %35 - %61 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %60 %ulong_4_0 - %28 = OpBitcast %_ptr_Generic_float %61 + %67 = OpBitcast %_ptr_Generic_uchar %35 + %68 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %67 %ulong_4_0 + %28 = OpBitcast %_ptr_Generic_float %68 OpStore %28 %24 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/call.spvtxt b/ptx/src/test/spirv_run/call.spvtxt index 6929b1e..c29984e 100644 --- a/ptx/src/test/spirv_run/call.spvtxt +++ b/ptx/src/test/spirv_run/call.spvtxt @@ -7,9 +7,13 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 + OpCapability DenormFlushToZero %37 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %4 "call" + OpExecutionMode %4 ContractionOff + OpDecorate %4 LinkageAttributes "call" Export + OpDecorate %1 LinkageAttributes "incr" Export %void = OpTypeVoid %ulong = OpTypeInt 64 0 %40 = OpTypeFunction %void %ulong %ulong diff --git a/ptx/src/test/spirv_run/extern_shared.spvtxt b/ptx/src/test/spirv_run/extern_shared.spvtxt index ed1c489..82d86ae 100644 --- a/ptx/src/test/spirv_run/extern_shared.spvtxt +++ b/ptx/src/test/spirv_run/extern_shared.spvtxt @@ -7,21 +7,23 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %24 = OpExtInstImport "OpenCL.std" + OpCapability DenormFlushToZero + %27 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %2 "extern_shared" %1 + OpEntryPoint Kernel %2 "extern_shared" %24 OpExecutionMode %2 ContractionOff - OpDecorate %1 LinkageAttributes "shared_mem" Import %void = OpTypeVoid - %uint = OpTypeInt 32 0 -%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint - %1 = OpVariable %_ptr_Workgroup_uint Workgroup + %uchar = OpTypeInt 8 0 +%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar + %24 = OpVariable %_ptr_Workgroup_uchar Workgroup %ulong = OpTypeInt 64 0 - %29 = OpTypeFunction %void %ulong %ulong + %32 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong + %uint = OpTypeInt 32 0 +%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint %_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong - %2 = OpFunction %void None %29 + %2 = OpFunction %void None %32 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong %22 = OpLabel @@ -41,9 +43,11 @@ %12 = OpLoad %ulong %18 Aligned 8 OpStore %7 %12 %14 = OpLoad %ulong %7 - %19 = OpBitcast %_ptr_Workgroup_ulong %1 + %25 = OpBitcast %_ptr_Workgroup_uint %24 + %19 = OpBitcast %_ptr_Workgroup_ulong %25 OpStore %19 %14 Aligned 8 - %20 = OpBitcast %_ptr_Workgroup_ulong %1 + %26 = OpBitcast %_ptr_Workgroup_uint %24 + %20 = OpBitcast %_ptr_Workgroup_ulong %26 %15 = OpLoad %ulong %20 Aligned 8 OpStore %7 %15 %16 = OpLoad %ulong %6 diff --git a/ptx/src/test/spirv_run/extern_shared_call.spvtxt b/ptx/src/test/spirv_run/extern_shared_call.spvtxt index 941eb39..3cc78cb 100644 --- a/ptx/src/test/spirv_run/extern_shared_call.spvtxt +++ b/ptx/src/test/spirv_run/extern_shared_call.spvtxt @@ -7,38 +7,42 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %34 = OpExtInstImport "OpenCL.std" + OpCapability DenormFlushToZero + %41 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %12 "extern_shared_call" %1 + OpEntryPoint Kernel %12 "extern_shared_call" %37 OpExecutionMode %12 ContractionOff - OpDecorate %1 Alignment 4 - OpDecorate %1 LinkageAttributes "shared_mem" Import %void = OpTypeVoid - %uint = OpTypeInt 32 0 -%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint - %1 = OpVariable %_ptr_Workgroup_uint Workgroup - %38 = OpTypeFunction %void + %uchar = OpTypeInt 8 0 +%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar + %45 = OpTypeFunction %void %_ptr_Workgroup_uchar %ulong = OpTypeInt 64 0 %_ptr_Function_ulong = OpTypePointer Function %ulong + %uint = OpTypeInt 32 0 +%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint %_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong %ulong_2 = OpConstant %ulong 2 - %42 = OpTypeFunction %void %ulong %ulong + %37 = OpVariable %_ptr_Workgroup_uchar Workgroup + %51 = OpTypeFunction %void %ulong %ulong %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong - %2 = OpFunction %void None %38 + %2 = OpFunction %void None %45 + %34 = OpFunctionParameter %_ptr_Workgroup_uchar %11 = OpLabel %3 = OpVariable %_ptr_Function_ulong Function - %9 = OpBitcast %_ptr_Workgroup_ulong %1 + %35 = OpBitcast %_ptr_Workgroup_uint %34 + %9 = OpBitcast %_ptr_Workgroup_ulong %35 %4 = OpLoad %ulong %9 Aligned 8 OpStore %3 %4 %6 = OpLoad %ulong %3 %5 = OpIAdd %ulong %6 %ulong_2 OpStore %3 %5 %7 = OpLoad %ulong %3 - %10 = OpBitcast %_ptr_Workgroup_ulong %1 + %36 = OpBitcast %_ptr_Workgroup_uint %34 + %10 = OpBitcast %_ptr_Workgroup_ulong %36 OpStore %10 %7 Aligned 8 OpReturn OpFunctionEnd - %12 = OpFunction %void None %42 + %12 = OpFunction %void None %51 %18 = OpFunctionParameter %ulong %19 = OpFunctionParameter %ulong %32 = OpLabel @@ -58,10 +62,13 @@ %22 = OpLoad %ulong %28 Aligned 8 OpStore %17 %22 %24 = OpLoad %ulong %17 - %29 = OpBitcast %_ptr_Workgroup_ulong %1 + %38 = OpBitcast %_ptr_Workgroup_uint %37 + %29 = OpBitcast %_ptr_Workgroup_ulong %38 OpStore %29 %24 Aligned 8 - %44 = OpFunctionCall %void %2 - %30 = OpBitcast %_ptr_Workgroup_ulong %1 + %39 = OpBitcast %_ptr_Workgroup_uchar %37 + %53 = OpFunctionCall %void %2 %39 + %40 = OpBitcast %_ptr_Workgroup_uint %37 + %30 = OpBitcast %_ptr_Workgroup_ulong %40 %25 = OpLoad %ulong %30 Aligned 8 OpStore %17 %25 %26 = OpLoad %ulong %16 diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index b7fd386..6c073f3 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -222,6 +222,7 @@ test_ptx!(prmt, [0x70c507d6u32, 0x6fbd4b5cu32], [0x6fbdd65cu32]); test_ptx!(activemask, [0u32], [1u32]); test_ptx!(membar, [152731u32], [152731u32]); test_ptx!(shared_unify_extern, [7681u64], [15362u64]); +test_ptx!(shared_unify_private, [67153u64], [134306u64]); test_ptx!(func_ptr); test_ptx!(lanemask_lt); diff --git a/ptx/src/test/spirv_run/reg_local.spvtxt b/ptx/src/test/spirv_run/reg_local.spvtxt index 4a69450..ddb6a9e 100644 --- a/ptx/src/test/spirv_run/reg_local.spvtxt +++ b/ptx/src/test/spirv_run/reg_local.spvtxt @@ -7,6 +7,7 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 + OpCapability DenormFlushToZero %34 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "reg_local" @@ -51,22 +52,24 @@ OpStore %7 %12 %14 = OpLoad %ulong %7 %19 = OpIAdd %ulong %14 %ulong_1 - %26 = OpPtrCastToGeneric %_ptr_Generic_ulong %4 + %46 = OpBitcast %_ptr_Function_ulong %4 + %26 = OpPtrCastToGeneric %_ptr_Generic_ulong %46 %27 = OpCopyObject %ulong %19 OpStore %26 %27 Aligned 8 - %28 = OpPtrCastToGeneric %_ptr_Generic_ulong %4 - %47 = OpBitcast %_ptr_Generic_uchar %28 - %48 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %47 %ulong_0 - %21 = OpBitcast %_ptr_Generic_ulong %48 + %47 = OpBitcast %_ptr_Function_ulong %4 + %28 = OpPtrCastToGeneric %_ptr_Generic_ulong %47 + %49 = OpBitcast %_ptr_Generic_uchar %28 + %50 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %49 %ulong_0 + %21 = OpBitcast %_ptr_Generic_ulong %50 %29 = OpLoad %ulong %21 Aligned 8 %15 = OpCopyObject %ulong %29 OpStore %7 %15 %16 = OpLoad %ulong %6 %17 = OpLoad %ulong %7 %30 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %16 - %50 = OpBitcast %_ptr_CrossWorkgroup_uchar %30 - %51 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %50 %ulong_0_0 - %23 = OpBitcast %_ptr_CrossWorkgroup_ulong %51 + %52 = OpBitcast %_ptr_CrossWorkgroup_uchar %30 + %53 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %52 %ulong_0_0 + %23 = OpBitcast %_ptr_CrossWorkgroup_ulong %53 %31 = OpCopyObject %ulong %17 OpStore %23 %31 Aligned 8 OpReturn diff --git a/ptx/src/test/spirv_run/shared_ptr_32.spvtxt b/ptx/src/test/spirv_run/shared_ptr_32.spvtxt index 1b2e3dd..020c15b 100644 --- a/ptx/src/test/spirv_run/shared_ptr_32.spvtxt +++ b/ptx/src/test/spirv_run/shared_ptr_32.spvtxt @@ -7,26 +7,28 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %32 = OpExtInstImport "OpenCL.std" + OpCapability DenormFlushToZero + %34 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "shared_ptr_32" %4 - OpDecorate %4 Alignment 4 + OpEntryPoint Kernel %1 "shared_ptr_32" %32 + OpExecutionMode %1 ContractionOff %void = OpTypeVoid %uint = OpTypeInt 32 0 %uchar = OpTypeInt 8 0 %uint_128 = OpConstant %uint 128 %_arr_uchar_uint_128 = OpTypeArray %uchar %uint_128 %_ptr_Workgroup__arr_uchar_uint_128 = OpTypePointer Workgroup %_arr_uchar_uint_128 - %4 = OpVariable %_ptr_Workgroup__arr_uchar_uint_128 Workgroup + %32 = OpVariable %_ptr_Workgroup__arr_uchar_uint_128 Workgroup %ulong = OpTypeInt 64 0 - %40 = OpTypeFunction %void %ulong %ulong + %42 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_Function_uint = OpTypePointer Function %uint + %uint_128_0 = OpConstant %uint 128 %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong %_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong %ulong_0 = OpConstant %ulong 0 %_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar - %1 = OpFunction %void None %40 + %1 = OpFunction %void None %42 %10 = OpFunctionParameter %ulong %11 = OpFunctionParameter %ulong %30 = OpLabel @@ -43,7 +45,8 @@ OpStore %5 %12 %13 = OpLoad %ulong %3 Aligned 8 OpStore %6 %13 - %25 = OpConvertPtrToU %uint %4 + %33 = OpBitcast %_ptr_Workgroup__arr_uchar_uint_128 %32 + %25 = OpConvertPtrToU %uint %33 %14 = OpCopyObject %uint %25 OpStore %7 %14 %16 = OpLoad %ulong %5 @@ -56,9 +59,9 @@ OpStore %27 %18 Aligned 8 %20 = OpLoad %uint %7 %28 = OpConvertUToPtr %_ptr_Workgroup_ulong %20 - %46 = OpBitcast %_ptr_Workgroup_uchar %28 - %47 = OpInBoundsPtrAccessChain %_ptr_Workgroup_uchar %46 %ulong_0 - %24 = OpBitcast %_ptr_Workgroup_ulong %47 + %49 = OpBitcast %_ptr_Workgroup_uchar %28 + %50 = OpInBoundsPtrAccessChain %_ptr_Workgroup_uchar %49 %ulong_0 + %24 = OpBitcast %_ptr_Workgroup_ulong %50 %19 = OpLoad %ulong %24 Aligned 8 OpStore %9 %19 %21 = OpLoad %ulong %6 diff --git a/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt b/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt index 3ebe810..90e04f3 100644 --- a/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt +++ b/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt @@ -7,22 +7,21 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %30 = OpExtInstImport "OpenCL.std" + OpCapability DenormFlushToZero + %32 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %2 "shared_ptr_take_address" %1 + OpEntryPoint Kernel %2 "shared_ptr_take_address" %30 OpExecutionMode %2 ContractionOff - OpDecorate %1 Alignment 4 - OpDecorate %1 LinkageAttributes "shared_mem" Import %void = OpTypeVoid %uchar = OpTypeInt 8 0 %_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar - %1 = OpVariable %_ptr_Workgroup_uchar Workgroup + %30 = OpVariable %_ptr_Workgroup_uchar Workgroup %ulong = OpTypeInt 64 0 - %35 = OpTypeFunction %void %ulong %ulong + %37 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong %_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong - %2 = OpFunction %void None %35 + %2 = OpFunction %void None %37 %10 = OpFunctionParameter %ulong %11 = OpFunctionParameter %ulong %28 = OpLabel @@ -39,7 +38,8 @@ OpStore %5 %12 %13 = OpLoad %ulong %4 Aligned 8 OpStore %6 %13 - %23 = OpConvertPtrToU %ulong %1 + %31 = OpBitcast %_ptr_Workgroup_uchar %30 + %23 = OpConvertPtrToU %ulong %31 %14 = OpCopyObject %ulong %23 OpStore %7 %14 %16 = OpLoad %ulong %5 diff --git a/ptx/src/test/spirv_run/shared_unify_extern.spvtxt b/ptx/src/test/spirv_run/shared_unify_extern.spvtxt index 9b62045..2dd2056 100644 --- a/ptx/src/test/spirv_run/shared_unify_extern.spvtxt +++ b/ptx/src/test/spirv_run/shared_unify_extern.spvtxt @@ -1,62 +1,79 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %30 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %2 "shared_ptr_take_address" %1 - OpExecutionMode %2 ContractionOff - OpDecorate %1 Alignment 4 - OpDecorate %1 LinkageAttributes "shared_mem" Import - %void = OpTypeVoid - %uchar = OpTypeInt 8 0 -%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar - %1 = OpVariable %_ptr_Workgroup_uchar Workgroup - %ulong = OpTypeInt 64 0 - %35 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong -%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong - %2 = OpFunction %void None %35 - %10 = OpFunctionParameter %ulong - %11 = OpFunctionParameter %ulong - %28 = OpLabel - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ulong Function - %7 = OpVariable %_ptr_Function_ulong Function - %8 = OpVariable %_ptr_Function_ulong Function - %9 = OpVariable %_ptr_Function_ulong Function - OpStore %3 %10 - OpStore %4 %11 - %12 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %12 - %13 = OpLoad %ulong %4 Aligned 8 - OpStore %6 %13 - %23 = OpConvertPtrToU %ulong %1 - %14 = OpCopyObject %ulong %23 - OpStore %7 %14 - %16 = OpLoad %ulong %5 - %24 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %16 - %15 = OpLoad %ulong %24 Aligned 8 - OpStore %8 %15 - %17 = OpLoad %ulong %7 - %18 = OpLoad %ulong %8 - %25 = OpConvertUToPtr %_ptr_Workgroup_ulong %17 - OpStore %25 %18 Aligned 8 - %20 = OpLoad %ulong %7 - %26 = OpConvertUToPtr %_ptr_Workgroup_ulong %20 - %19 = OpLoad %ulong %26 Aligned 8 - OpStore %9 %19 - %21 = OpLoad %ulong %6 - %22 = OpLoad %ulong %9 - %27 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %21 - OpStore %27 %22 Aligned 8 - OpReturn - OpFunctionEnd + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + OpCapability DenormFlushToZero + %41 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %10 "shared_unify_extern" %38 + OpExecutionMode %10 ContractionOff + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %uchar = OpTypeInt 8 0 +%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar + %46 = OpTypeFunction %ulong %_ptr_Workgroup_uchar +%_ptr_Function_ulong = OpTypePointer Function %ulong + %uint = OpTypeInt 32 0 + %uint_4 = OpConstant %uint 4 +%_arr_uint_uint_4 = OpTypeArray %uint %uint_4 +%_ptr_Workgroup__arr_uint_uint_4 = OpTypePointer Workgroup %_arr_uint_uint_4 +%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong + %38 = OpVariable %_ptr_Workgroup_uchar Workgroup + %53 = OpTypeFunction %void %ulong %ulong +%_ptr_Generic_ulong = OpTypePointer Generic %ulong +%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint + %3 = OpFunction %ulong None %46 + %36 = OpFunctionParameter %_ptr_Workgroup_uchar + %9 = OpLabel + %4 = OpVariable %_ptr_Function_ulong Function + %37 = OpBitcast %_ptr_Workgroup__arr_uint_uint_4 %36 + %8 = OpBitcast %_ptr_Workgroup_ulong %37 + %7 = OpLoad %ulong %8 Aligned 8 + %5 = OpCopyObject %ulong %7 + OpStore %4 %5 + %6 = OpLoad %ulong %4 + OpReturnValue %6 + OpFunctionEnd + %10 = OpFunction %void None %53 + %17 = OpFunctionParameter %ulong + %18 = OpFunctionParameter %ulong + %34 = OpLabel + %11 = OpVariable %_ptr_Function_ulong Function + %12 = OpVariable %_ptr_Function_ulong Function + %13 = OpVariable %_ptr_Function_ulong Function + %14 = OpVariable %_ptr_Function_ulong Function + %15 = OpVariable %_ptr_Function_ulong Function + %16 = OpVariable %_ptr_Function_ulong Function + OpStore %11 %17 + OpStore %12 %18 + %19 = OpLoad %ulong %11 Aligned 8 + OpStore %13 %19 + %20 = OpLoad %ulong %12 Aligned 8 + OpStore %14 %20 + %22 = OpLoad %ulong %13 + %30 = OpConvertUToPtr %_ptr_Generic_ulong %22 + %21 = OpLoad %ulong %30 Aligned 8 + OpStore %15 %21 + %23 = OpLoad %ulong %15 + %39 = OpBitcast %_ptr_Workgroup_uint %38 + %31 = OpBitcast %_ptr_Workgroup_ulong %39 + OpStore %31 %23 Aligned 8 + %40 = OpBitcast %_ptr_Workgroup_uchar %38 + %32 = OpFunctionCall %ulong %3 %40 + %24 = OpCopyObject %ulong %32 + OpStore %16 %24 + %26 = OpLoad %ulong %16 + %27 = OpLoad %ulong %15 + %25 = OpIAdd %ulong %26 %27 + OpStore %16 %25 + %28 = OpLoad %ulong %14 + %29 = OpLoad %ulong %16 + %33 = OpConvertUToPtr %_ptr_Generic_ulong %28 + OpStore %33 %29 Aligned 8 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/shared_unify_private.ptx b/ptx/src/test/spirv_run/shared_unify_private.ptx new file mode 100644 index 0000000..fd31357 --- /dev/null +++ b/ptx/src/test/spirv_run/shared_unify_private.ptx @@ -0,0 +1,32 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.func (.reg .b64 out) load_from_shared() +{ + .shared .b32 shared_mod[4]; + ld.shared.u64 out, [shared_mod]; + ret; +} + +.visible .entry shared_unify_private( + .param .u64 input, + .param .u64 output +) +{ + .shared .b32 shared_ex[2]; + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u64 temp1; + .reg .u64 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.u64 temp1, [in_addr]; + st.shared.u64 [shared_ex], temp1; + call (temp2), load_from_shared; + add.u64 temp2, temp2, temp1; + st.u64 [out_addr], temp2; + ret; +} diff --git a/ptx/src/test/spirv_run/shared_unify_private.spvtxt b/ptx/src/test/spirv_run/shared_unify_private.spvtxt new file mode 100644 index 0000000..69bf018 --- /dev/null +++ b/ptx/src/test/spirv_run/shared_unify_private.spvtxt @@ -0,0 +1,84 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + OpCapability DenormFlushToZero + %41 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %9 "shared_unify_private" %38 + OpExecutionMode %9 ContractionOff + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %uchar = OpTypeInt 8 0 +%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar + %46 = OpTypeFunction %ulong %_ptr_Workgroup_uchar +%_ptr_Function_ulong = OpTypePointer Function %ulong + %uint = OpTypeInt 32 0 + %uint_4 = OpConstant %uint 4 +%_arr_uint_uint_4 = OpTypeArray %uint %uint_4 +%_ptr_Workgroup__arr_uint_uint_4 = OpTypePointer Workgroup %_arr_uint_uint_4 +%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong + %uint_16 = OpConstant %uint 16 +%_arr_uchar_uint_16 = OpTypeArray %uchar %uint_16 +%_ptr_Workgroup__arr_uchar_uint_16 = OpTypePointer Workgroup %_arr_uchar_uint_16 + %38 = OpVariable %_ptr_Workgroup__arr_uchar_uint_16 Workgroup + %56 = OpTypeFunction %void %ulong %ulong +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %uint_2 = OpConstant %uint 2 +%_arr_uint_uint_2 = OpTypeArray %uint %uint_2 +%_ptr_Workgroup__arr_uint_uint_2 = OpTypePointer Workgroup %_arr_uint_uint_2 + %1 = OpFunction %ulong None %46 + %36 = OpFunctionParameter %_ptr_Workgroup_uchar + %8 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %37 = OpBitcast %_ptr_Workgroup__arr_uint_uint_4 %36 + %7 = OpBitcast %_ptr_Workgroup_ulong %37 + %6 = OpLoad %ulong %7 Aligned 8 + %4 = OpCopyObject %ulong %6 + OpStore %2 %4 + %5 = OpLoad %ulong %2 + OpReturnValue %5 + OpFunctionEnd + %9 = OpFunction %void None %56 + %17 = OpFunctionParameter %ulong + %18 = OpFunctionParameter %ulong + %34 = OpLabel + %10 = OpVariable %_ptr_Function_ulong Function + %11 = OpVariable %_ptr_Function_ulong Function + %13 = OpVariable %_ptr_Function_ulong Function + %14 = OpVariable %_ptr_Function_ulong Function + %15 = OpVariable %_ptr_Function_ulong Function + %16 = OpVariable %_ptr_Function_ulong Function + OpStore %10 %17 + OpStore %11 %18 + %19 = OpLoad %ulong %10 Aligned 8 + OpStore %13 %19 + %20 = OpLoad %ulong %11 Aligned 8 + OpStore %14 %20 + %22 = OpLoad %ulong %13 + %30 = OpConvertUToPtr %_ptr_Generic_ulong %22 + %21 = OpLoad %ulong %30 Aligned 8 + OpStore %15 %21 + %23 = OpLoad %ulong %15 + %39 = OpBitcast %_ptr_Workgroup__arr_uint_uint_2 %38 + %31 = OpBitcast %_ptr_Workgroup_ulong %39 + OpStore %31 %23 Aligned 8 + %40 = OpBitcast %_ptr_Workgroup_uchar %38 + %32 = OpFunctionCall %ulong %1 %40 + %24 = OpCopyObject %ulong %32 + OpStore %16 %24 + %26 = OpLoad %ulong %16 + %27 = OpLoad %ulong %15 + %25 = OpIAdd %ulong %26 %27 + OpStore %16 %25 + %28 = OpLoad %ulong %14 + %29 = OpLoad %ulong %16 + %33 = OpConvertUToPtr %_ptr_Generic_ulong %28 + OpStore %33 %29 Aligned 8 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/shared_variable.spvtxt b/ptx/src/test/spirv_run/shared_variable.spvtxt index 49278a8..7a97c0e 100644 --- a/ptx/src/test/spirv_run/shared_variable.spvtxt +++ b/ptx/src/test/spirv_run/shared_variable.spvtxt @@ -7,23 +7,26 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %25 = OpExtInstImport "OpenCL.std" + OpCapability DenormFlushToZero + %28 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "shared_variable" %4 - OpDecorate %4 Alignment 4 + OpEntryPoint Kernel %1 "shared_variable" %25 + OpExecutionMode %1 ContractionOff %void = OpTypeVoid %uint = OpTypeInt 32 0 %uchar = OpTypeInt 8 0 %uint_128 = OpConstant %uint 128 %_arr_uchar_uint_128 = OpTypeArray %uchar %uint_128 %_ptr_Workgroup__arr_uchar_uint_128 = OpTypePointer Workgroup %_arr_uchar_uint_128 - %4 = OpVariable %_ptr_Workgroup__arr_uchar_uint_128 Workgroup + %25 = OpVariable %_ptr_Workgroup__arr_uchar_uint_128 Workgroup %ulong = OpTypeInt 64 0 - %33 = OpTypeFunction %void %ulong %ulong + %36 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong + %uint_128_0 = OpConstant %uint 128 %_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong - %1 = OpFunction %void None %33 + %uint_128_1 = OpConstant %uint 128 + %1 = OpFunction %void None %36 %9 = OpFunctionParameter %ulong %10 = OpFunctionParameter %ulong %23 = OpLabel @@ -44,9 +47,11 @@ %13 = OpLoad %ulong %19 Aligned 8 OpStore %7 %13 %15 = OpLoad %ulong %7 - %20 = OpBitcast %_ptr_Workgroup_ulong %4 + %26 = OpBitcast %_ptr_Workgroup__arr_uchar_uint_128 %25 + %20 = OpBitcast %_ptr_Workgroup_ulong %26 OpStore %20 %15 Aligned 8 - %21 = OpBitcast %_ptr_Workgroup_ulong %4 + %27 = OpBitcast %_ptr_Workgroup__arr_uchar_uint_128 %25 + %21 = OpBitcast %_ptr_Workgroup_ulong %27 %16 = OpLoad %ulong %21 Aligned 8 OpStore %8 %16 %17 = OpLoad %ulong %6 diff --git a/ptx/src/test/spirv_run/verify.py b/ptx/src/test/spirv_run/verify.py index dbfab00..4ef6465 100644 --- a/ptx/src/test/spirv_run/verify.py +++ b/ptx/src/test/spirv_run/verify.py @@ -1,7 +1,7 @@ import os, sys, subprocess def main(path): - dirs = os.listdir(path) + dirs = sorted(os.listdir(path)) for file in dirs: if not file.endswith(".spvtxt"): continue diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index e96cdc2..3f27522 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -434,6 +434,7 @@ pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result, _>>()?; + let directives = hoist_function_globals(directives); let must_link_ptx_impl = ptx_impl_imports.len() > 0; let mut directives = ptx_impl_imports .into_iter() @@ -458,6 +459,7 @@ pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result(ast: ast::Module<'input>) -> Result(ast: ast::Module<'input>) -> Result( + directives: Vec>, +) -> ( + Vec>, + HashMap, HashSet>, +) { + let mut known_globals = HashSet::new(); + for directive in directives.iter() { + match directive { + Directive::Variable(_, ast::Variable { name, .. }) => { + known_globals.insert(*name); + } + Directive::Method(..) => {} + } + } + let mut symbol_uses_map = HashMap::new(); + let directives = directives + .into_iter() + .map(|directive| match directive { + Directive::Variable(..) | Directive::Method(Function { body: None, .. }) => directive, + Directive::Method(Function { + func_decl, + body: Some(mut statements), + globals, + import_as, + tuning, + linkage, + }) => { + let method_name = func_decl.borrow().name; + statements = statements + .into_iter() + .map(|statement| { + statement.map_id(&mut |symbol, _| { + if known_globals.contains(&symbol) { + multi_hash_map_append(&mut symbol_uses_map, method_name, symbol); + } + symbol + }) + }) + .collect::>(); + Directive::Method(Function { + func_decl, + body: Some(statements), + globals, + import_as, + tuning, + linkage, + }) + } + }) + .collect::>(); + (directives, symbol_uses_map) +} + +fn hoist_function_globals(directives: Vec) -> Vec { + let mut result = Vec::with_capacity(directives.len()); + for directive in directives { + match directive { + Directive::Method(method) => { + for variable in method.globals { + result.push(Directive::Variable(ast::LinkingDirective::NONE, variable)); + } + result.push(Directive::Method(Function { + globals: Vec::new(), + ..method + })) + } + _ => result.push(directive), + } + } + result +} + // TODO: remove this once we have pef-function support for denorms fn emit_denorm_build_string<'input>( call_map: &HashMap<&str, HashSet>, @@ -531,6 +607,7 @@ fn emit_directives<'input>( opencl_id: spirv::Word, should_flush_denorms: bool, call_map: &HashMap<&'input str, HashSet>, + globals_use_map: HashMap, HashSet>, directives: Vec>, kernel_info: &mut HashMap, ) -> Result<(), TranslateError> { @@ -559,10 +636,9 @@ fn emit_directives<'input>( builder, map, &id_defs, - &f.globals, &*func_decl, call_map, - &directives, + &globals_use_map, kernel_info, )?; if func_decl.name.is_kernel() { @@ -626,16 +702,17 @@ fn emit_function_linkage<'input>( if f.linkage == ast::LinkingDirective::NONE { return Ok(()); }; - let linking_name = f.import_as.as_deref().map_or_else( - || match f.func_decl.borrow().name { - ast::MethodName::Kernel(kernel_name) => Ok(kernel_name), - ast::MethodName::Func(fn_id) => match id_defs.reverse_variables.get(&fn_id) { + let linking_name = match f.func_decl.borrow().name { + // According to SPIR-V rules linkage attributes are invalid on kernels + ast::MethodName::Kernel(..) => return Ok(()), + ast::MethodName::Func(fn_id) => f.import_as.as_deref().map_or_else( + || match id_defs.reverse_variables.get(&fn_id) { Some(fn_name) => Ok(fn_name), None => Err(error_unknown_symbol()), }, - }, - Result::Ok, - )?; + Result::Ok, + )?, + }; emit_linking_decoration(builder, id_defs, Some(linking_name), fn_name, f.linkage); Ok(()) } @@ -712,7 +789,7 @@ fn multi_hash_map_append< entry.get_mut().extend(iter::once(value)); } hash_map::Entry::Vacant(entry) => { - entry.insert(Default::default()); + entry.insert(Default::default()).extend(iter::once(value)); } } } @@ -857,6 +934,9 @@ fn convert_dynamic_shared_memory_usage<'input>( linkage, })); } + // Existing .shared globals are now unused, they were replaced by kernel-specific globals + Directive::Variable(_, ast::Variable { name, .. }) + if globals_shared.contains_key(&name) => {} directive => result.push(directive), } } @@ -873,34 +953,35 @@ fn insert_arguments_remap_statements( globals_shared: &HashMap, statements: Vec, ExpandedArgParams>>, ) -> Vec, ExpandedArgParams>> { - let shared_id_param = match method_name { + let (shared_id_param, shared_id_type) = match method_name { ast::MethodName::Kernel(kernel_name) => { let globals_shared_size = match kernels_to_global_shared.get(kernel_name) { Some(s) => *s, None => return statements, }; let shared_id_param = new_id(); + func_decl_ref.shared_mem = Some(shared_id_param); let (linkage, type_) = match globals_shared_size { GlobalSharedSize::ExternUnsized => ( ast::LinkingDirective::EXTERN, - ast::Type::Array(ast::ScalarType::U8, Vec::new()), + ast::Type::Array(ast::ScalarType::B8, Vec::new()), ), GlobalSharedSize::Sized(size) => ( ast::LinkingDirective::NONE, - ast::Type::Array(ast::ScalarType::U8, vec![size as u32]), + ast::Type::Array(ast::ScalarType::B8, vec![size as u32]), ), }; result.push(Directive::Variable( linkage, ast::Variable { align: None, - v_type: type_, + v_type: type_.clone(), state_space: ast::StateSpace::Shared, name: shared_id_param, array_init: Vec::new(), }, )); - shared_id_param + (shared_id_param, Some(type_)) } ast::MethodName::Func(function_name) => { if !functions_to_global_shared.contains(&function_name) { @@ -914,7 +995,7 @@ fn insert_arguments_remap_statements( name: shared_id_param, array_init: Vec::new(), }); - shared_id_param + (shared_id_param, None) } }; replace_uses_of_shared_memory( @@ -922,6 +1003,7 @@ fn insert_arguments_remap_statements( globals_shared, functions_to_global_shared, shared_id_param, + shared_id_type, statements, ) } @@ -948,6 +1030,7 @@ fn replace_uses_of_shared_memory<'a>( extern_shared_decls: &HashMap, methods_using_extern_shared: &HashSet, shared_id_param: spirv::Word, + shared_id_type: Option, statements: Vec, ) -> Vec { let mut result = Vec::with_capacity(statements.len()); @@ -958,6 +1041,22 @@ fn replace_uses_of_shared_memory<'a>( // because there's simply no way to pass shared ptr // without converting it to .b64 first if methods_using_extern_shared.contains(&call.name) { + let shared_id_param = match shared_id_type { + Some(ref global_shared_defined_type) => { + let dst = new_id(); + result.push(Statement::Conversion(ImplicitConversion { + src: shared_id_param, + dst, + from_type: global_shared_defined_type.clone(), + to_type: ast::Type::Scalar(ast::ScalarType::B8), + from_space: ast::StateSpace::Shared, + to_space: ast::StateSpace::Shared, + kind: ConversionKind::PtrToPtr, + })); + dst + } + None => shared_id_param, + }; call.input_arguments.push(( shared_id_param, ast::Type::Scalar(ast::ScalarType::B8), @@ -998,10 +1097,7 @@ fn replace_uses_of_shared_memory<'a>( // * If it's a kernel -> size of .shared globals in use (direct or indirect) // * If it's a function -> does it use .shared global (directly or indirectly) fn resolve_indirect_uses_of_globals_shared<'input>( - mut methods_use_of_globals_shared: HashMap< - ast::MethodName<'input, spirv::Word>, - GlobalSharedSize, - >, + methods_use_of_globals_shared: HashMap, GlobalSharedSize>, kernels_methods_call_map: &HashMap<&'input str, HashSet>, ) -> (HashMap<&'input str, GlobalSharedSize>, HashSet) { let mut kernel_use = HashMap::new(); @@ -1116,14 +1212,13 @@ fn compute_denorm_information<'input>( .collect() } -fn emit_function_header<'a>( +fn emit_function_header<'input>( builder: &mut dr::Builder, map: &mut TypeWordMap, - defined_globals: &GlobalStringIdResolver<'a>, - synthetic_globals: &[ast::Variable], - func_decl: &ast::MethodDeclaration<'a, spirv::Word>, - call_map: &HashMap<&'a str, HashSet>, - direcitves: &[Directive], + defined_globals: &GlobalStringIdResolver<'input>, + func_decl: &ast::MethodDeclaration<'input, spirv::Word>, + call_map: &HashMap<&'input str, HashSet>, + globals_use_map: &HashMap, HashSet>, kernel_info: &mut HashMap, ) -> Result { if let ast::MethodName::Kernel(name) = func_decl.name { @@ -1154,38 +1249,28 @@ fn emit_function_header<'a>( let fn_id = match func_decl.name { ast::MethodName::Kernel(name) => { let fn_id = defined_globals.get_id(name)?; - let mut global_variables = defined_globals - .variables_type_check - .iter() - .filter_map(|(k, t)| t.as_ref().map(|_| *k)) - .collect::>(); - let mut interface = defined_globals.special_registers.interface(); - for ast::Variable { name, .. } in synthetic_globals { - interface.push(*name); - } - let empty_hash_set = HashSet::new(); - let child_fns = call_map.get(name).unwrap_or(&empty_hash_set); - for directive in direcitves { - match directive { - Directive::Method(Function { - func_decl, globals, .. - }) => { - match (**func_decl).borrow().name { - ast::MethodName::Func(name) => { - if child_fns.contains(&name) { - for var in globals { - interface.push(var.name); - } - } - } - ast::MethodName::Kernel(_) => {} - }; - } - _ => {} - } - } - global_variables.append(&mut interface); - builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, global_variables); + let interface = globals_use_map + .get(&ast::MethodName::Kernel(name)) + .into_iter() + .flatten() + .copied() + .chain({ + call_map + .get(name) + .into_iter() + .flat_map(|subfunctions| { + subfunctions.iter().flat_map(|subfunction| { + globals_use_map + .get(&ast::MethodName::Func(*subfunction)) + .into_iter() + .flatten() + .copied() + }) + }) + .into_iter() + }) + .collect::>(); + builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, interface); fn_id } ast::MethodName::Func(name) => name, @@ -3571,7 +3656,8 @@ fn emit_variable<'input>( [dr::Operand::LiteralInt32(align)].iter().cloned(), ); } - if var.state_space != ast::StateSpace::Shared || !linking.contains(ast::LinkingDirective::EXTERN) + if var.state_space != ast::StateSpace::Shared + || !linking.contains(ast::LinkingDirective::EXTERN) { emit_linking_decoration(builder, id_defs, None, var.name, linking); } @@ -4328,11 +4414,35 @@ fn emit_implicit_conversion( ); if cv.to_space == ast::StateSpace::Generic && cv.from_space != ast::StateSpace::Generic { - builder.ptr_cast_to_generic(result_type, Some(cv.dst), cv.src)?; + let src = if cv.from_type != cv.to_type { + let temp_type = map.get_or_add( + builder, + SpirvType::Pointer( + Box::new(SpirvType::new(cv.to_type.clone())), + cv.from_space.to_spirv(), + ), + ); + builder.bitcast(temp_type, None, cv.src)? + } else { + cv.src + }; + builder.ptr_cast_to_generic(result_type, Some(cv.dst), src)?; } else if cv.from_space == ast::StateSpace::Generic && cv.to_space != ast::StateSpace::Generic { - builder.generic_cast_to_ptr(result_type, Some(cv.dst), cv.src)?; + let src = if cv.from_type != cv.to_type { + let temp_type = map.get_or_add( + builder, + SpirvType::Pointer( + Box::new(SpirvType::new(cv.to_type.clone())), + cv.from_space.to_spirv(), + ), + ); + builder.bitcast(temp_type, None, cv.src)? + } else { + cv.src + }; + builder.generic_cast_to_ptr(result_type, Some(cv.dst), src)?; } else { builder.bitcast(result_type, Some(cv.dst), cv.src)?; } @@ -5035,22 +5145,6 @@ impl SpecialRegistersMap { } } - fn interface(&self) -> Vec { - return Vec::new(); - /* - self.reg_to_id - .iter() - .filter_map(|(sreg, id)| { - if sreg.normalized_sreg_and_type().is_none() { - Some(*id) - } else { - None - } - }) - .collect::>() - */ - } - fn get(&self, id: spirv::Word) -> Option { self.id_to_reg.get(&id).copied() } -- cgit v1.2.3