From d5a4b068dd9bf72a0e1b6448583ebad609ed72c1 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 17 Sep 2021 20:53:44 +0000 Subject: Redo handling of sregs --- ptx/lib/zluda_ptx_impl.bc | Bin 31940 -> 33884 bytes ptx/lib/zluda_ptx_impl.cl | 16 + ptx/src/test/spirv_run/lanemask_lt.spvtxt | 69 ++-- ptx/src/test/spirv_run/ntid.spvtxt | 67 ++-- ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt | 134 ++++---- .../spirv_run/stateful_ld_st_ntid_chain.spvtxt | 138 ++++---- .../test/spirv_run/stateful_ld_st_ntid_sub.spvtxt | 154 +++++---- ptx/src/translate.rs | 379 +++++++++++---------- 8 files changed, 511 insertions(+), 446 deletions(-) diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc index 7aa12c8..e2f956d 100644 Binary files a/ptx/lib/zluda_ptx_impl.bc and b/ptx/lib/zluda_ptx_impl.bc differ diff --git a/ptx/lib/zluda_ptx_impl.cl b/ptx/lib/zluda_ptx_impl.cl index d439795..0870fb5 100644 --- a/ptx/lib/zluda_ptx_impl.cl +++ b/ptx/lib/zluda_ptx_impl.cl @@ -297,6 +297,22 @@ atomic_add(atom_acq_rel_sys_shared_add_f64, memory_order_acq_rel, memory_order_a return (uint)__builtin_amdgcn_uicmp(1, 0, 33); } + uint FUNC(sreg_tid)(uchar dim) { + return (uint)get_local_id(dim); + } + + uint FUNC(sreg_ntid)(uchar dim) { + return (uint)get_local_size(dim); + } + + uint FUNC(sreg_ctaid)(uchar dim) { + return (uint)get_group_id(dim); + } + + uint FUNC(sreg_nctaid)(uchar dim) { + return (uint)get_num_groups(dim); + } + uint FUNC(sreg_clock)() { return (uint)__builtin_amdgcn_s_memtime(); } diff --git a/ptx/src/test/spirv_run/lanemask_lt.spvtxt b/ptx/src/test/spirv_run/lanemask_lt.spvtxt index 0753c95..3de53ce 100644 --- a/ptx/src/test/spirv_run/lanemask_lt.spvtxt +++ b/ptx/src/test/spirv_run/lanemask_lt.spvtxt @@ -7,39 +7,64 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %18 = OpExtInstImport "OpenCL.std" + %40 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "activemask" + OpEntryPoint Kernel %1 "lanemask_lt" OpExecutionMode %1 ContractionOff - OpDecorate %15 LinkageAttributes "__zluda_ptx_impl__activemask" Import + OpDecorate %11 LinkageAttributes "__zluda_ptx_impl__sreg_lanemask_lt" Import %void = OpTypeVoid %uint = OpTypeInt 32 0 - %21 = OpTypeFunction %uint + %43 = OpTypeFunction %uint %ulong = OpTypeInt 64 0 - %23 = OpTypeFunction %void %ulong %ulong + %45 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint - %15 = OpFunction %uint None %21 + %uint_1 = OpConstant %uint 1 + %11 = OpFunction %uint None %43 OpFunctionEnd - %1 = OpFunction %void None %23 - %6 = OpFunctionParameter %ulong - %7 = OpFunctionParameter %ulong - %14 = OpLabel + %1 = OpFunction %void None %45 + %13 = OpFunctionParameter %ulong + %14 = OpFunctionParameter %ulong + %38 = OpLabel %2 = OpVariable %_ptr_Function_ulong Function %3 = OpVariable %_ptr_Function_ulong Function %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_uint Function - OpStore %2 %6 - OpStore %3 %7 - %8 = OpLoad %ulong %3 Aligned 8 - OpStore %4 %8 - %9 = OpFunctionCall %uint %15 - OpStore %5 %9 - %10 = OpLoad %ulong %4 - %11 = OpLoad %uint %5 - %12 = OpConvertUToPtr %_ptr_Generic_uint %10 - %13 = OpCopyObject %uint %11 - OpStore %12 %13 Aligned 4 + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_uint Function + %7 = OpVariable %_ptr_Function_uint Function + %8 = OpVariable %_ptr_Function_uint Function + OpStore %2 %13 + OpStore %3 %14 + %15 = OpLoad %ulong %2 Aligned 8 + OpStore %4 %15 + %16 = OpLoad %ulong %3 Aligned 8 + OpStore %5 %16 + %18 = OpLoad %ulong %4 + %29 = OpConvertUToPtr %_ptr_Generic_uint %18 + %28 = OpLoad %uint %29 Aligned 4 + %17 = OpCopyObject %uint %28 + OpStore %6 %17 + %20 = OpLoad %uint %6 + %31 = OpCopyObject %uint %20 + %30 = OpIAdd %uint %31 %uint_1 + %19 = OpCopyObject %uint %30 + OpStore %7 %19 + %10 = OpFunctionCall %uint %11 + %32 = OpCopyObject %uint %10 + %21 = OpCopyObject %uint %32 + OpStore %8 %21 + %23 = OpLoad %uint %7 + %24 = OpLoad %uint %8 + %34 = OpCopyObject %uint %23 + %35 = OpCopyObject %uint %24 + %33 = OpIAdd %uint %34 %35 + %22 = OpCopyObject %uint %33 + OpStore %7 %22 + %25 = OpLoad %ulong %5 + %26 = OpLoad %uint %7 + %36 = OpConvertUToPtr %_ptr_Generic_uint %25 + %37 = OpCopyObject %uint %26 + OpStore %36 %37 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/ntid.spvtxt b/ptx/src/test/spirv_run/ntid.spvtxt index e5f343c..6754ce4 100644 --- a/ptx/src/test/spirv_run/ntid.spvtxt +++ b/ptx/src/test/spirv_run/ntid.spvtxt @@ -7,55 +7,54 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %31 = OpExtInstImport "OpenCL.std" + %30 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "ntid" OpExecutionMode %1 ContractionOff - OpDecorate %24 LinkageAttributes "get_local_size" Import + OpDecorate %11 LinkageAttributes "__zluda_ptx_impl__sreg_ntid" Import %void = OpTypeVoid - %ulong = OpTypeInt 64 0 %uint = OpTypeInt 32 0 - %35 = OpTypeFunction %ulong %uint + %uchar = OpTypeInt 8 0 + %34 = OpTypeFunction %uint %uchar + %ulong = OpTypeInt 64 0 %36 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint - %uint_0 = OpConstant %uint 0 - %24 = OpFunction %ulong None %35 - %26 = OpFunctionParameter %uint + %uchar_0 = OpConstant %uchar 0 + %11 = OpFunction %uint None %34 + %13 = OpFunctionParameter %uchar OpFunctionEnd %1 = OpFunction %void None %36 - %9 = OpFunctionParameter %ulong - %10 = OpFunctionParameter %ulong - %29 = OpLabel + %14 = OpFunctionParameter %ulong + %15 = OpFunctionParameter %ulong + %28 = OpLabel %2 = OpVariable %_ptr_Function_ulong Function %3 = OpVariable %_ptr_Function_ulong Function %4 = OpVariable %_ptr_Function_ulong Function %5 = OpVariable %_ptr_Function_ulong Function %6 = OpVariable %_ptr_Function_uint Function %7 = OpVariable %_ptr_Function_uint Function - OpStore %2 %9 - OpStore %3 %10 - %11 = OpLoad %ulong %2 Aligned 8 - OpStore %4 %11 - %12 = OpLoad %ulong %3 Aligned 8 - OpStore %5 %12 - %14 = OpLoad %ulong %4 - %27 = OpConvertUToPtr %_ptr_Generic_uint %14 - %13 = OpLoad %uint %27 Aligned 4 - OpStore %6 %13 - %23 = OpFunctionCall %ulong %24 %uint_0 - %40 = OpBitcast %ulong %23 - %16 = OpUConvert %uint %40 - %15 = OpCopyObject %uint %16 - OpStore %7 %15 - %18 = OpLoad %uint %6 - %19 = OpLoad %uint %7 - %17 = OpIAdd %uint %18 %19 - OpStore %6 %17 - %20 = OpLoad %ulong %5 - %21 = OpLoad %uint %6 - %28 = OpConvertUToPtr %_ptr_Generic_uint %20 - OpStore %28 %21 Aligned 4 + OpStore %2 %14 + OpStore %3 %15 + %16 = OpLoad %ulong %2 Aligned 8 + OpStore %4 %16 + %17 = OpLoad %ulong %3 Aligned 8 + OpStore %5 %17 + %19 = OpLoad %ulong %4 + %26 = OpConvertUToPtr %_ptr_Generic_uint %19 + %18 = OpLoad %uint %26 Aligned 4 + OpStore %6 %18 + %10 = OpFunctionCall %uint %11 %uchar_0 + %20 = OpCopyObject %uint %10 + OpStore %7 %20 + %22 = OpLoad %uint %6 + %23 = OpLoad %uint %7 + %21 = OpIAdd %uint %22 %23 + OpStore %6 %21 + %24 = OpLoad %ulong %5 + %25 = OpLoad %uint %6 + %27 = OpConvertUToPtr %_ptr_Generic_uint %24 + OpStore %27 %25 Aligned 4 OpReturn - OpFunctionEnd \ No newline at end of file + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt index b99fb50..e2d4db6 100644 --- a/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt +++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt @@ -7,89 +7,87 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %57 = OpExtInstImport "OpenCL.std" + %56 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "stateful_ld_st_ntid" OpExecutionMode %1 ContractionOff - OpDecorate %44 LinkageAttributes "_Z12get_local_idj" Import + OpDecorate %12 LinkageAttributes "__zluda_ptx_impl__sreg_tid" Import %void = OpTypeVoid - %ulong = OpTypeInt 64 0 %uint = OpTypeInt 32 0 - %61 = OpTypeFunction %ulong %uint %uchar = OpTypeInt 8 0 + %60 = OpTypeFunction %uint %uchar %_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar - %64 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar + %62 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar %_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar %_ptr_Function_uint = OpTypePointer Function %uint + %ulong = OpTypeInt 64 0 %_ptr_Function_ulong = OpTypePointer Function %ulong - %uint_0 = OpConstant %uint 0 + %uchar_0 = OpConstant %uchar 0 %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong - %44 = OpFunction %ulong None %61 - %46 = OpFunctionParameter %uint + %12 = OpFunction %uint None %60 + %14 = OpFunctionParameter %uchar OpFunctionEnd - %1 = OpFunction %void None %64 - %20 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar - %21 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar - %55 = OpLabel - %12 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %13 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %10 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %11 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %1 = OpFunction %void None %62 + %25 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar + %26 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar + %54 = OpLabel + %17 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %18 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %15 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %16 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function %6 = OpVariable %_ptr_Function_uint Function %7 = OpVariable %_ptr_Function_ulong Function %8 = OpVariable %_ptr_Function_ulong Function - OpStore %12 %20 - OpStore %13 %21 - %48 = OpBitcast %_ptr_Function_ulong %12 - %47 = OpLoad %ulong %48 Aligned 8 - %14 = OpCopyObject %ulong %47 - %22 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %14 - OpStore %10 %22 - %50 = OpBitcast %_ptr_Function_ulong %13 - %49 = OpLoad %ulong %50 Aligned 8 - %15 = OpCopyObject %ulong %49 - %23 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %15 - OpStore %11 %23 - %24 = OpLoad %_ptr_CrossWorkgroup_uchar %10 - %17 = OpConvertPtrToU %ulong %24 - %16 = OpCopyObject %ulong %17 - %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %16 - OpStore %10 %25 - %26 = OpLoad %_ptr_CrossWorkgroup_uchar %11 - %19 = OpConvertPtrToU %ulong %26 - %18 = OpCopyObject %ulong %19 - %27 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %18 - OpStore %11 %27 - %43 = OpFunctionCall %ulong %44 %uint_0 - %68 = OpBitcast %ulong %43 - %29 = OpUConvert %uint %68 - %28 = OpCopyObject %uint %29 - OpStore %6 %28 - %31 = OpLoad %uint %6 - %69 = OpBitcast %uint %31 - %30 = OpUConvert %ulong %69 - OpStore %7 %30 - %33 = OpLoad %_ptr_CrossWorkgroup_uchar %10 - %34 = OpLoad %ulong %7 - %51 = OpCopyObject %ulong %34 - %70 = OpBitcast %_ptr_CrossWorkgroup_uchar %33 + OpStore %17 %25 + OpStore %18 %26 + %47 = OpBitcast %_ptr_Function_ulong %17 + %46 = OpLoad %ulong %47 Aligned 8 + %19 = OpCopyObject %ulong %46 + %27 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %19 + OpStore %15 %27 + %49 = OpBitcast %_ptr_Function_ulong %18 + %48 = OpLoad %ulong %49 Aligned 8 + %20 = OpCopyObject %ulong %48 + %28 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %20 + OpStore %16 %28 + %29 = OpLoad %_ptr_CrossWorkgroup_uchar %15 + %22 = OpConvertPtrToU %ulong %29 + %21 = OpCopyObject %ulong %22 + %30 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %21 + OpStore %15 %30 + %31 = OpLoad %_ptr_CrossWorkgroup_uchar %16 + %24 = OpConvertPtrToU %ulong %31 + %23 = OpCopyObject %ulong %24 + %32 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %23 + OpStore %16 %32 + %11 = OpFunctionCall %uint %12 %uchar_0 + %33 = OpCopyObject %uint %11 + OpStore %6 %33 + %35 = OpLoad %uint %6 + %67 = OpBitcast %uint %35 + %34 = OpUConvert %ulong %67 + OpStore %7 %34 + %37 = OpLoad %_ptr_CrossWorkgroup_uchar %15 + %38 = OpLoad %ulong %7 + %50 = OpCopyObject %ulong %38 + %68 = OpBitcast %_ptr_CrossWorkgroup_uchar %37 + %69 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %68 %50 + %36 = OpBitcast %_ptr_CrossWorkgroup_uchar %69 + OpStore %15 %36 + %40 = OpLoad %_ptr_CrossWorkgroup_uchar %16 + %41 = OpLoad %ulong %7 + %51 = OpCopyObject %ulong %41 + %70 = OpBitcast %_ptr_CrossWorkgroup_uchar %40 %71 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %70 %51 - %32 = OpBitcast %_ptr_CrossWorkgroup_uchar %71 - OpStore %10 %32 - %36 = OpLoad %_ptr_CrossWorkgroup_uchar %11 - %37 = OpLoad %ulong %7 - %52 = OpCopyObject %ulong %37 - %72 = OpBitcast %_ptr_CrossWorkgroup_uchar %36 - %73 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %72 %52 - %35 = OpBitcast %_ptr_CrossWorkgroup_uchar %73 - OpStore %11 %35 - %39 = OpLoad %_ptr_CrossWorkgroup_uchar %10 - %53 = OpBitcast %_ptr_CrossWorkgroup_ulong %39 - %38 = OpLoad %ulong %53 Aligned 8 - OpStore %8 %38 - %40 = OpLoad %_ptr_CrossWorkgroup_uchar %11 - %41 = OpLoad %ulong %8 - %54 = OpBitcast %_ptr_CrossWorkgroup_ulong %40 - OpStore %54 %41 Aligned 8 + %39 = OpBitcast %_ptr_CrossWorkgroup_uchar %71 + OpStore %16 %39 + %43 = OpLoad %_ptr_CrossWorkgroup_uchar %15 + %52 = OpBitcast %_ptr_CrossWorkgroup_ulong %43 + %42 = OpLoad %ulong %52 Aligned 8 + OpStore %8 %42 + %44 = OpLoad %_ptr_CrossWorkgroup_uchar %16 + %45 = OpLoad %ulong %8 + %53 = OpBitcast %_ptr_CrossWorkgroup_ulong %44 + OpStore %53 %45 Aligned 8 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt index 0239632..5da0ef3 100644 --- a/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt +++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt @@ -7,93 +7,91 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %65 = OpExtInstImport "OpenCL.std" + %64 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "stateful_ld_st_ntid_chain" OpExecutionMode %1 ContractionOff - OpDecorate %52 LinkageAttributes "_Z12get_local_idj" Import + OpDecorate %16 LinkageAttributes "__zluda_ptx_impl__sreg_tid" Import %void = OpTypeVoid - %ulong = OpTypeInt 64 0 %uint = OpTypeInt 32 0 - %69 = OpTypeFunction %ulong %uint %uchar = OpTypeInt 8 0 + %68 = OpTypeFunction %uint %uchar %_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar - %72 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar + %70 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar %_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar %_ptr_Function_uint = OpTypePointer Function %uint + %ulong = OpTypeInt 64 0 %_ptr_Function_ulong = OpTypePointer Function %ulong - %uint_0 = OpConstant %uint 0 + %uchar_0 = OpConstant %uchar 0 %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong - %52 = OpFunction %ulong None %69 - %54 = OpFunctionParameter %uint + %16 = OpFunction %uint None %68 + %18 = OpFunctionParameter %uchar OpFunctionEnd - %1 = OpFunction %void None %72 - %28 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar - %29 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar - %63 = OpLabel + %1 = OpFunction %void None %70 + %33 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar + %34 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar + %62 = OpLabel + %25 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %26 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %19 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function %20 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function %21 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %14 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %15 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %16 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %17 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %18 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %19 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %22 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %23 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %24 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function %10 = OpVariable %_ptr_Function_uint Function %11 = OpVariable %_ptr_Function_ulong Function %12 = OpVariable %_ptr_Function_ulong Function - OpStore %20 %28 - OpStore %21 %29 - %56 = OpBitcast %_ptr_Function_ulong %20 - %55 = OpLoad %ulong %56 Aligned 8 - %22 = OpCopyObject %ulong %55 - %30 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %22 - OpStore %14 %30 - %58 = OpBitcast %_ptr_Function_ulong %21 - %57 = OpLoad %ulong %58 Aligned 8 - %23 = OpCopyObject %ulong %57 - %31 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %23 - OpStore %17 %31 - %32 = OpLoad %_ptr_CrossWorkgroup_uchar %14 - %25 = OpConvertPtrToU %ulong %32 - %24 = OpCopyObject %ulong %25 - %33 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %24 - OpStore %15 %33 - %34 = OpLoad %_ptr_CrossWorkgroup_uchar %17 - %27 = OpConvertPtrToU %ulong %34 - %26 = OpCopyObject %ulong %27 - %35 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %26 - OpStore %18 %35 - %51 = OpFunctionCall %ulong %52 %uint_0 - %76 = OpBitcast %ulong %51 - %37 = OpUConvert %uint %76 - %36 = OpCopyObject %uint %37 - OpStore %10 %36 - %39 = OpLoad %uint %10 - %77 = OpBitcast %uint %39 - %38 = OpUConvert %ulong %77 - OpStore %11 %38 - %41 = OpLoad %_ptr_CrossWorkgroup_uchar %15 - %42 = OpLoad %ulong %11 - %59 = OpCopyObject %ulong %42 - %78 = OpBitcast %_ptr_CrossWorkgroup_uchar %41 + OpStore %25 %33 + OpStore %26 %34 + %55 = OpBitcast %_ptr_Function_ulong %25 + %54 = OpLoad %ulong %55 Aligned 8 + %27 = OpCopyObject %ulong %54 + %35 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %27 + OpStore %19 %35 + %57 = OpBitcast %_ptr_Function_ulong %26 + %56 = OpLoad %ulong %57 Aligned 8 + %28 = OpCopyObject %ulong %56 + %36 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %28 + OpStore %22 %36 + %37 = OpLoad %_ptr_CrossWorkgroup_uchar %19 + %30 = OpConvertPtrToU %ulong %37 + %29 = OpCopyObject %ulong %30 + %38 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %29 + OpStore %20 %38 + %39 = OpLoad %_ptr_CrossWorkgroup_uchar %22 + %32 = OpConvertPtrToU %ulong %39 + %31 = OpCopyObject %ulong %32 + %40 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %31 + OpStore %23 %40 + %15 = OpFunctionCall %uint %16 %uchar_0 + %41 = OpCopyObject %uint %15 + OpStore %10 %41 + %43 = OpLoad %uint %10 + %75 = OpBitcast %uint %43 + %42 = OpUConvert %ulong %75 + OpStore %11 %42 + %45 = OpLoad %_ptr_CrossWorkgroup_uchar %20 + %46 = OpLoad %ulong %11 + %58 = OpCopyObject %ulong %46 + %76 = OpBitcast %_ptr_CrossWorkgroup_uchar %45 + %77 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %76 %58 + %44 = OpBitcast %_ptr_CrossWorkgroup_uchar %77 + OpStore %21 %44 + %48 = OpLoad %_ptr_CrossWorkgroup_uchar %23 + %49 = OpLoad %ulong %11 + %59 = OpCopyObject %ulong %49 + %78 = OpBitcast %_ptr_CrossWorkgroup_uchar %48 %79 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %78 %59 - %40 = OpBitcast %_ptr_CrossWorkgroup_uchar %79 - OpStore %16 %40 - %44 = OpLoad %_ptr_CrossWorkgroup_uchar %18 - %45 = OpLoad %ulong %11 - %60 = OpCopyObject %ulong %45 - %80 = OpBitcast %_ptr_CrossWorkgroup_uchar %44 - %81 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %80 %60 - %43 = OpBitcast %_ptr_CrossWorkgroup_uchar %81 - OpStore %19 %43 - %47 = OpLoad %_ptr_CrossWorkgroup_uchar %16 - %61 = OpBitcast %_ptr_CrossWorkgroup_ulong %47 - %46 = OpLoad %ulong %61 Aligned 8 - OpStore %12 %46 - %48 = OpLoad %_ptr_CrossWorkgroup_uchar %19 - %49 = OpLoad %ulong %12 - %62 = OpBitcast %_ptr_CrossWorkgroup_ulong %48 - OpStore %62 %49 Aligned 8 + %47 = OpBitcast %_ptr_CrossWorkgroup_uchar %79 + OpStore %24 %47 + %51 = OpLoad %_ptr_CrossWorkgroup_uchar %21 + %60 = OpBitcast %_ptr_CrossWorkgroup_ulong %51 + %50 = OpLoad %ulong %60 Aligned 8 + OpStore %12 %50 + %52 = OpLoad %_ptr_CrossWorkgroup_uchar %24 + %53 = OpLoad %ulong %12 + %61 = OpBitcast %_ptr_CrossWorkgroup_ulong %52 + OpStore %61 %53 Aligned 8 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.spvtxt index 987e205..0ef5d28 100644 --- a/ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.spvtxt +++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.spvtxt @@ -7,103 +7,101 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %71 = OpExtInstImport "OpenCL.std" + %70 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "stateful_ld_st_ntid_sub" OpExecutionMode %1 ContractionOff - OpDecorate %54 LinkageAttributes "_Z12get_local_idj" Import + OpDecorate %16 LinkageAttributes "__zluda_ptx_impl__sreg_tid" Import %void = OpTypeVoid - %ulong = OpTypeInt 64 0 %uint = OpTypeInt 32 0 - %75 = OpTypeFunction %ulong %uint %uchar = OpTypeInt 8 0 + %74 = OpTypeFunction %uint %uchar %_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar - %78 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar + %76 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar %_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar %_ptr_Function_uint = OpTypePointer Function %uint + %ulong = OpTypeInt 64 0 %_ptr_Function_ulong = OpTypePointer Function %ulong - %uint_0 = OpConstant %uint 0 + %uchar_0 = OpConstant %uchar 0 %ulong_0 = OpConstant %ulong 0 %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong %ulong_0_0 = OpConstant %ulong 0 - %54 = OpFunction %ulong None %75 - %56 = OpFunctionParameter %uint + %16 = OpFunction %uint None %74 + %18 = OpFunctionParameter %uchar OpFunctionEnd - %1 = OpFunction %void None %78 - %30 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar - %31 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar - %69 = OpLabel + %1 = OpFunction %void None %76 + %35 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar + %36 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar + %68 = OpLabel + %25 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %26 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %19 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function %20 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function %21 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %14 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %15 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %16 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %17 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %18 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %19 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %22 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %23 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %24 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function %10 = OpVariable %_ptr_Function_uint Function %11 = OpVariable %_ptr_Function_ulong Function %12 = OpVariable %_ptr_Function_ulong Function - OpStore %20 %30 - OpStore %21 %31 - %62 = OpBitcast %_ptr_Function_ulong %20 - %61 = OpLoad %ulong %62 Aligned 8 - %22 = OpCopyObject %ulong %61 - %32 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %22 - OpStore %14 %32 - %64 = OpBitcast %_ptr_Function_ulong %21 - %63 = OpLoad %ulong %64 Aligned 8 - %23 = OpCopyObject %ulong %63 - %33 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %23 - OpStore %17 %33 - %34 = OpLoad %_ptr_CrossWorkgroup_uchar %14 - %25 = OpConvertPtrToU %ulong %34 - %24 = OpCopyObject %ulong %25 - %35 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %24 - OpStore %15 %35 - %36 = OpLoad %_ptr_CrossWorkgroup_uchar %17 - %27 = OpConvertPtrToU %ulong %36 - %26 = OpCopyObject %ulong %27 - %37 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %26 - OpStore %18 %37 - %53 = OpFunctionCall %ulong %54 %uint_0 - %82 = OpBitcast %ulong %53 - %39 = OpUConvert %uint %82 - %38 = OpCopyObject %uint %39 - OpStore %10 %38 - %41 = OpLoad %uint %10 - %83 = OpBitcast %uint %41 - %40 = OpUConvert %ulong %83 - OpStore %11 %40 - %42 = OpLoad %ulong %11 - %65 = OpCopyObject %ulong %42 - %28 = OpSNegate %ulong %65 - %44 = OpLoad %_ptr_CrossWorkgroup_uchar %15 - %84 = OpBitcast %_ptr_CrossWorkgroup_uchar %44 - %85 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %84 %28 - %43 = OpBitcast %_ptr_CrossWorkgroup_uchar %85 - OpStore %16 %43 - %45 = OpLoad %ulong %11 - %66 = OpCopyObject %ulong %45 - %29 = OpSNegate %ulong %66 - %47 = OpLoad %_ptr_CrossWorkgroup_uchar %18 - %86 = OpBitcast %_ptr_CrossWorkgroup_uchar %47 - %87 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %86 %29 - %46 = OpBitcast %_ptr_CrossWorkgroup_uchar %87 - OpStore %19 %46 - %49 = OpLoad %_ptr_CrossWorkgroup_uchar %16 - %67 = OpBitcast %_ptr_CrossWorkgroup_ulong %49 + OpStore %25 %35 + OpStore %26 %36 + %61 = OpBitcast %_ptr_Function_ulong %25 + %60 = OpLoad %ulong %61 Aligned 8 + %27 = OpCopyObject %ulong %60 + %37 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %27 + OpStore %19 %37 + %63 = OpBitcast %_ptr_Function_ulong %26 + %62 = OpLoad %ulong %63 Aligned 8 + %28 = OpCopyObject %ulong %62 + %38 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %28 + OpStore %22 %38 + %39 = OpLoad %_ptr_CrossWorkgroup_uchar %19 + %30 = OpConvertPtrToU %ulong %39 + %29 = OpCopyObject %ulong %30 + %40 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %29 + OpStore %20 %40 + %41 = OpLoad %_ptr_CrossWorkgroup_uchar %22 + %32 = OpConvertPtrToU %ulong %41 + %31 = OpCopyObject %ulong %32 + %42 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %31 + OpStore %23 %42 + %15 = OpFunctionCall %uint %16 %uchar_0 + %43 = OpCopyObject %uint %15 + OpStore %10 %43 + %45 = OpLoad %uint %10 + %81 = OpBitcast %uint %45 + %44 = OpUConvert %ulong %81 + OpStore %11 %44 + %46 = OpLoad %ulong %11 + %64 = OpCopyObject %ulong %46 + %33 = OpSNegate %ulong %64 + %48 = OpLoad %_ptr_CrossWorkgroup_uchar %20 + %82 = OpBitcast %_ptr_CrossWorkgroup_uchar %48 + %83 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %82 %33 + %47 = OpBitcast %_ptr_CrossWorkgroup_uchar %83 + OpStore %21 %47 + %49 = OpLoad %ulong %11 + %65 = OpCopyObject %ulong %49 + %34 = OpSNegate %ulong %65 + %51 = OpLoad %_ptr_CrossWorkgroup_uchar %23 + %84 = OpBitcast %_ptr_CrossWorkgroup_uchar %51 + %85 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %84 %34 + %50 = OpBitcast %_ptr_CrossWorkgroup_uchar %85 + OpStore %24 %50 + %53 = OpLoad %_ptr_CrossWorkgroup_uchar %21 + %66 = OpBitcast %_ptr_CrossWorkgroup_ulong %53 + %87 = OpBitcast %_ptr_CrossWorkgroup_uchar %66 + %88 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %87 %ulong_0 + %57 = OpBitcast %_ptr_CrossWorkgroup_ulong %88 + %52 = OpLoad %ulong %57 Aligned 8 + OpStore %12 %52 + %54 = OpLoad %_ptr_CrossWorkgroup_uchar %24 + %55 = OpLoad %ulong %12 + %67 = OpBitcast %_ptr_CrossWorkgroup_ulong %54 %89 = OpBitcast %_ptr_CrossWorkgroup_uchar %67 - %90 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %89 %ulong_0 - %58 = OpBitcast %_ptr_CrossWorkgroup_ulong %90 - %48 = OpLoad %ulong %58 Aligned 8 - OpStore %12 %48 - %50 = OpLoad %_ptr_CrossWorkgroup_uchar %19 - %51 = OpLoad %ulong %12 - %68 = OpBitcast %_ptr_CrossWorkgroup_ulong %50 - %91 = OpBitcast %_ptr_CrossWorkgroup_uchar %68 - %92 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %91 %ulong_0_0 - %60 = OpBitcast %_ptr_CrossWorkgroup_ulong %92 - OpStore %60 %51 Aligned 8 + %90 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %89 %ulong_0_0 + %59 = OpBitcast %_ptr_CrossWorkgroup_ulong %90 + OpStore %59 %55 Aligned 8 OpReturn OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 15dcdd1..8135c51 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -10,8 +10,6 @@ use rspirv::binary::{Assemble, Disassemble}; static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../lib/zluda_ptx_impl.spv"); static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../lib/zluda_ptx_impl.bc"); const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl__"; -const ZLUDA_PTX_PREFIX_SREG_CLOCK: &'static str = "__zluda_ptx_impl__sreg_clock"; -const ZLUDA_PTX_PREFIX_SREG_LANEMASK_LT: &'static str = "__zluda_ptx_impl__sreg_lanemask_lt"; quick_error! { #[derive(Debug)] @@ -426,8 +424,8 @@ pub struct KernelInfo { pub uses_shared_mem: bool, } -pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result { - let mut id_defs = GlobalStringIdResolver::new(1); +pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result { + let mut id_defs = GlobalStringIdResolver::<'input>::new(1); let mut ptx_impl_imports = HashMap::new(); let directives = ast .directives @@ -1135,9 +1133,9 @@ fn emit_memory_model(builder: &mut dr::Builder) { ); } -fn translate_directive<'input>( - id_defs: &mut GlobalStringIdResolver<'input>, - ptx_impl_imports: &mut HashMap>, +fn translate_directive<'input, 'a>( + id_defs: &'a mut GlobalStringIdResolver<'input>, + ptx_impl_imports: &'a mut HashMap>, d: ast::Directive<'input, ast::ParsedArgParams<'input>>, ) -> Result>, TranslateError> { Ok(match d { @@ -1157,11 +1155,11 @@ fn translate_directive<'input>( }) } -fn translate_function<'a>( - id_defs: &mut GlobalStringIdResolver<'a>, - ptx_impl_imports: &mut HashMap>, - f: ast::ParsedFunction<'a>, -) -> Result>, TranslateError> { +fn translate_function<'input, 'a>( + id_defs: &'a mut GlobalStringIdResolver<'input>, + ptx_impl_imports: &'a mut HashMap>, + f: ast::ParsedFunction<'input>, +) -> Result>, TranslateError> { let import_as = match &f.func_directive { ast::MethodDeclaration { name: ast::MethodName::Func("__assertfail"), @@ -1206,7 +1204,7 @@ fn rename_fn_params<'a, 'b>( } fn to_ssa<'input, 'b>( - ptx_impl_imports: &mut HashMap, + ptx_impl_imports: &'b mut HashMap>, mut id_defs: FnStringIdResolver<'input, 'b>, fn_defs: GlobalFnDeclResolver<'input, 'b>, func_decl: Rc>>, @@ -1231,6 +1229,8 @@ fn to_ssa<'input, 'b>( let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?; let typed_statements = convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?; + let typed_statements = + fix_special_registers2(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?; let (func_decl, typed_statements) = convert_to_stateful_memory_access(func_decl, typed_statements, &mut numeric_id_defs)?; let ssa_statements = insert_mem_ssa_statements( @@ -1238,8 +1238,6 @@ fn to_ssa<'input, 'b>( &mut numeric_id_defs, &mut (*func_decl).borrow_mut(), )?; - let ssa_statements = - fix_special_registers(ptx_impl_imports, ssa_statements, &mut numeric_id_defs)?; let mut numeric_id_defs = numeric_id_defs.finish(); let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?; let expanded_statements = @@ -1257,90 +1255,147 @@ fn to_ssa<'input, 'b>( }) } -fn fix_special_registers( - ptx_impl_imports: &mut HashMap, +fn fix_special_registers2<'a, 'b, 'input>( + ptx_impl_imports: &'a mut HashMap>, typed_statements: Vec, - numeric_id_defs: &mut NumericIdResolver, + numeric_id_defs: &'a mut NumericIdResolver<'b>, ) -> Result, TranslateError> { - let mut result = Vec::with_capacity(typed_statements.len()); + let result = Vec::with_capacity(typed_statements.len()); + let mut sreg_sresolver = SpecialRegisterResolver { + ptx_impl_imports, + numeric_id_defs, + result, + }; for s in typed_statements { match s { - Statement::LoadVar( - details - @ - LoadVarDetails { - member_index: Some((_, Some(_))), - .. - }, - ) => { - let index = details.member_index.unwrap().0; - let sreg = numeric_id_defs - .special_registers - .get(details.arg.src) - .ok_or_else(|| error_unreachable())?; - let (ocl_name, ocl_type) = sreg.get_opencl_fn_type(); - let index_constant = numeric_id_defs.register_intermediate(Some(( - ast::Type::Scalar(ast::ScalarType::U32), - ast::StateSpace::Reg, - ))); - result.push(Statement::Constant(ConstantDefinition { - dst: index_constant, - typ: ast::ScalarType::U32, - value: ast::ImmediateValue::U64(index as u64), - })); - let fn_result = numeric_id_defs.register_intermediate(Some(( - ast::Type::Scalar(ocl_type), - ast::StateSpace::Reg, - ))); - let return_arguments = - vec![(fn_result, ast::Type::Scalar(ocl_type), ast::StateSpace::Reg)]; - let input_arguments = vec![( - TypedOperand::Reg(index_constant), - ast::Type::Scalar(ast::ScalarType::U32), - ast::StateSpace::Reg, - )]; - let fn_call = register_external_fn_call( - numeric_id_defs, - ptx_impl_imports, - ocl_name.to_string(), - return_arguments.iter().map(|(_, typ, space)| (typ, *space)), - input_arguments.iter().map(|(_, typ, space)| (typ, *space)), - )?; - result.push(Statement::Call(ResolvedCall { - uniform: false, - return_arguments, - name: fn_call, - input_arguments, - })); - result.push(Statement::Conversion(ImplicitConversion { - src: fn_result, - dst: details.arg.dst, - from_type: ast::Type::Scalar(ocl_type), - from_space: ast::StateSpace::Reg, - to_type: ast::Type::Scalar(ast::ScalarType::U32), - to_space: ast::StateSpace::Reg, - kind: ConversionKind::Default, - })); + Statement::Call(details) => { + let new_statement = details.visit(&mut sreg_sresolver)?; + sreg_sresolver.result.push(new_statement); + } + Statement::Instruction(details) => { + let new_statement = details.visit(&mut sreg_sresolver)?; + sreg_sresolver.result.push(new_statement); } - s => result.push(s), + Statement::Conditional(details) => { + let new_statement = details.visit(&mut sreg_sresolver)?; + sreg_sresolver.result.push(new_statement); + } + Statement::Conversion(details) => { + let new_statement = details.visit(&mut sreg_sresolver)?; + sreg_sresolver.result.push(new_statement); + } + Statement::PtrAccess(details) => { + let new_statement = details.visit(&mut sreg_sresolver)?; + sreg_sresolver.result.push(new_statement); + } + Statement::RepackVector(details) => { + let new_statement = details.visit(&mut sreg_sresolver)?; + sreg_sresolver.result.push(new_statement); + } + s @ Statement::Variable(_) + | s @ Statement::Label(_) + | s @ Statement::FunctionPointer(_) => sreg_sresolver.result.push(s), + _ => return Err(error_unreachable()), } } - Ok(result) + Ok(sreg_sresolver.result) } -fn get_sreg_id_scalar_type( - numeric_id_defs: &mut NumericIdResolver, - sreg: PtxSpecialRegister, -) -> Option<(spirv::Word, ast::ScalarType, u8)> { - match sreg.normalized_sreg_and_type() { - Some((normalized_sreg, typ, vec_width)) => Some(( - numeric_id_defs - .special_registers - .get_or_add(numeric_id_defs.current_id, normalized_sreg), - typ, - vec_width, - )), - None => None, +struct SpecialRegisterResolver<'a, 'b, 'input> { + ptx_impl_imports: &'a mut HashMap>, + numeric_id_defs: &'a mut NumericIdResolver<'b>, + result: Vec, +} + +impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'b, 'input> { + fn replace_sreg( + &mut self, + desc: ArgumentDescriptor, + vector_index: Option, + ) -> Result { + if let Some(sreg) = self.numeric_id_defs.special_registers.get(desc.op) { + if desc.is_dst { + return Err(TranslateError::MismatchedType); + } + let input_arguments = match (vector_index, sreg.get_function_input_type()) { + (Some(idx), Some(inp_type)) => { + if inp_type != ast::ScalarType::U8 { + return Err(TranslateError::Unreachable); + } + let constant = self.numeric_id_defs.register_intermediate(Some(( + ast::Type::Scalar(inp_type), + ast::StateSpace::Reg, + ))); + self.result.push(Statement::Constant(ConstantDefinition { + dst: constant, + typ: inp_type, + value: ast::ImmediateValue::U64(idx as u64), + })); + vec![( + TypedOperand::Reg(constant), + ast::Type::Scalar(inp_type), + ast::StateSpace::Reg, + )] + } + (None, None) => Vec::new(), + _ => return Err(TranslateError::MismatchedType), + }; + let ocl_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat(); + let return_type = sreg.get_function_return_type(); + let fn_result = self.numeric_id_defs.register_intermediate(Some(( + ast::Type::Scalar(return_type), + ast::StateSpace::Reg, + ))); + let return_arguments = vec![( + fn_result, + ast::Type::Scalar(return_type), + ast::StateSpace::Reg, + )]; + let fn_call = register_external_fn_call( + self.numeric_id_defs, + self.ptx_impl_imports, + ocl_fn_name.to_string(), + return_arguments.iter().map(|(_, typ, space)| (typ, *space)), + input_arguments.iter().map(|(_, typ, space)| (typ, *space)), + )?; + self.result.push(Statement::Call(ResolvedCall { + uniform: false, + return_arguments, + name: fn_call, + input_arguments, + })); + Ok(fn_result) + } else { + Ok(desc.op) + } + } +} + +impl<'a, 'b, 'input> ArgumentMapVisitor + for SpecialRegisterResolver<'a, 'b, 'input> +{ + fn id( + &mut self, + desc: ArgumentDescriptor, + _: Option<(&ast::Type, ast::StateSpace)>, + ) -> Result { + self.replace_sreg(desc, None) + } + + fn operand( + &mut self, + desc: ArgumentDescriptor, + typ: &ast::Type, + state_space: ast::StateSpace, + ) -> Result { + Ok(match desc.op { + TypedOperand::Reg(reg) => TypedOperand::Reg(self.replace_sreg(desc.new_op(reg), None)?), + op @ TypedOperand::RegOffset(_, _) => op, + op @ TypedOperand::Imm(_) => op, + TypedOperand::VecMember(reg, idx) => { + TypedOperand::VecMember(self.replace_sreg(desc.new_op(reg), Some(idx))?, idx) + } + }) } } @@ -1968,22 +2023,8 @@ fn insert_mem_ssa_statements<'a, 'b>( } inst => insert_mem_ssa_statement_default(id_def, &mut result, inst)?, }, - Statement::Conditional(mut bra) => { - let generated_id = id_def.register_intermediate(Some(( - ast::Type::Scalar(ast::ScalarType::Pred), - ast::StateSpace::Reg, - ))); - result.push(Statement::LoadVar(LoadVarDetails { - arg: Arg2 { - dst: generated_id, - src: bra.predicate, - }, - state_space: ast::StateSpace::Reg, - typ: ast::Type::Scalar(ast::ScalarType::Pred), - member_index: None, - })); - bra.predicate = generated_id; - result.push(Statement::Conditional(bra)); + Statement::Conditional(bra) => { + insert_mem_ssa_statement_default(id_def, &mut result, bra)? } Statement::Conversion(conv) => { insert_mem_ssa_statement_default(id_def, &mut result, conv)? @@ -1997,7 +2038,9 @@ fn insert_mem_ssa_statements<'a, 'b>( Statement::FunctionPointer(func_ptr) => { insert_mem_ssa_statement_default(id_def, &mut result, func_ptr)? } - s @ Statement::Variable(_) | s @ Statement::Label(_) => result.push(s), + s @ Statement::Variable(_) | s @ Statement::Label(_) | s @ Statement::Constant(..) => { + result.push(s) + } _ => return Err(error_unreachable()), } } @@ -4539,6 +4582,7 @@ fn convert_to_stateful_memory_access<'a, 'input>( match statement { l @ Statement::Label(_) => result.push(l), c @ Statement::Conditional(_) => result.push(c), + c @ Statement::Constant(..) => result.push(c), Statement::Variable(var) => { if !remapped_ids.contains_key(&var.name) { result.push(Statement::Variable(var)); @@ -4791,13 +4835,9 @@ fn is_64_bit_integer(id_defs: &NumericIdResolver, id: spirv::Word) -> bool { #[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)] enum PtxSpecialRegister { Tid, - Tid64, Ntid, - Ntid64, Ctaid, - Ctaid64, Nctaid, - Nctaid64, Clock, LanemaskLt, } @@ -4817,71 +4857,43 @@ impl PtxSpecialRegister { fn get_type(self) -> ast::Type { match self { - PtxSpecialRegister::Tid => ast::Type::Vector(ast::ScalarType::U32, 4), - PtxSpecialRegister::Tid64 => ast::Type::Vector(ast::ScalarType::U64, 3), - PtxSpecialRegister::Ntid => ast::Type::Vector(ast::ScalarType::U32, 4), - PtxSpecialRegister::Ntid64 => ast::Type::Vector(ast::ScalarType::U64, 3), - PtxSpecialRegister::Ctaid => ast::Type::Vector(ast::ScalarType::U32, 4), - PtxSpecialRegister::Ctaid64 => ast::Type::Vector(ast::ScalarType::U64, 3), - PtxSpecialRegister::Nctaid => ast::Type::Vector(ast::ScalarType::U32, 4), - PtxSpecialRegister::Nctaid64 => ast::Type::Vector(ast::ScalarType::U64, 3), - PtxSpecialRegister::Clock => ast::Type::Scalar(ast::ScalarType::U32), - PtxSpecialRegister::LanemaskLt => ast::Type::Scalar(ast::ScalarType::U32), + PtxSpecialRegister::Tid + | PtxSpecialRegister::Ntid + | PtxSpecialRegister::Ctaid + | PtxSpecialRegister::Nctaid => ast::Type::Vector(self.get_function_return_type(), 4), + _ => ast::Type::Scalar(self.get_function_return_type()), } } - fn get_scalar_type(self) -> ast::ScalarType { + fn get_function_return_type(self) -> ast::ScalarType { match self { - PtxSpecialRegister::Tid - | PtxSpecialRegister::Ntid - | PtxSpecialRegister::Ctaid - | PtxSpecialRegister::Nctaid - | PtxSpecialRegister::Clock - | PtxSpecialRegister::LanemaskLt => ast::ScalarType::U32, - PtxSpecialRegister::Tid64 - | PtxSpecialRegister::Ntid64 - | PtxSpecialRegister::Ctaid64 - | PtxSpecialRegister::Nctaid64 => ast::ScalarType::U64, + PtxSpecialRegister::Tid => ast::ScalarType::U32, + PtxSpecialRegister::Ntid => ast::ScalarType::U32, + PtxSpecialRegister::Ctaid => ast::ScalarType::U32, + PtxSpecialRegister::Nctaid => ast::ScalarType::U32, + PtxSpecialRegister::Clock => ast::ScalarType::U32, + PtxSpecialRegister::LanemaskLt => ast::ScalarType::U32, } } - fn get_opencl_fn_type(self) -> (&'static str, ast::ScalarType) { + fn get_function_input_type(self) -> Option { match self { - PtxSpecialRegister::Tid | PtxSpecialRegister::Tid64 => { - ("_Z12get_local_idj", ast::ScalarType::U64) - } - PtxSpecialRegister::Ntid | PtxSpecialRegister::Ntid64 => { - ("_Z14get_local_sizej", ast::ScalarType::U64) - } - PtxSpecialRegister::Ctaid | PtxSpecialRegister::Ctaid64 => { - ("_Z12get_group_idj", ast::ScalarType::U64) - } - PtxSpecialRegister::Nctaid | PtxSpecialRegister::Nctaid64 => { - ("_Z14get_num_groupsj", ast::ScalarType::U64) - } - PtxSpecialRegister::Clock => (ZLUDA_PTX_PREFIX_SREG_CLOCK, ast::ScalarType::U32), - PtxSpecialRegister::LanemaskLt => { - (ZLUDA_PTX_PREFIX_SREG_LANEMASK_LT, ast::ScalarType::U32) - } + PtxSpecialRegister::Tid + | PtxSpecialRegister::Ntid + | PtxSpecialRegister::Ctaid + | PtxSpecialRegister::Nctaid => Some(ast::ScalarType::U8), + PtxSpecialRegister::Clock | PtxSpecialRegister::LanemaskLt => None, } } - fn normalized_sreg_and_type(self) -> Option<(PtxSpecialRegister, ast::ScalarType, u8)> { + fn get_unprefixed_function_name(self) -> &'static str { match self { - PtxSpecialRegister::Tid => Some((PtxSpecialRegister::Tid64, ast::ScalarType::U64, 3)), - PtxSpecialRegister::Ntid => Some((PtxSpecialRegister::Ntid64, ast::ScalarType::U64, 3)), - PtxSpecialRegister::Ctaid => { - Some((PtxSpecialRegister::Ctaid64, ast::ScalarType::U64, 3)) - } - PtxSpecialRegister::Nctaid => { - Some((PtxSpecialRegister::Nctaid64, ast::ScalarType::U64, 3)) - } - PtxSpecialRegister::Tid64 - | PtxSpecialRegister::Ntid64 - | PtxSpecialRegister::Ctaid64 - | PtxSpecialRegister::Nctaid64 - | PtxSpecialRegister::Clock => None, - PtxSpecialRegister::LanemaskLt => None, + PtxSpecialRegister::Tid => "sreg_tid", + PtxSpecialRegister::Ntid => "sreg_ntid", + PtxSpecialRegister::Ctaid => "sreg_ctaid", + PtxSpecialRegister::Nctaid => "sreg_nctaid", + PtxSpecialRegister::Clock => "sreg_clock", + PtxSpecialRegister::LanemaskLt => "sreg_lanemask_lt", } } } @@ -4899,16 +4911,6 @@ impl SpecialRegistersMap { } } - fn builtins<'a>(&'a self) -> impl Iterator + 'a { - self.reg_to_id.iter().filter_map(|(sreg, id)| { - if sreg.normalized_sreg_and_type().is_none() { - Some((*sreg, *id)) - } else { - None - } - }) - } - fn interface(&self) -> Vec { return Vec::new(); /* @@ -6416,6 +6418,35 @@ struct BrachCondition { if_false: spirv::Word, } +impl, To: ArgParamsEx> Visitable + for BrachCondition +{ + fn visit( + self, + visitor: &mut impl ArgumentMapVisitor, + ) -> Result, To>, TranslateError> { + let predicate = visitor.id( + ArgumentDescriptor { + op: self.predicate, + is_dst: false, + is_memory_access: false, + non_default_implicit_conversion: None, + }, + Some(( + &ast::Type::Scalar(ast::ScalarType::Pred), + ast::StateSpace::Reg, + )), + )?; + let if_true = self.if_true; + let if_false = self.if_false; + Ok(Statement::Conditional(BrachCondition { + predicate, + if_true, + if_false, + })) + } +} + #[derive(Clone)] struct ImplicitConversion { src: spirv::Word, -- cgit v1.2.3