From 0172dc58e52f2ac1e4d01951002a94a69b3589d0 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Wed, 29 Sep 2021 02:24:32 +0200 Subject: Redo shared memory transformation --- ptx/src/test/spirv_run/atom_add.spvtxt | 37 +- ptx/src/test/spirv_run/atom_add_float.spvtxt | 41 +- ptx/src/test/spirv_run/extern_shared.spvtxt | 24 +- ptx/src/test/spirv_run/extern_shared_call.spvtxt | 40 +- ptx/src/test/spirv_run/ld_st_implicit.spvtxt | 39 +- ptx/src/test/spirv_run/mod.rs | 3 +- ptx/src/test/spirv_run/shared_ptr_32.spvtxt | 23 +- .../test/spirv_run/shared_ptr_take_address.spvtxt | 16 +- ptx/src/test/spirv_run/shared_unify_extern.ptx | 22 +- ptx/src/test/spirv_run/shared_unify_extern.spvtxt | 153 +++++--- ptx/src/test/spirv_run/shared_unify_private.ptx | 32 -- ptx/src/test/spirv_run/shared_unify_private.spvtxt | 84 ---- ptx/src/test/spirv_run/shared_variable.spvtxt | 21 +- ptx/src/translate.rs | 424 ++++++++++----------- 14 files changed, 439 insertions(+), 520 deletions(-) delete mode 100644 ptx/src/test/spirv_run/shared_unify_private.ptx delete mode 100644 ptx/src/test/spirv_run/shared_unify_private.spvtxt diff --git a/ptx/src/test/spirv_run/atom_add.spvtxt b/ptx/src/test/spirv_run/atom_add.spvtxt index 3609247..987fdef 100644 --- a/ptx/src/test/spirv_run/atom_add.spvtxt +++ b/ptx/src/test/spirv_run/atom_add.spvtxt @@ -8,32 +8,32 @@ OpCapability Float16 OpCapability Float64 OpCapability DenormFlushToZero - %42 = OpExtInstImport "OpenCL.std" + OpExtension "SPV_KHR_float_controls" + OpExtension "SPV_KHR_no_integer_wrap_decoration" + %38 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "atom_add" %38 + OpEntryPoint Kernel %1 "atom_add" %4 OpExecutionMode %1 ContractionOff + OpDecorate %4 Alignment 4 %void = OpTypeVoid %uint = OpTypeInt 32 0 %uchar = OpTypeInt 8 0 %uint_1024 = OpConstant %uint 1024 %_arr_uchar_uint_1024 = OpTypeArray %uchar %uint_1024 %_ptr_Workgroup__arr_uchar_uint_1024 = OpTypePointer Workgroup %_arr_uchar_uint_1024 - %38 = OpVariable %_ptr_Workgroup__arr_uchar_uint_1024 Workgroup + %4 = OpVariable %_ptr_Workgroup__arr_uchar_uint_1024 Workgroup %ulong = OpTypeInt 64 0 - %50 = OpTypeFunction %void %ulong %ulong + %46 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint %ulong_4 = OpConstant %ulong 4 %_ptr_Generic_uchar = OpTypePointer Generic %uchar -%uint_1024_0 = OpConstant %uint 1024 %_ptr_Workgroup_uint = OpTypePointer Workgroup %uint -%uint_1024_1 = OpConstant %uint 1024 %uint_1 = OpConstant %uint 1 %uint_0 = OpConstant %uint 0 -%uint_1024_2 = OpConstant %uint 1024 %ulong_4_0 = OpConstant %ulong 4 - %1 = OpFunction %void None %50 + %1 = OpFunction %void None %46 %9 = OpFunctionParameter %ulong %10 = OpFunctionParameter %ulong %36 = OpLabel @@ -55,22 +55,19 @@ OpStore %7 %13 %16 = OpLoad %ulong %5 %30 = OpConvertUToPtr %_ptr_Generic_uint %16 - %55 = OpBitcast %_ptr_Generic_uchar %30 - %56 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %55 %ulong_4 - %26 = OpBitcast %_ptr_Generic_uint %56 + %51 = OpBitcast %_ptr_Generic_uchar %30 + %52 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %51 %ulong_4 + %26 = OpBitcast %_ptr_Generic_uint %52 %15 = OpLoad %uint %26 Aligned 4 OpStore %8 %15 %17 = OpLoad %uint %7 - %39 = OpBitcast %_ptr_Workgroup__arr_uchar_uint_1024 %38 - %31 = OpBitcast %_ptr_Workgroup_uint %39 + %31 = OpBitcast %_ptr_Workgroup_uint %4 OpStore %31 %17 Aligned 4 %19 = OpLoad %uint %8 - %40 = OpBitcast %_ptr_Workgroup__arr_uchar_uint_1024 %38 - %32 = OpBitcast %_ptr_Workgroup_uint %40 + %32 = OpBitcast %_ptr_Workgroup_uint %4 %18 = OpAtomicIAdd %uint %32 %uint_1 %uint_0 %19 OpStore %7 %18 - %41 = OpBitcast %_ptr_Workgroup__arr_uchar_uint_1024 %38 - %33 = OpBitcast %_ptr_Workgroup_uint %41 + %33 = OpBitcast %_ptr_Workgroup_uint %4 %20 = OpLoad %uint %33 Aligned 4 OpStore %8 %20 %21 = OpLoad %ulong %6 @@ -80,9 +77,9 @@ %23 = OpLoad %ulong %6 %24 = OpLoad %uint %8 %35 = OpConvertUToPtr %_ptr_Generic_uint %23 - %63 = OpBitcast %_ptr_Generic_uchar %35 - %64 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %63 %ulong_4_0 - %28 = OpBitcast %_ptr_Generic_uint %64 + %56 = OpBitcast %_ptr_Generic_uchar %35 + %57 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %56 %ulong_4_0 + %28 = OpBitcast %_ptr_Generic_uint %57 OpStore %28 %24 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/atom_add_float.spvtxt b/ptx/src/test/spirv_run/atom_add_float.spvtxt index 9533d83..067c347 100644 --- a/ptx/src/test/spirv_run/atom_add_float.spvtxt +++ b/ptx/src/test/spirv_run/atom_add_float.spvtxt @@ -8,37 +8,37 @@ OpCapability Float16 OpCapability Float64 OpCapability DenormFlushToZero - %46 = OpExtInstImport "OpenCL.std" + OpExtension "SPV_KHR_float_controls" + OpExtension "SPV_KHR_no_integer_wrap_decoration" + %42 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "atom_add_float" %42 + OpEntryPoint Kernel %1 "atom_add_float" %4 OpExecutionMode %1 ContractionOff OpDecorate %37 LinkageAttributes "__zluda_ptx_impl__atom_relaxed_gpu_shared_add_f32" Import + OpDecorate %4 Alignment 4 %void = OpTypeVoid %float = OpTypeFloat 32 %_ptr_Workgroup_float = OpTypePointer Workgroup %float - %50 = OpTypeFunction %float %_ptr_Workgroup_float %float + %46 = OpTypeFunction %float %_ptr_Workgroup_float %float %uint = OpTypeInt 32 0 %uchar = OpTypeInt 8 0 %uint_1024 = OpConstant %uint 1024 %_arr_uchar_uint_1024 = OpTypeArray %uchar %uint_1024 %_ptr_Workgroup__arr_uchar_uint_1024 = OpTypePointer Workgroup %_arr_uchar_uint_1024 - %42 = OpVariable %_ptr_Workgroup__arr_uchar_uint_1024 Workgroup + %4 = OpVariable %_ptr_Workgroup__arr_uchar_uint_1024 Workgroup %ulong = OpTypeInt 64 0 - %57 = OpTypeFunction %void %ulong %ulong + %53 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_Function_float = OpTypePointer Function %float %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 %_ptr_Generic_uchar = OpTypePointer Generic %uchar -%uint_1024_0 = OpConstant %uint 1024 -%uint_1024_1 = OpConstant %uint 1024 -%uint_1024_2 = OpConstant %uint 1024 %ulong_4_0 = OpConstant %ulong 4 - %37 = OpFunction %float None %50 + %37 = OpFunction %float None %46 %39 = OpFunctionParameter %_ptr_Workgroup_float %40 = OpFunctionParameter %float OpFunctionEnd - %1 = OpFunction %void None %57 + %1 = OpFunction %void None %53 %9 = OpFunctionParameter %ulong %10 = OpFunctionParameter %ulong %36 = OpLabel @@ -60,22 +60,19 @@ OpStore %7 %13 %16 = OpLoad %ulong %5 %30 = OpConvertUToPtr %_ptr_Generic_float %16 - %62 = OpBitcast %_ptr_Generic_uchar %30 - %63 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %62 %ulong_4 - %26 = OpBitcast %_ptr_Generic_float %63 + %58 = OpBitcast %_ptr_Generic_uchar %30 + %59 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %58 %ulong_4 + %26 = OpBitcast %_ptr_Generic_float %59 %15 = OpLoad %float %26 Aligned 4 OpStore %8 %15 %17 = OpLoad %float %7 - %43 = OpBitcast %_ptr_Workgroup__arr_uchar_uint_1024 %42 - %31 = OpBitcast %_ptr_Workgroup_float %43 + %31 = OpBitcast %_ptr_Workgroup_float %4 OpStore %31 %17 Aligned 4 %19 = OpLoad %float %8 - %44 = OpBitcast %_ptr_Workgroup__arr_uchar_uint_1024 %42 - %32 = OpBitcast %_ptr_Workgroup_float %44 + %32 = OpBitcast %_ptr_Workgroup_float %4 %18 = OpFunctionCall %float %37 %32 %19 OpStore %7 %18 - %45 = OpBitcast %_ptr_Workgroup__arr_uchar_uint_1024 %42 - %33 = OpBitcast %_ptr_Workgroup_float %45 + %33 = OpBitcast %_ptr_Workgroup_float %4 %20 = OpLoad %float %33 Aligned 4 OpStore %8 %20 %21 = OpLoad %ulong %6 @@ -85,9 +82,9 @@ %23 = OpLoad %ulong %6 %24 = OpLoad %float %8 %35 = OpConvertUToPtr %_ptr_Generic_float %23 - %67 = OpBitcast %_ptr_Generic_uchar %35 - %68 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %67 %ulong_4_0 - %28 = OpBitcast %_ptr_Generic_float %68 + %60 = OpBitcast %_ptr_Generic_uchar %35 + %61 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %60 %ulong_4_0 + %28 = OpBitcast %_ptr_Generic_float %61 OpStore %28 %24 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/extern_shared.spvtxt b/ptx/src/test/spirv_run/extern_shared.spvtxt index 82d86ae..025cd81 100644 --- a/ptx/src/test/spirv_run/extern_shared.spvtxt +++ b/ptx/src/test/spirv_run/extern_shared.spvtxt @@ -8,22 +8,22 @@ OpCapability Float16 OpCapability Float64 OpCapability DenormFlushToZero - %27 = OpExtInstImport "OpenCL.std" + OpExtension "SPV_KHR_float_controls" + OpExtension "SPV_KHR_no_integer_wrap_decoration" + %24 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %2 "extern_shared" %24 + OpEntryPoint Kernel %2 "extern_shared" %1 OpExecutionMode %2 ContractionOff %void = OpTypeVoid - %uchar = OpTypeInt 8 0 -%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar - %24 = OpVariable %_ptr_Workgroup_uchar Workgroup + %uint = OpTypeInt 32 0 +%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint + %1 = OpVariable %_ptr_Workgroup_uint Workgroup %ulong = OpTypeInt 64 0 - %32 = OpTypeFunction %void %ulong %ulong + %29 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong - %uint = OpTypeInt 32 0 -%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint %_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong - %2 = OpFunction %void None %32 + %2 = OpFunction %void None %29 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong %22 = OpLabel @@ -43,11 +43,9 @@ %12 = OpLoad %ulong %18 Aligned 8 OpStore %7 %12 %14 = OpLoad %ulong %7 - %25 = OpBitcast %_ptr_Workgroup_uint %24 - %19 = OpBitcast %_ptr_Workgroup_ulong %25 + %19 = OpBitcast %_ptr_Workgroup_ulong %1 OpStore %19 %14 Aligned 8 - %26 = OpBitcast %_ptr_Workgroup_uint %24 - %20 = OpBitcast %_ptr_Workgroup_ulong %26 + %20 = OpBitcast %_ptr_Workgroup_ulong %1 %15 = OpLoad %ulong %20 Aligned 8 OpStore %7 %15 %16 = OpLoad %ulong %6 diff --git a/ptx/src/test/spirv_run/extern_shared_call.spvtxt b/ptx/src/test/spirv_run/extern_shared_call.spvtxt index 3cc78cb..bf1dccd 100644 --- a/ptx/src/test/spirv_run/extern_shared_call.spvtxt +++ b/ptx/src/test/spirv_run/extern_shared_call.spvtxt @@ -8,41 +8,40 @@ OpCapability Float16 OpCapability Float64 OpCapability DenormFlushToZero - %41 = OpExtInstImport "OpenCL.std" + OpExtension "SPV_KHR_float_controls" + OpExtension "SPV_KHR_no_integer_wrap_decoration" + %35 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %12 "extern_shared_call" %37 + OpEntryPoint Kernel %12 "extern_shared_call" %1 OpExecutionMode %12 ContractionOff + OpDecorate %1 Alignment 4 %void = OpTypeVoid - %uchar = OpTypeInt 8 0 -%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar - %45 = OpTypeFunction %void %_ptr_Workgroup_uchar - %ulong = OpTypeInt 64 0 -%_ptr_Function_ulong = OpTypePointer Function %ulong %uint = OpTypeInt 32 0 %_ptr_Workgroup_uint = OpTypePointer Workgroup %uint + %1 = OpVariable %_ptr_Workgroup_uint Workgroup + %39 = OpTypeFunction %void %_ptr_Workgroup_uint + %ulong = OpTypeInt 64 0 +%_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong %ulong_2 = OpConstant %ulong 2 - %37 = OpVariable %_ptr_Workgroup_uchar Workgroup - %51 = OpTypeFunction %void %ulong %ulong + %43 = OpTypeFunction %void %ulong %ulong %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong - %2 = OpFunction %void None %45 - %34 = OpFunctionParameter %_ptr_Workgroup_uchar + %2 = OpFunction %void None %39 + %34 = OpFunctionParameter %_ptr_Workgroup_uint %11 = OpLabel %3 = OpVariable %_ptr_Function_ulong Function - %35 = OpBitcast %_ptr_Workgroup_uint %34 - %9 = OpBitcast %_ptr_Workgroup_ulong %35 + %9 = OpBitcast %_ptr_Workgroup_ulong %34 %4 = OpLoad %ulong %9 Aligned 8 OpStore %3 %4 %6 = OpLoad %ulong %3 %5 = OpIAdd %ulong %6 %ulong_2 OpStore %3 %5 %7 = OpLoad %ulong %3 - %36 = OpBitcast %_ptr_Workgroup_uint %34 - %10 = OpBitcast %_ptr_Workgroup_ulong %36 + %10 = OpBitcast %_ptr_Workgroup_ulong %34 OpStore %10 %7 Aligned 8 OpReturn OpFunctionEnd - %12 = OpFunction %void None %51 + %12 = OpFunction %void None %43 %18 = OpFunctionParameter %ulong %19 = OpFunctionParameter %ulong %32 = OpLabel @@ -62,13 +61,10 @@ %22 = OpLoad %ulong %28 Aligned 8 OpStore %17 %22 %24 = OpLoad %ulong %17 - %38 = OpBitcast %_ptr_Workgroup_uint %37 - %29 = OpBitcast %_ptr_Workgroup_ulong %38 + %29 = OpBitcast %_ptr_Workgroup_ulong %1 OpStore %29 %24 Aligned 8 - %39 = OpBitcast %_ptr_Workgroup_uchar %37 - %53 = OpFunctionCall %void %2 %39 - %40 = OpBitcast %_ptr_Workgroup_uint %37 - %30 = OpBitcast %_ptr_Workgroup_ulong %40 + %45 = OpFunctionCall %void %2 %1 + %30 = OpBitcast %_ptr_Workgroup_ulong %1 %25 = OpLoad %ulong %30 Aligned 8 OpStore %17 %25 %26 = OpLoad %ulong %16 diff --git a/ptx/src/test/spirv_run/ld_st_implicit.spvtxt b/ptx/src/test/spirv_run/ld_st_implicit.spvtxt index 29f46f9..9c0e508 100644 --- a/ptx/src/test/spirv_run/ld_st_implicit.spvtxt +++ b/ptx/src/test/spirv_run/ld_st_implicit.spvtxt @@ -7,20 +7,25 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %21 = OpExtInstImport "OpenCL.std" + OpCapability DenormFlushToZero + OpExtension "SPV_KHR_float_controls" + OpExtension "SPV_KHR_no_integer_wrap_decoration" + %23 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "ld_st_implicit" + OpExecutionMode %1 ContractionOff %void = OpTypeVoid %ulong = OpTypeInt 64 0 - %24 = OpTypeFunction %void %ulong %ulong + %26 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong +%ulong_81985529216486895 = OpConstant %ulong 81985529216486895 %float = OpTypeFloat 32 %_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float %uint = OpTypeInt 32 0 - %1 = OpFunction %void None %24 + %1 = OpFunction %void None %26 %7 = OpFunctionParameter %ulong %8 = OpFunctionParameter %ulong - %19 = OpLabel + %21 = OpLabel %2 = OpVariable %_ptr_Function_ulong Function %3 = OpVariable %_ptr_Function_ulong Function %4 = OpVariable %_ptr_Function_ulong Function @@ -32,18 +37,20 @@ OpStore %4 %9 %10 = OpLoad %ulong %3 Aligned 8 OpStore %5 %10 - %12 = OpLoad %ulong %4 - %16 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %12 - %15 = OpLoad %float %16 Aligned 4 - %29 = OpBitcast %uint %15 - %11 = OpUConvert %ulong %29 + %11 = OpCopyObject %ulong %ulong_81985529216486895 OpStore %6 %11 - %13 = OpLoad %ulong %5 - %14 = OpLoad %ulong %6 - %17 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %13 - %30 = OpBitcast %ulong %14 - %31 = OpUConvert %uint %30 - %18 = OpBitcast %float %31 - OpStore %17 %18 Aligned 4 + %13 = OpLoad %ulong %4 + %18 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %13 + %17 = OpLoad %float %18 Aligned 4 + %31 = OpBitcast %uint %17 + %12 = OpUConvert %ulong %31 + OpStore %6 %12 + %14 = OpLoad %ulong %5 + %15 = OpLoad %ulong %6 + %19 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %14 + %32 = OpBitcast %ulong %15 + %33 = OpUConvert %uint %32 + %20 = OpBitcast %float %33 + OpStore %19 %20 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 512b6cf..dfc252d 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -218,8 +218,7 @@ test_ptx!(cvt_f64_f32, [0.125f32], [0.125f64]); test_ptx!(prmt, [0x70c507d6u32, 0x6fbd4b5cu32], [0x6fbdd65cu32]); test_ptx!(activemask, [0u32], [1u32]); test_ptx!(membar, [152731u32], [152731u32]); -test_ptx!(shared_unify_extern, [7681u64], [15362u64]); -test_ptx!(shared_unify_private, [67153u64], [134306u64]); +test_ptx!(shared_unify_extern, [7681u64, 7682u64], [15363u64]); test_ptx!(assertfail); test_ptx!(func_ptr); diff --git a/ptx/src/test/spirv_run/shared_ptr_32.spvtxt b/ptx/src/test/spirv_run/shared_ptr_32.spvtxt index 020c15b..787a71c 100644 --- a/ptx/src/test/spirv_run/shared_ptr_32.spvtxt +++ b/ptx/src/test/spirv_run/shared_ptr_32.spvtxt @@ -8,27 +8,29 @@ OpCapability Float16 OpCapability Float64 OpCapability DenormFlushToZero - %34 = OpExtInstImport "OpenCL.std" + OpExtension "SPV_KHR_float_controls" + OpExtension "SPV_KHR_no_integer_wrap_decoration" + %32 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "shared_ptr_32" %32 + OpEntryPoint Kernel %1 "shared_ptr_32" %4 OpExecutionMode %1 ContractionOff + OpDecorate %4 Alignment 4 %void = OpTypeVoid %uint = OpTypeInt 32 0 %uchar = OpTypeInt 8 0 %uint_128 = OpConstant %uint 128 %_arr_uchar_uint_128 = OpTypeArray %uchar %uint_128 %_ptr_Workgroup__arr_uchar_uint_128 = OpTypePointer Workgroup %_arr_uchar_uint_128 - %32 = OpVariable %_ptr_Workgroup__arr_uchar_uint_128 Workgroup + %4 = OpVariable %_ptr_Workgroup__arr_uchar_uint_128 Workgroup %ulong = OpTypeInt 64 0 - %42 = OpTypeFunction %void %ulong %ulong + %40 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_Function_uint = OpTypePointer Function %uint - %uint_128_0 = OpConstant %uint 128 %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong %_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong %ulong_0 = OpConstant %ulong 0 %_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar - %1 = OpFunction %void None %42 + %1 = OpFunction %void None %40 %10 = OpFunctionParameter %ulong %11 = OpFunctionParameter %ulong %30 = OpLabel @@ -45,8 +47,7 @@ OpStore %5 %12 %13 = OpLoad %ulong %3 Aligned 8 OpStore %6 %13 - %33 = OpBitcast %_ptr_Workgroup__arr_uchar_uint_128 %32 - %25 = OpConvertPtrToU %uint %33 + %25 = OpConvertPtrToU %uint %4 %14 = OpCopyObject %uint %25 OpStore %7 %14 %16 = OpLoad %ulong %5 @@ -59,9 +60,9 @@ OpStore %27 %18 Aligned 8 %20 = OpLoad %uint %7 %28 = OpConvertUToPtr %_ptr_Workgroup_ulong %20 - %49 = OpBitcast %_ptr_Workgroup_uchar %28 - %50 = OpInBoundsPtrAccessChain %_ptr_Workgroup_uchar %49 %ulong_0 - %24 = OpBitcast %_ptr_Workgroup_ulong %50 + %46 = OpBitcast %_ptr_Workgroup_uchar %28 + %47 = OpInBoundsPtrAccessChain %_ptr_Workgroup_uchar %46 %ulong_0 + %24 = OpBitcast %_ptr_Workgroup_ulong %47 %19 = OpLoad %ulong %24 Aligned 8 OpStore %9 %19 %21 = OpLoad %ulong %6 diff --git a/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt b/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt index 90e04f3..14926ef 100644 --- a/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt +++ b/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt @@ -8,20 +8,23 @@ OpCapability Float16 OpCapability Float64 OpCapability DenormFlushToZero - %32 = OpExtInstImport "OpenCL.std" + OpExtension "SPV_KHR_float_controls" + OpExtension "SPV_KHR_no_integer_wrap_decoration" + %30 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %2 "shared_ptr_take_address" %30 + OpEntryPoint Kernel %2 "shared_ptr_take_address" %1 OpExecutionMode %2 ContractionOff + OpDecorate %1 Alignment 4 %void = OpTypeVoid %uchar = OpTypeInt 8 0 %_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar - %30 = OpVariable %_ptr_Workgroup_uchar Workgroup + %1 = OpVariable %_ptr_Workgroup_uchar Workgroup %ulong = OpTypeInt 64 0 - %37 = OpTypeFunction %void %ulong %ulong + %35 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong %_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong - %2 = OpFunction %void None %37 + %2 = OpFunction %void None %35 %10 = OpFunctionParameter %ulong %11 = OpFunctionParameter %ulong %28 = OpLabel @@ -38,8 +41,7 @@ OpStore %5 %12 %13 = OpLoad %ulong %4 Aligned 8 OpStore %6 %13 - %31 = OpBitcast %_ptr_Workgroup_uchar %30 - %23 = OpConvertPtrToU %ulong %31 + %23 = OpConvertPtrToU %ulong %1 %14 = OpCopyObject %ulong %23 OpStore %7 %14 %16 = OpLoad %ulong %5 diff --git a/ptx/src/test/spirv_run/shared_unify_extern.ptx b/ptx/src/test/spirv_run/shared_unify_extern.ptx index 8b406b2..0416e2c 100644 --- a/ptx/src/test/spirv_run/shared_unify_extern.ptx +++ b/ptx/src/test/spirv_run/shared_unify_extern.ptx @@ -5,10 +5,20 @@ .extern .shared .b32 shared_ex[]; .shared .b32 shared_mod[4]; +.func (.reg .b64 out) add() +{ + .reg .u64 temp1; + .reg .u64 temp2; + ld.shared.u64 temp1, [shared_mod]; + ld.shared.u64 temp2, [shared_ex]; + add.u64 out, temp2, temp1; + ret; +} -.func (.reg .b64 out) load_from_shared() +.func (.reg .b64 out) set_shared_temp1(.reg .b64 temp1) { - ld.shared.u64 out, [shared_mod]; + st.shared.u64 [shared_ex], temp1; + call (out), add; ret; } @@ -25,10 +35,10 @@ ld.param.u64 in_addr, [input]; ld.param.u64 out_addr, [output]; - ld.u64 temp1, [in_addr]; - st.shared.u64 [shared_ex], temp1; - call (temp2), load_from_shared; - add.u64 temp2, temp2, temp1; + ld.global.u64 temp1, [in_addr]; + ld.global.u64 temp2, [in_addr+8]; + st.shared.u64 [shared_mod], temp2; + call (temp2), set_shared_temp1, (temp1); st.u64 [out_addr], temp2; ret; } diff --git a/ptx/src/test/spirv_run/shared_unify_extern.spvtxt b/ptx/src/test/spirv_run/shared_unify_extern.spvtxt index 2dd2056..90fc156 100644 --- a/ptx/src/test/spirv_run/shared_unify_extern.spvtxt +++ b/ptx/src/test/spirv_run/shared_unify_extern.spvtxt @@ -8,72 +8,111 @@ OpCapability Float16 OpCapability Float64 OpCapability DenormFlushToZero - %41 = OpExtInstImport "OpenCL.std" + OpExtension "SPV_KHR_float_controls" + OpExtension "SPV_KHR_no_integer_wrap_decoration" + %61 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %10 "shared_unify_extern" %38 - OpExecutionMode %10 ContractionOff + OpEntryPoint Kernel %27 "shared_unify_extern" %1 %2 + OpExecutionMode %27 ContractionOff %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %uchar = OpTypeInt 8 0 -%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar - %46 = OpTypeFunction %ulong %_ptr_Workgroup_uchar -%_ptr_Function_ulong = OpTypePointer Function %ulong %uint = OpTypeInt 32 0 +%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint + %1 = OpVariable %_ptr_Workgroup_uint Workgroup %uint_4 = OpConstant %uint 4 %_arr_uint_uint_4 = OpTypeArray %uint %uint_4 %_ptr_Workgroup__arr_uint_uint_4 = OpTypePointer Workgroup %_arr_uint_uint_4 + %2 = OpVariable %_ptr_Workgroup__arr_uint_uint_4 Workgroup + %ulong = OpTypeInt 64 0 + %uint_4_0 = OpConstant %uint 4 + %70 = OpTypeFunction %ulong %_ptr_Workgroup_uint %_ptr_Workgroup__arr_uint_uint_4 + %uint_4_1 = OpConstant %uint 4 +%_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong - %38 = OpVariable %_ptr_Workgroup_uchar Workgroup - %53 = OpTypeFunction %void %ulong %ulong + %uint_4_2 = OpConstant %uint 4 + %75 = OpTypeFunction %ulong %ulong %_ptr_Workgroup_uint %_ptr_Workgroup__arr_uint_uint_4 + %uint_4_3 = OpConstant %uint 4 + %77 = OpTypeFunction %void %ulong %ulong +%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong + %ulong_8 = OpConstant %ulong 8 + %uchar = OpTypeInt 8 0 +%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar %_ptr_Generic_ulong = OpTypePointer Generic %ulong -%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint - %3 = OpFunction %ulong None %46 - %36 = OpFunctionParameter %_ptr_Workgroup_uchar - %9 = OpLabel + %3 = OpFunction %ulong None %70 + %57 = OpFunctionParameter %_ptr_Workgroup_uint + %58 = OpFunctionParameter %_ptr_Workgroup__arr_uint_uint_4 + %16 = OpLabel %4 = OpVariable %_ptr_Function_ulong Function - %37 = OpBitcast %_ptr_Workgroup__arr_uint_uint_4 %36 - %8 = OpBitcast %_ptr_Workgroup_ulong %37 - %7 = OpLoad %ulong %8 Aligned 8 - %5 = OpCopyObject %ulong %7 - OpStore %4 %5 - %6 = OpLoad %ulong %4 - OpReturnValue %6 + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + %13 = OpBitcast %_ptr_Workgroup_ulong %58 + %7 = OpLoad %ulong %13 Aligned 8 + OpStore %5 %7 + %14 = OpBitcast %_ptr_Workgroup_ulong %57 + %8 = OpLoad %ulong %14 Aligned 8 + OpStore %6 %8 + %10 = OpLoad %ulong %6 + %11 = OpLoad %ulong %5 + %15 = OpIAdd %ulong %10 %11 + %9 = OpCopyObject %ulong %15 + OpStore %4 %9 + %12 = OpLoad %ulong %4 + OpReturnValue %12 + OpFunctionEnd + %17 = OpFunction %ulong None %75 + %20 = OpFunctionParameter %ulong + %59 = OpFunctionParameter %_ptr_Workgroup_uint + %60 = OpFunctionParameter %_ptr_Workgroup__arr_uint_uint_4 + %26 = OpLabel + %19 = OpVariable %_ptr_Function_ulong Function + %18 = OpVariable %_ptr_Function_ulong Function + OpStore %19 %20 + %21 = OpLoad %ulong %19 + %24 = OpBitcast %_ptr_Workgroup_ulong %59 + %25 = OpCopyObject %ulong %21 + OpStore %24 %25 Aligned 8 + %22 = OpFunctionCall %ulong %3 %59 %60 + OpStore %18 %22 + %23 = OpLoad %ulong %18 + OpReturnValue %23 OpFunctionEnd - %10 = OpFunction %void None %53 - %17 = OpFunctionParameter %ulong - %18 = OpFunctionParameter %ulong - %34 = OpLabel - %11 = OpVariable %_ptr_Function_ulong Function - %12 = OpVariable %_ptr_Function_ulong Function - %13 = OpVariable %_ptr_Function_ulong Function - %14 = OpVariable %_ptr_Function_ulong Function - %15 = OpVariable %_ptr_Function_ulong Function - %16 = OpVariable %_ptr_Function_ulong Function - OpStore %11 %17 - OpStore %12 %18 - %19 = OpLoad %ulong %11 Aligned 8 - OpStore %13 %19 - %20 = OpLoad %ulong %12 Aligned 8 - OpStore %14 %20 - %22 = OpLoad %ulong %13 - %30 = OpConvertUToPtr %_ptr_Generic_ulong %22 - %21 = OpLoad %ulong %30 Aligned 8 - OpStore %15 %21 - %23 = OpLoad %ulong %15 - %39 = OpBitcast %_ptr_Workgroup_uint %38 - %31 = OpBitcast %_ptr_Workgroup_ulong %39 - OpStore %31 %23 Aligned 8 - %40 = OpBitcast %_ptr_Workgroup_uchar %38 - %32 = OpFunctionCall %ulong %3 %40 - %24 = OpCopyObject %ulong %32 - OpStore %16 %24 - %26 = OpLoad %ulong %16 - %27 = OpLoad %ulong %15 - %25 = OpIAdd %ulong %26 %27 - OpStore %16 %25 - %28 = OpLoad %ulong %14 - %29 = OpLoad %ulong %16 - %33 = OpConvertUToPtr %_ptr_Generic_ulong %28 - OpStore %33 %29 Aligned 8 + %27 = OpFunction %void None %77 + %34 = OpFunctionParameter %ulong + %35 = OpFunctionParameter %ulong + %55 = OpLabel + %28 = OpVariable %_ptr_Function_ulong Function + %29 = OpVariable %_ptr_Function_ulong Function + %30 = OpVariable %_ptr_Function_ulong Function + %31 = OpVariable %_ptr_Function_ulong Function + %32 = OpVariable %_ptr_Function_ulong Function + %33 = OpVariable %_ptr_Function_ulong Function + OpStore %28 %34 + OpStore %29 %35 + %36 = OpLoad %ulong %28 Aligned 8 + OpStore %30 %36 + %37 = OpLoad %ulong %29 Aligned 8 + OpStore %31 %37 + %39 = OpLoad %ulong %30 + %49 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %39 + %38 = OpLoad %ulong %49 Aligned 8 + OpStore %32 %38 + %41 = OpLoad %ulong %30 + %50 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %41 + %81 = OpBitcast %_ptr_CrossWorkgroup_uchar %50 + %82 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %81 %ulong_8 + %48 = OpBitcast %_ptr_CrossWorkgroup_ulong %82 + %40 = OpLoad %ulong %48 Aligned 8 + OpStore %33 %40 + %42 = OpLoad %ulong %33 + %51 = OpBitcast %_ptr_Workgroup_ulong %2 + OpStore %51 %42 Aligned 8 + %44 = OpLoad %ulong %32 + %53 = OpCopyObject %ulong %44 + %52 = OpFunctionCall %ulong %17 %53 %1 %2 + %43 = OpCopyObject %ulong %52 + OpStore %33 %43 + %45 = OpLoad %ulong %31 + %46 = OpLoad %ulong %33 + %54 = OpConvertUToPtr %_ptr_Generic_ulong %45 + OpStore %54 %46 Aligned 8 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/shared_unify_private.ptx b/ptx/src/test/spirv_run/shared_unify_private.ptx deleted file mode 100644 index fd31357..0000000 --- a/ptx/src/test/spirv_run/shared_unify_private.ptx +++ /dev/null @@ -1,32 +0,0 @@ -.version 6.5 -.target sm_30 -.address_size 64 - -.func (.reg .b64 out) load_from_shared() -{ - .shared .b32 shared_mod[4]; - ld.shared.u64 out, [shared_mod]; - ret; -} - -.visible .entry shared_unify_private( - .param .u64 input, - .param .u64 output -) -{ - .shared .b32 shared_ex[2]; - .reg .u64 in_addr; - .reg .u64 out_addr; - .reg .u64 temp1; - .reg .u64 temp2; - - ld.param.u64 in_addr, [input]; - ld.param.u64 out_addr, [output]; - - ld.u64 temp1, [in_addr]; - st.shared.u64 [shared_ex], temp1; - call (temp2), load_from_shared; - add.u64 temp2, temp2, temp1; - st.u64 [out_addr], temp2; - ret; -} diff --git a/ptx/src/test/spirv_run/shared_unify_private.spvtxt b/ptx/src/test/spirv_run/shared_unify_private.spvtxt deleted file mode 100644 index 69bf018..0000000 --- a/ptx/src/test/spirv_run/shared_unify_private.spvtxt +++ /dev/null @@ -1,84 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - OpCapability DenormFlushToZero - %41 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %9 "shared_unify_private" %38 - OpExecutionMode %9 ContractionOff - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %uchar = OpTypeInt 8 0 -%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar - %46 = OpTypeFunction %ulong %_ptr_Workgroup_uchar -%_ptr_Function_ulong = OpTypePointer Function %ulong - %uint = OpTypeInt 32 0 - %uint_4 = OpConstant %uint 4 -%_arr_uint_uint_4 = OpTypeArray %uint %uint_4 -%_ptr_Workgroup__arr_uint_uint_4 = OpTypePointer Workgroup %_arr_uint_uint_4 -%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong - %uint_16 = OpConstant %uint 16 -%_arr_uchar_uint_16 = OpTypeArray %uchar %uint_16 -%_ptr_Workgroup__arr_uchar_uint_16 = OpTypePointer Workgroup %_arr_uchar_uint_16 - %38 = OpVariable %_ptr_Workgroup__arr_uchar_uint_16 Workgroup - %56 = OpTypeFunction %void %ulong %ulong -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %uint_2 = OpConstant %uint 2 -%_arr_uint_uint_2 = OpTypeArray %uint %uint_2 -%_ptr_Workgroup__arr_uint_uint_2 = OpTypePointer Workgroup %_arr_uint_uint_2 - %1 = OpFunction %ulong None %46 - %36 = OpFunctionParameter %_ptr_Workgroup_uchar - %8 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %37 = OpBitcast %_ptr_Workgroup__arr_uint_uint_4 %36 - %7 = OpBitcast %_ptr_Workgroup_ulong %37 - %6 = OpLoad %ulong %7 Aligned 8 - %4 = OpCopyObject %ulong %6 - OpStore %2 %4 - %5 = OpLoad %ulong %2 - OpReturnValue %5 - OpFunctionEnd - %9 = OpFunction %void None %56 - %17 = OpFunctionParameter %ulong - %18 = OpFunctionParameter %ulong - %34 = OpLabel - %10 = OpVariable %_ptr_Function_ulong Function - %11 = OpVariable %_ptr_Function_ulong Function - %13 = OpVariable %_ptr_Function_ulong Function - %14 = OpVariable %_ptr_Function_ulong Function - %15 = OpVariable %_ptr_Function_ulong Function - %16 = OpVariable %_ptr_Function_ulong Function - OpStore %10 %17 - OpStore %11 %18 - %19 = OpLoad %ulong %10 Aligned 8 - OpStore %13 %19 - %20 = OpLoad %ulong %11 Aligned 8 - OpStore %14 %20 - %22 = OpLoad %ulong %13 - %30 = OpConvertUToPtr %_ptr_Generic_ulong %22 - %21 = OpLoad %ulong %30 Aligned 8 - OpStore %15 %21 - %23 = OpLoad %ulong %15 - %39 = OpBitcast %_ptr_Workgroup__arr_uint_uint_2 %38 - %31 = OpBitcast %_ptr_Workgroup_ulong %39 - OpStore %31 %23 Aligned 8 - %40 = OpBitcast %_ptr_Workgroup_uchar %38 - %32 = OpFunctionCall %ulong %1 %40 - %24 = OpCopyObject %ulong %32 - OpStore %16 %24 - %26 = OpLoad %ulong %16 - %27 = OpLoad %ulong %15 - %25 = OpIAdd %ulong %26 %27 - OpStore %16 %25 - %28 = OpLoad %ulong %14 - %29 = OpLoad %ulong %16 - %33 = OpConvertUToPtr %_ptr_Generic_ulong %28 - OpStore %33 %29 Aligned 8 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/shared_variable.spvtxt b/ptx/src/test/spirv_run/shared_variable.spvtxt index 7a97c0e..fbbfe4a 100644 --- a/ptx/src/test/spirv_run/shared_variable.spvtxt +++ b/ptx/src/test/spirv_run/shared_variable.spvtxt @@ -8,25 +8,26 @@ OpCapability Float16 OpCapability Float64 OpCapability DenormFlushToZero - %28 = OpExtInstImport "OpenCL.std" + OpExtension "SPV_KHR_float_controls" + OpExtension "SPV_KHR_no_integer_wrap_decoration" + %25 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "shared_variable" %25 + OpEntryPoint Kernel %1 "shared_variable" %4 OpExecutionMode %1 ContractionOff + OpDecorate %4 Alignment 4 %void = OpTypeVoid %uint = OpTypeInt 32 0 %uchar = OpTypeInt 8 0 %uint_128 = OpConstant %uint 128 %_arr_uchar_uint_128 = OpTypeArray %uchar %uint_128 %_ptr_Workgroup__arr_uchar_uint_128 = OpTypePointer Workgroup %_arr_uchar_uint_128 - %25 = OpVariable %_ptr_Workgroup__arr_uchar_uint_128 Workgroup + %4 = OpVariable %_ptr_Workgroup__arr_uchar_uint_128 Workgroup %ulong = OpTypeInt 64 0 - %36 = OpTypeFunction %void %ulong %ulong + %33 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong - %uint_128_0 = OpConstant %uint 128 %_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong - %uint_128_1 = OpConstant %uint 128 - %1 = OpFunction %void None %36 + %1 = OpFunction %void None %33 %9 = OpFunctionParameter %ulong %10 = OpFunctionParameter %ulong %23 = OpLabel @@ -47,11 +48,9 @@ %13 = OpLoad %ulong %19 Aligned 8 OpStore %7 %13 %15 = OpLoad %ulong %7 - %26 = OpBitcast %_ptr_Workgroup__arr_uchar_uint_128 %25 - %20 = OpBitcast %_ptr_Workgroup_ulong %26 + %20 = OpBitcast %_ptr_Workgroup_ulong %4 OpStore %20 %15 Aligned 8 - %27 = OpBitcast %_ptr_Workgroup__arr_uchar_uint_128 %25 - %21 = OpBitcast %_ptr_Workgroup_ulong %27 + %21 = OpBitcast %_ptr_Workgroup_ulong %4 %16 = OpLoad %ulong %21 Aligned 8 OpStore %8 %16 %17 = OpLoad %ulong %6 diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 4265d33..165997e 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -2,7 +2,7 @@ use crate::ast; use half::f16; use rspirv::dr; use std::cell::RefCell; -use std::collections::{hash_map, HashMap, HashSet}; +use std::collections::{hash_map, BTreeMap, HashMap, HashSet}; use std::{borrow::Cow, collections::BTreeSet, ffi::CString, hash::Hash, iter, mem, rc::Rc}; use rspirv::binary::{Assemble, Disassemble}; @@ -443,7 +443,7 @@ pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result>(); let mut builder = dr::Builder::new(); builder.reserve_ids(id_defs.current_id()); - let call_map = get_kernels_call_map(&directives); + let call_map = MethodsCallMap::new(&directives); let mut directives = convert_dynamic_shared_memory_usage(directives, &call_map, &mut || builder.id()); normalize_variable_decls(&mut directives); @@ -559,7 +559,7 @@ fn hoist_function_globals(directives: Vec) -> Vec { // TODO: remove this once we have pef-function support for denorms fn emit_denorm_build_string<'input>( - call_map: &HashMap<&str, HashSet>, + call_map: &MethodsCallMap, denorm_information: &HashMap< ast::MethodName<'input, spirv::Word>, HashMap, @@ -580,7 +580,7 @@ fn emit_denorm_build_string<'input>( }) .collect::>(); let mut flush_over_preserve = 0; - for (kernel, children) in call_map { + for (kernel, children) in call_map.kernels() { flush_over_preserve += *denorm_counts .get(&ast::MethodName::Kernel(kernel)) .unwrap_or(&0); @@ -606,7 +606,7 @@ fn emit_directives<'input>( id_defs: &GlobalStringIdResolver<'input>, opencl_id: spirv::Word, should_flush_denorms: bool, - call_map: &HashMap<&'input str, HashSet>, + call_map: &MethodsCallMap<'input>, globals_use_map: HashMap, HashSet>, directives: Vec>, kernel_info: &mut HashMap, @@ -717,61 +717,89 @@ fn emit_function_linkage<'input>( Ok(()) } -fn get_kernels_call_map<'input>( - module: &[Directive<'input>], -) -> HashMap<&'input str, HashSet> { - let mut directly_called_by = HashMap::new(); - for directive in module { - match directive { - Directive::Method(Function { - func_decl, - body: Some(statements), - .. - }) => { - let call_key: ast::MethodName<_> = (**func_decl).borrow().name; - if let hash_map::Entry::Vacant(entry) = directly_called_by.entry(call_key) { - entry.insert(Vec::new()); - } - for statement in statements { - match statement { - Statement::Call(call) => { - multi_hash_map_append(&mut directly_called_by, call_key, call.name); +struct MethodsCallMap<'input> { + map: HashMap, HashSet>, +} + +impl<'input> MethodsCallMap<'input> { + fn new(module: &[Directive<'input>]) -> Self { + let mut directly_called_by = HashMap::new(); + for directive in module { + match directive { + Directive::Method(Function { + func_decl, + body: Some(statements), + .. + }) => { + let call_key: ast::MethodName<_> = (**func_decl).borrow().name; + if let hash_map::Entry::Vacant(entry) = directly_called_by.entry(call_key) { + entry.insert(Vec::new()); + } + for statement in statements { + match statement { + Statement::Call(call) => { + multi_hash_map_append(&mut directly_called_by, call_key, call.name); + } + _ => {} } - _ => {} } } + _ => {} } - _ => {} } + let mut result = HashMap::new(); + for (&method_key, children) in directly_called_by.iter() { + let mut visited = HashSet::new(); + for child in children { + Self::add_call_map_single(&directly_called_by, &mut visited, *child); + } + result.insert(method_key, visited); + } + MethodsCallMap { map: result } } - let mut result = HashMap::new(); - for (method_key, children) in directly_called_by.iter() { - match method_key { - ast::MethodName::Kernel(name) => { - let mut visited = HashSet::new(); - for child in children { - add_call_map_single(&directly_called_by, &mut visited, *child); - } - result.insert(*name, visited); + + fn add_call_map_single( + directly_called_by: &HashMap, Vec>, + visited: &mut HashSet, + current: spirv::Word, + ) { + if !visited.insert(current) { + return; + } + if let Some(children) = directly_called_by.get(&ast::MethodName::Func(current)) { + for child in children { + Self::add_call_map_single(directly_called_by, visited, *child); } - ast::MethodName::Func(_) => {} } } - result -} -fn add_call_map_single<'input>( - directly_called_by: &HashMap, Vec>, - visited: &mut HashSet, - current: spirv::Word, -) { - if !visited.insert(current) { - return; + fn get_kernel_children(&self, name: &'input str) -> impl Iterator { + self.map + .get(&ast::MethodName::Kernel(name)) + .into_iter() + .flatten() } - if let Some(children) = directly_called_by.get(&ast::MethodName::Func(current)) { - for child in children { - add_call_map_single(directly_called_by, visited, *child); - } + + fn kernels(&self) -> impl Iterator)> { + self.map + .iter() + .filter_map(|(method, children)| match method { + ast::MethodName::Kernel(kernel) => Some((*kernel, children)), + ast::MethodName::Func(..) => None, + }) + } + + fn visit_callees( + &self, + method: ast::MethodName<'input, spirv::Word>, + f: impl FnMut(spirv::Word), + ) { + self.map + .get(&method) + .into_iter() + .flatten() + .copied() + .for_each(f); } } @@ -820,14 +848,14 @@ fn multi_hash_map_append< */ fn convert_dynamic_shared_memory_usage<'input>( module: Vec>, - kernels_methods_call_map: &HashMap<&'input str, HashSet>, + kernels_methods_call_map: &MethodsCallMap<'input>, new_id: &mut impl FnMut() -> spirv::Word, ) -> Vec> { let mut globals_shared = HashMap::new(); for dir in module.iter() { match dir { Directive::Variable( - linking, + _, ast::Variable { state_space: ast::StateSpace::Shared, name, @@ -835,12 +863,7 @@ fn convert_dynamic_shared_memory_usage<'input>( .. }, ) => { - let size = if linking.contains(ast::LinkingDirective::EXTERN) { - GlobalSharedSize::ExternUnsized - } else { - GlobalSharedSize::Sized((*v_type).size_of()) - }; - globals_shared.insert(*name, (size, v_type.clone())); + globals_shared.insert(*name, v_type.clone()); } _ => {} } @@ -848,7 +871,7 @@ fn convert_dynamic_shared_memory_usage<'input>( if globals_shared.len() == 0 { return module; } - let mut methods_to_globals_shared_direct_only_use = HashMap::<_, GlobalSharedSize>::new(); + let mut methods_to_directly_used_shared_globals = HashMap::<_, HashSet>::new(); let module = module .into_iter() .map(|directive| match directive { @@ -865,16 +888,11 @@ fn convert_dynamic_shared_memory_usage<'input>( .into_iter() .map(|statement| { statement.map_id(&mut |id, _| { - if let Some((size, _)) = globals_shared.get(&id) { - match methods_to_globals_shared_direct_only_use.entry(call_key) { - hash_map::Entry::Occupied(mut e) => { - let original_size = *e.get(); - e.insert(original_size.fold(*size)); - } - hash_map::Entry::Vacant(mut e) => { - e.insert(*size); - } - } + if let Some(type_) = globals_shared.get(&id) { + methods_to_directly_used_shared_globals + .entry(call_key) + .or_insert_with(HashSet::new) + .insert(id); } id }) @@ -894,13 +912,12 @@ fn convert_dynamic_shared_memory_usage<'input>( .collect::>(); // If there's a chain `kernel` -> `fn1` -> `fn2`, where only `fn2` uses extern shared, // make sure it gets propagated to `fn1` and `kernel` - let (kernels_to_global_shared, functions_to_global_shared) = - resolve_indirect_uses_of_globals_shared( - methods_to_globals_shared_direct_only_use, - kernels_methods_call_map, - ); + let methods_to_indirectly_used_shared_globals = resolve_indirect_uses_of_globals_shared( + methods_to_directly_used_shared_globals, + kernels_methods_call_map, + ); // now visit every method declaration and inject those additional arguments - let mut result = Vec::with_capacity(module.len()); + let mut directives = Vec::with_capacity(module.len()); for directive in module.into_iter() { match directive { Directive::Method(Function { @@ -915,17 +932,17 @@ fn convert_dynamic_shared_memory_usage<'input>( let func_decl_ref = &mut (*func_decl).borrow_mut(); let method_name = func_decl_ref.name; insert_arguments_remap_statements( - method_name, - &kernels_to_global_shared, new_id, - &mut result, - &functions_to_global_shared, - func_decl_ref, + kernels_methods_call_map, &globals_shared, + &methods_to_indirectly_used_shared_globals, + method_name, + &mut directives, + func_decl_ref, statements, ) }; - result.push(Directive::Method(Function { + directives.push(Directive::Method(Function { func_decl, globals, body: Some(statements), @@ -934,77 +951,79 @@ fn convert_dynamic_shared_memory_usage<'input>( linkage, })); } - // Existing .shared globals are now unused, they were replaced by kernel-specific globals - Directive::Variable(_, ast::Variable { name, .. }) - if globals_shared.contains_key(&name) => {} - directive => result.push(directive), + directive => directives.push(directive), } } - result + directives } -fn insert_arguments_remap_statements( - method_name: ast::MethodName, - kernels_to_global_shared: &HashMap<&str, GlobalSharedSize>, +fn insert_arguments_remap_statements<'input>( new_id: &mut impl FnMut() -> u32, + kernels_methods_call_map: &MethodsCallMap<'input>, + globals_shared: &HashMap, + methods_to_indirectly_used_shared_globals: &HashMap< + ast::MethodName<'input, spirv::Word>, + BTreeSet, + >, + method_name: ast::MethodName, result: &mut Vec, - functions_to_global_shared: &HashSet, func_decl_ref: &mut std::cell::RefMut>, - globals_shared: &HashMap, statements: Vec, ExpandedArgParams>>, ) -> Vec, ExpandedArgParams>> { - let (shared_id_param, shared_id_type) = match method_name { - ast::MethodName::Kernel(kernel_name) => { - let globals_shared_size = match kernels_to_global_shared.get(kernel_name) { - Some(s) => *s, - None => return statements, - }; - let shared_id_param = new_id(); - func_decl_ref.shared_mem = Some(shared_id_param); - let (linkage, type_) = match globals_shared_size { - GlobalSharedSize::ExternUnsized => ( - ast::LinkingDirective::EXTERN, - ast::Type::Array(ast::ScalarType::B8, Vec::new()), - ), - GlobalSharedSize::Sized(size) => ( - ast::LinkingDirective::NONE, - ast::Type::Array(ast::ScalarType::B8, vec![size as u32]), - ), - }; - result.push(Directive::Variable( - linkage, - ast::Variable { - align: None, - v_type: type_.clone(), - state_space: ast::StateSpace::Shared, - name: shared_id_param, - array_init: Vec::new(), - }, - )); - (shared_id_param, Some(type_)) - } - ast::MethodName::Func(function_name) => { - if !functions_to_global_shared.contains(&function_name) { - return statements; + let remapped_globals_in_method = + if let Some(method_globals) = methods_to_indirectly_used_shared_globals.get(&method_name) { + match method_name { + ast::MethodName::Func(..) => { + let remapped_globals = method_globals + .iter() + .map(|global| { + ( + *global, + ( + new_id(), + globals_shared + .get(&global) + .unwrap_or_else(|| todo!()) + .clone(), + ), + ) + }) + .collect::>(); + for (_, (new_shared_global_id, shared_global_type)) in remapped_globals.iter() { + func_decl_ref.input_arguments.push(ast::Variable { + align: None, + v_type: shared_global_type.clone(), + state_space: ast::StateSpace::Shared, + name: *new_shared_global_id, + array_init: Vec::new(), + }); + } + remapped_globals + } + ast::MethodName::Kernel(..) => method_globals + .iter() + .map(|global| { + ( + *global, + ( + *global, + globals_shared + .get(&global) + .unwrap_or_else(|| todo!()) + .clone(), + ), + ) + }) + .collect::>(), } - let shared_id_param = new_id(); - func_decl_ref.input_arguments.push(ast::Variable { - align: None, - v_type: ast::Type::Pointer(ast::ScalarType::B8, ast::StateSpace::Shared), - state_space: ast::StateSpace::Reg, - name: shared_id_param, - array_init: Vec::new(), - }); - (shared_id_param, None) - } - }; + } else { + return statements; + }; replace_uses_of_shared_memory( new_id, - globals_shared, - functions_to_global_shared, - shared_id_param, - shared_id_type, + methods_to_indirectly_used_shared_globals, statements, + remapped_globals_in_method, ) } @@ -1025,13 +1044,14 @@ impl GlobalSharedSize { } } -fn replace_uses_of_shared_memory<'a>( +fn replace_uses_of_shared_memory<'input>( new_id: &mut impl FnMut() -> spirv::Word, - extern_shared_decls: &HashMap, - methods_using_extern_shared: &HashSet, - shared_id_param: spirv::Word, - shared_id_type: Option, + methods_to_indirectly_used_shared_globals: &HashMap< + ast::MethodName<'input, spirv::Word>, + BTreeSet, + >, statements: Vec, + remapped_globals_in_method: BTreeMap, ) -> Vec { let mut result = Vec::with_capacity(statements.len()); for statement in statements { @@ -1040,48 +1060,26 @@ fn replace_uses_of_shared_memory<'a>( // We can safely skip checking call arguments, // because there's simply no way to pass shared ptr // without converting it to .b64 first - if methods_using_extern_shared.contains(&call.name) { - let shared_id_param = match shared_id_type { - Some(ref global_shared_defined_type) => { - let dst = new_id(); - result.push(Statement::Conversion(ImplicitConversion { - src: shared_id_param, - dst, - from_type: global_shared_defined_type.clone(), - to_type: ast::Type::Scalar(ast::ScalarType::B8), - from_space: ast::StateSpace::Shared, - to_space: ast::StateSpace::Shared, - kind: ConversionKind::PtrToPtr, - })); - dst - } - None => shared_id_param, - }; - call.input_arguments.push(( - shared_id_param, - ast::Type::Scalar(ast::ScalarType::B8), - ast::StateSpace::Shared, - )); + if let Some(shared_globals_used_by_callee) = + methods_to_indirectly_used_shared_globals.get(&ast::MethodName::Func(call.name)) + { + for &shared_global_used_by_callee in shared_globals_used_by_callee { + let (remapped_shared_id, type_) = remapped_globals_in_method + .get(&shared_global_used_by_callee) + .unwrap_or_else(|| todo!()); + call.input_arguments.push(( + *remapped_shared_id, + type_.clone(), + ast::StateSpace::Shared, + )); + } } result.push(Statement::Call(call)) } statement => { let new_statement = statement.map_id(&mut |id, _| { - if let Some((_, type_)) = extern_shared_decls.get(&id) { - if *type_ == ast::Type::Scalar(ast::ScalarType::B8) { - return shared_id_param; - } - let replacement_id = new_id(); - result.push(Statement::Conversion(ImplicitConversion { - src: shared_id_param, - dst: replacement_id, - from_type: ast::Type::Scalar(ast::ScalarType::B8), - from_space: ast::StateSpace::Shared, - to_type: type_.clone(), - to_space: ast::StateSpace::Shared, - kind: ConversionKind::PtrToPtr, - })); - replacement_id + if let Some((remapped_shared_id, _)) = remapped_globals_in_method.get(&id) { + *remapped_shared_id } else { id } @@ -1097,33 +1095,27 @@ fn replace_uses_of_shared_memory<'a>( // * If it's a kernel -> size of .shared globals in use (direct or indirect) // * If it's a function -> does it use .shared global (directly or indirectly) fn resolve_indirect_uses_of_globals_shared<'input>( - methods_use_of_globals_shared: HashMap, GlobalSharedSize>, - kernels_methods_call_map: &HashMap<&'input str, HashSet>, -) -> (HashMap<&'input str, GlobalSharedSize>, HashSet) { - let mut kernel_use = HashMap::new(); - let mut functions_using_global = HashSet::new(); - let empty = HashSet::new(); - for (method, globals) in methods_use_of_globals_shared.iter() { - match method { - ast::MethodName::Kernel(kernel_name) => { - let mut size = *globals; - for &called_subfunction in - kernels_methods_call_map.get(kernel_name).unwrap_or(&empty) - { - if let Some(new_size) = methods_use_of_globals_shared - .get(&ast::MethodName::Func(called_subfunction)) - { - size = size.fold(*new_size); - } - } - kernel_use.insert(*kernel_name, size); - } - ast::MethodName::Func(fn_id) => { - functions_using_global.insert(*fn_id); - } - } + methods_use_of_globals_shared: HashMap< + ast::MethodName<'input, spirv::Word>, + HashSet, + >, + kernels_methods_call_map: &MethodsCallMap<'input>, +) -> HashMap, BTreeSet> { + let mut result = HashMap::new(); + for (method, direct_globals) in methods_use_of_globals_shared.iter() { + let mut indirect_globals = direct_globals.iter().copied().collect::>(); + kernels_methods_call_map.visit_callees(*method, |func| { + indirect_globals.extend( + methods_use_of_globals_shared + .get(&ast::MethodName::Func(func)) + .into_iter() + .flatten() + .copied(), + ); + }); + result.insert(*method, indirect_globals); } - (kernel_use, functions_using_global) + result } type DenormCountMap = HashMap; @@ -1217,7 +1209,7 @@ fn emit_function_header<'input>( map: &mut TypeWordMap, defined_globals: &GlobalStringIdResolver<'input>, func_decl: &ast::MethodDeclaration<'input, spirv::Word>, - call_map: &HashMap<&'input str, HashSet>, + call_map: &MethodsCallMap<'input>, globals_use_map: &HashMap, HashSet>, kernel_info: &mut HashMap, ) -> Result { @@ -1256,16 +1248,14 @@ fn emit_function_header<'input>( .copied() .chain({ call_map - .get(name) - .into_iter() - .flat_map(|subfunctions| { - subfunctions.iter().flat_map(|subfunction| { - globals_use_map - .get(&ast::MethodName::Func(*subfunction)) - .into_iter() - .flatten() - .copied() - }) + .get_kernel_children(name) + .copied() + .flat_map(|subfunction| { + globals_use_map + .get(&ast::MethodName::Func(subfunction)) + .into_iter() + .flatten() + .copied() }) .into_iter() }) -- cgit v1.2.3